visualising ymm registers of VINSERTF128 transpose kernel


we take in two 32-byte aligned float array pointers, from and to, representing column-major matrices, along with lda, the stride length of from, and ldb, the stride length of to:


inline void t_load_8x8_ps(const float* from, float* to, int lda, int ldb) {

  __m256 t0, t1, t2, t3, t4, t5, t6, t7,
         r0, r1, r2, r3, r4, r5, r6, r7;

   r0 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[0*lda+0])),
                              _mm_load_ps(&from[4*lda+0]) , 1);
   r1 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[1*lda+0])), 
                              _mm_load_ps(&from[5*lda+0]), 1);
   r2 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[2*lda+0])), 
                              _mm_load_ps(&from[6*lda+0]), 1);
   r3 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[3*lda+0])), 
                              _mm_load_ps(&from[7*lda+0]), 1);

   r4 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[0*lda+4])),
                              _mm_load_ps(&from[4*lda+4]), 1);
   r5 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[1*lda+4])),
                              _mm_load_ps(&from[5*lda+4]), 1);
   r6 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[2*lda+4])), 
                              _mm_load_ps(&from[6*lda+4]), 1);
   r7 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[3*lda+4])),
                              _mm_load_ps(&from[7*lda+4]), 1);
      

_mm256_insertf128_ps takes three parameters: a __m256 (ymm) register, a __m128 (xmm) register, and an immediate that selects which 128-bit lane of the result the xmm value is inserted into (0 for the lower half, 1 for the upper half).

in each of the first 4 calls, the __m256 argument loads the first 4 elements (128 bits) of one of the first 4 rows into an xmm register with _mm_load_ps, then casts it to a ymm with _mm256_castps128_ps256. the __m128 argument loads the first 4 elements of the matching row 4 rows further down (rows 4 to 7) into an xmm.

_mm256_insertf128_ps copies the ymm argument into a new ymm register, then overwrites the 128-bit lane selected by the offset (1 in our case, the upper 128 bits) with the contents of the xmm argument.

the latter 4 calls do the same for elements 4 to 7 of each row (the latter 128 bits).

considering a 1024x1024 matrix where each element's value is its linear index in memory, the elements handled in the first call are:

r0    0    1    2    3 | 4096 4097 4098 4099
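
a minimal sketch of how that first call assembles r0 (the 8x8 test matrix, the lda of 8 and the printing are mine, purely illustrative):

#include <immintrin.h>
#include <cstdio>

int main() {
  // illustrative 8x8 test matrix (lda = 8) where each element's value is its index
  alignas(32) float m[8 * 8];
  for (int i = 0; i < 64; ++i) m[i] = static_cast<float>(i);
  const int lda = 8;

  // same pattern as the kernel's first call: row 0 in the lower lane, row 4 in the upper
  __m256 r0 = _mm256_insertf128_ps(
      _mm256_castps128_ps256(_mm_load_ps(&m[0 * lda + 0])),
      _mm_load_ps(&m[4 * lda + 0]), 1);

  alignas(32) float out[8];
  _mm256_store_ps(out, r0);
  for (int i = 0; i < 8; ++i) std::printf("%g ", out[i]);  // prints: 0 1 2 3 32 33 34 35
  std::printf("\n");
}
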
next, _mm256_unpacklo_ps and _mm256_unpackhi_ps take the lower or upper two values of each 128-bit half of two registers and interleave them into a new ymm register:


   t0 = _mm256_unpacklo_ps(r0, r1);
   t1 = _mm256_unpackhi_ps(r0, r1);
   t2 = _mm256_unpacklo_ps(r2, r3);
   t3 = _mm256_unpackhi_ps(r2, r3);
   t4 = _mm256_unpacklo_ps(r4, r5);
   t5 = _mm256_unpackhi_ps(r4, r5);
   t6 = _mm256_unpacklo_ps(r6, r7);
   t7 = _mm256_unpackhi_ps(r6, r7); 

in the case of r0 and r1, the two unpacks give:

r0                    0    1    2    3 | 4096 4097 4098 4099
r1                 1024 1025 1026 1027 | 5120 5121 5122 5123
_mm256_unpacklo_ps    0 1024    1 1025 | 4096 5120 4097 5121
_mm256_unpackhi_ps    2 1026    3 1027 | 4098 5122 4099 5123
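
a minimal standalone sketch of these unpack semantics, using the r0/r1 values above (the _mm256_set_ps initialisation and printing are illustrative only; note that set_ps takes its arguments highest-element first):

#include <immintrin.h>
#include <cstdio>

int main() {
  // within each 128-bit half, unpacklo interleaves elements 0 and 1 of the two
  // sources, unpackhi interleaves elements 2 and 3
  __m256 r0 = _mm256_set_ps(4099, 4098, 4097, 4096, 3, 2, 1, 0);
  __m256 r1 = _mm256_set_ps(5123, 5122, 5121, 5120, 1027, 1026, 1025, 1024);

  alignas(32) float lo[8], hi[8];
  _mm256_store_ps(lo, _mm256_unpacklo_ps(r0, r1));
  _mm256_store_ps(hi, _mm256_unpackhi_ps(r0, r1));

  for (int i = 0; i < 8; ++i) std::printf("%g ", lo[i]);  // 0 1024 1 1025 4096 5120 4097 5121
  std::printf("\n");
  for (int i = 0; i < 8; ++i) std::printf("%g ", hi[i]);  // 2 1026 3 1027 4098 5122 4099 5123
  std::printf("\n");
}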

the registers are:

t0    0 1024    1 1025 | 4096 5120 4097 5121
t1    2 1026    3 1027 | 4098 5122 4099 5123
t2 2048 3072 2049 3073 | 6144 7168 6145 7169
t3 2050 3074 2051 3075 | 6146 7170 6147 7171
t4    4 1028    5 1029 | 4100 5124 4101 5125
t5    6 1030    7 1031 | 4102 5126 4103 5127
t6 2052 3076 2053 3077 | 6148 7172 6149 7173
t7 2054 3078 2055 3079 | 6150 7174 6151 7175

next, we shuffle them into the right place with _mm256_shuffle_ps. the first and second arguments are __m256 registers, the third is an imm8 value used to index into them. each pair of bits selects an element index from 0 to 3 within a 128-bit half: the low two pairs pick the first two destination elements from the first register, the high two pairs pick the last two from the second register, and the selection is applied to both 128-bit lanes independently.


   r0 = _mm256_shuffle_ps(t0, t2, 0x44);
   r1 = _mm256_shuffle_ps(t0, t2, 0xee);
   r2 = _mm256_shuffle_ps(t1, t3, 0x44);
   r3 = _mm256_shuffle_ps(t1, t3, 0xee);
   r4 = _mm256_shuffle_ps(t4, t6, 0x44);
   r5 = _mm256_shuffle_ps(t4, t6, 0xee);
   r6 = _mm256_shuffle_ps(t5, t7, 0x44);
   r7 = _mm256_shuffle_ps(t5, t7, 0xee);
      

you can use _MM_SHUFFLE() to create the imm8, or pass a constant whose bit pattern you know. here we use 0x44 (0b01000100), equivalent to _MM_SHUFFLE(1, 0, 1, 0), and 0xee (0b11101110), equivalent to _MM_SHUFFLE(3, 2, 3, 2).
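
a minimal sketch confirming the imm8 encoding and the per-lane selection (the values and variable names are illustrative only):

#include <immintrin.h>
#include <cstdio>

int main() {
  static_assert(_MM_SHUFFLE(1, 0, 1, 0) == 0x44, "");
  static_assert(_MM_SHUFFLE(3, 2, 3, 2) == 0xee, "");

  // memory order 10..17 and 20..27 (set_ps takes highest element first)
  __m256 t0 = _mm256_set_ps(17, 16, 15, 14, 13, 12, 11, 10);
  __m256 t2 = _mm256_set_ps(27, 26, 25, 24, 23, 22, 21, 20);

  // 0x44: per lane, take elements 0 and 1 of t0 followed by elements 0 and 1 of t2
  alignas(32) float out[8];
  _mm256_store_ps(out, _mm256_shuffle_ps(t0, t2, 0x44));
  for (int i = 0; i < 8; ++i) std::printf("%g ", out[i]);  // 10 11 20 21 14 15 24 25
  std::printf("\n");
}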

then, we store the values into to:


   _mm256_store_ps( &to[0*ldb], r0); 
   _mm256_store_ps( &to[1*ldb], r1); 
   _mm256_store_ps( &to[2*ldb], r2); 
   _mm256_store_ps( &to[3*ldb], r3); 
   _mm256_store_ps( &to[4*ldb], r4); 
   _mm256_store_ps( &to[5*ldb], r5); 
   _mm256_store_ps( &to[6*ldb], r6); 
   _mm256_store_ps( &to[7*ldb], r7); 


benchmarks

TODO





full code


#include <immintrin.h>

inline void t_load_8x8_ps(const float* from, float* to, int lda, int ldb) {

  __m256 t0, t1, t2, t3, t4, t5, t6, t7,
         r0, r1, r2, r3, r4, r5, r6, r7;

   // elements 0..3 of rows 0..7: row i fills the lower 128 bits, row i+4 the upper
   r0 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[0*lda+0])),
                              _mm_load_ps(&from[4*lda+0]), 1);
   r1 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[1*lda+0])), 
                              _mm_load_ps(&from[5*lda+0]), 1);
   r2 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[2*lda+0])), 
                              _mm_load_ps(&from[6*lda+0]), 1);
   r3 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[3*lda+0])), 
                              _mm_load_ps(&from[7*lda+0]), 1);
   // elements 4..7 of rows 0..7
   r4 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[0*lda+4])),
                              _mm_load_ps(&from[4*lda+4]), 1);
   r5 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[1*lda+4])),
                              _mm_load_ps(&from[5*lda+4]), 1);
   r6 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[2*lda+4])), 
                              _mm_load_ps(&from[6*lda+4]), 1);
   r7 = _mm256_insertf128_ps( _mm256_castps128_ps256(_mm_load_ps(&from[3*lda+4])),
                              _mm_load_ps(&from[7*lda+4]), 1);

   // interleave adjacent row pairs within each 128-bit lane
   t0 = _mm256_unpacklo_ps(r0, r1);
   t1 = _mm256_unpackhi_ps(r0, r1);
   t2 = _mm256_unpacklo_ps(r2, r3);
   t3 = _mm256_unpackhi_ps(r2, r3);
   t4 = _mm256_unpacklo_ps(r4, r5);
   t5 = _mm256_unpackhi_ps(r4, r5);
   t6 = _mm256_unpacklo_ps(r6, r7);
   t7 = _mm256_unpackhi_ps(r6, r7);

   // pick out the transposed rows
   r0 = _mm256_shuffle_ps(t0, t2, 0x44);
   r1 = _mm256_shuffle_ps(t0, t2, 0xee);
   r2 = _mm256_shuffle_ps(t1, t3, 0x44);
   r3 = _mm256_shuffle_ps(t1, t3, 0xee);
   r4 = _mm256_shuffle_ps(t4, t6, 0x44);
   r5 = _mm256_shuffle_ps(t4, t6, 0xee);
   r6 = _mm256_shuffle_ps(t5, t7, 0x44);
   r7 = _mm256_shuffle_ps(t5, t7, 0xee);
   
   // store the transposed 8x8 block row by row
   _mm256_store_ps( &to[0*ldb], r0);
   _mm256_store_ps( &to[1*ldb], r1); 
   _mm256_store_ps( &to[2*ldb], r2); 
   _mm256_store_ps( &to[3*ldb], r3); 
   _mm256_store_ps( &to[4*ldb], r4); 
   _mm256_store_ps( &to[5*ldb], r5); 
   _mm256_store_ps( &to[6*ldb], r6); 
   _mm256_store_ps( &to[7*ldb], r7);
}
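
as a quick sanity check, here is a small test harness of my own (not part of the kernel) that fills one 8x8 block with its linear indices, runs t_load_8x8_ps and asserts the result is the transpose:

#include <immintrin.h>
#include <cassert>
#include <cstdio>

// assumes t_load_8x8_ps from above is visible in this translation unit
int main() {
  alignas(32) float from[8 * 8], to[8 * 8];
  const int lda = 8, ldb = 8;
  for (int i = 0; i < 64; ++i) from[i] = static_cast<float>(i);

  t_load_8x8_ps(from, to, lda, ldb);

  // to should now hold the transpose of from
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      assert(to[j * ldb + i] == from[i * lda + j]);
  std::printf("8x8 block transposed correctly\n");
}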

to call it, create a wrapper blocking function:


#include <algorithm>
#include <chrono>
#include <iostream>

template< int block, int n, int m >
inline void t_ps_load128(float* a, float* b, int lda, int ldb) {
  auto start = std::chrono::high_resolution_clock::now();
  #pragma omp parallel for shared(a, b, lda, ldb) default(none) collapse(2) num_threads(12)
  for(int i=0; i<n; i+=block) {
    for(int j=0; j<m; j+=block) {
      int mk = std::min(i+block, n);
      int ml = std::min(j+block, m);
      for(int k=i; k<mk; k+=8) {
        for(int l=j; l<ml; l+=8) {
          t_load_8x8_ps(&a[k*lda+l], &b[l*ldb+k], lda, ldb);
        }
      }
    }
  }
  auto end = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double, std::milli> ms_double = end - start;
  std::cout << "transpose runtime: " << ms_double.count() << " ms " << std::endl; 
}
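
a hedged usage sketch: the buffer sizes, the block size of 64 and the use of std::aligned_alloc are my own illustrative choices, not prescribed by the kernel:

#include <cstdlib>

// assumes t_load_8x8_ps and t_ps_load128 from above are visible
int main() {
  constexpr int n = 1024, m = 1024;

  // 32-byte aligned allocations so the aligned load/store intrinsics are legal
  float* a = static_cast<float*>(std::aligned_alloc(32, sizeof(float) * n * m));
  float* b = static_cast<float*>(std::aligned_alloc(32, sizeof(float) * n * m));
  for (int i = 0; i < n * m; ++i) a[i] = static_cast<float>(i);

  // a is n x m with row stride lda = m; its transpose b is m x n with row stride ldb = n
  t_ps_load128<64, n, m>(a, b, m, n);

  std::free(a);
  std::free(b);
}

compile with AVX and OpenMP enabled (e.g. -mavx -fopenmp on gcc/clang); std::aligned_alloc needs C++17, and the requested size must stay a multiple of the alignment.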