How to decompose 2D (m*n, m*n) matrix into 4D (m, m, n, n) matrix in Python?

Time:01-18

To start, consider a black box that takes two variables m and n (where m is a multiple of n) and outputs a 2D matrix of shape (m*n, m*n). This 2D matrix needs to be transformed into a 4D matrix of shape (m, m, n, n). I am unsure of the best way to describe this in writing, but the data is structured so that the 2D (m*n)x(m*n) matrix contains m (nxn) "tiles" in each direction. Consider an example array a with m = 3 and n = 2, so the incoming 2D matrix is 6x6:

print(a)
[[0  1  4  5  8  9 ]
 [2  3  6  7  10 11]
 [12 13 16 17 20 21]
 [14 15 18 19 22 23]
 [24 25 28 29 32 33]
 [26 27 30 31 34 35]]

this is then passed into some function:

b = some_func(a)

to which the required output would be the 4D array:

print(b)
[[[[ 0  1]
   [ 2  3]]

  [[ 4  5]
   [ 6  7]]

  [[ 8  9]
   [10 11]]]


 [[[12 13]
   [14 15]]

  [[16 17]
   [18 19]]

  [[20 21]
   [22 23]]]


 [[[24 25]
   [26 27]]

  [[28 29]
   [30 31]]

  [[32 33]
   [34 35]]]]

In other words, we need to separate out the nxn tiles within the larger 2D array. What this really represents is an mxm matrix in which each entry is itself an nxn matrix, giving a 4D matrix that we can then use in subsequent work. This is a highly simplified example for demonstrative purposes, taken from a much more complicated system. In my case there is also an extra axis, m = 256, the entries in the matrix are complex (64-bit), and we are highly concerned about performance; however, these details are irrelevant to the issue. If it helps, n = 2 is the only case we are concerned with, though I would hope that there is a more general solution.

I can reasonably conceive of a solution that uses for loops, indexing, modulo arithmetic, etc.; however, this would be drastically inefficient in Python.
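
For reference, the loop-based approach mentioned above could look roughly like the sketch below. The function name tiles_loop and the way the example array a is rebuilt are assumptions made for illustration; the loop version is mainly useful as a slow but obviously correct baseline to check faster solutions against.

```python
import numpy as np

def tiles_loop(a, m, n):
    """Reference implementation: copy each n x n tile with explicit loops."""
    b = np.empty((m, m, n, n), dtype=a.dtype)
    for i in range(m):
        for j in range(m):
            # Tile (i, j) occupies rows i*n..(i+1)*n and columns j*n..(j+1)*n.
            b[i, j] = a[i * n:(i + 1) * n, j * n:(j + 1) * n]
    return b

m, n = 3, 2
# Assumed construction that reproduces the 6x6 example array from the question.
a = np.arange(m * m * n * n).reshape(m, m, n, n).swapaxes(1, 2).reshape(m * n, m * n)

b = tiles_loop(a, m, n)
print(b.shape)   # (3, 3, 2, 2)
print(b[0, 1])   # the tile holding 4, 5, 6, 7
```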

Potential Solutions?

  1. The mind instantly jumps to something like np.reshape(); however, we cannot simply use a.reshape(m, m, n, n), as the correct order is not preserved due to the way np.reshape() first ravels the array, as is outlined in a Matrix operation in question
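
A small sketch of why the direct reshape goes wrong, using the example array from the question (the construction of a here is an assumption that reproduces the printed 6x6 matrix). The first "tile" of the naive result is taken from the first four elements of row 0, not from the top-left 2x2 block.

```python
import numpy as np

m, n = 3, 2
# Assumed construction that reproduces the 6x6 example array from the question.
a = np.arange(m * m * n * n).reshape(m, m, n, n).swapaxes(1, 2).reshape(m * n, m * n)

naive = a.reshape(m, m, n, n)   # fills tiles in raveled (row-major) order
wanted = a[:n, :n]              # the actual top-left n x n tile

print(naive[0, 0])   # [[0 1] [4 5]]  -> first four elements of row 0
print(wanted)        # [[0 1] [2 3]]  -> what the question asks for
```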

    CodePudding user response:

    I'll try to illustrate the issues discussed in the comments.

    A starting array - a reshape of a 1d arange:

    In [160]: arr = np.arange(16).reshape(4,4)
    In [161]: arr
    Out[161]: 
    array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11],
           [12, 13, 14, 15]])
    In [162]: arr.ravel()
    Out[162]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
    In [163]: arr.strides
    Out[163]: (32, 8)
    

    Further reshape to 4d. Note the ravel is the same. I could also use arr2.__array_interface__ to show the data buffer id.

    In [164]: arr1 = arr.reshape(2,2,2,2)
    In [165]: arr1.ravel()
    Out[165]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
    In [166]: arr1.strides            
    Out[166]: (64, 32, 16, 8)
    

    It would be a good idea to test your understanding of the change in strides with change in shape.
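
One way to do that check is to compute the strides by hand. The helper c_strides below is a hypothetical name, not a NumPy function; it assumes a C-contiguous layout, where the last axis steps by one item and every earlier axis steps by the product of the later dimensions.

```python
import numpy as np

def c_strides(shape, itemsize):
    """Byte strides of a C-contiguous array with the given shape."""
    strides = []
    step = itemsize
    for dim in reversed(shape):
        strides.append(step)   # stride of this axis
        step *= dim            # the next axis outward skips a whole block
    return tuple(reversed(strides))

arr = np.arange(16).reshape(4, 4)
arr1 = arr.reshape(2, 2, 2, 2)

# Using arr.itemsize keeps the comparison valid whatever the default integer size is.
print(c_strides(arr.shape, arr.itemsize) == arr.strides)     # True
print(c_strides(arr1.shape, arr1.itemsize) == arr1.strides)  # True
print(c_strides((2, 2, 2, 2), 8))                            # (64, 32, 16, 8)
```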

    Now swap:

    In [167]: arr2 = arr1.swapaxes(1,2)
    In [168]: arr2
    Out[168]: 
    array([[[[ 0,  1],
             [ 4,  5]],
    
            [[ 2,  3],
             [ 6,  7]]],
    
    
           [[[ 8,  9],
             [12, 13]],
    
            [[10, 11],
             [14, 15]]]])
    In [169]: arr2.strides
    Out[169]: (64, 16, 32, 8)
    

    Still a (2,2,2,2), but the strides have changed. This too is a view. But a reshape of this (including a ravel) will make a copy. The elements have been reordered:

    In [170]: arr2.ravel()
    Out[170]: array([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
    In [171]: arr3 = arr2.reshape(4,4)
    In [172]: arr3
    Out[172]: 
    array([[ 0,  1,  4,  5],
           [ 2,  3,  6,  7],
           [ 8,  9, 12, 13],
           [10, 11, 14, 15]])
    In [173]: arr3.ravel()
    Out[173]: array([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
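
The view-versus-copy claims can be checked directly with np.shares_memory and the contiguity flags. This is only a sketch that repeats the arrays from the session above:

```python
import numpy as np

arr = np.arange(16).reshape(4, 4)
arr1 = arr.reshape(2, 2, 2, 2)   # view: same buffer, new shape and strides
arr2 = arr1.swapaxes(1, 2)       # view: same buffer, permuted strides
arr3 = arr2.reshape(4, 4)        # copy: this order cannot be expressed with strides alone

print(np.shares_memory(arr, arr1))   # True
print(np.shares_memory(arr, arr2))   # True
print(np.shares_memory(arr, arr3))   # False
print(arr2.flags['C_CONTIGUOUS'])    # False
print(arr3.flags['C_CONTIGUOUS'])    # True
```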
    

    We see the same change in strides in a simpler 2d transpose:

    In [174]: arr4 = arr.T
    In [175]: arr4.strides
    Out[175]: (8, 32)
    In [176]: arr4.ravel()
    Out[176]: array([ 0,  4,  8, 12,  1,  5,  9, 13,  2,  6, 10, 14,  3,  7, 11, 15])
    

    We can make a view ravel by specifying the 'F' column order. Though that may not help with the understanding. Order does not readily extend to higher dimensions, but strides does.

    In [177]: arr4.ravel(order='F')
    Out[177]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
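
Applied to the array in the question, the same reshape-then-swap idea might look like the sketch below. The name to_tiles is made up for illustration, and treating the question's "extra axis" as leading batch axes is an assumption. The result is a view of a contiguous input, so no data is copied until something forces a contiguous layout.

```python
import numpy as np

def to_tiles(a, m, n):
    """Split the last two axes (m*n, m*n) into (m, m, n, n) tiles as a view."""
    # (..., m*n, m*n) -> (..., m, n, m, n): split rows and columns into (tile, offset).
    split = a.reshape(a.shape[:-2] + (m, n, m, n))
    # Bring the two tile indices together: (..., m, m, n, n).
    return split.swapaxes(-3, -2)

m, n = 3, 2
# Assumed construction that reproduces the 6x6 example array from the question.
a = np.arange(m * m * n * n).reshape(m, m, n, n).swapaxes(1, 2).reshape(m * n, m * n)

b = to_tiles(a, m, n)
print(b.shape)                 # (3, 3, 2, 2)
print(b[1, 2])                 # [[20 21] [22 23]]
print(np.shares_memory(a, b))  # True

# Hypothetical extra leading axis, e.g. a stack of two such matrices.
stack = np.stack([a, a + 100])
print(to_tiles(stack, m, n).shape)   # (2, 3, 3, 2, 2)
```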
    

    CodePudding user response:

    I am not sure whether this is better or faster than hpaulj's answer; I guess it will be similar in performance. Please check it yourself on big arrays:

    a.reshape((m, m * n * n)).reshape(m, n, m * n).transpose(0, 2, 1).reshape(m, m, n, n).transpose(0, 1, 3, 2)
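
A quick sanity check of the expression above, as a sketch: it wraps the chain in a function (the name chained is an assumption) and compares it against the index definition b[i, j, r, s] = a[i*n + r, j*n + s], built with np.indices on a random complex64 array. It only tests correctness; timing on realistic sizes is left to the reader, as suggested.

```python
import numpy as np

def chained(a, m, n):
    """The reshape/transpose chain from the answer, unchanged."""
    return (a.reshape((m, m * n * n))
             .reshape(m, n, m * n)
             .transpose(0, 2, 1)
             .reshape(m, m, n, n)
             .transpose(0, 1, 3, 2))

m, n = 4, 2
rng = np.random.default_rng(0)
real = rng.standard_normal((m * n, m * n))
imag = rng.standard_normal((m * n, m * n))
a = (real + 1j * imag).astype(np.complex64)

# Index definition of the target: tile (i, j), element (r, s).
i, j, r, s = np.indices((m, m, n, n))
expected = a[i * n + r, j * n + s]

result = chained(a, m, n)
print(result.shape)                      # (4, 4, 2, 2)
print(np.array_equal(result, expected))  # True
```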
    