numpy einsum/tensordot with shared non-contracted axis

Time:01-31

Suppose I have two arrays:

import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)

and want to sum over the last 3 axes while keeping the shared axis. The output shape should be (32,6,6,20,128). Notice that the axis of size 20 is shared by both a and b. Let's call this axis the "group" axis.
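Written out with the index names used in method1, the requested contraction is the following. g appears in both operands and in the output, so it is not summed over, while i, h and w are:

```latex
\mathrm{out}_{N,H,W,g,o} \;=\; \sum_{i=0}^{63} \sum_{h=0}^{2} \sum_{w=0}^{2} a_{N,H,W,g,i,h,w}\; b_{g,o,i,h,w}
```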

I have two methods for this task:
The first one is just a simple einsum:

def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)  # output shape:(32,6,6,20,128)

In the second method, I loop over the group dimension, use einsum/tensordot to compute the result for each group, and then stack the results:

def method2(a, b):
    result = []
    for g in range(b.shape[0]): # loop through each group dimension
        # result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))  # output shape:(32,6,6,128)
    return np.stack(result, axis=-2)  # output shape:(32,6,6,20,128)
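A quick way to confirm that both methods compute the same thing is to run them on small shapes. The shapes below are arbitrary, chosen only so the check runs instantly; the two functions are copied from above.

```python
import numpy as np

def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)

def method2(a, b):
    result = []
    for g in range(b.shape[0]):  # loop over the group axis
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))
    return np.stack(result, axis=-2)

# Small, arbitrary shapes: (N, H, W, g, i, h, w) and (g, o, i, h, w)
rng = np.random.default_rng(0)
a_small = rng.standard_normal((2, 3, 3, 4, 5, 2, 2))
b_small = rng.standard_normal((4, 6, 5, 2, 2))
r1 = method1(a_small, b_small)
r2 = method2(a_small, b_small)
print(r1.shape, np.allclose(r1, r2))  # (2, 3, 3, 4, 6) True
```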

I timed both methods in my Jupyter notebook, and the second method, with a loop, is faster than the first method.

My question is:

  1. How come method1 is that much slower? It doesn't compute more things.
  2. Is there a more efficient way without using loops? (I'm a bit reluctant to use loops because they are slow in Python.)

Thanks for any help!

Answer:

As pointed out by @Murali in the comments, method1 is not very efficient because it does not manage to use BLAS calls, as opposed to method2, which does. In fact, np.einsum is actually quite good in method1 given that it computes the result sequentially, while method2 mostly runs in parallel thanks to OpenBLAS (used by Numpy on most machines). That being said, method2 is sub-optimal since it does not fully use the available cores (parts of the computation are done sequentially). On my 6-core machine, it barely uses 50% of all the cores.
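Regarding the second question (avoiding the Python loop), one alternative worth trying is to flatten the (i, h, w) axes and express the whole contraction as a single batched matrix product over g with np.matmul. This is a swapped-in technique, not one of the methods benchmarked here, and it has not been measured; whether NumPy hands each of the g products to BLAS depends on the build and on the memory layout. The name method_matmul is hypothetical.

```python
import numpy as np

def method_matmul(a, b):
    """Loop-free sketch: one batched matrix product over the group axis g."""
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)
    K = si * sh * sw  # flattened reduction axis (i, h, w)
    # (N,H,W,g,i,h,w) -> (NHW, g, K) -> (g, NHW, K)
    ra = a.reshape(sN * sH * sW, sg, K).transpose(1, 0, 2)
    # (g,o,i,h,w) -> (g, o, K) -> (g, K, o)
    rb = b.reshape(sg, so, K).transpose(0, 2, 1)
    # np.matmul treats g as a batch axis: (g, NHW, K) @ (g, K, o) -> (g, NHW, o)
    out = np.matmul(ra, rb)
    # (g, NHW, o) -> (NHW, g, o) -> (N, H, W, g, o)
    return out.transpose(1, 0, 2).reshape(sN, sH, sW, sg, so)

# Check against the einsum formulation on small, arbitrary shapes
rng = np.random.default_rng(1)
a = rng.standard_normal((2, 3, 3, 4, 5, 2, 2))
b = rng.standard_normal((4, 6, 5, 2, 2))
ref = np.einsum('NHWgihw,goihw->NHWgo', a, b)
print(np.allclose(method_matmul(a, b), ref))  # True
```

The final transpose/reshape produces a copy; that copy is small compared to the contraction itself, but it is one reason to measure before adopting this variant.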


Faster implementation

One solution to speed up this computation is to write highly optimized parallel Numba code.

First of all, a semi-naive implementation uses many for loops to compute the Einstein summation and reshapes the input/output arrays so Numba can better optimize the code (e.g. unrolling, use of SIMD instructions). Here is the result:

import numba as nb
import numpy as np

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)

    for NHW in range(sN*sH*sW):
        for g in range(sg):
            for o in range(so):
                s = 0.0

                # Reduction
                for ihw in range(si*sh*sw):
                    s += ra[NHW, g, ihw] * rb[g, o, ihw]

                out[NHW, g, o] = s

    return out.reshape((sN, sH, sW, sg, so))

Note that the input arrays are assumed to be contiguous. If this is not the case, please consider performing a copy first (which is cheap compared to the computation).
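The reshapes in compute rely on C-contiguous inputs. The NumPy-only sketch below (small, arbitrary shapes, no Numba needed) illustrates three points: the reshape is a view rather than a copy, the flattened NHW and ihw indices follow C order, and np.ascontiguousarray produces the copy mentioned above for a non-contiguous input such as a transposed array.

```python
import numpy as np

a = np.random.randn(2, 3, 3, 4, 5, 2, 2)
sN, sH, sW, sg, si, sh, sw = a.shape

# On a C-contiguous array, the reshape used by compute() is a view: no data is copied.
ra = a.reshape(sN * sH * sW, sg, si * sh * sw)
print(np.shares_memory(a, ra))  # True

# C-order flattening: (N, H, W) -> NHW and (i, h, w) -> ihw
N, H, W, g, i, h, w = 1, 2, 0, 3, 4, 1, 1
NHW = np.ravel_multi_index((N, H, W), (sN, sH, sW))
ihw = np.ravel_multi_index((i, h, w), (si, sh, sw))
print(NHW == N * sH * sW + H * sW + W, ihw == i * sh * sw + h * sw + w)  # True True
print(ra[NHW, g, ihw] == a[N, H, W, g, i, h, w])  # True

# A transposed array is not C-contiguous; np.ascontiguousarray makes the copy.
a_t = a.transpose(0, 2, 1, 3, 4, 5, 6)
print(a_t.flags['C_CONTIGUOUS'])  # False
a_c = np.ascontiguousarray(a_t)
print(a_c.flags['C_CONTIGUOUS'])  # True
```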

While the above code works, it is far from being efficient. Here are some improvements that can be performed:

  • run the outermost NHW loop in parallel;
  • use the Numba flag fastmath=True. This flag is unsafe if the input data contains special values like NaN or inf/-inf. However, it helps the compiler generate much faster code using SIMD instructions (this is not possible otherwise since IEEE-754 floating-point operations are not associative);
  • swap the NHW-based loop and the g-based loop: this results in better performance since it improves cache locality (rb is more likely to fit in the last-level cache of mainstream CPUs, whereas it would likely be fetched from the RAM otherwise);
  • use register blocking to better saturate the SIMD computing units of the processor and reduce the pressure on the memory hierarchy;
  • use tiling by splitting the o-based loop so rb can almost fully be read from lower-level caches (e.g. L1 or L2).
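The associativity point behind fastmath can be seen in plain Python: regrouping a floating-point sum changes its rounding, which is why the compiler may not reorder (and therefore vectorize) the reduction unless it is explicitly allowed to.

```python
# IEEE-754 addition is not associative: regrouping a sum can change its rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right)    # 0.6000000000000001 0.6
print(left == right)  # False
```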

All these improvements except the last one are implemented in the following code:

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)

    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)

    for g in range(sg):
        for k in nb.prange((sN*sH*sW)//2):
            NHW = k*2
            so_vect_max = (so // 4) * 4

            for o in range(0, so_vect_max, 4):
                s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0

                # Optimized reduction using register blocking
                for ihw in range(si*sh*sw):
                    ra_0 = ra[NHW+0, g, ihw]
                    ra_1 = ra[NHW+1, g, ihw]
                    rb_0 = rb[g, o+0, ihw]
                    rb_1 = rb[g, o+1, ihw]
                    rb_2 = rb[g, o+2, ihw]
                    rb_3 = rb[g, o+3, ihw]
                    s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
                    s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
                    s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
                    s12 += ra_1 * rb_2; s13 += ra_1 * rb_3

                out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
                out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
                out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
                out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13

            # Remaining part for `o` (both rows NHW and NHW+1 of the pair)
            for o in range(so_vect_max, so):
                for ihw in range(si*sh*sw):
                    out[NHW+0, g, o] += ra[NHW+0, g, ihw] * rb[g, o, ihw]
                    out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]

        # Remaining part for `k`
        if (sN*sH*sW) % 2 == 1:
            k = sN*sH*sW - 1
            for o in range(so):
                for ihw in range(si*sh*sw):
                    out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]


    return out.reshape((sN, sH, sW, sg, so))

This code is much more complex and uglier, but also far more efficient. I did not implement the tiling optimization since it would make the code even less readable. However, it should result in significantly faster code on many-core processors (especially the ones with a small L2/L3 cache).
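To make the register-blocking structure of method3 easier to follow, here is a plain-Python sketch of the same 2x4 loop nest on the already-flattened arrays (ra of shape (X, G, K), rb of shape (G, O, K)). It is for understanding only: it is not compiled with Numba, its Python accumulators do not live in CPU registers, and blocked_contract is a hypothetical name. The small shapes are chosen so that both remainder paths (odd X, and O not a multiple of 4) are exercised.

```python
import numpy as np

def blocked_contract(ra, rb):
    """out[x, g, o] = sum_k ra[x, g, k] * rb[g, o, k], using a 2x4 blocked loop nest."""
    X, G, K = ra.shape
    O = rb.shape[1]
    out = np.zeros((X, G, O))
    x_even = (X // 2) * 2      # rows processed two at a time
    o_vect_max = (O // 4) * 4  # columns processed four at a time
    for g in range(G):
        for x in range(0, x_even, 2):
            for o in range(0, o_vect_max, 4):
                # 2x4 block of independent accumulators (8 separate sums)
                acc = [[0.0] * 4 for _ in range(2)]
                for k in range(K):
                    for dx in range(2):
                        for do in range(4):
                            acc[dx][do] += ra[x + dx, g, k] * rb[g, o + do, k]
                for dx in range(2):
                    for do in range(4):
                        out[x + dx, g, o + do] = acc[dx][do]
            # Remaining columns: both rows x and x + 1 of the pair
            for o in range(o_vect_max, O):
                for dx in range(2):
                    for k in range(K):
                        out[x + dx, g, o] += ra[x + dx, g, k] * rb[g, o, k]
        # Remaining row when X is odd
        if X % 2 == 1:
            for o in range(O):
                for k in range(K):
                    out[X - 1, g, o] += ra[X - 1, g, k] * rb[g, o, k]
    return out

rng = np.random.default_rng(2)
ra = rng.standard_normal((5, 2, 7))  # X = 5 (odd)
rb = rng.standard_normal((2, 6, 7))  # O = 6 (not a multiple of 4)
print(np.allclose(blocked_contract(ra, rb), np.einsum('xgk,gok->xgo', ra, rb)))  # True
```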


Performance results

Here are performance results on my i5-9600KF 6-core processor:

method1:              816 ms
method2:              104 ms
method3:               71 ms
Theoretical optimal:   33 ms

This code is about 1.5 times faster than method2. There is not much room for improvement since the optimal time is only about twice as fast as method3, although implementing the tiling strategy should help get significantly closer to the optimal time.
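To put these timings in perspective, the amount of arithmetic can be counted directly from the shapes: every method performs the same number of multiply-adds, which is also why method1 does not "compute more things". The sketch below only counts operations and converts the measured timings from the table into an achieved throughput; it does not try to reproduce the 33 ms "theoretical optimal" figure.

```python
# One multiply-add per (N, H, W, g, o, i, h, w) combination, for every method.
sN, sH, sW, sg, so = 32, 6, 6, 20, 128
si, sh, sw = 64, 3, 3
n_fma = sN * sH * sW * sg * so * si * sh * sw
print(n_fma)  # 1698693120

flops = 2 * n_fma  # counting a multiply-add as 2 floating-point operations
timings_s = {"method1": 0.816, "method2": 0.104, "method3": 0.071}
for name, t in timings_s.items():
    print(f"{name}: {flops / t / 1e9:.1f} GFLOP/s")
```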

If you want a faster implementation, then you need faster or more specialized hardware. Server-based GPUs (i.e. not the ones in personal computers) should be able to speed up such a computation a lot, since your input is small, clearly compute-bound, and makes massive use of FMA floating-point operations. A first step is to try cupy.einsum.
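A minimal sketch of the cupy.einsum route, assuming CuPy is installed and a CUDA GPU is available; it falls back to NumPy otherwise so the code still runs. It has not been benchmarked here, and the subscripts are lowercased as a precaution (n, y, x stand for N, H, W). method_gpu and to_host are hypothetical helper names.

```python
import numpy as np

try:
    import cupy as xp
    xp.zeros(1)  # raises here if CuPy is installed but no usable GPU is present
except Exception:
    xp = np  # CPU fallback so the sketch still runs

def method_gpu(a, b):
    """Same contraction as method1, evaluated by xp.einsum (CuPy if available)."""
    return xp.einsum('nyxgikl,goikl->nyxgo', xp.asarray(a), xp.asarray(b))

def to_host(arr):
    # cupy.ndarray.get() copies a device array back to a NumPy array
    return arr.get() if hasattr(arr, "get") else arr

rng = np.random.default_rng(3)
a = rng.standard_normal((2, 3, 3, 4, 5, 2, 2))
b = rng.standard_normal((4, 6, 5, 2, 2))
out = to_host(method_gpu(a, b))
print(out.shape)  # (2, 3, 3, 4, 6)
```

The inputs here are float64; consumer GPUs usually have much lower float64 throughput than server GPUs, which is consistent with the remark above.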


Under the hood: low-level analysis

In order to understand why method1 is not faster, I checked the executed code. Here is the main loop:

1a0:┌─→; Part of the reduction (see below)
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1000]
    │  
    │  ; Decrement the number of loop cycle
    │  sub        r9,0x8 
    │  
    │  ; Prefetch items so to reduce the impact 
    │  ; of the latency of reading from the RAM.
    │  prefetcht0 BYTE PTR [r8]
    │  prefetcht0 BYTE PTR [rdi]
    │  
    │  ; Part of the reduction (see below)
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1000]
    │  
    │  ; Increment iterator for the two arrays
    │  add        rdi,0x40 
    │  add        r8,0x40 
    │  
    │  ; Main computational part: 
    │  ; reduction using add+mul SSE2 instructions
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1030]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1030]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  movapd     xmm0,XMMWORD PTR [rdi-0x1020]
    │  mulpd      xmm0,XMMWORD PTR [r8-0x1020]
    │  addpd      xmm0,xmm1                     <--- Slow
    │  movapd     xmm1,XMMWORD PTR [rdi-0x1010]
    │  mulpd      xmm1,XMMWORD PTR [r8-0x1010]
    │  addpd      xmm1,xmm0                     <--- Slow
    │  
    │  ; Is the loop over? 
    │  ; If not, jump to the beginning of the loop.
    ├──cmp        r9,0x7 
    └──jg         1a0

It turns out that Numpy uses the SSE2 instruction set (which is available on all x86-64 processors). However, my machine, like almost all relatively recent processors, supports the AVX instruction set, which can compute twice as many items per instruction. My machine also supports fused multiply-add (FMA) instructions, which are twice as fast in this case. Moreover, the loop is clearly bound by the addpd instructions, which accumulate the result in mostly the same register. The processor cannot execute them efficiently since an addpd has a latency of a few cycles, and up to two can be executed at the same time on modern x86-64 processors (which is not possible here since only one instruction at a time can perform the accumulation in xmm1).
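The dependency-chain problem can be illustrated with a dot product that keeps several independent partial sums instead of a single accumulator, the same idea as the separate ymm4 to ymm7 accumulators in the OpenBLAS loop. In pure Python this is not faster; it only shows the transformation. It also changes the order of additions, so results may differ in the last bits. dot_multi_acc is a hypothetical name.

```python
def dot_multi_acc(x, y, n_acc=4):
    """Dot product using n_acc independent partial sums."""
    acc = [0.0] * n_acc
    n_main = (len(x) // n_acc) * n_acc
    for i in range(0, n_main, n_acc):
        # Each partial sum only depends on its own previous value, so a CPU
        # can overlap these additions; a single accumulator serializes them.
        for j in range(n_acc):
            acc[j] += x[i + j] * y[i + j]
    total = sum(acc)
    for i in range(n_main, len(x)):  # leftover elements
        total += x[i] * y[i]
    return total

x = [float(v) for v in range(10)]
y = [2.0] * 10
print(dot_multi_acc(x, y))  # 90.0
```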

Here is the executed code of the main computational part of method2 (dgemm call of OpenBLAS):

6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi-0x80]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi-0x60]
     │  vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
     │  vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi-0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi-0x20]
     │  vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
     │  vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
     │  vfmadd231pd  ymm4,ymm0,YMMWORD PTR [rdi]
     │  vfmadd231pd  ymm5,ymm1,YMMWORD PTR [rdi+0x20]
     │  vfmadd231pd  ymm6,ymm2,YMMWORD PTR [rdi+0x40]
     │  vfmadd231pd  ymm7,ymm3,YMMWORD PTR [rdi+0x60]
     │  add          rsi,0x40
     │  add          rdi,0x100
     ├──dec          rax
     └──jne          6a40

This loop is far more optimized: it makes use of the AVX instruction set as well as the FMA one (i.e. vfmadd231pd instructions). Furthermore, the loop is better unrolled and there is no latency/dependency issue like in the Numpy code. However, while this loop is highly efficient, the cores are not efficiently used due to some sequential checks done in Numpy and a sequential copy performed in OpenBLAS.

Here is the executed code of the main computational part of method3:

.LBB8_7:┌─→leaq (%r15,%r13), %r14
        │  vmovsd    (%rbp,%r14,8), %xmm1
        │  leaq (%r8,%r13), %rdi
        │  vmovsd    (%rbp,%rdi,8), %xmm2
        │  addq %rcx, %r13
        │  leaq (%rbx,%r13), %rdi
        │  vmovsd    (%r11,%rdi,8), %xmm3
        │  leaq (%r12,%r13), %rdi
        │  leaq (%r9,%r13), %rdx
        │  addq %rax, %r13
        │  cmpq $1, %r10
        │  leaq -1(%r10), %r10
        │  vmovsd    (%r11,%rdi,8), %xmm4
        │  vmovsd    (%r11,%rdx,8), %xmm5
        │  vmovsd    (%r11,%r13,8), %xmm6
        │  vfmadd231sd    %xmm1, %xmm3, %xmm0
        │  vfmadd231sd    %xmm1, %xmm4, %xmm7
        │  vfmadd231sd    %xmm1, %xmm5, %xmm14
        │  vfmadd231sd    %xmm1, %xmm6, %xmm13
        │  vfmadd231sd    %xmm3, %xmm2, %xmm12
        │  vfmadd231sd    %xmm4, %xmm2, %xmm11
        │  vfmadd231sd    %xmm5, %xmm2, %xmm10
        │  vfmadd231sd    %xmm2, %xmm6, %xmm9
        │  movl $0, %r13d
        │  cmovgq    %rsi, %r13
        │  movq 144(%rsp), %rdx
        │  addq %rsi, %rdx
        │  incq %rdx
        │  incq %rsi
        ├──cmpq $1, %rdx
        └──jne  .LBB8_7

The reduction part of the loop is very efficient (though not perfect). However, the rest of the loop is not efficient (especially the leaq instructions). On my machine, the loop seems to be bound by the number of instructions the processor can execute, and such unwanted instructions should be responsible for 40-45% of the execution time. I guess this comes from the Numpy array indexing done by Numba. This is unfortunate since the loop would otherwise be very close to the optimal time (assuming the cache is not an issue, which does not appear to be the case on my machine).

I also found out that the BLAS code is faster on Linux than on Windows on my machine (with the default packages from PIP and the same Numpy version, 1.20.3). Thus, the gap between method2 and method3 is smaller on Linux, but the latter is still a bit faster.
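When comparing installs across operating systems, it helps to check which BLAS/LAPACK library each NumPy build is linked against. np.show_config() prints this information; its exact output format varies between NumPy versions.

```python
import numpy as np

# Print the build configuration, including the BLAS/LAPACK libraries NumPy uses
np.show_config()
print(np.__version__)
```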
