#!/usr/bin/env python3
import sys
I = int(sys.argv[1])
assert I in (32,64)
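# code-generation parameters (descriptive summary, inferred from the code below):
#   unrolledsort: (nlow,nhigh) ranges given fully unrolled sorting networks
#   unrolledsortxor: power-of-2 sizes given unrolled networks with xor masking
#   unrolledVsort, unrolledVsortxor: same for the merging (V_sort) path
#   threestages_specialize: p values given constant-p threestages_8 variants
#   pads: (nlow,nhigh) input ranges padded with sentinels to a rounder size
#   favorxor: test for power-of-2 n before testing the unrolled size ranges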
# XXX: do more searches for optimal parameters here
if I == 32:
unrolledsort = (8,16),(16,32),(32,64),(64,128)
unrolledsortxor = 64,128
unrolledVsort = (8,16),(16,32),(32,64)
unrolledVsortxor = 32,64,128
threestages_specialize = 32,
pads = (161,191),(193,255) # (449,511) is also interesting
favorxor = True
else:
unrolledsort = (8,16),(16,32),(32,64)
unrolledsortxor = 32,64
unrolledVsort = (8,16),(16,32),(32,64)
unrolledVsortxor = 16,32,64
threestages_specialize = 16,
pads = (81,95),(97,127) # (225,255) is also interesting
favorxor = True
sort2_min = min(unrolledsortxor)
sort2 = f'sort_2poweratleast{sort2_min}'
Vsort2_min = min(unrolledVsortxor)
Vsort2 = f'V_sort_2poweratleast{Vsort2_min}'
vectorlen = 256//I
allowmaskload = False # fast on Intel, and on AMD starting with Zen 2, _but_ AMD documentation claims that it can fail
allowmaskstore = False # fast on Intel but not on AMD
usepartialmem = True # irrelevant unless allowmaskload or allowmaskstore
partiallimit = 64 # XXX: could allow 128 for non-mask versions without extra operations
assert vectorlen&(vectorlen-1) == 0
intI = f'int{I}'
intIvec = f'{intI}x{vectorlen}'
int8vec = f'int8x{(I*vectorlen)//8}'
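# e.g. I == 32: intI == 'int32', intIvec == 'int32x8', int8vec == 'int8x32';
#      I == 64: intI == 'int64', intIvec == 'int64x4', int8vec == 'int8x32'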
def preamble():
print(fr'''/* WARNING: auto-generated (by autogen/sort); do not edit */
#include <immintrin.h>
#include "{intI}_sort.h"
#define {intI} {intI}_t
#define {intI}_largest {hex((1<<(I-1))-1)}
#include "crypto_{intI}.h"
#define {intI}_min crypto_{intI}_min
#define {intI}_MINMAX(a,b) crypto_{intI}_minmax(&(a),&(b))
#define NOINLINE __attribute__((noinline))
typedef __m256i {intIvec};
#define {intIvec}_load(z) _mm256_loadu_si256((__m256i *) (z))
#define {intIvec}_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i))
#define {intIvec}_smaller_mask(a,b) _mm256_cmpgt_epi{I}(b,a)
#define {intIvec}_add _mm256_add_epi{I}
#define {intIvec}_sub _mm256_sub_epi{I}
#define {int8vec}_iftopthenelse(c,t,e) _mm256_blendv_epi8(e,t,c)
#define {intIvec}_leftleft(a,b) _mm256_permute2x128_si256(a,b,0x20)
#define {intIvec}_rightright(a,b) _mm256_permute2x128_si256(a,b,0x31)
''')
if vectorlen == 8:
print(fr'''#define {intIvec}_MINMAX(a,b) \
do {{ \
{intIvec} c = {intIvec}_min(a,b); \
b = {intIvec}_max(a,b); \
a = c; \
}} while(0)
#define {intIvec}_min _mm256_min_epi{I}
#define {intIvec}_max _mm256_max_epi{I}
#define {intIvec}_set _mm256_setr_epi{I}
#define {intIvec}_broadcast _mm256_set1_epi{I}
#define {intIvec}_varextract _mm256_permutevar8x32_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3,p4,p5,p6,p7) {intIvec}_varextract(v,_mm256_setr_epi{I}(p0,p1,p2,p3,p4,p5,p6,p7))
#define {intIvec}_constextract_eachside(v,p0,p1,p2,p3) _mm256_shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_aabb_eachside(a,b,p0,p1,p2,p3) _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b),_MM_SHUFFLE(p3,p2,p1,p0)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,c4,c5,c6,c7,t,e) _mm256_blend_epi{I}(e,t,(c0)|((c1)<<1)|((c2)<<2)|((c3)<<3)|((c4)<<4)|((c5)<<5)|((c6)<<6)|((c7)<<7))
''')
if vectorlen == 4:
print(fr'''#define {intIvec}_MINMAX(a,b) \
do {{ \
{intIvec} t = {intIvec}_smaller_mask(a,b); \
{intIvec} c = {int8vec}_iftopthenelse(t,a,b); \
b = {int8vec}_iftopthenelse(t,b,a); \
a = c; \
}} while(0)
#define int32x8_add _mm256_add_epi32
#define int32x8_sub _mm256_sub_epi32
#define int32x8_set _mm256_setr_epi32
#define int32x8_broadcast _mm256_set1_epi32
#define int32x8_varextract _mm256_permutevar8x32_epi32
#define {intIvec}_set _mm256_setr_epi64x
#define {intIvec}_broadcast _mm256_set1_epi64x
#define {intIvec}_extract(v,p0,p1,p2,p3) _mm256_permute4x64_epi64(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_eachside(v,p0,p1) _mm256_shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_constextract_a01b01a23b23(a,b,p0,p1,p2,p3) _mm256_castpd_si256(_mm256_shuffle_pd(_mm256_castsi256_pd(a),_mm256_castsi256_pd(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,t,e) _mm256_blend_epi32(e,t,(c0)|((c0)<<1)|((c1)<<2)|((c1)<<3)|((c2)<<4)|((c2)<<5)|((c3)<<6)|((c3)<<7))
#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')
# XXX: can skip some of the macros above if allowmaskload and allowmaskstore
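    # partialmem points at the boundary between 64 copies of -1 and 64 copies
    # of 0, so reading vectorlen entries starting at partialmem[offset-n]
    # gives lane i == -1 exactly when offset+i < n:
    # a mask selecting the in-range lanes of the vector at offset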
if usepartialmem and (allowmaskload or allowmaskstore):
print(fr'''#define partialmem (partialmem_storage+64)
static const {intI} partialmem_storage[] __attribute__((aligned(128))) = {{
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
}} ;
#define {intIvec}_partialload(p,z) _mm256_maskload_epi{I}((void *) (z),p)
#define {intIvec}_partialstore(p,z,i) _mm256_maskstore_epi{I}((void *) (z),p,(i))
''')
# ===== serial fixed-size networks
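# one branchy routine for 3 <= n <= 7: the n >= 4 case runs a 4-element
# network and extends it for n == 5, 6, 7; n == 3 gets its own 3-element network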
def smallsize():
print(fr'''NOINLINE
static void {intI}_sort_3through7({intI} *x,long long n)
{{
if (n >= 4) {{
{intI} x0 = x[0];
{intI} x1 = x[1];
{intI} x2 = x[2];
{intI} x3 = x[3];
{intI}_MINMAX(x0,x1);
{intI}_MINMAX(x2,x3);
{intI}_MINMAX(x0,x2);
{intI}_MINMAX(x1,x3);
{intI}_MINMAX(x1,x2);
if (n >= 5) {{
if (n == 5) {{
{intI} x4 = x[4];
{intI}_MINMAX(x0,x4);
{intI}_MINMAX(x2,x4);
{intI}_MINMAX(x1,x2);
{intI}_MINMAX(x3,x4);
x[4] = x4;
}} else {{
{intI} x4 = x[4];
{intI} x5 = x[5];
{intI}_MINMAX(x4,x5);
if (n == 6) {{
{intI}_MINMAX(x0,x4);
{intI}_MINMAX(x2,x4);
{intI}_MINMAX(x1,x5);
{intI}_MINMAX(x3,x5);
}} else {{
{intI} x6 = x[6];
{intI}_MINMAX(x4,x6);
{intI}_MINMAX(x5,x6);
{intI}_MINMAX(x0,x4);
{intI}_MINMAX(x2,x6);
{intI}_MINMAX(x2,x4);
{intI}_MINMAX(x1,x5);
{intI}_MINMAX(x3,x5);
{intI}_MINMAX(x5,x6);
x[6] = x6;
}}
{intI}_MINMAX(x1,x2);
{intI}_MINMAX(x3,x4);
x[4] = x4;
x[5] = x5;
}}
}}
x[0] = x0;
x[1] = x1;
x[2] = x2;
x[3] = x3;
}} else {{
{intI} x0 = x[0];
{intI} x1 = x[1];
{intI} x2 = x[2];
{intI}_MINMAX(x0,x1);
{intI}_MINMAX(x0,x2);
{intI}_MINMAX(x1,x2);
x[0] = x0;
x[1] = x1;
x[2] = x2;
}}
}}
''')
# ===== vectorized fixed-size networks
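# vectorsortingnetwork emits one straight-line C block.
# self.layout[k] tracks which array positions currently sit in the lanes of
# vector k; minmax() performs vectorlen comparators at once, and the
# shuffle/rearrange methods permute lanes so that each comparator of a stage
# pairs lane i of vector k with lane i of vector k+1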
class vectorsortingnetwork:
def __init__(self,layout,xor=True):
layout = [list(x) for x in layout]
for x in layout: assert len(x) == vectorlen
assert len(layout)%2 == 0 # XXX: drop this restriction?
self.layout = layout
self.phys = [f'x{r}' for r in range(len(layout))]
# arch register r is stored in self.phys[r]
self.operations = []
self.usedregs = set(self.phys)
self.usedregssmallindex = set()
self.usexor = xor
self.useinfty = False
if xor:
self.assign('vecxor',f'{intIvec}_broadcast(xor)')
def print(self):
print('{')
if len(self.usedregssmallindex) > 0:
print(f' int32_t {",".join(sorted(self.usedregssmallindex))};')
if len(self.usedregs) > 0:
print(f' {intIvec} {",".join(sorted(self.usedregs))};')
for op in self.operations:
if op[0] == 'assign':
phys,result,newlayout = op[1:]
if newlayout is None:
print(f' {phys} = {result};')
else:
print(f' {phys} = {result}; // {" ".join(map(str,newlayout))}')
elif op[0] == 'assignsmallindex':
phys,result = op[1:]
print(f' {phys} = {result};')
elif op[0] == 'store':
result, = op[1:]
print(f' {result};')
elif op[0] == 'comment':
c, = op[1:]
print(f' // {c}')
else:
raise Exception(f'unrecognized operation {op}')
print('}')
print('')
def allocate(self,r):
'''Assign a new free physical register to arch register r. OK for caller to also use old register until calling allocate again.'''
self.phys[r] = {'x':'y','y':'x'}[self.phys[r][0]]+self.phys[r][1:]
# XXX: for generating low-level asm would instead want to cycle between fewer regs
# but, for generating C (or qhasm), trusting register allocator makes generated assignments a bit more readable
# (at least for V_sort)
def comment(self,c):
self.operations.append(('comment',c))
def assign(self,phys,result,newlayout=None):
self.usedregs.add(phys)
self.operations.append(('assign',phys,result,newlayout))
def assignsmallindex(self,phys,result):
self.usedregssmallindex.add(phys)
self.operations.append(('assignsmallindex',phys,result))
def createinfty(self):
if self.useinfty: return
self.useinfty = True
self.assign('infty',f'{intIvec}_broadcast({intI}_largest)')
def partialmask(self,offset):
if usepartialmem:
return f'{intIvec}_load(&partialmem[{offset}-n])'
# XXX: also do version with offset in constants rather than n?
rangevectorlen = ','.join(map(str,range(vectorlen)))
return f'{intIvec}_smaller_mask({intIvec}_set({rangevectorlen}),{intIvec}_broadcast(n-{offset}))'
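    # partial loads handle a vector that may reach past n: either a masked
    # load with out-of-range lanes blended to infty, or (without maskload)
    # loading the vectorlen entries ending at pos = min(end,n), rotating them
    # into place with varextract, and blending in infty where diff >= 0,
    # i.e. where the target index is >= n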
def load(self,r,partial=False):
rphys = self.phys[r]
xor = 'vecxor^' if self.usexor else ''
# uint instead of int would slightly streamline infty usage
if partial:
self.createinfty()
if allowmaskload:
self.assign(f'partial{r}',f'{self.partialmask(r*vectorlen)}')
self.assign(rphys,f'{xor}{intIvec}_partialload(partial{r},x+{vectorlen*r})',self.layout[r])
self.assign(rphys,f'{int8vec}_iftopthenelse(partial{r},{rphys},infty)')
else:
if I == 32:
mplus = ','.join(map(str,range(r*vectorlen,(r+1)*vectorlen)))
self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
self.assign(f'diff{r}',f'{intIvec}_sub({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))')
self.assign(rphys,f'{intIvec}_varextract({intIvec}_load(x+pos{r}-8),diff{r})')
self.assign(rphys,f'{int8vec}_iftopthenelse(diff{r},{rphys},infty)',self.layout[r])
else:
mplus = ','.join(map(str,range(2*r*vectorlen,2*(r+1)*vectorlen)))
self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
self.assign(f'diff{r}',f'int32x8_sub(int32x8_set({mplus}),int32x8_broadcast(2*pos{r}))')
self.assign(rphys,f'int32x8_varextract({intIvec}_load(x+pos{r}-4),diff{r})')
self.assign(rphys,f'{int8vec}_iftopthenelse(diff{r},{rphys},infty)',self.layout[r])
else:
self.assign(rphys,f'{xor}{intIvec}_load(x+{vectorlen*r})',self.layout[r])
    # warning: when not allowmaskstore, partial stores must run in reverse
    # order: the rotated store writes a full vector ending at pos, clobbering
    # entries owned by the preceding vector, whose own (later) store rewrites them
def store(self,r,partial=False):
rphys = self.phys[r]
xor = 'vecxor^' if self.usexor else ''
if partial:
if allowmaskstore:
if not allowmaskload:
self.assign(f'partial{r}',f'{self.partialmask(r*vectorlen)}')
self.operations.append(('store',f'{intIvec}_partialstore(partial{r},x+{vectorlen*r},{xor}{rphys})'))
else:
if allowmaskload:
self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorlen},n)')
else:
# this is why store has to be in reverse order
if I == 32:
mplus = ','.join(map(str,range(vectorlen))) # XXX: or could offset by minimum r to reuse a vector from before
self.operations.append(('store',f'{intIvec}_store(x+pos{r}-8,{intIvec}_varextract({xor}{rphys},{intIvec}_add({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))))'))
else:
mplus = ','.join(map(str,range(2*vectorlen))) # XXX: or could offset by minimum r to reuse a vector from before
self.operations.append(('store',f'{intIvec}_store(x+pos{r}-4,int32x8_varextract({xor}{rphys},int32x8_add(int32x8_set({mplus}),int32x8_broadcast(2*pos{r}))))'))
else:
self.operations.append(('store',f'{intIvec}_store(x+{vectorlen*r},{xor}{rphys})'))
def vecswap(self,r,s):
assert r != s
self.phys[r],self.phys[s] = self.phys[s],self.phys[r]
self.layout[r],self.layout[s] = self.layout[s],self.layout[r]
def minmax(self,r,s):
assert r != s
rphys = self.phys[r]
sphys = self.phys[s]
newlayout0 = [min(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
newlayout1 = [max(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
self.layout[r],self.layout[s] = newlayout0,newlayout1
self.allocate(r)
rphysnew = self.phys[r]
if I == 32:
self.assign(rphysnew,f'{intIvec}_min({rphys},{sphys})',newlayout0)
self.assign(sphys,f'{intIvec}_max({rphys},{sphys})',newlayout1)
else: # XXX: avx-512 has min_epi64 but avx2 does not
self.assign('t',f'{intIvec}_smaller_mask({rphys},{sphys})')
self.assign(rphysnew,f'{int8vec}_iftopthenelse(t,{rphys},{sphys})',newlayout0)
self.assign(sphys,f'{int8vec}_iftopthenelse(t,{sphys},{rphys})',newlayout1)
def shuffle1(self,r,L):
r'''Rearrange layout of vector r to match L.'''
L = list(L)
oldL = self.layout[r]
perm = [oldL.index(a) for a in L]
rphys = self.phys[r]
if vectorlen == 8 and perm[4:] == [perm[0]+4,perm[1]+4,perm[2]+4,perm[3]+4]:
self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
elif vectorlen == 4 and perm[2:] == [perm[0]+2,perm[1]+2]:
self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]})',L)
elif vectorlen == 8:
self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]},{perm[4]},{perm[5]},{perm[6]},{perm[7]})',L)
elif vectorlen == 4:
self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
else:
raise Exception(f'unhandled permutation from {oldL} to {L}')
self.layout[r] = L
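    # rearrange vectors r,s into layouts L,M, trying strategies in increasing
    # cost: a shuffle_pd pair (vectorlen 4), a shuffle_ps pair (vectorlen 8,
    # exact or with middle pairs swapped), a permute2x128 pair, then generic
    # shuffle+blend, optionally followed by a 128-bit flip; each strategy
    # bails out to the next via assert/index failure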
def shuffle2(self,r,s,L,M,exact=False):
oldL = self.layout[r]
oldM = self.layout[s]
rphys = self.phys[r]
sphys = self.phys[s]
try:
assert vectorlen == 4
p0 = oldL[:2].index(L[0])
p1 = oldM[:2].index(L[1])
p2 = oldL[2:].index(L[2])
p3 = oldM[2:].index(L[3])
q0 = oldL[:2].index(M[0])
q1 = oldM[:2].index(M[1])
q2 = oldL[2:].index(M[2])
q3 = oldM[2:].index(M[3])
self.allocate(r)
rphysnew = self.phys[r]
self.layout[r] = L
self.layout[s] = M
self.assign(rphysnew,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
self.assign(sphys,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
return
        except (AssertionError, ValueError):
pass
try:
assert vectorlen == 8
p0 = oldL[:4].index(L[0])
p1 = oldL[:4].index(L[1])
p2 = oldM[:4].index(L[2])
p3 = oldM[:4].index(L[3])
assert p0 == oldL[4:].index(L[4])
assert p1 == oldL[4:].index(L[5])
assert p2 == oldM[4:].index(L[6])
assert p3 == oldM[4:].index(L[7])
q0 = oldL[:4].index(M[0])
q1 = oldL[:4].index(M[1])
q2 = oldM[:4].index(M[2])
q3 = oldM[:4].index(M[3])
assert q0 == oldL[4:].index(M[4])
assert q1 == oldL[4:].index(M[5])
assert q2 == oldM[4:].index(M[6])
assert q3 == oldM[4:].index(M[7])
self.allocate(r)
rphysnew = self.phys[r]
self.layout[r] = L
self.layout[s] = M
self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
return
        except (AssertionError, ValueError):
pass
try:
assert not exact
assert vectorlen == 8
p0 = oldL[:4].index(L[0])
p1 = oldL[:4].index(L[2])
p2 = oldM[:4].index(L[1])
p3 = oldM[:4].index(L[3])
assert p0 == oldL[4:].index(L[4])
assert p1 == oldL[4:].index(L[6])
assert p2 == oldM[4:].index(L[5])
assert p3 == oldM[4:].index(L[7])
q0 = oldL[:4].index(M[0])
q1 = oldL[:4].index(M[2])
q2 = oldM[:4].index(M[1])
q3 = oldM[:4].index(M[3])
assert q0 == oldL[4:].index(M[4])
assert q1 == oldL[4:].index(M[6])
assert q2 == oldM[4:].index(M[5])
assert q3 == oldM[4:].index(M[7])
self.allocate(r)
rphysnew = self.phys[r]
self.layout[r] = [L[0],L[2],L[1],L[3],L[4],L[6],L[5],L[7]]
self.layout[s] = [M[0],M[2],M[1],M[3],M[4],M[6],M[5],M[7]]
self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',self.layout[r])
self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',self.layout[s])
return
        except (AssertionError, ValueError):
pass
try:
half = vectorlen//2
assert oldL[:half] == L[:half]
assert oldL[half:] == M[:half]
assert oldM[:half] == L[half:]
assert oldM[half:] == M[half:]
self.allocate(r)
rphysnew = self.phys[r]
self.layout[r] = L
self.layout[s] = M
self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
return
        except (AssertionError, ValueError):
pass
for bigflip in False,True:
try:
if bigflip:
half = vectorlen//2
z0 = L[:half]+M[:half]
z1 = L[half:]+M[half:]
else:
z0 = L
z1 = M
blend = []
shuf0 = [None]*vectorlen
shuf1 = [None]*vectorlen
for i in range(vectorlen):
if oldL[i] == z0[i]:
blend.append(0)
shuf0[i] = oldL.index(z1[i])
shuf1[i] = 0
if i < vectorlen//2:
assert shuf0[i] < vectorlen//2
else:
assert shuf0[i] == shuf0[i-vectorlen//2]+vectorlen//2
else:
blend.append(1)
assert oldM[i] == z1[i]
shuf0[i] = 0
shuf1[i] = oldM.index(z0[i])
if i < vectorlen//2:
assert shuf1[i] < vectorlen//2
else:
assert shuf1[i] == shuf1[i-vectorlen//2]+vectorlen//2
blend = ','.join(map(str,blend))
s0 = ','.join(map(str,shuf0[:vectorlen//2]))
s1 = ','.join(map(str,shuf1[:vectorlen//2]))
# XXX: encapsulate temporaries better
self.assign('u',f'{intIvec}_constextract_eachside({sphys},{s1})')
self.assign('t',f'{intIvec}_constextract_eachside({rphys},{s0})')
self.assign(rphys,f'{intIvec}_ifconstthenelse({blend},u,{rphys})',L)
self.assign(sphys,f'{intIvec}_ifconstthenelse({blend},{sphys},t)',M)
if bigflip:
self.allocate(r)
rphysnew = self.phys[r]
self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
self.layout[r] = L
self.layout[s] = M
return
            except (AssertionError, ValueError):
pass
raise Exception(f'unhandled permutation from {oldL},{oldM} to {L},{M}')
def rearrange_onestep(self,comparators):
numvectors = len(self.layout)
for k in range(0,numvectors,2):
if all(comparators[a] == a for a in self.layout[k]):
if all(comparators[a] == a for a in self.layout[k+1]):
continue
collected = set(self.layout[k]+self.layout[k+1])
if not all(comparators[a] in collected for a in collected):
for j in range(numvectors):
if all(comparators[a] in self.layout[j] for a in self.layout[k]):
self.vecswap(j,k+1)
return True
if any(comparators[a] in self.layout[k] for a in self.layout[k]):
newlayout0 = list(self.layout[k])
newlayout1 = list(self.layout[k+1])
for i in range(vectorlen):
a = newlayout0[i]
b = comparators[a]
if b == newlayout1[i]: continue
if b in newlayout0:
j = newlayout0.index(b)
assert j > i
newlayout1[i],newlayout0[j] = newlayout0[j],newlayout1[i]
else:
j = newlayout1.index(b)
assert j > i
newlayout1[i],newlayout1[j] = newlayout1[j],newlayout1[i]
self.shuffle2(k,k+1,newlayout0,newlayout1)
return True
if [comparators[a] for a in self.layout[k]] != self.layout[k+1]:
newlayout = [comparators[a] for a in self.layout[k]]
self.shuffle1(k+1,newlayout)
return True
return False
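    # permute lanes until, for every vector pair (k,k+1), each comparator
    # pairs lane i of vector k with lane i of vector k+1 (or leaves the pair
    # untouched), so that the whole stage becomes one minmax per pair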
def rearrange(self,comparators):
comparators = dict(comparators)
while self.rearrange_onestep(comparators): pass
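# generate intI_[V_]sort_{nlow}through{nhigh} / intI_[V_]sort_{n}_xor: a fully
# unrolled network for nlow <= n <= nhigh. example initial layout for
# nhigh == 16, vectorlen == 8, non-V: vector 0 holds positions 0,2,4,...,14
# and vector 1 holds 1,3,5,...,15, so the first stage's comparators i:i^1
# line up lane by lane; for V (merging) the first half of the array is
# loaded in reversed lane order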
def fixedsize(nlow,nhigh,xor=False,V=False):
nlow = int(nlow)
nhigh = int(nhigh)
assert 0 <= nlow
assert nlow <= nhigh
assert nhigh-partiallimit <= nlow
assert nhigh%vectorlen == 0
assert nlow%vectorlen == 0
lgn = 4
while 2**lgn < nhigh: lgn += 1
if nhigh < 2*vectorlen:
raise Exception(f'unable to handle sizes below {2*vectorlen}')
if V:
funname = f'{intI}_V_sort_'
else:
funname = f'{intI}_sort_'
funargs = f'{intI} *x'
if nlow < nhigh:
funname += f'{nlow}through'
funargs += ',long long n'
funname += f'{nhigh}'
if xor:
funname += '_xor'
funargs += f',{intI} xor'
print('NOINLINE')
print(f'static void {funname}({funargs})')
# ===== decide on initial layout of nodes
numvectors = nhigh//vectorlen
layout = {}
if V:
for k in range(numvectors//2):
layout[k] = list(reversed(range(vectorlen*(numvectors//2-1-k),vectorlen*(numvectors//2-k))))
for k in range(numvectors//2,numvectors):
layout[k] = list(range(vectorlen*k,vectorlen*(k+1)))
else:
for k in range(0,numvectors,nhigh//vectorlen):
for offset in range(nhigh//vectorlen):
layout[k+offset] = list(range(vectorlen*k+offset,vectorlen*(k+nhigh//vectorlen),nhigh//vectorlen))
layout = [layout[k] for k in range(len(layout))]
# ===== build network
S = vectorsortingnetwork(layout,xor=xor)
for k in range(numvectors):
S.load(k,partial=k*vectorlen>=nlow)
for lgsubsort in range(1,lgn+1):
if V and lgsubsort < lgn: continue
for stage in reversed(range(lgsubsort)):
if nhigh >= 16 and (lgsubsort,stage) == (1,0):
comparators = {a:a^1 for a in range(nhigh)}
elif nhigh >= 32 and (lgsubsort,stage) == (2,1):
comparators = {a:a^2 for a in range(nhigh)}
elif nhigh >= 32 and (lgsubsort,stage) == (2,0):
comparators = {a:a^(3*(1&((a>>0)^(a>>1)))) for a in range(nhigh)}
elif nhigh >= 64 and (lgsubsort,stage) == (3,2):
comparators = {a:a^4 for a in range(nhigh)}
elif nhigh >= 64 and (lgsubsort,stage) == (3,1):
comparators = {a:a+[0,0,2,2,-2,-2,0,0][a%8] for a in range(nhigh)}
elif nhigh >= 64 and (lgsubsort,stage) == (3,0):
comparators = {a:a+[0,1,-1,1,-1,1,-1,0][a%8] for a in range(nhigh)}
elif nhigh >= 128 and (lgsubsort,stage) == (4,3):
comparators = {a:a^8 for a in range(nhigh)}
elif nhigh >= 128 and (lgsubsort,stage) == (4,2):
comparators = {a:a+[0,0,0,0,4,4,4,4,-4,-4,-4,-4,0,0,0,0][a%16] for a in range(nhigh)}
elif nhigh >= 128 and (lgsubsort,stage) == (4,1):
comparators = {a:a+[0,0,2,2,-2,-2,2,2,-2,-2,2,2,-2,-2,0,0][a%16] for a in range(nhigh)}
elif nhigh >= 128 and (lgsubsort,stage) == (4,0):
comparators = {a:a+[0,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,0][a%16] for a in range(nhigh)}
else:
if stage == lgsubsort-1:
stagemask = (2<<stage)-1
else:
stagemask = 1<<stage
comparators = {a:a^stagemask for a in range(nhigh)}
if vectorlen == 8 and 2<<stage == nhigh and not V:
for k in range(0,numvectors,2):
p = S.layout[k]
                    newlayout = [p[i^((((i>>2)^(i>>1))&1)*6)] for i in range(vectorlen)] # XXX
S.shuffle1(k,newlayout)
strcomparators = ' '.join(f'{a}:{comparators[a]}' for a in range(nhigh) if comparators[a] > a)
S.comment(f'stage ({lgsubsort},{stage}) {strcomparators}')
S.rearrange(comparators)
for k in range(0,numvectors,2):
if all(comparators[a] == a for a in S.layout[k]):
if all(comparators[a] == a for a in S.layout[k+1]):
continue
S.minmax(k,k+1)
for k in range(0,numvectors,2):
for offset in 0,1:
for i in range(numvectors):
if k*vectorlen+offset in S.layout[i]:
if i != k+offset:
S.vecswap(i,k+offset)
break
for k in range(0,numvectors,2):
y0 = list(range(k*vectorlen,(k+1)*vectorlen))
y1 = list(range((k+1)*vectorlen,(k+2)*vectorlen))
S.shuffle2(k,k+1,y0,y1,exact=True)
for k in range(numvectors) if allowmaskstore or nlow == nhigh else reversed(range(numvectors)):
S.store(k,partial=k*vectorlen>=nlow)
S.print()
# ===== V_sort
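# threestages views x as k interleaved streams of p entries (stream j living
# at x[j*p+i]) and applies three merge stages across streams: the twelve
# comparators below are strides 4, 2, 1 of an 8-stream merge, with
# comparators touching streams >= k dropped; down=True reverses each
# comparator's direction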
def threestages(k,down=False,p=None,atleast=None):
assert k in (4,5,6,7,8)
print('NOINLINE')
updown = 'down' if down else 'up'
if p is not None:
print(f'static void {intI}_threestages_{k}_{updown}_{p}({intI} *x)')
elif atleast is not None:
print(f'static void {intI}_threestages_{k}_{updown}_atleast{atleast}({intI} *x,long long p)')
else:
print(f'static void {intI}_threestages_{k}_{updown}({intI} *x,long long p,long long n)')
print('{')
print(' long long i;')
if p is not None:
print(f' long long p = {p};')
if p is not None or atleast is not None:
print(f' long long n = p;')
for vector in True,False: # must be this order
if p is not None and p%vectorlen == 0 and not vector: break
if atleast is not None and atleast%vectorlen == 0 and not vector: break
xtype = intIvec if vector else intI
if vector:
print(f' for (i = 0;i+{vectorlen} <= n;i += {vectorlen}) {{')
else:
print(' for (;i < n;++i) {')
for j in range(k):
addr = 'i' if j == 0 else 'p+i' if j == 1 else f'{j}*p+i'
if vector:
print(f' {xtype} x{j} = {xtype}_load(&x[{addr}]);')
else:
print(f' {xtype} x{j} = x[{addr}];')
for i,j in (0,4),(1,5),(2,6),(3,7),(0,2),(1,3),(4,6),(5,7),(0,1),(2,3),(4,5),(6,7):
if j >= k: continue
if down: i,j = j,i
print(f' {xtype}_MINMAX(x{i},x{j});')
for j in range(k):
addr = 'i' if j == 0 else 'p+i' if j == 1 else f'{j}*p+i'
if vector:
print(f' {xtype}_store(&x[{addr}],x{j});')
else:
print(f' x[{addr}] = x{j};')
print(' }')
print('}')
print('')
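# the _xor merger handles power-of-2 n: split into 8 streams of n/8
# (threestages_8), then recurse on each stream. the general intI_V_sort
# handles 8 <= q < n <= 2q: replace q by q/4, run three merge stages via a
# threestages pair (one extra stream for the leading columns), recursively
# V_sort each batch of q entries, and finish the short tail with small
# networks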
def V_sort():
for nlow,nhigh in unrolledVsort:
fixedsize(nlow,nhigh,V=True)
for n in unrolledVsortxor:
fixedsize(n,n,xor=True,V=True)
threestages(8)
threestages(7)
threestages(6)
threestages(5)
threestages(4)
threestages_min = Vsort2_min
for p in threestages_specialize:
assert p == threestages_min
threestages_min *= 2
threestages(8,p=p)
threestages(8,down=True,p=p)
threestages(8,down=True,atleast=threestages_min)
threestages(6,down=True)
print(f'''// XXX: currently xor must be 0 or -1
NOINLINE
static void {intI}_{Vsort2}_xor({intI} *x,long long n,{intI} xor)
{{''')
for n in unrolledVsortxor:
print(f' if (n == {n}) {{ {intI}_V_sort_{n}_xor(x,xor); return; }}')
assert unrolledVsortxor[:3] == (Vsort2_min,Vsort2_min*2,Vsort2_min*4)
# so n is at least Vsort2_min*8, justifying the following recursive calls
for p in threestages_specialize:
assert p in unrolledVsortxor
print(f''' if (n == {p*8}) {{
if (xor)
{intI}_threestages_8_down_{p}(x);
else
{intI}_threestages_8_up_{p}(x);
for (long long i = 0;i < 8;++i)
{intI}_V_sort_{p}_xor(x+{p}*i,xor);
return;
}}''')
print(f''' if (xor)
{intI}_threestages_8_down_atleast{threestages_min}(x,n>>3);
else
{intI}_threestages_8_up(x,n>>3,n>>3);
for (long long i = 0;i < 8;++i)
{intI}_{Vsort2}_xor(x+(n>>3)*i,n>>3,xor);
}}
/* q is power of 2; want only merge stages q,q/2,q/4,...,1 */
// XXX: assuming 8 <= q < n <= 2q; q is a power of 2
NOINLINE
static void {intI}_V_sort({intI} *x,long long q,long long n)
{{''')
assert any(nhigh == Vsort2_min for nlow,nhigh in unrolledVsort)
for nlow,nhigh in unrolledVsort:
if nhigh == Vsort2_min and favorxor:
print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
print(f' if (n <= {nhigh}) {{ {intI}_V_sort_{nlow}through{nhigh}(x,n); return; }}')
if not favorxor:
print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
print(f'''
// 64 <= q < n < 2q
q >>= 2;
// 64 <= 4q < n < 8q
if (7*q < n) {{
{intI}_threestages_8_up(x,q,n-7*q);
{intI}_threestages_7_up(x+n-7*q,q,8*q-n);
}} else if (6*q < n) {{
{intI}_threestages_7_up(x,q,n-6*q);
{intI}_threestages_6_up(x+n-6*q,q,7*q-n);
}} else if (5*q < n) {{
{intI}_threestages_6_up(x,q,n-5*q);
{intI}_threestages_5_up(x+n-5*q,q,6*q-n);
}} else {{
{intI}_threestages_5_up(x,q,n-4*q);
{intI}_threestages_4_up(x+n-4*q,q,5*q-n);
}}
// now want to handle each batch of q entries separately
{intI}_V_sort(x,q>>1,q);
{intI}_V_sort(x+q,q>>1,q);
{intI}_V_sort(x+2*q,q>>1,q);
{intI}_V_sort(x+3*q,q>>1,q);
x += 4*q;
n -= 4*q;
while (n >= q) {{
{intI}_V_sort(x,q>>1,q);
x += q;
n -= q;
}}
// have n entries left in last batch, with 0 <= n < q
if (n <= 1) return;
while (q >= n) q >>= 1; // empty merge stage
// now 1 <= q < n <= 2q
if (q >= 8) {{ {intI}_V_sort(x,q,n); return; }}
if (n == 8) {{
{intI}_MINMAX(x[0],x[4]);
{intI}_MINMAX(x[1],x[5]);
{intI}_MINMAX(x[2],x[6]);
{intI}_MINMAX(x[3],x[7]);
{intI}_MINMAX(x[0],x[2]);
{intI}_MINMAX(x[1],x[3]);
{intI}_MINMAX(x[0],x[1]);
{intI}_MINMAX(x[2],x[3]);
{intI}_MINMAX(x[4],x[6]);
{intI}_MINMAX(x[5],x[7]);
{intI}_MINMAX(x[4],x[5]);
{intI}_MINMAX(x[6],x[7]);
return;
}}
if (4 <= n) {{
for (long long i = 0;i < n-4;++i)
{intI}_MINMAX(x[i],x[4+i]);
{intI}_MINMAX(x[0],x[2]);
{intI}_MINMAX(x[1],x[3]);
{intI}_MINMAX(x[0],x[1]);
{intI}_MINMAX(x[2],x[3]);
n -= 4;
x += 4;
}}
if (3 <= n)
{intI}_MINMAX(x[0],x[2]);
if (2 <= n)
{intI}_MINMAX(x[0],x[1]);
}}
''')
# ===== main sort
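# intI_sort_2poweratleast*_xor sorts power-of-2 sizes bitonically: sort the
# first half with ~xor (xor == -1 complements each element on load and store,
# so that half ends up in reverse order), sort the second half with xor, then
# let the V_sort merger finish the resulting bitonic sequence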
def main_sort_prep():
smallsize()
for nlow,nhigh in unrolledsort:
fixedsize(nlow,nhigh)
for n in unrolledsortxor:
fixedsize(n,n,xor=True)
def main_sort():
print('// XXX: currently xor must be 0 or -1')
print('NOINLINE')
print(f'static void {intI}_{sort2}_xor({intI} *x,long long n,{intI} xor)')
print('{')
for n in unrolledsortxor:
print(f' if (n == {n}) {{ {intI}_sort_{n}_xor(x,xor); return; }}')
print(f' {intI}_{sort2}_xor(x,n>>1,~xor);')
print(f' {intI}_{sort2}_xor(x+(n>>1),n>>1,xor);')
print(f' {intI}_{Vsort2}_xor(x,n,xor);')
print('}')
print(fr'''
void {intI}_sort({intI} *x,long long n)
{{ long long q;
if (n <= 1) return;
if (n == 2) {{ {intI}_MINMAX(x[0],x[1]); return; }}
if (n <= 7) {{ {intI}_sort_3through7(x,n); return; }}''')
# XXX: n cutoff here should be another variable to optimize
nmin = 8
# invariant: n in program is at least nmin
assert any(nhigh == sort2_min for nlow,nhigh in unrolledsort)
for nlow,nhigh in unrolledsort:
if nhigh == sort2_min and favorxor:
print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
if nlow <= nmin:
print(f' if (n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
while nlow <= nmin and nmin <= nhigh: nmin += 1
else:
print(f' if ({nlow} <= n && n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
if not favorxor:
print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
qmin = 1
while qmin < nmin-qmin: qmin += qmin
assert sort2_min <= qmin
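    # awkward sizes in pads: copy into a stack buffer rounded up to a multiple
    # of vectorlen, fill the tail with intI_largest sentinels, sort the padded
    # buffer recursively, and copy the first n entries back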
for padlow,padhigh in pads:
padlowdown = padlow
while padlowdown%vectorlen: padlowdown -= 1
padhighup = padhigh
while padhighup%vectorlen: padhighup += 1
print(fr''' if ({padlow} <= n && n <= {padhigh}) {{
{intI} buf[{padhighup}];
for (long long i = {padlowdown};i < {padhighup};++i) buf[i] = {intI}_largest;
for (long long i = 0;i < n;++i) buf[i] = x[i];
{intI}_sort(buf,{padhighup});
for (long long i = 0;i < n;++i) x[i] = buf[i];
return;
}}''')
assert sort2_min%2 == 0
assert sort2_min//2 >= Vsort2_min
print(fr'''
q = {qmin};
while (q < n - q) q += q;
// {qmin} <= q < n < 2q
if ({sort2_min*16} <= n && n <= (7*q)>>2) {{
long long m = (3*q)>>2; // strategy: sort m, sort n-m, merge
long long r = q>>3; // at least {sort2_min} since q is at least {sort2_min*8}
{intI}_{sort2}_xor(x,4*r,0);
{intI}_{sort2}_xor(x+4*r,r,0);
{intI}_{sort2}_xor(x+5*r,r,-1);
{intI}_{Vsort2}_xor(x+4*r,2*r,-1);
{intI}_threestages_6_down(x,r,r);
for (long long i = 0;i < 6;++i)
{intI}_{Vsort2}_xor(x+i*r,r,-1);
{intI}_sort(x+m,n-m);
}} else if ({sort2_min*2} <= q && n == (3*q)>>1) {{
// strategy: sort q, sort q/2, merge
long long r = q>>2; // at least {sort2_min//2} since q is at least {sort2_min*2}
{intI}_{sort2}_xor(x,4*r,-1);
{intI}_{sort2}_xor(x+4*r,2*r,0);
{intI}_threestages_6_up(x,r,r);
for (long long i = 0;i < 6;++i)
{intI}_{Vsort2}_xor(x+i*r,r,0);
return;
}} else {{
{intI}_{sort2}_xor(x,q,-1);
{intI}_sort(x+q,n-q);
}}
{intI}_V_sort(x,q,n);
}}''')
# ===== driver
preamble()
main_sort_prep()
V_sort()
main_sort()