Une version AVX512 de Karatsuba

La version 128 bits de Karatsuba précédemment détaillée (cf. Application à Karatsuba), nécessite de calculer 3 produits de polynômes de degré au plus 63 : \(A_0(x)B_0(x), A_1(x)B_1(x)\) et \((A_1(x)+A_0(x))(B_1(x)+B_0(x))\) . Ces 3 produits sont effectués via 3 appels à la fonction _mm_clmulepi64_si128. On peut réduire ces 3 appels en 1 seul appel de la fonction _mm512_clmulepi64_epi128 (c’est une sous utilisation de cette dernière puisqu’elle permet d’effectuer 4 produits, mais encore une fois, il s’agit ici que d’une simple introduction à l’utilisation de ce jeu d’instructions).

Le code source correspondant est donné ci-après.

Avertissement

Même si vous disposez d’un processeur AVX512, il se peut que le code fourni ne puisse pas être compilé. Il faut en effet que votre processeur dispose de plus de l’extension VPCLMULQDQ.

     % cat /proc/cpuinfo
     flags        : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36
     clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc
     art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq
     pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2
     x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch
     cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept
     vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx
     smap avx512ifma clflushopt intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves
     dtherm arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke
     avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm
     md_clear flush_l1d arch_capabilities

Le code :

#include <stdint.h>
#include <string.h>
#include <stdio.h>
#include <immintrin.h>

/******************************************************************************************/
/*                                                                                        */
/* Programme pour réaliser la multiplication de 2 polynômes de degré au plus 127          */
/* en utilisant la méthode de Karatsuba. Un polynôme est codé sur un registre de 128 bits */
/* le résultat est renvoyé dans 1 registre de 256 bits.                                   */  
/* Le code tire partie de l'extension VPCLMULQDQ  disponible sur certains processeurs     */
/* AVX512                                                                                 */
/*                                                                                        */
/******************************************************************************************/
 
typedef union
{
  __m128i vec128;
  uint64_t vec64[2];
  uint32_t vec32[4];
  uint16_t vec16[8];
  uint8_t vec8[16];
} __reg128;

typedef union
{
  __m256i vec256;
  __m128i vec128[2];
  uint64_t vec64[4];
  uint32_t vec32[8];
  uint16_t vec16[16];
  uint8_t vec8[32];
} __reg256;

typedef union
{
  __m512i vec512;
  __m256i vec256[2];
  __m128i vec128[4];
  uint64_t vec64[8];
  uint32_t vec32[16];
  uint16_t vec16[32];
  uint8_t vec8[64];
} __reg512;

void print128(char *s, int mode, __reg128 A)
{
  switch(mode)
  {
      case 8 :
            printf("%s=(",s);
            for(int i =0; i < 15; i++) printf("'%2x',",A.vec8[i]);
            printf("'%2x')\n",A.vec8[15]);
            break;
      case 16 :
            printf("%s=(",s);
            for(int i =0; i < 7; i++) printf("'%4x',",A.vec16[i]);
            printf("'%4x')\n",A.vec16[7]);
            break;
      case 32 :
            printf("%s=(",s);
            for(int i =0; i < 3; i++) printf("'%8x',",A.vec32[i]);
            printf("'%8x')\n",A.vec32[3]);
            break;
      case 64 :
           printf("%s=('0x%lx','0x%lx')\n",s,A.vec64[0],A.vec64[1]);

  }
}

void karat_mult_F2(__reg256 *C, __reg128 A, __reg128 B) {

__reg512 aux1,aux2;
__reg128 middle;

 // A(x) = A[1]x^(64) + A[0]
 // B(x) = B[1]x^(64) + B[0]
 // C(X) = A[1]B[1]x^(128) + ((A[1]+A[0])(B[1]+B[0])-A[1]B[1]-A[0]B[0])x^(64) + A[0]B[0]
 //      = C[1]x^(128)+C[0]
 //
 //  255                   191                  127                  63                  0  
 //  |--------------------|--------------------|--------------------|--------------------|
 //  |____________________|____________________|____________________|____________________|
 //                     192                  128                   64 
 //  |-----------------------------------------||----------------------------------------|
 //                   A[1]B[1]                                   A[0][B0]                 
 //                        |----------------------------------------|
 //                         (A[1]+A[0])(B[1]+B[0])-A[1]B[1]-A[0]B[0]
 //
 // On initialise les registres aux1 et aux2 en plaçant A0^A1, A1 , A0 dans aux1
 // et B0^B1, B1, B0 dans aux2 de la façon suivante :
 // 
 //  511       447       383       319       255       191       127       63        0  
 //  |---------|---------|---------|---------|---------|---------|---------|---------|
 //  |____0____|____0____|____0____|__A0^A1__|___0_____|___A1____|____0____|___A0____|
 //          448       384       320       256       192       128        64 
 //
 //  511       447       383       319       255       191       127       63        0  
 //  |---------|---------|---------|---------|---------|---------|---------|---------|
 //  |____0____|____0____|____0____|__B0^B1__|___0_____|___B1____|____0____|___B0____|
 //          448       384       320       256       192       128        64 
 // 
aux1.vec512 = _mm512_set_epi64(0,0,0,A.vec64[0]^A.vec64[1], 0, A.vec64[1],0, A.vec64[0]);
aux2.vec512 = _mm512_set_epi64(0,0,0,B.vec64[0]^B.vec64[1], 0, B.vec64[1],0, B.vec64[0]); 
//
// On utilise l'instruction clmulepi64_epi128 avec la constante Imm8 égale à 0
// pour spécifier que les polynômes à multiplier sont dans les 64 bits de poids
// faible de chaque paquet de 128 bits.
// On effectue les 3 multiplications de Karatsuba en une seule instruction.
aux1.vec512 = _mm512_clmulepi64_epi128(aux1.vec512, aux2.vec512, 0);
 // 
 // aux1 est de la forme suivante
 //  511                 383                 255                 127                 0  
 //  |-------------------|-------------------|-------------------|-------------------|
 //  |_________0_________|__(A0^A1)(B0^B1)___|________A1B1_______|________A0B0_______|
 //                    384                 256                 128         
 
 // Calcul de (A[1]+A[0])(B[1]+B[0])-A[1]B[1]-A[0]B[0]
 middle.vec128 = _mm_xor_si128(aux1.vec128[2],aux1.vec128[1]);
 middle.vec128 = _mm_xor_si128(middle.vec128,aux1.vec128[0]);

 // On récupère facilement les 64 bits de poids faible de middle en utilisant la structure 
 // __reg128. Idem pour les 64 bits de poids fort de A0B0 en utilisant la structure __reg512.
 aux1.vec64[1]^=middle.vec64[0];
 C->vec128[0] = aux1.vec128[0];
 // idem pour additionner les 64 bits de poids fort de middle aux 64 bits de poids faible
 // de A1B1.
 aux1.vec64[2]^=middle.vec64[1];
 C->vec128[1] = aux1.vec128[1];        
 }
 
 int main(void)
 {
  __reg128 A, B;
  __reg256 *C = calloc(1,sizeof(__reg256));
  
   // on initialise les registres 128 bits
   // {64 bits poids fort, 64 bits de poids faible}
   A.vec128 = _mm_set_epi64x(0xfffabfffeeffffff,0xffffaa1256ee1234);
   B.vec128 = _mm_set_epi64x(0xbfeefffdffffffff,0xea0d362010800099);

   printf("Données : \n");
   printf("A0=0x%lx\n", A.vec64[0]);
   printf("A1=0x%lx\n", A.vec64[1]);
   printf("B0=0x%lx\n", B.vec64[0]);
   printf("B1=0x%lx\n\n", B.vec64[1]);
   karat_mult_F2(C, A, B);
   printf("Resultat : \n");
   print128("C0",64,(__reg128)(C->vec128[0]));
   print128("C1",64,(__reg128)(C->vec128[1]));
}