#!/usr/bin/env python3 import sys I = 32 vectorunit = 'avx2' if len(sys.argv) > 1: I = int(sys.argv[1]) assert I in (32,64) if len(sys.argv) > 2: vectorunit = sys.argv[2] assert vectorunit in ('neon','sse42','avx2') # XXX: further porting requires various tweaks below # e.g. arm32 does not have vclt_s64 even with neon # e.g. sse2 does not have blendv_epi8 vectorbits = {'neon':128,'sse2':128,'sse42':128,'avx2':256}[vectorunit] vectorwords = vectorbits//I # XXX: do more searches for optimal parameters here if vectorwords == 8: unrolledsort = (8,16),(16,32),(32,64),(64,128) unrolledsortxor = 64,128 unrolledVsort = (8,16),(16,32),(32,64) unrolledVsortxor = 32,64,128 threestages_specialize = 32, pads = (161,191),(193,255) # (449,511) is also interesting favorxor = True elif vectorwords == 4: if I == 32: unrolledsort = (8,16),(16,32) else: unrolledsort = (8,16),(16,32),(32,64) unrolledsortxor = 32,64 unrolledVsort = (8,16),(16,32),(32,64) unrolledVsortxor = 16,32,64 threestages_specialize = 16, pads = (81,95),(97,127) # (225,255) is also interesting favorxor = True elif vectorwords == 2: unrolledsort = (8,16),(16,32), unrolledsortxor = 32, unrolledVsort = (8,16),(16,32),(32,64) unrolledVsortxor = 16,32,64 threestages_specialize = 16, pads = () favorxor = True if vectorunit != 'avx2': pads = () # for simplicity in the absence of speed study sort2_min = min(unrolledsortxor) sort2 = f'sort_2poweratleast{sort2_min}' Vsort2_min = min(unrolledVsortxor) Vsort2 = f'V_sort_2poweratleast{Vsort2_min}' allowmaskload = False # fast on Intel, and on AMD starting with Zen 2, _but_ AMD documentation claims that it can fail allowmaskstore = False # fast on Intel but not on AMD usepartialmem = True # irrelevant unless allowmaskload or allowmaskstore partiallimit = 64 # XXX: could allow 128 for non-mask versions without extra operations assert vectorwords&(vectorwords-1) == 0 intI = f'int{I}' intIvec = f'{intI}x{vectorwords}' int8vec = f'int8x{(I*vectorwords)//8}' int32vec = 
f'int32x{(I*vectorwords)//32}' def preamble(): print('/* WARNING: auto-generated (by autogen/sort); do not edit */') print('') if vectorunit == 'neon': print('#include ') else: print('#include ') print('') print(fr'''#include "{intI}_sort.h" #define {intI} {intI}_t #define {intI}_largest {hex((1<<(I-1))-1)} #include "crypto_{intI}.h" #define {intI}_min crypto_{intI}_min #define {intI}_MINMAX(a,b) crypto_{intI}_minmax(&(a),&(b)) #define NOINLINE __attribute__((noinline)) ''') if vectorunit == 'neon': print(f'''#include "crypto_int8.h" #define int8 crypto_int8 #define int8_min crypto_int8_min #define int8x16 int8x16_t #include "crypto_uint8.h" #define uint8 crypto_uint8 #define uint8x16 uint8x16_t ''') if I != 32: print(f'''#include "crypto_int32.h" #define int32 crypto_int32 #define int32_min crypto_int32_min #define int32x4 int32x4_t ''') print(f'#define {intIvec} {intIvec}_t') print(f'#define u{intIvec} u{intIvec}_t') # XXX: should also look for ways to use ld2 etc print(fr'''#define {intIvec}_load vld1q_s{I} #define {intIvec}_store vst1q_s{I} #define {intIvec}_ifthenelse vbslq_s{I} ''') if I == 32: print(fr'''#define {intIvec}_smaller_umask vcltq_s{I} #define {intIvec}_min vminq_s{I} #define {intIvec}_max vmaxq_s{I} #define {intIvec}_MINMAX(a,b) \ do {{ \ {intIvec} c = {intIvec}_min(a,b); \ b = {intIvec}_max(a,b); \ a = c; \ }} while(0) ''') else: print(fr'''#define {int32vec}_smaller_umask vcltq_s32 #define {intIvec}_smaller_umask vcltq_s{I} #define {intIvec}_MINMAX(a,b) \ do {{ \ u{intIvec} t = {intIvec}_smaller_umask(a,b); \ {intIvec} c = {intIvec}_ifthenelse(t,a,b); \ b = {intIvec}_ifthenelse(t,b,a); \ a = c; \ }} while(0) ''') # XXX: tweak varextract name to reflect differences in out-of-range handling # XXX: also use tbx for infty etc print(f'''#define {int8vec}_load vld1q_s8 #define {int8vec}_varextract vqtbl1q_s8 #define {int8vec}_add vaddq_s8 #define {int8vec}_sub vsubq_s8 #define {int8vec}_broadcast vdupq_n_s8 #define u{int8vec}_load vld1q_u8 #define 
u{int8vec}_add vaddq_u8 #define u{int8vec}_sub vsubq_u8 #define u{int8vec}_broadcast vdupq_n_u8 #define {int8vec}_from_{intIvec} vreinterpretq_s8_s{I} #define u{intIvec}_from_{intIvec} vreinterpretq_u{I}_s{I} #define {intIvec}_from_u{intIvec} vreinterpretq_s{I}_u{I} #define {intIvec}_from_{int8vec} vreinterpretq_s{I}_s8 #define {int32vec}_load vld1q_s32 #define {int32vec}_add vaddq_s32 #define {int32vec}_sub vsubq_s32 #define {intIvec}_broadcast vdupq_n_s{I} static inline u{int8vec} u{int8vec}_set(uint8 x0,uint8 x1,uint8 x2,uint8 x3,uint8 x4,uint8 x5,uint8 x6,uint8 x7,uint8 x8,uint8 x9,uint8 x10,uint8 x11,uint8 x12,uint8 x13,uint8 x14,uint8 x15) {{ uint8 x[16] = {{x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15}}; return u{int8vec}_load(x); }} ''') if I == 32: print(f'''static inline {int32vec} {int32vec}_set(int32 x0,int32 x1,int32 x2,int32 x3) {{ int32 x[4] = {{x0,x1,x2,x3}}; return {int32vec}_load(x); }} ''') if I == 64: print(f'''static inline {intIvec} {intIvec}_set(int{I} x0,int{I} x1) {{ int{I} x[2] = {{x0,x1}}; return {intIvec}_load(x); }} ''') if I == 32: # XXX: maybe better to use vswp for 2301 print(f'''#define {intIvec}_1032 vrev64q_s32 #define {intIvec}_2301(v) vextq_s32(v,v,2) #define {intIvec}_3210(v) {intIvec}_1032({intIvec}_2301(v)) #define {intIvec}_a0b0a2b2 vtrn1q_s32 #define {intIvec}_a1b1a3b3 vtrn2q_s32 #define {intIvec}_a0b0a1b1 vzip1q_s32 #define {intIvec}_a2b2a3b3 vzip2q_s32 #define {intIvec}_leftleft(a,b) vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(a),vreinterpretq_s64_s32(b))) #define {intIvec}_rightright(a,b) vreinterpretq_s32_s64(vzip2q_s64(vreinterpretq_s64_s32(a),vreinterpretq_s64_s32(b))) ''') else: print(f'''#define {intIvec}_10(v) vextq_s64(v,v,1) #define {intIvec}_leftleft vzip1q_s64 #define {intIvec}_rightright vzip2q_s64 ''') else: mmprefix = {'sse2':'_mm_','sse42':'_mm_','avx2':'_mm256_'}[vectorunit] print(fr'''typedef __m{vectorbits}i {intIvec}; #define {intIvec}_load(z) 
{mmprefix}loadu_si{vectorbits}((__m{vectorbits}i *) (z)) #define {intIvec}_store(z,i) {mmprefix}storeu_si{vectorbits}((__m{vectorbits}i *) (z),(i)) #define {intIvec}_smaller_mask(a,b) {mmprefix}cmpgt_epi{I}(b,a) #define {intIvec}_add {mmprefix}add_epi{I} #define {intIvec}_sub {mmprefix}sub_epi{I}''') print(f'#define {int8vec}_iftopthenelse(c,t,e) {mmprefix}blendv_epi8(e,t,c)') if vectorunit == 'sse42': print(fr'''#define {intIvec}_leftleft(a,b) {mmprefix}unpacklo_epi64(a,b) #define {intIvec}_rightright(a,b) {mmprefix}unpackhi_epi64(a,b)''') if I == 32: print(fr'''#define {intIvec}_a0b0a1b1(a,b) {mmprefix}unpacklo_epi32(a,b) #define {intIvec}_a2b2a3b3(a,b) {mmprefix}unpackhi_epi32(a,b)''') if vectorunit == 'avx2': print(fr'''#define {intIvec}_leftleft(a,b) {mmprefix}permute2x128_si256(a,b,0x20) #define {intIvec}_rightright(a,b) {mmprefix}permute2x128_si256(a,b,0x31)''') if I == 32: print(fr''' #define {intIvec}_MINMAX(a,b) \ do {{ \ {intIvec} c = {intIvec}_min(a,b); \ b = {intIvec}_max(a,b); \ a = c; \ }} while(0) ''') if I == 64: print(fr''' #define {intIvec}_MINMAX(a,b) \ do {{ \ {intIvec} t = {intIvec}_smaller_mask(a,b); \ {intIvec} c = {int8vec}_iftopthenelse(t,a,b); \ b = {int8vec}_iftopthenelse(t,b,a); \ a = c; \ }} while(0) ''') if I == 32: print(fr'''#define {intIvec}_min {mmprefix}min_epi{I} #define {intIvec}_max {mmprefix}max_epi{I} #define {intIvec}_set {mmprefix}setr_epi{I} #define {intIvec}_broadcast {mmprefix}set1_epi{I}''') if (I,vectorwords) == (32,4): print(fr'''#define {int8vec}_add {mmprefix}add_epi8 #define {int8vec}_sub {mmprefix}sub_epi8 #define {int8vec}_set {mmprefix}setr_epi8 #define {int8vec}_broadcast {mmprefix}set1_epi8 #define {int8vec}_varextract {mmprefix}shuffle_epi8 #define {int32vec}_add {mmprefix}add_epi32 #define {int32vec}_sub {mmprefix}sub_epi32 #define {int32vec}_set {mmprefix}setr_epi32 #define {int32vec}_broadcast {mmprefix}set1_epi32 #define {intIvec}_extract(v,p0,p1,p2,p3) 
{mmprefix}shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0)) #define {intIvec}_constextract_ab0ab1ab2ab3(a,b,p0,p1,p2,p3) {mmprefix}castps_si128({mmprefix}blend_ps({mmprefix}castsi128_ps(a),{mmprefix}castsi128_ps(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3))) #define {intIvec}_1032(v) {intIvec}_extract(v,1,0,3,2) #define {intIvec}_2301(v) {intIvec}_extract(v,2,3,0,1) #define {intIvec}_3210(v) {intIvec}_extract(v,3,2,1,0) #include "crypto_int8.h" #define int8_min crypto_int8_min ''') if (I,vectorwords) == (64,2): print(fr'''#define {int8vec}_add {mmprefix}add_epi8 #define {int8vec}_sub {mmprefix}sub_epi8 #define {int8vec}_set {mmprefix}setr_epi8 #define {int8vec}_broadcast {mmprefix}set1_epi8 #define {int8vec}_varextract {mmprefix}shuffle_epi8 #define {int32vec}_add {mmprefix}add_epi32 #define {int32vec}_sub {mmprefix}sub_epi32 #define {int32vec}_set {mmprefix}setr_epi32 #define {int32vec}_broadcast {mmprefix}set1_epi32 #define {intIvec}_extract(v,p0,p1) {mmprefix}shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0))) #define {intIvec}_set(a,b) {mmprefix}set_epi64x(b,a) #define {intIvec}_broadcast {mmprefix}set1_epi64x #define {intIvec}_10(v) {intIvec}_extract(v,1,0) #include "crypto_int8.h" #define int8_min crypto_int8_min #include "crypto_int32.h" #define int32_min crypto_int32_min ''') if (I,vectorwords) == (32,8): print(fr'''#define {intIvec}_varextract {mmprefix}permutevar8x32_epi32 #define {intIvec}_extract(v,p0,p1,p2,p3,p4,p5,p6,p7) {intIvec}_varextract(v,{mmprefix}setr_epi{I}(p0,p1,p2,p3,p4,p5,p6,p7)) #define {intIvec}_constextract_eachside(v,p0,p1,p2,p3) {mmprefix}shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0)) #define {intIvec}_constextract_aabb_eachside(a,b,p0,p1,p2,p3) {mmprefix}castps_si256({mmprefix}shuffle_ps({mmprefix}castsi256_ps(a),{mmprefix}castsi256_ps(b),_MM_SHUFFLE(p3,p2,p1,p0))) #define {intIvec}_ifconstthenelse(c0,c1,c2,c3,c4,c5,c6,c7,t,e) {mmprefix}blend_epi{I}(e,t,(c0)|((c1)<<1)|((c2)<<2)|((c3)<<3)|((c4)<<4)|((c5)<<5)|((c6)<<6)|((c7)<<7)) ''') if 
(I,vectorwords) == (64,4): print(fr'''#define {int32vec}_add {mmprefix}add_epi32 #define {int32vec}_sub {mmprefix}sub_epi32 #define {int32vec}_set {mmprefix}setr_epi32 #define {int32vec}_broadcast {mmprefix}set1_epi32 #define {int32vec}_varextract {mmprefix}permutevar8x32_epi32 #define {intIvec}_set {mmprefix}setr_epi64x #define {intIvec}_broadcast {mmprefix}set1_epi64x #define {intIvec}_extract(v,p0,p1,p2,p3) {mmprefix}permute4x64_epi64(v,_MM_SHUFFLE(p3,p2,p1,p0)) #define {intIvec}_constextract_eachside(v,p0,p1) {mmprefix}shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0))) #define {intIvec}_constextract_a01b01a23b23(a,b,p0,p1,p2,p3) {mmprefix}castpd_si256({mmprefix}shuffle_pd({mmprefix}castsi256_pd(a),{mmprefix}castsi256_pd(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3))) #define {intIvec}_ifconstthenelse(c0,c1,c2,c3,t,e) {mmprefix}blend_epi32(e,t,(c0)|((c0)<<1)|((c1)<<2)|((c1)<<3)|((c2)<<4)|((c2)<<5)|((c3)<<6)|((c3)<<7)) #define {intIvec}_1032(v) {intIvec}_extract(v,1,0,3,2) #define {intIvec}_2301(v) {intIvec}_extract(v,2,3,0,1) #define {intIvec}_3210(v) {intIvec}_extract(v,3,2,1,0) #include "crypto_int32.h" #define int32_min crypto_int32_min ''') # XXX: can skip some of the macros above if allowmaskload and allowmaskstore if usepartialmem and (allowmaskload or allowmaskstore): print(fr'''#define partialmem (partialmem_storage+64) static const {intI} partialmem_storage[] __attribute__((aligned(128))) = {{ -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, }} ; #define {intIvec}_partialload(p,z) {mmprefix}maskload_epi{I}((void *) (z),p) #define {intIvec}_partialstore(p,z,i) {mmprefix}maskstore_epi{I}((void *) (z),p,(i)) ''') # ===== serial fixed-size networks def smallsize(): 
print(fr'''NOINLINE static void {intI}_sort_3through7({intI} *x,long long n) {{ if (n >= 4) {{ {intI} x0 = x[0]; {intI} x1 = x[1]; {intI} x2 = x[2]; {intI} x3 = x[3]; {intI}_MINMAX(x0,x1); {intI}_MINMAX(x2,x3); {intI}_MINMAX(x0,x2); {intI}_MINMAX(x1,x3); {intI}_MINMAX(x1,x2); if (n >= 5) {{ if (n == 5) {{ {intI} x4 = x[4]; {intI}_MINMAX(x0,x4); {intI}_MINMAX(x2,x4); {intI}_MINMAX(x1,x2); {intI}_MINMAX(x3,x4); x[4] = x4; }} else {{ {intI} x4 = x[4]; {intI} x5 = x[5]; {intI}_MINMAX(x4,x5); if (n == 6) {{ {intI}_MINMAX(x0,x4); {intI}_MINMAX(x2,x4); {intI}_MINMAX(x1,x5); {intI}_MINMAX(x3,x5); }} else {{ {intI} x6 = x[6]; {intI}_MINMAX(x4,x6); {intI}_MINMAX(x5,x6); {intI}_MINMAX(x0,x4); {intI}_MINMAX(x2,x6); {intI}_MINMAX(x2,x4); {intI}_MINMAX(x1,x5); {intI}_MINMAX(x3,x5); {intI}_MINMAX(x5,x6); x[6] = x6; }} {intI}_MINMAX(x1,x2); {intI}_MINMAX(x3,x4); x[4] = x4; x[5] = x5; }} }} x[0] = x0; x[1] = x1; x[2] = x2; x[3] = x3; }} else {{ {intI} x0 = x[0]; {intI} x1 = x[1]; {intI} x2 = x[2]; {intI}_MINMAX(x0,x1); {intI}_MINMAX(x0,x2); {intI}_MINMAX(x1,x2); x[0] = x0; x[1] = x1; x[2] = x2; }} }} ''') # ===== vectorized fixed-size networks class vectorsortingnetwork: def __init__(self,layout,xor=True): layout = [list(x) for x in layout] for x in layout: assert len(x) == vectorwords assert len(layout)%2 == 0 # XXX: drop this restriction? 
self.layout = layout self.phys = [f'x{r}' for r in range(len(layout))] # arch register r is stored in self.phys[r] self.operations = [] self.usedregs = set(self.phys) self.usedregssmallindex = set() self.usexor = xor self.useinfty = False if xor: self.assign('vecxor',f'{intIvec}_broadcast(xor)') def print(self): print('{') if len(self.usedregssmallindex) > 0: print(f' int32_t {",".join(sorted(self.usedregssmallindex))};') if len(self.usedregs) > 0: print(f' {intIvec} {",".join(sorted(self.usedregs))};') for op in self.operations: if op[0] == 'assign': phys,result,newlayout = op[1:] if newlayout is None: print(f' {phys} = {result};') else: print(f' {phys} = {result}; // {" ".join(map(str,newlayout))}') elif op[0] == 'assignsmallindex': phys,result = op[1:] print(f' {phys} = {result};') elif op[0] == 'store': result, = op[1:] print(f' {result};') elif op[0] == 'comment': c, = op[1:] print(f' // {c}') else: raise Exception(f'unrecognized operation {op}') print('}') print('') def allocate(self,r): '''Assign a new free physical register to arch register r. 
OK for caller to also use old register until calling allocate again.''' self.phys[r] = {'x':'y','y':'x'}[self.phys[r][0]]+self.phys[r][1:] # XXX: for generating low-level asm would instead want to cycle between fewer regs # but, for generating C (or qhasm), trusting register allocator makes generated assignments a bit more readable # (at least for V_sort) def comment(self,c): self.operations.append(('comment',c)) def assign(self,phys,result,newlayout=None): self.usedregs.add(phys) self.operations.append(('assign',phys,result,newlayout)) def assignsmallindex(self,phys,result): self.usedregssmallindex.add(phys) self.operations.append(('assignsmallindex',phys,result)) def createinfty(self): if self.useinfty: return self.useinfty = True self.assign('infty',f'{intIvec}_broadcast({intI}_largest)') def partialmask(self,offset): if usepartialmem: return f'{intIvec}_load(&partialmem[{offset}-n])' # XXX: also do version with offset in constants rather than n? rangevectorwords = ','.join(map(str,range(vectorwords))) return f'{intIvec}_smaller_mask({intIvec}_set({rangevectorwords}),{intIvec}_broadcast(n-{offset}))' def load(self,r,partial=False): rphys = self.phys[r] xor = 'vecxor^' if self.usexor else '' if not partial: self.assign(rphys,f'{xor}{intIvec}_load(x+{vectorwords*r})',self.layout[r]) return # uint instead of int would slightly streamline infty usage self.createinfty() if allowmaskload: self.assign(f'partial{r}',f'{self.partialmask(r*vectorwords)}') self.assign(rphys,f'{xor}{intIvec}_partialload(partial{r},x+{vectorwords*r})',self.layout[r]) self.assign(rphys,f'{int8vec}_iftopthenelse(partial{r},{rphys},infty)') return self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorwords},n)') xdata = f'{intIvec}_load(x+pos{r}-{vectorwords})' if vectorbits == 128: if vectorunit == 'neon': mplus = ','.join(map(str,range(0,vectorbits//8))) rotated = 
f'{intIvec}_from_{int8vec}({int8vec}_varextract({int8vec}_from_{intIvec}({xdata}),u{int8vec}_add(u{int8vec}_set({mplus}),u{int8vec}_broadcast({I//8}*((-pos{r})&{(vectorbits//I-1)})))))' else: mplus = ','.join(map(str,range(vectorbits//8,2*vectorbits//8))) rotated = f'{int8vec}_varextract({xdata},{int8vec}_sub({int8vec}_set({mplus}),{int8vec}_broadcast({I//8}*(pos{r}&{(vectorbits//I-1)}))))' mplus = ','.join(map(str,range(r*vectorwords,(r+1)*vectorwords))) if vectorunit == 'neon': control = f'{intIvec}_smaller_umask({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))' else: control = f'{intIvec}_smaller_mask({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))' elif (I,vectorwords) == (32,8): mplus = ','.join(map(str,range(r*vectorwords,(r+1)*vectorwords))) self.assign(f'diff{r}',f'{intIvec}_sub({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))') rotated = f'{intIvec}_varextract({xdata},diff{r})' control = f'diff{r}' elif (I,vectorwords) == (64,4): mplus = ','.join(map(str,range(2*r*vectorwords,2*(r+1)*vectorwords))) self.assign(f'diff{r}',f'{int32vec}_sub({int32vec}_set({mplus}),{int32vec}_broadcast(2*pos{r}))') rotated = f'{int32vec}_varextract({xdata},diff{r})' control = f'diff{r}' else: raise Exception('unhandled partial load') if vectorunit == 'neon': self.assign(rphys,f'{intIvec}_ifthenelse({control},{rotated},infty)',self.layout[r]) else: self.assign(rphys,f'{int8vec}_iftopthenelse({control},{rotated},infty)',self.layout[r]) # warning: must do store in reverse order def store(self,r,partial=False): rphys = self.phys[r] xor = 'vecxor^' if self.usexor else '' if not partial: self.operations.append(('store',f'{intIvec}_store(x+{vectorwords*r},{xor}{rphys})')) return if allowmaskstore: if not allowmaskload: self.assign(f'partial{r}',f'{self.partialmask(r*vectorwords)}') self.operations.append(('store',f'{intIvec}_partialstore(partial{r},x+{vectorwords*r},{xor}{rphys})')) return if allowmaskload: 
self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorwords},n)') # this is why store has to be in reverse order if vectorbits == 128: if vectorunit == 'neon': mplus = ','.join(map(str,range(0,vectorbits//8))) storeshift = f'u{int8vec}_sub(u{int8vec}_set({mplus}),u{int8vec}_broadcast({I//8}*((-pos{r})&{(vectorbits//I-1)})))' xdata = f'{intIvec}_from_{int8vec}({int8vec}_varextract({int8vec}_from_{intIvec}({xor}{rphys}),{storeshift}))' else: mplus = ','.join(map(str,range(vectorbits//8,2*vectorbits//8))) storeshift = f'{int8vec}_add({int8vec}_set({mplus}),{int8vec}_broadcast({I//8}*(pos{r}&{(vectorbits//I-1)})))' xdata = f'{int8vec}_varextract({xor}{rphys},{storeshift})' elif (I,vectorwords) == (32,8): mplus = ','.join(map(str,range(vectorwords))) xdata = f'{intIvec}_varextract({xor}{rphys},{intIvec}_add({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r})))' elif (I,vectorwords) == (64,4): mplus = ','.join(map(str,range(2*vectorwords))) xdata = f'{int32vec}_varextract({xor}{rphys},{int32vec}_add({int32vec}_set({mplus}),{int32vec}_broadcast(2*pos{r})))' else: raise Exception('unhandled partial store') self.operations.append(('store',f'{intIvec}_store(x+pos{r}-{vectorwords},{xdata})')) def vecswap(self,r,s): assert r != s self.phys[r],self.phys[s] = self.phys[s],self.phys[r] self.layout[r],self.layout[s] = self.layout[s],self.layout[r] def minmax(self,r,s): assert r != s rphys = self.phys[r] sphys = self.phys[s] newlayout0 = [min(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])] newlayout1 = [max(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])] self.layout[r],self.layout[s] = newlayout0,newlayout1 self.allocate(r) rphysnew = self.phys[r] if I == 32: self.assign(rphysnew,f'{intIvec}_min({rphys},{sphys})',newlayout0) self.assign(sphys,f'{intIvec}_max({rphys},{sphys})',newlayout1) else: if vectorunit == 'neon': self.assign('t',f'{intIvec}_from_u{intIvec}({intIvec}_smaller_umask({rphys},{sphys}))') 
self.assign(rphysnew,f'{intIvec}_ifthenelse(u{intIvec}_from_{intIvec}(t),{rphys},{sphys})',newlayout0) self.assign(sphys,f'{intIvec}_ifthenelse(u{intIvec}_from_{intIvec}(t),{sphys},{rphys})',newlayout1) else: # XXX: avx-512 has min_epi64 but avx2 does not self.assign('t',f'{intIvec}_smaller_mask({rphys},{sphys})') self.assign(rphysnew,f'{int8vec}_iftopthenelse(t,{rphys},{sphys})',newlayout0) self.assign(sphys,f'{int8vec}_iftopthenelse(t,{sphys},{rphys})',newlayout1) def shuffle1(self,r,L): r'''Rearrange layout of vector r to match L.''' L = list(L) oldL = self.layout[r] perm = [oldL.index(a) for a in L] rphys = self.phys[r] if vectorwords == 2 and perm == [1,0]: self.assign(rphys,f'{intIvec}_10({rphys})',L) elif vectorwords == 4 and perm == [1,0,3,2]: self.assign(rphys,f'{intIvec}_1032({rphys})',L) elif vectorwords == 4 and perm == [2,3,0,1]: self.assign(rphys,f'{intIvec}_2301({rphys})',L) elif vectorwords == 4 and perm == [3,2,1,0]: self.assign(rphys,f'{intIvec}_3210({rphys})',L) elif (I,vectorwords) == (32,8) and perm[4:] == [perm[0]+4,perm[1]+4,perm[2]+4,perm[3]+4]: self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L) elif (I,vectorwords) == (64,4) and perm[2:] == [perm[0]+2,perm[1]+2]: self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]})',L) elif vectorwords == 8 and vectorunit != 'neon': self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]},{perm[4]},{perm[5]},{perm[6]},{perm[7]})',L) elif vectorwords == 4 and vectorunit != 'neon': self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L) elif vectorwords == 2 and vectorunit != 'neon': self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]})',L) else: raise Exception(f'unhandled permutation from {oldL} to {L}') self.layout[r] = L def shuffle2(self,r,s,L,M,exact=False): oldL = self.layout[r] oldM = self.layout[s] rphys = self.phys[r] sphys = self.phys[s] try: 
assert (I,vectorwords) == (64,4) p0 = oldL[:2].index(L[0]) p1 = oldM[:2].index(L[1]) p2 = oldL[2:].index(L[2]) p3 = oldM[2:].index(L[3]) q0 = oldL[:2].index(M[0]) q1 = oldM[:2].index(M[1]) q2 = oldL[2:].index(M[2]) q3 = oldM[2:].index(M[3]) self.allocate(r) rphysnew = self.phys[r] self.layout[r] = L self.layout[s] = M self.assign(rphysnew,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{p0},{p1},{p2},{p3})',L) self.assign(sphys,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{q0},{q1},{q2},{q3})',M) return except: pass try: assert (I,vectorwords) == (32,8) p0 = oldL[:4].index(L[0]) p1 = oldL[:4].index(L[1]) p2 = oldM[:4].index(L[2]) p3 = oldM[:4].index(L[3]) assert p0 == oldL[4:].index(L[4]) assert p1 == oldL[4:].index(L[5]) assert p2 == oldM[4:].index(L[6]) assert p3 == oldM[4:].index(L[7]) q0 = oldL[:4].index(M[0]) q1 = oldL[:4].index(M[1]) q2 = oldM[:4].index(M[2]) q3 = oldM[:4].index(M[3]) assert q0 == oldL[4:].index(M[4]) assert q1 == oldL[4:].index(M[5]) assert q2 == oldM[4:].index(M[6]) assert q3 == oldM[4:].index(M[7]) self.allocate(r) rphysnew = self.phys[r] self.layout[r] = L self.layout[s] = M self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',L) self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',M) return except: pass try: assert not exact assert (I,vectorwords) == (32,8) p0 = oldL[:4].index(L[0]) p1 = oldL[:4].index(L[2]) p2 = oldM[:4].index(L[1]) p3 = oldM[:4].index(L[3]) assert p0 == oldL[4:].index(L[4]) assert p1 == oldL[4:].index(L[6]) assert p2 == oldM[4:].index(L[5]) assert p3 == oldM[4:].index(L[7]) q0 = oldL[:4].index(M[0]) q1 = oldL[:4].index(M[2]) q2 = oldM[:4].index(M[1]) q3 = oldM[:4].index(M[3]) assert q0 == oldL[4:].index(M[4]) assert q1 == oldL[4:].index(M[6]) assert q2 == oldM[4:].index(M[5]) assert q3 == oldM[4:].index(M[7]) self.allocate(r) rphysnew = self.phys[r] self.layout[r] = [L[0],L[2],L[1],L[3],L[4],L[6],L[5],L[7]] 
self.layout[s] = [M[0],M[2],M[1],M[3],M[4],M[6],M[5],M[7]] self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',self.layout[r]) self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',self.layout[s]) return except: pass try: half = vectorwords//2 assert oldL[:half] == L[:half] assert oldL[half:] == M[:half] assert oldM[:half] == L[half:] assert oldM[half:] == M[half:] self.allocate(r) rphysnew = self.phys[r] self.layout[r] = L self.layout[s] = M self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L) self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M) return except: pass try: assert oldL == [L[2],L[0],M[2],M[0]] assert oldM == [L[3],L[1],M[3],M[1]] self.assign('t',f'{intIvec}_a0b0a1b1({rphys},{sphys})',[L[2],L[3],L[0],L[1]]) self.assign('u',f'{intIvec}_a2b2a3b3({rphys},{sphys})',[M[2],M[3],M[0],M[1]]) self.layout[r] = L self.layout[s] = M self.assign(rphys,f'{intIvec}_2301(t)',L) self.assign(sphys,f'{intIvec}_2301(u)',M) return except: pass try: assert vectorunit == 'neon' assert oldL == [L[0],M[0],L[2],M[2]] assert oldM == [L[1],M[1],L[3],M[3]] self.allocate(r) rphysnew = self.phys[r] self.layout[r] = L self.layout[s] = M self.assign(rphysnew,f'{intIvec}_a0b0a2b2({rphys},{sphys})',L) self.assign(sphys,f'{intIvec}_a1b1a3b3({rphys},{sphys})',M) return except: pass try: assert vectorunit == 'neon' assert oldL == [L[0],L[2],M[0],M[2]] assert oldM == [L[1],L[3],M[1],M[3]] self.allocate(r) rphysnew = self.phys[r] self.layout[r] = L self.layout[s] = M self.assign(rphysnew,f'{intIvec}_a0b0a1b1({rphys},{sphys})',L) self.assign(sphys,f'{intIvec}_a2b2a3b3({rphys},{sphys})',M) return except: pass try: assert vectorunit != 'neon' assert not exact # XXX: generalize this assert oldL == [L[0],M[0],L[2],M[2]] assert oldM == [L[1],M[1],L[3],M[3]] self.assign('t',f'{intIvec}_1032({rphys})',[M[0],L[0],M[2],L[2]]) self.layout[r] = [L[1],L[0],L[3],L[2]] self.layout[s] = M 
self.assign(rphys,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},1,0,1,0)',self.layout[r]) self.assign(sphys,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},0,1,0,1)',self.layout[s]) return except: pass try: assert vectorunit != 'neon' # XXX: generalize this assert oldL == [M[0],L[0],M[2],L[2]] assert oldM == [M[1],L[1],M[3],L[3]] self.assign('t',f'{intIvec}_1032({rphys})',[L[0],M[0],L[2],M[2]]) self.assign('u',f'{intIvec}_1032({sphys})',[L[1],M[1],L[2],M[3]]) self.allocate(r) rphysnew = self.phys[r] self.layout[r] = L self.layout[s] = M self.assign(rphysnew,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},0,1,0,1)',L) self.assign(sphys,f'{intIvec}_constextract_ab0ab1ab2ab3(u,{rphys},1,0,1,0)',M) return except: pass for bigflip in False,True: try: assert vectorbits == 256 if bigflip: half = vectorwords//2 z0 = L[:half]+M[:half] z1 = L[half:]+M[half:] else: z0 = L z1 = M blend = [] shuf0 = [None]*vectorwords shuf1 = [None]*vectorwords for i in range(vectorwords): if oldL[i] == z0[i]: blend.append(0) shuf0[i] = oldL.index(z1[i]) shuf1[i] = 0 if i < vectorwords//2: assert shuf0[i] < vectorwords//2 else: assert shuf0[i] == shuf0[i-vectorwords//2]+vectorwords//2 else: blend.append(1) assert oldM[i] == z1[i] shuf0[i] = 0 shuf1[i] = oldM.index(z0[i]) if i < vectorwords//2: assert shuf1[i] < vectorwords//2 else: assert shuf1[i] == shuf1[i-vectorwords//2]+vectorwords//2 blend = ','.join(map(str,blend)) s0 = ','.join(map(str,shuf0[:vectorwords//2])) s1 = ','.join(map(str,shuf1[:vectorwords//2])) # XXX: encapsulate temporaries better self.assign('u',f'{intIvec}_constextract_eachside({sphys},{s1})') self.assign('t',f'{intIvec}_constextract_eachside({rphys},{s0})') self.assign(rphys,f'{intIvec}_ifconstthenelse({blend},u,{rphys})',L) self.assign(sphys,f'{intIvec}_ifconstthenelse({blend},{sphys},t)',M) if bigflip: self.allocate(r) rphysnew = self.phys[r] self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L) self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M) 
self.layout[r] = L self.layout[s] = M return except: pass raise Exception(f'unhandled permutation from {oldL},{oldM} to {L},{M}') def rearrange_onestep(self,comparators): numvectors = len(self.layout) for k in range(0,numvectors,2): if all(comparators[a] == a for a in self.layout[k]): if all(comparators[a] == a for a in self.layout[k+1]): continue collected = set(self.layout[k]+self.layout[k+1]) if not all(comparators[a] in collected for a in collected): for j in range(numvectors): if all(comparators[a] in self.layout[j] for a in self.layout[k]): self.vecswap(j,k+1) return True if any(comparators[a] in self.layout[k] for a in self.layout[k]): newlayout0 = list(self.layout[k]) newlayout1 = list(self.layout[k+1]) for i in range(vectorwords): a = newlayout0[i] b = comparators[a] if b == newlayout1[i]: continue if b in newlayout0: j = newlayout0.index(b) assert j > i newlayout1[i],newlayout0[j] = newlayout0[j],newlayout1[i] else: j = newlayout1.index(b) assert j > i newlayout1[i],newlayout1[j] = newlayout1[j],newlayout1[i] self.shuffle2(k,k+1,newlayout0,newlayout1) return True if [comparators[a] for a in self.layout[k]] != self.layout[k+1]: newlayout = [comparators[a] for a in self.layout[k]] self.shuffle1(k+1,newlayout) return True return False def rearrange(self,comparators): comparators = dict(comparators) while self.rearrange_onestep(comparators): pass def fixedsize(nlow,nhigh,xor=False,V=False): nlow = int(nlow) nhigh = int(nhigh) assert 0 <= nlow assert nlow <= nhigh assert nhigh-partiallimit <= nlow assert nhigh%vectorwords == 0 assert nlow%vectorwords == 0 lgn = 4 while 2**lgn < nhigh: lgn += 1 if nhigh < 2*vectorwords: raise Exception(f'unable to handle sizes below {2*vectorwords}') if V: funname = f'{intI}_V_sort_' else: funname = f'{intI}_sort_' funargs = f'{intI} *x' if nlow < nhigh: funname += f'{nlow}through' funargs += ',long long n' funname += f'{nhigh}' if xor: funname += '_xor' funargs += f',{intI} xor' print('NOINLINE') print(f'static void 
{funname}({funargs})') # ===== decide on initial layout of nodes numvectors = nhigh//vectorwords layout = {} if V: for k in range(numvectors//2): layout[k] = list(reversed(range(vectorwords*(numvectors//2-1-k),vectorwords*(numvectors//2-k)))) for k in range(numvectors//2,numvectors): layout[k] = list(range(vectorwords*k,vectorwords*(k+1))) else: for k in range(0,numvectors,nhigh//vectorwords): for offset in range(nhigh//vectorwords): layout[k+offset] = list(range(vectorwords*k+offset,vectorwords*(k+nhigh//vectorwords),nhigh//vectorwords)) layout = [layout[k] for k in range(len(layout))] # ===== build network S = vectorsortingnetwork(layout,xor=xor) for k in range(numvectors): S.load(k,partial=k*vectorwords>=nlow) for lgsubsort in range(1,lgn+1): if V and lgsubsort < lgn: continue for stage in reversed(range(lgsubsort)): if nhigh >= 2*vectorwords and (lgsubsort,stage) == (1,0): comparators = {a:a^1 for a in range(nhigh)} elif nhigh >= 4*vectorwords and (lgsubsort,stage) == (2,1): comparators = {a:a^2 for a in range(nhigh)} elif nhigh >= 4*vectorwords and (lgsubsort,stage) == (2,0): comparators = {a:a^(3*(1&((a>>0)^(a>>1)))) for a in range(nhigh)} elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,2): comparators = {a:a^4 for a in range(nhigh)} elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,1): comparators = {a:a+[0,0,2,2,-2,-2,0,0][a%8] for a in range(nhigh)} elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,0): comparators = {a:a+[0,1,-1,1,-1,1,-1,0][a%8] for a in range(nhigh)} elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,3): comparators = {a:a^8 for a in range(nhigh)} elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,2): comparators = {a:a+[0,0,0,0,4,4,4,4,-4,-4,-4,-4,0,0,0,0][a%16] for a in range(nhigh)} elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,1): comparators = {a:a+[0,0,2,2,-2,-2,2,2,-2,-2,2,2,-2,-2,0,0][a%16] for a in range(nhigh)} elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,0): comparators = 
# NOTE(review): this chunk reached review with its physical line structure
# mangled (real newlines collapsed, arbitrary wrap points inserted); the
# statement layout and the newlines inside the emitted-C template strings
# below are reconstructed -- logic tokens are unchanged.

# ---- tail of a code-generation routine whose `def` lies above this chunk.
# NOTE(review): nesting depths in this fragment are best-effort; verify
# against the original file.
        {a:a+[0,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,0][a%16] for a in range(nhigh)}
    else:
        if stage == lgsubsort-1:
            # FIXME(review): text appears to have been lost between `(2<` and
            # `>2)` below (angle-bracketed span stripped in transit); the
            # original expressions cannot be recovered from this chunk, so the
            # remnant is transcribed as found.
            stagemask = (2<>2)^(k>>1))&1)*6)] for k in range(vectorwords)]
            # XXX S.shuffle1(k,newlayout)
    # human-readable summary of this stage's comparator network
    strcomparators = ' '.join(f'{a}:{comparators[a]}' for a in range(nhigh) if comparators[a] > a)
    S.comment(f'stage ({lgsubsort},{stage}) {strcomparators}')
    S.rearrange(comparators)
    # minmax each vector pair, skipping a pair only when both vectors carry
    # nothing but identity comparators
    for k in range(0,numvectors,2):
        if all(comparators[a] == a for a in S.layout[k]):
            if all(comparators[a] == a for a in S.layout[k+1]):
                continue
        S.minmax(k,k+1)
    # move each pair's vectors back into canonical register positions
    for k in range(0,numvectors,2):
        for offset in 0,1:
            for i in range(numvectors):
                if k*vectorwords+offset in S.layout[i]:
                    if i != k+offset:
                        S.vecswap(i,k+offset)
                    break
    # restore consecutive in-vector word order
    for k in range(0,numvectors,2):
        y0 = list(range(k*vectorwords,(k+1)*vectorwords))
        y1 = list(range((k+1)*vectorwords,(k+2)*vectorwords))
        S.shuffle2(k,k+1,y0,y1,exact=True)
    # store back; without maskstore (and when nlow != nhigh) the vectors are
    # stored in reverse order -- presumably so partial stores past nlow do not
    # clobber data still needed (TODO confirm against the stores' semantics)
    for k in range(numvectors) if allowmaskstore or nlow == nhigh else reversed(range(numvectors)):
        S.store(k,partial=k*vectorwords>=nlow)
    S.print()

# ===== V_sort

def threestages(k,down=False,p=None,atleast=None):
    """Emit one C helper applying three consecutive comparator stages
    (column distances 4, 2, 1, in units of the column stride p) to k columns.

    k       -- number of columns (4..8); comparator pairs touching a column
               >= k are dropped
    down    -- swap comparator operands, producing the descending network
    p       -- specialize the stride to this compile-time constant; the
               emitted function name gains a _{p} suffix and takes no p arg
    atleast -- keep p as a run-time parameter but name the function
               _atleast{atleast}; when atleast (or p) is a multiple of
               vectorwords the scalar tail loop is omitted entirely
    """
    assert k in (4,5,6,7,8)
    print('NOINLINE')
    updown = 'down' if down else 'up'
    if p is not None:
        print(f'static void {intI}_threestages_{k}_{updown}_{p}({intI} *x)')
    elif atleast is not None:
        print(f'static void {intI}_threestages_{k}_{updown}_atleast{atleast}({intI} *x,long long p)')
    else:
        print(f'static void {intI}_threestages_{k}_{updown}({intI} *x,long long p,long long n)')
    print('{')
    print(' long long i;')
    if p is not None:
        print(f' long long p = {p};')
    if p is not None or atleast is not None:
        print(f' long long n = p;')
    for vector in True,False: # must be this order
        # when the stride is known to be vector-aligned the scalar tail is dead
        if p is not None and p%vectorwords == 0 and not vector: break
        if atleast is not None and atleast%vectorwords == 0 and not vector: break
        xtype = intIvec if vector else intI
        if vector:
            print(f' for (i = 0;i+{vectorwords} <= n;i += {vectorwords}) {{')
        else:
            # scalar loop continues from the i left by the vector loop
            print(' for (;i < n;++i) {')
        for j in range(k):
            addr = 'i' if j == 0 else 'p+i' if j == 1 else f'{j}*p+i'
            if vector:
                print(f' {xtype} x{j} = {xtype}_load(&x[{addr}]);')
            else:
                print(f' {xtype} x{j} = x[{addr}];')
        # distance-4, distance-2, distance-1 comparator stages
        for i,j in (0,4),(1,5),(2,6),(3,7),(0,2),(1,3),(4,6),(5,7),(0,1),(2,3),(4,5),(6,7):
            if j >= k: continue
            if down: i,j = j,i
            print(f' {xtype}_MINMAX(x{i},x{j});')
        for j in range(k):
            addr = 'i' if j == 0 else 'p+i' if j == 1 else f'{j}*p+i'
            if vector:
                print(f' {xtype}_store(&x[{addr}],x{j});')
            else:
                print(f' x[{addr}] = x{j};')
        print(' }')
    print('}')
    print('')

def V_sort():
    """Emit the vectorized layer: unrolled kernels for the sizes in
    unrolledVsort/unrolledVsortxor, the threestages_* helpers, the
    power-of-two dispatcher V_sort_2poweratleast*_xor (xor selects the
    direction: 0 ascending, -1 descending), and the recursive merger
    V_sort for 8 <= q < n <= 2q."""
    for nlow,nhigh in unrolledVsort:
        fixedsize(nlow,nhigh,V=True)
    for n in unrolledVsortxor:
        fixedsize(n,n,xor=True,V=True)
    threestages(8)
    threestages(7)
    threestages(6)
    threestages(5)
    threestages(4)
    threestages_min = Vsort2_min
    for p in threestages_specialize:
        assert p == threestages_min
        threestages_min *= 2
        threestages(8,p=p)
        threestages(8,down=True,p=p)
    threestages(8,down=True,atleast=threestages_min)
    threestages(6,down=True)
    print(f'''// XXX: currently xor must be 0 or -1
NOINLINE
static void {intI}_{Vsort2}_xor({intI} *x,long long n,{intI} xor)
{{''')
    for n in unrolledVsortxor:
        print(f' if (n == {n}) {{ {intI}_V_sort_{n}_xor(x,xor); return; }}')
    assert unrolledVsortxor[:3] == (Vsort2_min,Vsort2_min*2,Vsort2_min*4)
    # so n is at least Vsort2_min*8, justifying the following recursive calls
    for p in threestages_specialize:
        assert p in unrolledVsortxor
        print(f''' if (n == {p*8}) {{
 if (xor) {intI}_threestages_8_down_{p}(x);
 else {intI}_threestages_8_up_{p}(x);
 for (long long i = 0;i < 8;++i) {intI}_V_sort_{p}_xor(x+{p}*i,xor);
 return;
 }}''')
    print(f''' if (xor) {intI}_threestages_8_down_atleast{threestages_min}(x,n>>3);
 else {intI}_threestages_8_up(x,n>>3,n>>3);
 for (long long i = 0;i < 8;++i) {intI}_{Vsort2}_xor(x+(n>>3)*i,n>>3,xor);
}}

/* q is power of 2; want only merge stages q,q/2,q/4,...,1 */
// XXX: assuming 8 <= q < n <= 2q; q is a power of 2
NOINLINE
static void {intI}_V_sort({intI} *x,long long q,long long n)
{{''')
    assert any(nhigh == Vsort2_min for nlow,nhigh in unrolledVsort)
    for nlow,nhigh in unrolledVsort:
        if nhigh == Vsort2_min and favorxor:
            print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
        print(f' if (n <= {nhigh}) {{ {intI}_V_sort_{nlow}through{nhigh}(x,n); return; }}')
    if not favorxor:
        print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
    assert any(nhigh == 64 for nlow,nhigh in unrolledVsort)
    print(f''' // 64 <= q < n < 2q
 q >>= 2;
 // 64 <= 4q < n < 8q
 if (7*q < n) {{
 {intI}_threestages_8_up(x,q,n-7*q);
 {intI}_threestages_7_up(x+n-7*q,q,8*q-n);
 }} else if (6*q < n) {{
 {intI}_threestages_7_up(x,q,n-6*q);
 {intI}_threestages_6_up(x+n-6*q,q,7*q-n);
 }} else if (5*q < n) {{
 {intI}_threestages_6_up(x,q,n-5*q);
 {intI}_threestages_5_up(x+n-5*q,q,6*q-n);
 }} else {{
 {intI}_threestages_5_up(x,q,n-4*q);
 {intI}_threestages_4_up(x+n-4*q,q,5*q-n);
 }}
 // now want to handle each batch of q entries separately
 {intI}_V_sort(x,q>>1,q);
 {intI}_V_sort(x+q,q>>1,q);
 {intI}_V_sort(x+2*q,q>>1,q);
 {intI}_V_sort(x+3*q,q>>1,q);
 x += 4*q;
 n -= 4*q;
 while (n >= q) {{
 {intI}_V_sort(x,q>>1,q);
 x += q;
 n -= q;
 }}
 // have n entries left in last batch, with 0 <= n < q
 if (n <= 1) return;
 while (q >= n) q >>= 1; // empty merge stage
 // now 1 <= q < n <= 2q
 if (q >= 8) {{
 {intI}_V_sort(x,q,n);
 return;
 }}
 if (n == 8) {{
 {intI}_MINMAX(x[0],x[4]);
 {intI}_MINMAX(x[1],x[5]);
 {intI}_MINMAX(x[2],x[6]);
 {intI}_MINMAX(x[3],x[7]);
 {intI}_MINMAX(x[0],x[2]);
 {intI}_MINMAX(x[1],x[3]);
 {intI}_MINMAX(x[0],x[1]);
 {intI}_MINMAX(x[2],x[3]);
 {intI}_MINMAX(x[4],x[6]);
 {intI}_MINMAX(x[5],x[7]);
 {intI}_MINMAX(x[4],x[5]);
 {intI}_MINMAX(x[6],x[7]);
 return;
 }}
 if (4 <= n) {{
 for (long long i = 0;i < n-4;++i) {intI}_MINMAX(x[i],x[4+i]);
 {intI}_MINMAX(x[0],x[2]);
 {intI}_MINMAX(x[1],x[3]);
 {intI}_MINMAX(x[0],x[1]);
 {intI}_MINMAX(x[2],x[3]);
 n -= 4;
 x += 4;
 }}
 if (3 <= n) {intI}_MINMAX(x[0],x[2]);
 if (2 <= n) {intI}_MINMAX(x[0],x[1]);
}}
''')

# ===== main sort

def main_sort_prep():
    """Emit the kernels main_sort dispatches to: the small-n sorter plus the
    unrolled fixed-size sorts and the power-of-two xor variants."""
    smallsize()
    for nlow,nhigh in unrolledsort:
        fixedsize(nlow,nhigh)
    for n in unrolledsortxor:
        fixedsize(n,n,xor=True)

def main_sort():
    """Emit sort_2poweratleast*_xor (recursive power-of-two sorter; xor is 0
    for ascending, -1 for descending) and the public sort() entry point:
    size dispatch to unrolled kernels, optional pad-to-larger-size buckets,
    then merge strategies built on the xor sorters and V_sort."""
    print('// XXX: currently xor must be 0 or -1')
    print('NOINLINE')
    print(f'static void {intI}_{sort2}_xor({intI} *x,long long n,{intI} xor)')
    print('{')
    for n in unrolledsortxor:
        print(f' if (n == {n}) {{ {intI}_sort_{n}_xor(x,xor); return; }}')
    # sort the halves in opposite directions, then merge with the V layer
    print(f' {intI}_{sort2}_xor(x,n>>1,~xor);')
    print(f' {intI}_{sort2}_xor(x+(n>>1),n>>1,xor);')
    print(f' {intI}_{Vsort2}_xor(x,n,xor);')
    print('}')
    print(fr'''
void {intI}_sort({intI} *x,long long n)
{{
 long long q;
 if (n <= 1) return;
 if (n == 2) {{
 {intI}_MINMAX(x[0],x[1]);
 return;
 }}
 if (n <= 7) {{
 {intI}_sort_3through7(x,n);
 return;
 }}''')
    # XXX: n cutoff here should be another variable to optimize
    nmin = 8 # invariant: n in program is at least nmin
    assert any(nhigh == sort2_min for nlow,nhigh in unrolledsort)
    for nlow,nhigh in unrolledsort:
        if nhigh == sort2_min and favorxor:
            print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
        if nlow <= nmin:
            # range starts at or below the invariant: only an upper check,
            # and the invariant moves past nhigh
            print(f' if (n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
            while nlow <= nmin and nmin <= nhigh: nmin += 1
        else:
            print(f' if ({nlow} <= n && n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
    if not favorxor:
        print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
    qmin = 1
    while qmin < nmin-qmin: qmin += qmin
    assert sort2_min <= qmin
    # pad awkward sizes up to a vector-aligned size inside a local buffer
    for padlow,padhigh in pads:
        padlowdown = padlow
        while padlowdown%vectorwords: padlowdown -= 1
        padhighup = padhigh
        while padhighup%vectorwords: padhighup += 1
        print(fr''' if ({padlow} <= n && n <= {padhigh}) {{
 {intI} buf[{padhighup}];
 for (long long i = {padlowdown};i < {padhighup};++i) buf[i] = {intI}_largest;
 for (long long i = 0;i < n;++i) buf[i] = x[i];
 {intI}_sort(buf,{padhighup});
 for (long long i = 0;i < n;++i) x[i] = buf[i];
 return;
 }}''')
    assert sort2_min%2 == 0
    assert sort2_min//2 >= Vsort2_min
    print(fr''' q = {qmin};
 while (q < n - q) q += q;
 // {qmin} <= q < n < 2q
 if ({sort2_min*16} <= n && n <= (7*q)>>2) {{
 long long m = (3*q)>>2;
 // strategy: sort m, sort n-m, merge
 long long r = q>>3; // at least {sort2_min} since q is at least {sort2_min*8}
 {intI}_{sort2}_xor(x,4*r,0);
 {intI}_{sort2}_xor(x+4*r,r,0);
 {intI}_{sort2}_xor(x+5*r,r,-1);
 {intI}_{Vsort2}_xor(x+4*r,2*r,-1);
 {intI}_threestages_6_down(x,r,r);
 for (long long i = 0;i < 6;++i) {intI}_{Vsort2}_xor(x+i*r,r,-1);
 {intI}_sort(x+m,n-m);
 }} else if ({sort2_min*2} <= q && n == (3*q)>>1) {{
 // strategy: sort q, sort q/2, merge
 long long r = q>>2; // at least {sort2_min//2} since q is at least {sort2_min*2}
 {intI}_{sort2}_xor(x,4*r,-1);
 {intI}_{sort2}_xor(x+4*r,2*r,0);
 {intI}_threestages_6_up(x,r,r);
 for (long long i = 0;i < 6;++i) {intI}_{Vsort2}_xor(x+i*r,r,0);
 return;
 }} else {{
 {intI}_{sort2}_xor(x,q,-1);
 {intI}_sort(x+q,n-q);
 }}
 {intI}_V_sort(x,q,n);
}}''')

# ===== driver

preamble()
main_sort_prep()
V_sort()
main_sort()