#!/usr/bin/env python3

import sys

I = int(sys.argv[1])
assert I in (32,64)

# XXX: do more searches for optimal parameters here
if I == 32:
  unrolledsort = (8,16),(16,32),(32,64),(64,128)
  unrolledsortxor = 64,128
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 32,64,128
  threestages_specialize = 32,
  pads = (161,191),(193,255) # (449,511) is also interesting
  favorxor = True
else:
  unrolledsort = (8,16),(16,32),(32,64)
  unrolledsortxor = 32,64
  unrolledVsort = (8,16),(16,32),(32,64)
  unrolledVsortxor = 16,32,64
  threestages_specialize = 16,
  pads = (81,95),(97,127) # (225,255) is also interesting
  favorxor = True

sort2_min = min(unrolledsortxor)
sort2 = f'sort_2poweratleast{sort2_min}'
Vsort2_min = min(unrolledVsortxor)
Vsort2 = f'V_sort_2poweratleast{Vsort2_min}'

vectorlen = 256//I
allowmaskload = False # fast on Intel, and on AMD starting with Zen 2, _but_ AMD documentation claims that it can fail
allowmaskstore = False # fast on Intel but not on AMD
usepartialmem = True # irrelevant unless allowmaskload or allowmaskstore
partiallimit = 64 # XXX: could allow 128 for non-mask versions without extra operations

assert vectorlen&(vectorlen-1) == 0

intI = f'int{I}'
intIvec = f'{intI}x{vectorlen}'
int8vec = f'int8x{(I*vectorlen)//8}'

def preamble():
  print(fr'''/* WARNING: auto-generated (by autogen/sort); do not edit */

#include <immintrin.h>
#include "{intI}_sort.h"

#define {intI} {intI}_t
#define {intI}_largest {hex((1<<(I-1))-1)}

#include "crypto_{intI}.h"
#define {intI}_min crypto_{intI}_min
#define {intI}_MINMAX(a,b) crypto_{intI}_minmax(&(a),&(b))

#define NOINLINE __attribute__((noinline))

typedef __m256i {intIvec};
#define {intIvec}_load(z) _mm256_loadu_si256((__m256i *) (z))
#define {intIvec}_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i))
#define {intIvec}_smaller_mask(a,b) _mm256_cmpgt_epi{I}(b,a)
#define {intIvec}_add _mm256_add_epi{I}
#define {intIvec}_sub _mm256_sub_epi{I}
#define {int8vec}_iftopthenelse(c,t,e) _mm256_blendv_epi8(e,t,c)
#define {intIvec}_leftleft(a,b) _mm256_permute2x128_si256(a,b,0x20)
#define {intIvec}_rightright(a,b) _mm256_permute2x128_si256(a,b,0x31)
''')

  if vectorlen == 8:
    print(fr'''#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} c = {intIvec}_min(a,b); \
  b = {intIvec}_max(a,b); \
  a = c; \
}} while(0)

#define {intIvec}_min _mm256_min_epi{I}
#define {intIvec}_max _mm256_max_epi{I}
#define {intIvec}_set _mm256_setr_epi{I}
#define {intIvec}_broadcast _mm256_set1_epi{I}
#define {intIvec}_varextract _mm256_permutevar8x32_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3,p4,p5,p6,p7) {intIvec}_varextract(v,_mm256_setr_epi{I}(p0,p1,p2,p3,p4,p5,p6,p7))
#define {intIvec}_constextract_eachside(v,p0,p1,p2,p3) _mm256_shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_aabb_eachside(a,b,p0,p1,p2,p3) _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b),_MM_SHUFFLE(p3,p2,p1,p0)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,c4,c5,c6,c7,t,e) _mm256_blend_epi{I}(e,t,(c0)|((c1)<<1)|((c2)<<2)|((c3)<<3)|((c4)<<4)|((c5)<<5)|((c6)<<6)|((c7)<<7))
''')
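
  # Note on the 64-bit path below (explanatory sketch): AVX2 has no
  # _mm256_min_epi64/_mm256_max_epi64, so for vectorlen == 4 the MINMAX
  # macro is assembled from one compare plus two blends, i.e. the scalar
  # model
  #   t = (a < b) ? -1 : 0;   /* smaller_mask(a,b) */
  #   c = t ? a : b;          /* iftopthenelse(t,a,b): min */
  #   b = t ? b : a;  a = c;  /* iftopthenelse(t,b,a): max */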

  if vectorlen == 4:
    print(fr'''#define {intIvec}_MINMAX(a,b) \
do {{ \
  {intIvec} t = {intIvec}_smaller_mask(a,b); \
  {intIvec} c = {int8vec}_iftopthenelse(t,a,b); \
  b = {int8vec}_iftopthenelse(t,b,a); \
  a = c; \
}} while(0)

#define int32x8_add _mm256_add_epi32
#define int32x8_sub _mm256_sub_epi32
#define int32x8_set _mm256_setr_epi32
#define int32x8_broadcast _mm256_set1_epi32
#define int32x8_varextract _mm256_permutevar8x32_epi32

#define {intIvec}_set _mm256_setr_epi64x
#define {intIvec}_broadcast _mm256_set1_epi64x
#define {intIvec}_extract(v,p0,p1,p2,p3) _mm256_permute4x64_epi64(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_eachside(v,p0,p1) _mm256_shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_constextract_a01b01a23b23(a,b,p0,p1,p2,p3) _mm256_castpd_si256(_mm256_shuffle_pd(_mm256_castsi256_pd(a),_mm256_castsi256_pd(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,t,e) _mm256_blend_epi32(e,t,(c0)|((c0)<<1)|((c1)<<2)|((c1)<<3)|((c2)<<4)|((c2)<<5)|((c3)<<6)|((c3)<<7))

#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')

  # XXX: can skip some of the macros above if allowmaskload and allowmaskstore

  if usepartialmem and (allowmaskload or allowmaskstore):
    print(fr'''#define partialmem (partialmem_storage+64)
static const {intI} partialmem_storage[] __attribute__((aligned(128))) = {{
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
}} ;

#define {intIvec}_partialload(p,z) _mm256_maskload_epi{I}((void *) (z),p)
#define {intIvec}_partialstore(p,z,i) _mm256_maskstore_epi{I}((void *) (z),p,(i))
''')
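
# Note on the partialmem trick above: partialmem points 64 entries into the
# table, so partialmem[-64..-1] are all -1 and partialmem[0..63] are all 0.
# A plain vector load at &partialmem[offset-n] therefore returns -1 exactly
# in lanes 0..n-offset-1; e.g. vectorlen 8, n = 13, offset = 8 gives the
# mask (-1,-1,-1,-1,-1,0,0,0).  Keeping offset-n inside the 64-entry runs on
# each side is why fixedsize() asserts nhigh-partiallimit <= nlow.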

# ===== serial fixed-size networks

def smallsize():
  print(fr'''NOINLINE
static void {intI}_sort_3through7({intI} *x,long long n)
{{
  if (n >= 4) {{
    {intI} x0 = x[0];
    {intI} x1 = x[1];
    {intI} x2 = x[2];
    {intI} x3 = x[3];
    {intI}_MINMAX(x0,x1);
    {intI}_MINMAX(x2,x3);
    {intI}_MINMAX(x0,x2);
    {intI}_MINMAX(x1,x3);
    {intI}_MINMAX(x1,x2);
    if (n >= 5) {{
      if (n == 5) {{
        {intI} x4 = x[4];
        {intI}_MINMAX(x0,x4);
        {intI}_MINMAX(x2,x4);
        {intI}_MINMAX(x1,x2);
        {intI}_MINMAX(x3,x4);
        x[4] = x4;
      }} else {{
        {intI} x4 = x[4];
        {intI} x5 = x[5];
        {intI}_MINMAX(x4,x5);
        if (n == 6) {{
          {intI}_MINMAX(x0,x4);
          {intI}_MINMAX(x2,x4);
          {intI}_MINMAX(x1,x5);
          {intI}_MINMAX(x3,x5);
        }} else {{
          {intI} x6 = x[6];
          {intI}_MINMAX(x4,x6);
          {intI}_MINMAX(x5,x6);
          {intI}_MINMAX(x0,x4);
          {intI}_MINMAX(x2,x6);
          {intI}_MINMAX(x2,x4);
          {intI}_MINMAX(x1,x5);
          {intI}_MINMAX(x3,x5);
          {intI}_MINMAX(x5,x6);
          x[6] = x6;
        }}
        {intI}_MINMAX(x1,x2);
        {intI}_MINMAX(x3,x4);
        x[4] = x4;
        x[5] = x5;
      }}
    }}
    x[0] = x0;
    x[1] = x1;
    x[2] = x2;
    x[3] = x3;
  }} else {{
    {intI} x0 = x[0];
    {intI} x1 = x[1];
    {intI} x2 = x[2];
    {intI}_MINMAX(x0,x1);
    {intI}_MINMAX(x0,x2);
    {intI}_MINMAX(x1,x2);
    x[0] = x0;
    x[1] = x1;
    x[2] = x2;
  }}
}}
''')

# ===== vectorized fixed-size networks

class vectorsortingnetwork:
  def __init__(self,layout,xor=True):
    layout = [list(x) for x in layout]
    for x in layout:
      assert len(x) == vectorlen
    assert len(layout)%2 == 0 # XXX: drop this restriction?
    self.layout = layout
    self.phys = [f'x{r}' for r in range(len(layout))] # arch register r is stored in self.phys[r]
    self.operations = []
    self.usedregs = set(self.phys)
    self.usedregssmallindex = set()
    self.usexor = xor
    self.useinfty = False
    if xor:
      self.assign('vecxor',f'{intIvec}_broadcast(xor)')
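
  # Bookkeeping sketch: self.phys[r] names the C variable currently holding
  # arch register r, and self.layout[r] records which element index sits in
  # each lane of that register.  Every assign/minmax/shuffle updates the
  # layout, which is why the emitted C carries trailing comments such as
  # "// 0 2 4 6 8 10 12 14" tracking where each element lives.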

  def print(self):
    print('{')
    if len(self.usedregssmallindex) > 0:
      print(f' int32_t {",".join(sorted(self.usedregssmallindex))};')
    if len(self.usedregs) > 0:
      print(f' {intIvec} {",".join(sorted(self.usedregs))};')
    for op in self.operations:
      if op[0] == 'assign':
        phys,result,newlayout = op[1:]
        if newlayout is None:
          print(f' {phys} = {result};')
        else:
          print(f' {phys} = {result}; // {" ".join(map(str,newlayout))}')
      elif op[0] == 'assignsmallindex':
        phys,result = op[1:]
        print(f' {phys} = {result};')
      elif op[0] == 'store':
        result, = op[1:]
        print(f' {result};')
      elif op[0] == 'comment':
        c, = op[1:]
        print(f' // {c}')
      else:
        raise Exception(f'unrecognized operation {op}')
    print('}')
    print('')

  def allocate(self,r):
    '''Assign a new free physical register to arch register r.
       OK for caller to also use old register until calling allocate again.'''
    self.phys[r] = {'x':'y','y':'x'}[self.phys[r][0]]+self.phys[r][1:]
    # XXX: for generating low-level asm would instead want to cycle between fewer regs
    # but, for generating C (or qhasm), trusting register allocator makes generated assignments a bit more readable
    # (at least for V_sort)

  def comment(self,c):
    self.operations.append(('comment',c))

  def assign(self,phys,result,newlayout=None):
    self.usedregs.add(phys)
    self.operations.append(('assign',phys,result,newlayout))

  def assignsmallindex(self,phys,result):
    self.usedregssmallindex.add(phys)
    self.operations.append(('assignsmallindex',phys,result))

  def createinfty(self):
    if self.useinfty: return
    self.useinfty = True
    self.assign('infty',f'{intIvec}_broadcast({intI}_largest)')

  def partialmask(self,offset):
    if usepartialmem:
      return f'{intIvec}_load(&partialmem[{offset}-n])'
    # XXX: also do version with offset in constants rather than n?
    rangevectorlen = ','.join(map(str,range(vectorlen)))
    return f'{intIvec}_smaller_mask({intIvec}_set({rangevectorlen}),{intIvec}_broadcast(n-{offset}))'

  def load(self,r,partial=False):
    rphys = self.phys[r]
    xor = 'vecxor^' if self.usexor else '' # uint instead of int would slightly streamline infty usage
    if partial:
      self.createinfty()
      if allowmaskload:
        self.assign(f'partial{r}',f'{self.partialmask(r*vectorlen)}')
        self.assign(rphys,f'{xor}{intIvec}_partialload(partial{r},x+{vectorlen*r})',self.layout[r])
        self.assign(rphys,f'{int8vec}_iftopthenelse(partial{r},{rphys},infty)')
      else:
        if I == 32:
          mplus = ','.join(map(str,range(r*vectorlen,(r+1)*vectorlen)))
          self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
          self.assign(f'diff{r}',f'{intIvec}_sub({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))')
          self.assign(rphys,f'{intIvec}_varextract({intIvec}_load(x+pos{r}-8),diff{r})')
          self.assign(rphys,f'{int8vec}_iftopthenelse(diff{r},{rphys},infty)',self.layout[r])
        else:
          mplus = ','.join(map(str,range(2*r*vectorlen,2*(r+1)*vectorlen)))
          self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
          self.assign(f'diff{r}',f'int32x8_sub(int32x8_set({mplus}),int32x8_broadcast(2*pos{r}))')
          self.assign(rphys,f'int32x8_varextract({intIvec}_load(x+pos{r}-4),diff{r})')
          self.assign(rphys,f'{int8vec}_iftopthenelse(diff{r},{rphys},infty)',self.layout[r])
    else:
      self.assign(rphys,f'{xor}{intIvec}_load(x+{vectorlen*r})',self.layout[r])

  # warning: must do store in reverse order
  def store(self,r,partial=False):
    rphys = self.phys[r]
    xor = 'vecxor^' if self.usexor else ''
    if partial:
      if allowmaskstore:
        if not allowmaskload:
          self.assign(f'partial{r}',f'{self.partialmask(r*vectorlen)}')
        self.operations.append(('store',f'{intIvec}_partialstore(partial{r},x+{vectorlen*r},{xor}{rphys})'))
      else:
        if allowmaskload:
          self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
        # else: pos{r} is reused from load()
        # the overlapping unaligned store below is why store has to be in reverse order
        if I == 32:
          mplus = ','.join(map(str,range(vectorlen))) # XXX: or could offset by minimum r to reuse a vector from before
          self.operations.append(('store',f'{intIvec}_store(x+pos{r}-8,{intIvec}_varextract({xor}{rphys},{intIvec}_add({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))))'))
        else:
          mplus = ','.join(map(str,range(2*vectorlen))) # XXX: or could offset by minimum r to reuse a vector from before
          self.operations.append(('store',f'{intIvec}_store(x+pos{r}-4,int32x8_varextract({xor}{rphys},int32x8_add(int32x8_set({mplus}),int32x8_broadcast(2*pos{r}))))'))
    else:
      self.operations.append(('store',f'{intIvec}_store(x+{vectorlen*r},{xor}{rphys})'))
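
  # Worked example of the maskless partial load/store above (I == 32,
  # vectorlen 8, r == 1, n == 13): pos1 = min(16,13) = 13 and
  # diff1 = (8,...,15) - 13 = (-5,...,2).  varextract uses only the low 3
  # bits of each index, so the vector loaded at x+pos1-8 = x+5 is rotated to
  # (x[8],...,x[12],x[5],x[6],x[7]); iftopthenelse keeps the lanes whose
  # diff is negative and fills the rest with infty.  The store applies the
  # inverse rotation, spilling infty lanes below x[8]; stores of lower
  # vectors must follow (hence the reverse order) to overwrite that spill.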

  def vecswap(self,r,s):
    assert r != s
    self.phys[r],self.phys[s] = self.phys[s],self.phys[r]
    self.layout[r],self.layout[s] = self.layout[s],self.layout[r]

  def minmax(self,r,s):
    assert r != s
    rphys = self.phys[r]
    sphys = self.phys[s]
    newlayout0 = [min(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
    newlayout1 = [max(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
    self.layout[r],self.layout[s] = newlayout0,newlayout1
    self.allocate(r)
    rphysnew = self.phys[r]
    if I == 32:
      self.assign(rphysnew,f'{intIvec}_min({rphys},{sphys})',newlayout0)
      self.assign(sphys,f'{intIvec}_max({rphys},{sphys})',newlayout1)
    else: # XXX: avx-512 has min_epi64 but avx2 does not
      self.assign('t',f'{intIvec}_smaller_mask({rphys},{sphys})')
      self.assign(rphysnew,f'{int8vec}_iftopthenelse(t,{rphys},{sphys})',newlayout0)
      self.assign(sphys,f'{int8vec}_iftopthenelse(t,{sphys},{rphys})',newlayout1)

  def shuffle1(self,r,L):
    r'''Rearrange layout of vector r to match L.'''
    L = list(L)
    oldL = self.layout[r]
    perm = [oldL.index(a) for a in L]
    rphys = self.phys[r]
    if vectorlen == 8 and perm[4:] == [perm[0]+4,perm[1]+4,perm[2]+4,perm[3]+4]:
      self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
    elif vectorlen == 4 and perm[2:] == [perm[0]+2,perm[1]+2]:
      self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]})',L)
    elif vectorlen == 8:
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]},{perm[4]},{perm[5]},{perm[6]},{perm[7]})',L)
    elif vectorlen == 4:
      self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
    else:
      raise Exception(f'unhandled permutation from {oldL} to {L}')
    self.layout[r] = L
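
  # shuffle1 fast path, e.g. vectorlen 8: moving from
  # oldL = [0,1,2,3,8,9,10,11] to L = [1,0,3,2,9,8,11,10] gives
  # perm = [1,0,3,2,5,4,7,6], whose high half equals the low half plus 4,
  # so the cheap in-lane constextract_eachside(v,1,0,3,2) suffices instead
  # of a cross-lane extract.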

  def shuffle2(self,r,s,L,M,exact=False):
    oldL = self.layout[r]
    oldM = self.layout[s]
    rphys = self.phys[r]
    sphys = self.phys[s]

    try:
      assert vectorlen == 4
      p0 = oldL[:2].index(L[0])
      p1 = oldM[:2].index(L[1])
      p2 = oldL[2:].index(L[2])
      p3 = oldM[2:].index(L[3])
      q0 = oldL[:2].index(M[0])
      q1 = oldM[:2].index(M[1])
      q2 = oldL[2:].index(M[2])
      q3 = oldM[2:].index(M[3])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
      self.assign(sphys,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
      return
    except:
      pass

    try:
      assert vectorlen == 8
      p0 = oldL[:4].index(L[0])
      p1 = oldL[:4].index(L[1])
      p2 = oldM[:4].index(L[2])
      p3 = oldM[:4].index(L[3])
      assert p0 == oldL[4:].index(L[4])
      assert p1 == oldL[4:].index(L[5])
      assert p2 == oldM[4:].index(L[6])
      assert p3 == oldM[4:].index(L[7])
      q0 = oldL[:4].index(M[0])
      q1 = oldL[:4].index(M[1])
      q2 = oldM[:4].index(M[2])
      q3 = oldM[:4].index(M[3])
      assert q0 == oldL[4:].index(M[4])
      assert q1 == oldL[4:].index(M[5])
      assert q2 == oldM[4:].index(M[6])
      assert q3 == oldM[4:].index(M[7])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
      self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
      return
    except:
      pass

    try:
      assert not exact
      assert vectorlen == 8
      p0 = oldL[:4].index(L[0])
      p1 = oldL[:4].index(L[2])
      p2 = oldM[:4].index(L[1])
      p3 = oldM[:4].index(L[3])
      assert p0 == oldL[4:].index(L[4])
      assert p1 == oldL[4:].index(L[6])
      assert p2 == oldM[4:].index(L[5])
      assert p3 == oldM[4:].index(L[7])
      q0 = oldL[:4].index(M[0])
      q1 = oldL[:4].index(M[2])
      q2 = oldM[:4].index(M[1])
      q3 = oldM[:4].index(M[3])
      assert q0 == oldL[4:].index(M[4])
      assert q1 == oldL[4:].index(M[6])
      assert q2 == oldM[4:].index(M[5])
      assert q3 == oldM[4:].index(M[7])
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = [L[0],L[2],L[1],L[3],L[4],L[6],L[5],L[7]]
      self.layout[s] = [M[0],M[2],M[1],M[3],M[4],M[6],M[5],M[7]]
      self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',self.layout[r])
      self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',self.layout[s])
      return
    except:
      pass

    try:
      half = vectorlen//2
      assert oldL[:half] == L[:half]
      assert oldL[half:] == M[:half]
      assert oldM[:half] == L[half:]
      assert oldM[half:] == M[half:]
      self.allocate(r)
      rphysnew = self.phys[r]
      self.layout[r] = L
      self.layout[s] = M
      self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
      self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
      return
    except:
      pass

    for bigflip in False,True:
      try:
        if bigflip:
          half = vectorlen//2
          z0 = L[:half]+M[:half]
          z1 = L[half:]+M[half:]
        else:
          z0 = L
          z1 = M
        blend = []
        shuf0 = [None]*vectorlen
        shuf1 = [None]*vectorlen
        for i in range(vectorlen):
          if oldL[i] == z0[i]:
            blend.append(0)
            shuf0[i] = oldL.index(z1[i])
            shuf1[i] = 0
            if i < vectorlen//2:
              assert shuf0[i] < vectorlen//2
            else:
              assert shuf0[i] == shuf0[i-vectorlen//2]+vectorlen//2
          else:
            blend.append(1)
            assert oldM[i] == z1[i]
            shuf0[i] = 0
            shuf1[i] = oldM.index(z0[i])
            if i < vectorlen//2:
              assert shuf1[i] < vectorlen//2
            else:
              assert shuf1[i] == shuf1[i-vectorlen//2]+vectorlen//2
        blend = ','.join(map(str,blend))
        s0 = ','.join(map(str,shuf0[:vectorlen//2]))
        s1 = ','.join(map(str,shuf1[:vectorlen//2]))
        # XXX: encapsulate temporaries better
        self.assign('u',f'{intIvec}_constextract_eachside({sphys},{s1})')
        self.assign('t',f'{intIvec}_constextract_eachside({rphys},{s0})')
        self.assign(rphys,f'{intIvec}_ifconstthenelse({blend},u,{rphys})',L)
        self.assign(sphys,f'{intIvec}_ifconstthenelse({blend},{sphys},t)',M)
        if bigflip:
          self.allocate(r)
          rphysnew = self.phys[r]
          self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
          self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
        self.layout[r] = L
        self.layout[s] = M
        return
      except:
        pass

    raise Exception(f'unhandled permutation from {oldL},{oldM} to {L},{M}')

  def rearrange_onestep(self,comparators):
    numvectors = len(self.layout)
    for k in range(0,numvectors,2):
      if all(comparators[a] == a for a in self.layout[k]):
        if all(comparators[a] == a for a in self.layout[k+1]):
          continue
      collected = set(self.layout[k]+self.layout[k+1])
      if not all(comparators[a] in collected for a in collected):
        for j in range(numvectors):
          if all(comparators[a] in self.layout[j] for a in self.layout[k]):
            self.vecswap(j,k+1)
            return True
      if any(comparators[a] in self.layout[k] for a in self.layout[k]):
        newlayout0 = list(self.layout[k])
        newlayout1 = list(self.layout[k+1])
        for i in range(vectorlen):
          a = newlayout0[i]
          b = comparators[a]
          if b == newlayout1[i]: continue
          if b in newlayout0:
            j = newlayout0.index(b)
            assert j > i
            newlayout1[i],newlayout0[j] = newlayout0[j],newlayout1[i]
          else:
            j = newlayout1.index(b)
            assert j > i
            newlayout1[i],newlayout1[j] = newlayout1[j],newlayout1[i]
        self.shuffle2(k,k+1,newlayout0,newlayout1)
        return True
      if [comparators[a] for a in self.layout[k]] != self.layout[k+1]:
        newlayout = [comparators[a] for a in self.layout[k]]
        self.shuffle1(k+1,newlayout)
        return True
    return False

  def rearrange(self,comparators):
    comparators = dict(comparators)
    while self.rearrange_onestep(comparators):
      pass
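
# How rearrange() is used, on a small example: with vectorlen 4 and layouts
# [0,2,4,6],[1,3,5,7], the stage a <-> a^1 already pairs lane i of the two
# vectors (2i with 2i+1), so rearrange_onestep() finds nothing to change and
# a single vertical MINMAX performs four comparators at once.  When lanes do
# not line up, it emits vecswap/shuffle1/shuffle2 steps until every
# comparator joins equal lanes of some vector pair (k,k+1), or leaves a pair
# entirely fixed so its MINMAX can be skipped.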

def fixedsize(nlow,nhigh,xor=False,V=False):
  nlow = int(nlow)
  nhigh = int(nhigh)
  assert 0 <= nlow
  assert nlow <= nhigh
  assert nhigh-partiallimit <= nlow
  assert nhigh%vectorlen == 0
  assert nlow%vectorlen == 0

  lgn = 4
  while 2**lgn < nhigh: lgn += 1

  if nhigh < 2*vectorlen:
    raise Exception(f'unable to handle sizes below {2*vectorlen}')

  if V:
    funname = f'{intI}_V_sort_'
  else:
    funname = f'{intI}_sort_'
  funargs = f'{intI} *x'
  if nlow < nhigh:
    funname += f'{nlow}through'
    funargs += ',long long n'
  funname += f'{nhigh}'
  if xor:
    funname += '_xor'
    funargs += f',{intI} xor'
  print('NOINLINE')
  print(f'static void {funname}({funargs})')

  # ===== decide on initial layout of nodes

  numvectors = nhigh//vectorlen
  layout = {}
  if V:
    for k in range(numvectors//2):
      layout[k] = list(reversed(range(vectorlen*(numvectors//2-1-k),vectorlen*(numvectors//2-k))))
    for k in range(numvectors//2,numvectors):
      layout[k] = list(range(vectorlen*k,vectorlen*(k+1)))
  else:
    for k in range(0,numvectors,nhigh//vectorlen):
      for offset in range(nhigh//vectorlen):
        layout[k+offset] = list(range(vectorlen*k+offset,vectorlen*(k+nhigh//vectorlen),nhigh//vectorlen))
  layout = [layout[k] for k in range(len(layout))]

  # ===== build network

  S = vectorsortingnetwork(layout,xor=xor)

  for k in range(numvectors):
    S.load(k,partial=k*vectorlen>=nlow)

  for lgsubsort in range(1,lgn+1):
    if V and lgsubsort < lgn: continue
    for stage in reversed(range(lgsubsort)):
      if nhigh >= 16 and (lgsubsort,stage) == (1,0):
        comparators = {a:a^1 for a in range(nhigh)}
      elif nhigh >= 32 and (lgsubsort,stage) == (2,1):
        comparators = {a:a^2 for a in range(nhigh)}
      elif nhigh >= 32 and (lgsubsort,stage) == (2,0):
        comparators = {a:a^(3*(1&((a>>0)^(a>>1)))) for a in range(nhigh)}
      elif nhigh >= 64 and (lgsubsort,stage) == (3,2):
        comparators = {a:a^4 for a in range(nhigh)}
      elif nhigh >= 64 and (lgsubsort,stage) == (3,1):
        comparators = {a:a+[0,0,2,2,-2,-2,0,0][a%8] for a in range(nhigh)}
      elif nhigh >= 64 and (lgsubsort,stage) == (3,0):
        comparators = {a:a+[0,1,-1,1,-1,1,-1,0][a%8] for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,3):
        comparators = {a:a^8 for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,2):
        comparators = {a:a+[0,0,0,0,4,4,4,4,-4,-4,-4,-4,0,0,0,0][a%16] for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,1):
        comparators = {a:a+[0,0,2,2,-2,-2,2,2,-2,-2,2,2,-2,-2,0,0][a%16] for a in range(nhigh)}
      elif nhigh >= 128 and (lgsubsort,stage) == (4,0):
        comparators = {a:a+[0,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,0][a%16] for a in range(nhigh)}
      else:
        if stage == lgsubsort-1:
          stagemask = (2<<stage)-1
          comparators = {a:a^stagemask for a in range(nhigh)}
        else:
          comparators = {a:a^(1<<stage) for a in range(nhigh)}
        # newlayout = [...((k>>2)^(k>>1))&1)*6)] for k in range(vectorlen)]
        # XXX S.shuffle1(k,newlayout)

      strcomparators = ' '.join(f'{a}:{comparators[a]}' for a in range(nhigh) if comparators[a] > a)
      S.comment(f'stage ({lgsubsort},{stage}) {strcomparators}')

      S.rearrange(comparators)

      for k in range(0,numvectors,2):
        if all(comparators[a] == a for a in S.layout[k]):
          if all(comparators[a] == a for a in S.layout[k+1]):
            continue
        S.minmax(k,k+1)

  for k in range(0,numvectors,2):
    for offset in 0,1:
      for i in range(numvectors):
        if k*vectorlen+offset in S.layout[i]:
          if i != k+offset: S.vecswap(i,k+offset)
          break

  for k in range(0,numvectors,2):
    y0 = list(range(k*vectorlen,(k+1)*vectorlen))
    y1 = list(range((k+1)*vectorlen,(k+2)*vectorlen))
    S.shuffle2(k,k+1,y0,y1,exact=True)

  for k in range(numvectors) if allowmaskstore or nlow == nhigh else reversed(range(numvectors)):
    S.store(k,partial=k*vectorlen>=nlow)

  S.print()
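
# Layout example for the non-V case, nhigh = 16, vectorlen = 8: the two
# vectors are loaded transposed, layouts [0,2,...,14] and [1,3,...,15], so
# the first stage a <-> a^1 is one vertical MINMAX; rearrange() reshapes the
# layouts for the later stages.  The final vecswap/shuffle2(exact=True)
# passes restore the blocked layouts [0..7],[8..15] before the stores, which
# run in reverse order whenever partial stores can overlap (see store()).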

# ===== V_sort

def threestages(k,down=False,p=None,atleast=None):
  assert k in (4,5,6,7,8)
  print('NOINLINE')
  updown = 'down' if down else 'up'
  if p is not None:
    print(f'static void {intI}_threestages_{k}_{updown}_{p}({intI} *x)')
  elif atleast is not None:
    print(f'static void {intI}_threestages_{k}_{updown}_atleast{atleast}({intI} *x,long long p)')
  else:
    print(f'static void {intI}_threestages_{k}_{updown}({intI} *x,long long p,long long n)')
  print('{')
  print(' long long i;')
  if p is not None:
    print(f' long long p = {p};')
  if p is not None or atleast is not None:
    print(f' long long n = p;')
  for vector in True,False: # must be this order
    if p is not None and p%vectorlen == 0 and not vector: break
    if atleast is not None and atleast%vectorlen == 0 and not vector: break
    xtype = intIvec if vector else intI
    if vector:
      print(f' for (i = 0;i+{vectorlen} <= n;i += {vectorlen}) {{')
    else:
      print(' for (;i < n;++i) {')
    for j in range(k):
      addr = 'i' if j == 0 else 'p+i' if j == 1 else f'{j}*p+i'
      if vector:
        print(f' {xtype} x{j} = {xtype}_load(&x[{addr}]);')
      else:
        print(f' {xtype} x{j} = x[{addr}];')
    for i,j in (0,4),(1,5),(2,6),(3,7),(0,2),(1,3),(4,6),(5,7),(0,1),(2,3),(4,5),(6,7):
      if j >= k: continue
      if down: i,j = j,i
      print(f' {xtype}_MINMAX(x{i},x{j});')
    for j in range(k):
      addr = 'i' if j == 0 else 'p+i' if j == 1 else f'{j}*p+i'
      if vector:
        print(f' {xtype}_store(&x[{addr}],x{j});')
      else:
        print(f' x[{addr}] = x{j};')
    print(' }')
  print('}')
  print('')
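
# threestages emits the three comparator stages at distances 4p, 2p, p,
# applied columnwise across up to eight blocks of stride p: column i of
# blocks j and j^4, then j and j^2, then j and j^1, with directions flipped
# for down.  k < 8 drops the comparators touching missing blocks, and the
# vectorized loop is followed by a scalar tail unless p (or atleast) is a
# multiple of vectorlen.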

def V_sort():
  for nlow,nhigh in unrolledVsort:
    fixedsize(nlow,nhigh,V=True)
  for n in unrolledVsortxor:
    fixedsize(n,n,xor=True,V=True)

  threestages(8)
  threestages(7)
  threestages(6)
  threestages(5)
  threestages(4)

  threestages_min = Vsort2_min
  for p in threestages_specialize:
    assert p == threestages_min
    threestages_min *= 2
    threestages(8,p=p)
    threestages(8,down=True,p=p)
  threestages(8,down=True,atleast=threestages_min)
  threestages(6,down=True)

  print(f'''// XXX: currently xor must be 0 or -1
NOINLINE
static void {intI}_{Vsort2}_xor({intI} *x,long long n,{intI} xor)
{{''')
  for n in unrolledVsortxor:
    print(f' if (n == {n}) {{ {intI}_V_sort_{n}_xor(x,xor); return; }}')

  assert unrolledVsortxor[:3] == (Vsort2_min,Vsort2_min*2,Vsort2_min*4)
  # so n is at least Vsort2_min*8, justifying the following recursive calls

  for p in threestages_specialize:
    assert p in unrolledVsortxor
    print(f'''
  if (n == {p*8}) {{
    if (xor) {intI}_threestages_8_down_{p}(x);
    else {intI}_threestages_8_up_{p}(x);
    for (long long i = 0;i < 8;++i) {intI}_V_sort_{p}_xor(x+{p}*i,xor);
    return;
  }}''')

  print(f'''
  if (xor) {intI}_threestages_8_down_atleast{threestages_min}(x,n>>3);
  else {intI}_threestages_8_up(x,n>>3,n>>3);
  for (long long i = 0;i < 8;++i) {intI}_{Vsort2}_xor(x+(n>>3)*i,n>>3,xor);
}}

/* q is power of 2; want only merge stages q,q/2,q/4,...,1 */
// XXX: assuming 8 <= q < n <= 2q; q is a power of 2
NOINLINE
static void {intI}_V_sort({intI} *x,long long q,long long n)
{{''')

  assert any(nhigh == Vsort2_min for nlow,nhigh in unrolledVsort)
  for nlow,nhigh in unrolledVsort:
    if nhigh == Vsort2_min and favorxor:
      print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}')
    print(f' if (n <= {nhigh}) {{ {intI}_V_sort_{nlow}through{nhigh}(x,n); return; }}')
  if not favorxor:
    print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}')

  print(f'''
  // 64 <= q < n < 2q
  q >>= 2;
  // 64 <= 4q < n < 8q
  if (7*q < n) {{
    {intI}_threestages_8_up(x,q,n-7*q);
    {intI}_threestages_7_up(x+n-7*q,q,8*q-n);
  }} else if (6*q < n) {{
    {intI}_threestages_7_up(x,q,n-6*q);
    {intI}_threestages_6_up(x+n-6*q,q,7*q-n);
  }} else if (5*q < n) {{
    {intI}_threestages_6_up(x,q,n-5*q);
    {intI}_threestages_5_up(x+n-5*q,q,6*q-n);
  }} else {{
    {intI}_threestages_5_up(x,q,n-4*q);
    {intI}_threestages_4_up(x+n-4*q,q,5*q-n);
  }}

  // now want to handle each batch of q entries separately
  {intI}_V_sort(x,q>>1,q);
  {intI}_V_sort(x+q,q>>1,q);
  {intI}_V_sort(x+2*q,q>>1,q);
  {intI}_V_sort(x+3*q,q>>1,q);
  x += 4*q;
  n -= 4*q;
  while (n >= q) {{
    {intI}_V_sort(x,q>>1,q);
    x += q;
    n -= q;
  }}

  // have n entries left in last batch, with 0 <= n < q
  if (n <= 1) return;
  while (q >= n) q >>= 1; // empty merge stage
  // now 1 <= q < n <= 2q
  if (q >= 8) {{
    {intI}_V_sort(x,q,n);
    return;
  }}
  if (n == 8) {{
    {intI}_MINMAX(x[0],x[4]);
    {intI}_MINMAX(x[1],x[5]);
    {intI}_MINMAX(x[2],x[6]);
    {intI}_MINMAX(x[3],x[7]);
    {intI}_MINMAX(x[0],x[2]);
    {intI}_MINMAX(x[1],x[3]);
    {intI}_MINMAX(x[0],x[1]);
    {intI}_MINMAX(x[2],x[3]);
    {intI}_MINMAX(x[4],x[6]);
    {intI}_MINMAX(x[5],x[7]);
    {intI}_MINMAX(x[4],x[5]);
    {intI}_MINMAX(x[6],x[7]);
    return;
  }}
  if (4 <= n) {{
    for (long long i = 0;i < n-4;++i) {intI}_MINMAX(x[i],x[4+i]);
    {intI}_MINMAX(x[0],x[2]);
    {intI}_MINMAX(x[1],x[3]);
    {intI}_MINMAX(x[0],x[1]);
    {intI}_MINMAX(x[2],x[3]);
    n -= 4;
    x += 4;
  }}
  if (3 <= n) {intI}_MINMAX(x[0],x[2]);
  if (2 <= n) {intI}_MINMAX(x[0],x[1]);
}}
''')

# ===== main sort

def main_sort_prep():
  smallsize()
  for nlow,nhigh in unrolledsort:
    fixedsize(nlow,nhigh)
  for n in unrolledsortxor:
    fixedsize(n,n,xor=True)

def main_sort():
  print('// XXX: currently xor must be 0 or -1')
  print('NOINLINE')
  print(f'static void {intI}_{sort2}_xor({intI} *x,long long n,{intI} xor)')
  print('{')
  for n in unrolledsortxor:
    print(f' if (n == {n}) {{ {intI}_sort_{n}_xor(x,xor); return; }}')
  print(f' {intI}_{sort2}_xor(x,n>>1,~xor);')
  print(f' {intI}_{sort2}_xor(x+(n>>1),n>>1,xor);')
  print(f' {intI}_{Vsort2}_xor(x,n,xor);')
  print('}')

  print(fr'''
void {intI}_sort({intI} *x,long long n)
{{
  long long q;

  if (n <= 1) return;
  if (n == 2) {{
    {intI}_MINMAX(x[0],x[1]);
    return;
  }}
  if (n <= 7) {{
    {intI}_sort_3through7(x,n);
    return;
  }}''')
  # XXX: n cutoff here should be another variable to optimize

  nmin = 8 # invariant: n in program is at least nmin

  assert any(nhigh == sort2_min for nlow,nhigh in unrolledsort)
  for nlow,nhigh in unrolledsort:
    if nhigh == sort2_min and favorxor:
      print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}')
    if nlow <= nmin:
      print(f' if (n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
      while nlow <= nmin and nmin <= nhigh: nmin += 1
    else:
      print(f' if ({nlow} <= n && n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
  if not favorxor:
    print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}')

  qmin = 1
  while qmin < nmin-qmin: qmin += qmin
  assert sort2_min <= qmin

  for padlow,padhigh in pads:
    padlowdown = padlow
    while padlowdown%vectorlen: padlowdown -= 1
    padhighup = padhigh
    while padhighup%vectorlen: padhighup += 1
    print(fr'''
  if ({padlow} <= n && n <= {padhigh}) {{
    {intI} buf[{padhighup}];
    for (long long i = {padlowdown};i < {padhighup};++i) buf[i] = {intI}_largest;
    for (long long i = 0;i < n;++i) buf[i] = x[i];
    {intI}_sort(buf,{padhighup});
    for (long long i = 0;i < n;++i) x[i] = buf[i];
    return;
  }}''')

  assert sort2_min%2 == 0
  assert sort2_min//2 >= Vsort2_min

  print(fr'''
  q = {qmin};
  while (q < n - q) q += q;
  // {qmin} <= q < n < 2q

  if ({sort2_min*16} <= n && n <= (7*q)>>2) {{
    long long m = (3*q)>>2; // strategy: sort m, sort n-m, merge
    long long r = q>>3; // at least {sort2_min} since q is at least {sort2_min*8}
    {intI}_{sort2}_xor(x,4*r,0);
    {intI}_{sort2}_xor(x+4*r,r,0);
    {intI}_{sort2}_xor(x+5*r,r,-1);
    {intI}_{Vsort2}_xor(x+4*r,2*r,-1);
    {intI}_threestages_6_down(x,r,r);
    for (long long i = 0;i < 6;++i) {intI}_{Vsort2}_xor(x+i*r,r,-1);
    {intI}_sort(x+m,n-m);
  }} else if ({sort2_min*2} <= q && n == (3*q)>>1) {{
    // strategy: sort q, sort q/2, merge
    long long r = q>>2; // at least {sort2_min//2} since q is at least {sort2_min*2}
    {intI}_{sort2}_xor(x,4*r,-1);
    {intI}_{sort2}_xor(x+4*r,2*r,0);
    {intI}_threestages_6_up(x,r,r);
    for (long long i = 0;i < 6;++i) {intI}_{Vsort2}_xor(x+i*r,r,0);
    return;
  }} else {{
    {intI}_{sort2}_xor(x,q,-1);
    {intI}_sort(x+q,n-q);
  }}

  {intI}_V_sort(x,q,n);
}}''')

# ===== driver

preamble()
main_sort_prep()
V_sort()
main_sort()
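
# Example invocation (an assumption; the script itself only reads the lane
# width from sys.argv[1] and writes the generated C to stdout):
#   python3 autogen/sort 32 > int32_sort.c
#   python3 autogen/sort 64 > int64_sort.c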