-rwxr-xr-x 46059 djbsort-20260210/autogen/sort raw
#!/usr/bin/env python3
# Code generator for vectorized sorting networks (djbsort autogen/sort).
# Prints C source on stdout. Command line:
#   argv[1]: word size I in bits (32 or 64); argv[2]: vector unit.
import sys
# I: width in bits of the integers being sorted
I = 32
# vectorunit: target SIMD instruction set
vectorunit = 'avx2'
if len(sys.argv) > 1:
    I = int(sys.argv[1])
    assert I in (32,64)
if len(sys.argv) > 2:
    vectorunit = sys.argv[2]
    # note: 'sse2' appears in the vectorbits table but is deliberately not
    # selectable here; see the porting notes just below
    assert vectorunit in ('neon','sse42','avx2')
# XXX: further porting requires various tweaks below
# e.g. arm32 does not have vclt_s64 even with neon
# e.g. sse2 does not have blendv_epi8
vectorbits = {'neon':128,'sse2':128,'sse42':128,'avx2':256}[vectorunit]
# number of I-bit lanes per vector register (always a power of 2; asserted below)
vectorwords = vectorbits//I
# XXX: do more searches for optimal parameters here
# Per-lane-count tuning tables:
#   unrolledsort: (nlow,nhigh) size ranges given fully unrolled networks
#   unrolledsortxor: sizes given specialized xor variants
#   unrolledVsort/unrolledVsortxor: same for the V_sort entry points
#   threestages_specialize: p values given dedicated threestages code
#   pads: size ranges handled elsewhere in the file — presumably padded up to
#         the next unrolled size (usage not visible in this chunk)
#   favorxor: prefer the xor-based strategy
if vectorwords == 8:
    unrolledsort = (8,16),(16,32),(32,64),(64,128)
    unrolledsortxor = 64,128
    unrolledVsort = (8,16),(16,32),(32,64)
    unrolledVsortxor = 32,64,128
    threestages_specialize = 32,
    pads = (161,191),(193,255) # (449,511) is also interesting
    favorxor = True
elif vectorwords == 4:
    if I == 32:
        unrolledsort = (8,16),(16,32)
    else:
        unrolledsort = (8,16),(16,32),(32,64)
    unrolledsortxor = 32,64
    unrolledVsort = (8,16),(16,32),(32,64)
    unrolledVsortxor = 16,32,64
    threestages_specialize = 16,
    pads = (81,95),(97,127) # (225,255) is also interesting
    favorxor = True
elif vectorwords == 2:
    unrolledsort = (8,16),(16,32),
    unrolledsortxor = 32,
    unrolledVsort = (8,16),(16,32),(32,64)
    unrolledVsortxor = 16,32,64
    threestages_specialize = 16,
    pads = ()
    favorxor = True
if vectorunit != 'avx2':
    pads = () # for simplicity in the absence of speed study
# names of the smallest power-of-2 xor-specialized entry points
sort2_min = min(unrolledsortxor)
sort2 = f'sort_2poweratleast{sort2_min}'
Vsort2_min = min(unrolledVsortxor)
Vsort2 = f'V_sort_2poweratleast{Vsort2_min}'
allowmaskload = False # fast on Intel, and on AMD starting with Zen 2, _but_ AMD documentation claims that it can fail
allowmaskstore = False # fast on Intel but not on AMD
usepartialmem = True # irrelevant unless allowmaskload or allowmaskstore
partiallimit = 64 # XXX: could allow 128 for non-mask versions without extra operations
# lane-permutation code below relies on vectorwords being a power of 2
assert vectorwords&(vectorwords-1) == 0
# name stems for the C types/macros used throughout the generated code
intI = f'int{I}'
intIvec = f'{intI}x{vectorwords}'
int8vec = f'int8x{(I*vectorwords)//8}'
int32vec = f'int32x{(I*vectorwords)//32}'
def preamble():
    """Print the C prelude: warning banner, includes, vector typedefs, and the
    SIMD macro layer ({intIvec}_load/_store/_MINMAX, permutation helpers, ...)
    that the generated sorting routines are written against."""
    print('/* WARNING: auto-generated (by autogen/sort); do not edit */')
    print('')
    if vectorunit == 'neon':
        print('#include <arm_neon.h>')
    else:
        print('#include <immintrin.h>')
    print('')
    # scalar definitions shared by all targets
    print(fr'''#include "{intI}_sort.h"
#define {intI} {intI}_t
#define {intI}_largest {hex((1<<(I-1))-1)}
#include "crypto_{intI}.h"
#define {intI}_min crypto_{intI}_min
#define {intI}_MINMAX(a,b) crypto_{intI}_minmax(&(a),&(b))
#define NOINLINE __attribute__((noinline))
''')
    if vectorunit == 'neon':
        # ----- NEON macro layer -----
        print(f'''#include "crypto_int8.h"
#define int8 crypto_int8
#define int8_min crypto_int8_min
#define int8x16 int8x16_t
#include "crypto_uint8.h"
#define uint8 crypto_uint8
#define uint8x16 uint8x16_t
''')
        if I != 32:
            print(f'''#include "crypto_int32.h"
#define int32 crypto_int32
#define int32_min crypto_int32_min
#define int32x4 int32x4_t
''')
        print(f'#define {intIvec} {intIvec}_t')
        print(f'#define u{intIvec} u{intIvec}_t')
        # XXX: should also look for ways to use ld2 etc
        print(fr'''#define {intIvec}_load vld1q_s{I}
#define {intIvec}_store vst1q_s{I}
#define {intIvec}_ifthenelse vbslq_s{I}
''')
        if I == 32:
            # 32-bit lanes: NEON has vmin/vmax, so MINMAX is two instructions
            print(fr'''#define {intIvec}_smaller_umask vcltq_s{I}
#define {intIvec}_min vminq_s{I}
#define {intIvec}_max vmaxq_s{I}
#define {intIvec}_MINMAX(a,b) \
do {{ \
{intIvec} c = {intIvec}_min(a,b); \
b = {intIvec}_max(a,b); \
a = c; \
}} while(0)
''')
        else:
            # 64-bit lanes: no vmin/vmax; build MINMAX from compare + bitselect
            print(fr'''#define {int32vec}_smaller_umask vcltq_s32
#define {intIvec}_smaller_umask vcltq_s{I}
#define {intIvec}_MINMAX(a,b) \
do {{ \
u{intIvec} t = {intIvec}_smaller_umask(a,b); \
{intIvec} c = {intIvec}_ifthenelse(t,a,b); \
b = {intIvec}_ifthenelse(t,b,a); \
a = c; \
}} while(0)
''')
        # XXX: tweak varextract name to reflect differences in out-of-range handling
        # XXX: also use tbx for infty etc
        print(f'''#define {int8vec}_load vld1q_s8
#define {int8vec}_varextract vqtbl1q_s8
#define {int8vec}_add vaddq_s8
#define {int8vec}_sub vsubq_s8
#define {int8vec}_broadcast vdupq_n_s8
#define u{int8vec}_load vld1q_u8
#define u{int8vec}_add vaddq_u8
#define u{int8vec}_sub vsubq_u8
#define u{int8vec}_broadcast vdupq_n_u8
#define {int8vec}_from_{intIvec} vreinterpretq_s8_s{I}
#define u{intIvec}_from_{intIvec} vreinterpretq_u{I}_s{I}
#define {intIvec}_from_u{intIvec} vreinterpretq_s{I}_u{I}
#define {intIvec}_from_{int8vec} vreinterpretq_s{I}_s8
#define {int32vec}_load vld1q_s32
#define {int32vec}_add vaddq_s32
#define {int32vec}_sub vsubq_s32
#define {intIvec}_broadcast vdupq_n_s{I}
static inline u{int8vec} u{int8vec}_set(uint8 x0,uint8 x1,uint8 x2,uint8 x3,uint8 x4,uint8 x5,uint8 x6,uint8 x7,uint8 x8,uint8 x9,uint8 x10,uint8 x11,uint8 x12,uint8 x13,uint8 x14,uint8 x15)
{{
uint8 x[16] = {{x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15}};
return u{int8vec}_load(x);
}}
''')
        if I == 32:
            print(f'''static inline {int32vec} {int32vec}_set(int32 x0,int32 x1,int32 x2,int32 x3)
{{
int32 x[4] = {{x0,x1,x2,x3}};
return {int32vec}_load(x);
}}
''')
        if I == 64:
            print(f'''static inline {intIvec} {intIvec}_set(int{I} x0,int{I} x1)
{{
int{I} x[2] = {{x0,x1}};
return {intIvec}_load(x);
}}
''')
        if I == 32:
            # lane-permutation macros; names encode the resulting lane order
            # XXX: maybe better to use vswp for 2301
            print(f'''#define {intIvec}_1032 vrev64q_s32
#define {intIvec}_2301(v) vextq_s32(v,v,2)
#define {intIvec}_3210(v) {intIvec}_1032({intIvec}_2301(v))
#define {intIvec}_a0b0a2b2 vtrn1q_s32
#define {intIvec}_a1b1a3b3 vtrn2q_s32
#define {intIvec}_a0b0a1b1 vzip1q_s32
#define {intIvec}_a2b2a3b3 vzip2q_s32
#define {intIvec}_leftleft(a,b) vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(a),vreinterpretq_s64_s32(b)))
#define {intIvec}_rightright(a,b) vreinterpretq_s32_s64(vzip2q_s64(vreinterpretq_s64_s32(a),vreinterpretq_s64_s32(b)))
''')
        else:
            print(f'''#define {intIvec}_10(v) vextq_s64(v,v,1)
#define {intIvec}_leftleft vzip1q_s64
#define {intIvec}_rightright vzip2q_s64
''')
    else:
        # ----- SSE4.2 / AVX2 macro layer -----
        mmprefix = {'sse2':'_mm_','sse42':'_mm_','avx2':'_mm256_'}[vectorunit]
        print(fr'''typedef __m{vectorbits}i {intIvec};
#define {intIvec}_load(z) {mmprefix}loadu_si{vectorbits}((__m{vectorbits}i *) (z))
#define {intIvec}_store(z,i) {mmprefix}storeu_si{vectorbits}((__m{vectorbits}i *) (z),(i))
#define {intIvec}_smaller_mask(a,b) {mmprefix}cmpgt_epi{I}(b,a)
#define {intIvec}_add {mmprefix}add_epi{I}
#define {intIvec}_sub {mmprefix}sub_epi{I}''')
        print(f'#define {int8vec}_iftopthenelse(c,t,e) {mmprefix}blendv_epi8(e,t,c)')
        if vectorunit == 'sse42':
            print(fr'''#define {intIvec}_leftleft(a,b) {mmprefix}unpacklo_epi64(a,b)
#define {intIvec}_rightright(a,b) {mmprefix}unpackhi_epi64(a,b)''')
            if I == 32:
                print(fr'''#define {intIvec}_a0b0a1b1(a,b) {mmprefix}unpacklo_epi32(a,b)
#define {intIvec}_a2b2a3b3(a,b) {mmprefix}unpackhi_epi32(a,b)''')
        if vectorunit == 'avx2':
            # leftleft/rightright combine the low/high 128-bit halves of a,b
            print(fr'''#define {intIvec}_leftleft(a,b) {mmprefix}permute2x128_si256(a,b,0x20)
#define {intIvec}_rightright(a,b) {mmprefix}permute2x128_si256(a,b,0x31)''')
        if I == 32:
            print(fr'''
#define {intIvec}_MINMAX(a,b) \
do {{ \
{intIvec} c = {intIvec}_min(a,b); \
b = {intIvec}_max(a,b); \
a = c; \
}} while(0)
''')
        if I == 64:
            # no min_epi64/max_epi64 below AVX-512: compare + blend instead
            print(fr'''
#define {intIvec}_MINMAX(a,b) \
do {{ \
{intIvec} t = {intIvec}_smaller_mask(a,b); \
{intIvec} c = {int8vec}_iftopthenelse(t,a,b); \
b = {int8vec}_iftopthenelse(t,b,a); \
a = c; \
}} while(0)
''')
        if I == 32:
            print(fr'''#define {intIvec}_min {mmprefix}min_epi{I}
#define {intIvec}_max {mmprefix}max_epi{I}
#define {intIvec}_set {mmprefix}setr_epi{I}
#define {intIvec}_broadcast {mmprefix}set1_epi{I}''')
        if (I,vectorwords) == (32,4):
            print(fr'''#define {int8vec}_add {mmprefix}add_epi8
#define {int8vec}_sub {mmprefix}sub_epi8
#define {int8vec}_set {mmprefix}setr_epi8
#define {int8vec}_broadcast {mmprefix}set1_epi8
#define {int8vec}_varextract {mmprefix}shuffle_epi8
#define {int32vec}_add {mmprefix}add_epi32
#define {int32vec}_sub {mmprefix}sub_epi32
#define {int32vec}_set {mmprefix}setr_epi32
#define {int32vec}_broadcast {mmprefix}set1_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3) {mmprefix}shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_ab0ab1ab2ab3(a,b,p0,p1,p2,p3) {mmprefix}castps_si128({mmprefix}blend_ps({mmprefix}castsi128_ps(a),{mmprefix}castsi128_ps(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_1032(v) {intIvec}_extract(v,1,0,3,2)
#define {intIvec}_2301(v) {intIvec}_extract(v,2,3,0,1)
#define {intIvec}_3210(v) {intIvec}_extract(v,3,2,1,0)
#include "crypto_int8.h"
#define int8_min crypto_int8_min
''')
        if (I,vectorwords) == (64,2):
            print(fr'''#define {int8vec}_add {mmprefix}add_epi8
#define {int8vec}_sub {mmprefix}sub_epi8
#define {int8vec}_set {mmprefix}setr_epi8
#define {int8vec}_broadcast {mmprefix}set1_epi8
#define {int8vec}_varextract {mmprefix}shuffle_epi8
#define {int32vec}_add {mmprefix}add_epi32
#define {int32vec}_sub {mmprefix}sub_epi32
#define {int32vec}_set {mmprefix}setr_epi32
#define {int32vec}_broadcast {mmprefix}set1_epi32
#define {intIvec}_extract(v,p0,p1) {mmprefix}shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_set(a,b) {mmprefix}set_epi64x(b,a)
#define {intIvec}_broadcast {mmprefix}set1_epi64x
#define {intIvec}_10(v) {intIvec}_extract(v,1,0)
#include "crypto_int8.h"
#define int8_min crypto_int8_min
#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')
        if (I,vectorwords) == (32,8):
            print(fr'''#define {intIvec}_varextract {mmprefix}permutevar8x32_epi32
#define {intIvec}_extract(v,p0,p1,p2,p3,p4,p5,p6,p7) {intIvec}_varextract(v,{mmprefix}setr_epi{I}(p0,p1,p2,p3,p4,p5,p6,p7))
#define {intIvec}_constextract_eachside(v,p0,p1,p2,p3) {mmprefix}shuffle_epi{I}(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_aabb_eachside(a,b,p0,p1,p2,p3) {mmprefix}castps_si256({mmprefix}shuffle_ps({mmprefix}castsi256_ps(a),{mmprefix}castsi256_ps(b),_MM_SHUFFLE(p3,p2,p1,p0)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,c4,c5,c6,c7,t,e) {mmprefix}blend_epi{I}(e,t,(c0)|((c1)<<1)|((c2)<<2)|((c3)<<3)|((c4)<<4)|((c5)<<5)|((c6)<<6)|((c7)<<7))
''')
        if (I,vectorwords) == (64,4):
            print(fr'''#define {int32vec}_add {mmprefix}add_epi32
#define {int32vec}_sub {mmprefix}sub_epi32
#define {int32vec}_set {mmprefix}setr_epi32
#define {int32vec}_broadcast {mmprefix}set1_epi32
#define {int32vec}_varextract {mmprefix}permutevar8x32_epi32
#define {intIvec}_set {mmprefix}setr_epi64x
#define {intIvec}_broadcast {mmprefix}set1_epi64x
#define {intIvec}_extract(v,p0,p1,p2,p3) {mmprefix}permute4x64_epi64(v,_MM_SHUFFLE(p3,p2,p1,p0))
#define {intIvec}_constextract_eachside(v,p0,p1) {mmprefix}shuffle_epi32(v,_MM_SHUFFLE(2*(p1)+1,2*(p1),2*(p0)+1,2*(p0)))
#define {intIvec}_constextract_a01b01a23b23(a,b,p0,p1,p2,p3) {mmprefix}castpd_si256({mmprefix}shuffle_pd({mmprefix}castsi256_pd(a),{mmprefix}castsi256_pd(b),(p0)|((p1)<<1)|((p2)<<2)|((p3)<<3)))
#define {intIvec}_ifconstthenelse(c0,c1,c2,c3,t,e) {mmprefix}blend_epi32(e,t,(c0)|((c0)<<1)|((c1)<<2)|((c1)<<3)|((c2)<<4)|((c2)<<5)|((c3)<<6)|((c3)<<7))
#define {intIvec}_1032(v) {intIvec}_extract(v,1,0,3,2)
#define {intIvec}_2301(v) {intIvec}_extract(v,2,3,0,1)
#define {intIvec}_3210(v) {intIvec}_extract(v,3,2,1,0)
#include "crypto_int32.h"
#define int32_min crypto_int32_min
''')
    # XXX: can skip some of the macros above if allowmaskload and allowmaskstore
    # dead with the current defaults (allowmaskload = allowmaskstore = False);
    # uses {mmprefix}, so only meaningful for the non-neon targets
    if usepartialmem and (allowmaskload or allowmaskstore):
        print(fr'''#define partialmem (partialmem_storage+64)
static const {intI} partialmem_storage[] __attribute__((aligned(128))) = {{
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
}} ;
#define {intIvec}_partialload(p,z) {mmprefix}maskload_epi{I}((void *) (z),p)
#define {intIvec}_partialstore(p,z,i) {mmprefix}maskstore_epi{I}((void *) (z),p,(i))
''')
# ===== serial fixed-size networks
def smallsize():
    """Print {intI}_sort_3through7: a fixed scalar sorting network for
    3 <= n <= 7 elements (4-element base network extended per n; a
    separate 3-element network in the else branch)."""
    print(fr'''NOINLINE
static void {intI}_sort_3through7({intI} *x,long long n)
{{
if (n >= 4) {{
{intI} x0 = x[0];
{intI} x1 = x[1];
{intI} x2 = x[2];
{intI} x3 = x[3];
{intI}_MINMAX(x0,x1);
{intI}_MINMAX(x2,x3);
{intI}_MINMAX(x0,x2);
{intI}_MINMAX(x1,x3);
{intI}_MINMAX(x1,x2);
if (n >= 5) {{
if (n == 5) {{
{intI} x4 = x[4];
{intI}_MINMAX(x0,x4);
{intI}_MINMAX(x2,x4);
{intI}_MINMAX(x1,x2);
{intI}_MINMAX(x3,x4);
x[4] = x4;
}} else {{
{intI} x4 = x[4];
{intI} x5 = x[5];
{intI}_MINMAX(x4,x5);
if (n == 6) {{
{intI}_MINMAX(x0,x4);
{intI}_MINMAX(x2,x4);
{intI}_MINMAX(x1,x5);
{intI}_MINMAX(x3,x5);
}} else {{
{intI} x6 = x[6];
{intI}_MINMAX(x4,x6);
{intI}_MINMAX(x5,x6);
{intI}_MINMAX(x0,x4);
{intI}_MINMAX(x2,x6);
{intI}_MINMAX(x2,x4);
{intI}_MINMAX(x1,x5);
{intI}_MINMAX(x3,x5);
{intI}_MINMAX(x5,x6);
x[6] = x6;
}}
{intI}_MINMAX(x1,x2);
{intI}_MINMAX(x3,x4);
x[4] = x4;
x[5] = x5;
}}
}}
x[0] = x0;
x[1] = x1;
x[2] = x2;
x[3] = x3;
}} else {{
{intI} x0 = x[0];
{intI} x1 = x[1];
{intI} x2 = x[2];
{intI}_MINMAX(x0,x1);
{intI}_MINMAX(x0,x2);
{intI}_MINMAX(x1,x2);
x[0] = x0;
x[1] = x1;
x[2] = x2;
}}
}}
''')
# ===== vectorized fixed-size networks
class vectorsortingnetwork:
    """Symbolic builder for one vectorized sorting-network routine.

    self.layout[r][i] records which array position ("node") currently lives
    in lane i of architectural register r; self.phys[r] is the C variable
    name currently backing register r.  Methods append abstract operations
    (loads, lane permutations, minmax steps, stores) to self.operations,
    and print() emits them as one C compound statement.

    NOTE(review): the pattern search in shuffle2 (and the guards it relies
    on) is driven by assert statements caught by except clauses; running
    this generator under python -O, which strips asserts, would break it.
    """
    def __init__(self,layout,xor=True):
        # layout: one vectorwords-long node list per architectural register;
        # registers are paired (even,odd), hence the even-length requirement
        layout = [list(x) for x in layout]
        for x in layout: assert len(x) == vectorwords
        assert len(layout)%2 == 0 # XXX: drop this restriction?
        self.layout = layout
        self.phys = [f'x{r}' for r in range(len(layout))]
        # arch register r is stored in self.phys[r]
        self.operations = []
        self.usedregs = set(self.phys)
        self.usedregssmallindex = set()
        self.usexor = xor
        self.useinfty = False
        if xor:
            # broadcast the xor argument once; loads/stores apply it per vector
            self.assign('vecxor',f'{intIvec}_broadcast(xor)')
    def print(self):
        """Emit the recorded operations as a C block, declarations first."""
        print('{')
        if len(self.usedregssmallindex) > 0:
            print(f' int32_t {",".join(sorted(self.usedregssmallindex))};')
        if len(self.usedregs) > 0:
            print(f' {intIvec} {",".join(sorted(self.usedregs))};')
        for op in self.operations:
            if op[0] == 'assign':
                phys,result,newlayout = op[1:]
                if newlayout is None:
                    print(f' {phys} = {result};')
                else:
                    # the layout annotation becomes a trailing // comment
                    print(f' {phys} = {result}; // {" ".join(map(str,newlayout))}')
            elif op[0] == 'assignsmallindex':
                phys,result = op[1:]
                print(f' {phys} = {result};')
            elif op[0] == 'store':
                result, = op[1:]
                print(f' {result};')
            elif op[0] == 'comment':
                c, = op[1:]
                print(f' // {c}')
            else:
                raise Exception(f'unrecognized operation {op}')
        print('}')
        print('')
    def allocate(self,r):
        '''Assign a new free physical register to arch register r. OK for caller to also use old register until calling allocate again.'''
        # flip the x/y prefix of the variable name: x3 <-> y3
        self.phys[r] = {'x':'y','y':'x'}[self.phys[r][0]]+self.phys[r][1:]
        # XXX: for generating low-level asm would instead want to cycle between fewer regs
        # but, for generating C (or qhasm), trusting register allocator makes generated assignments a bit more readable
        # (at least for V_sort)
    def comment(self,c):
        """Record a // comment for the generated C."""
        self.operations.append(('comment',c))
    def assign(self,phys,result,newlayout=None):
        """Record 'phys = result;', optionally annotated with the lane layout."""
        self.usedregs.add(phys)
        self.operations.append(('assign',phys,result,newlayout))
    def assignsmallindex(self,phys,result):
        """Record an int32_t scalar assignment (index arithmetic)."""
        self.usedregssmallindex.add(phys)
        self.operations.append(('assignsmallindex',phys,result))
    def createinfty(self):
        """Materialize 'infty' (all lanes = {intI}_largest) exactly once."""
        if self.useinfty: return
        self.useinfty = True
        self.assign('infty',f'{intIvec}_broadcast({intI}_largest)')
    def partialmask(self,offset):
        """Return a C expression for a lane mask: lane i active iff offset+i < n."""
        if usepartialmem:
            return f'{intIvec}_load(&partialmem[{offset}-n])'
        # XXX: also do version with offset in constants rather than n?
        rangevectorwords = ','.join(map(str,range(vectorwords)))
        return f'{intIvec}_smaller_mask({intIvec}_set({rangevectorwords}),{intIvec}_broadcast(n-{offset}))'
    def load(self,r,partial=False):
        """Record the load of register r from x[r*vectorwords...]; a partial
        load fills out-of-range lanes with infty so they sort to the top.
        NOTE(review): the non-maskload partial path does not apply vecxor;
        with the visible callers, xor variants use nlow == nhigh so partial
        loads never occur together with xor."""
        rphys = self.phys[r]
        xor = 'vecxor^' if self.usexor else ''
        if not partial:
            self.assign(rphys,f'{xor}{intIvec}_load(x+{vectorwords*r})',self.layout[r])
            return
        # uint instead of int would slightly streamline infty usage
        self.createinfty()
        if allowmaskload:
            self.assign(f'partial{r}',f'{self.partialmask(r*vectorwords)}')
            self.assign(rphys,f'{xor}{intIvec}_partialload(partial{r},x+{vectorwords*r})',self.layout[r])
            self.assign(rphys,f'{int8vec}_iftopthenelse(partial{r},{rphys},infty)')
            return
        # no maskload: load the last fully in-range vector, rotate the wanted
        # bytes into place, and blend the out-of-range lanes with infty
        self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorwords},n)')
        xdata = f'{intIvec}_load(x+pos{r}-{vectorwords})'
        if vectorbits == 128:
            if vectorunit == 'neon':
                mplus = ','.join(map(str,range(0,vectorbits//8)))
                rotated = f'{intIvec}_from_{int8vec}({int8vec}_varextract({int8vec}_from_{intIvec}({xdata}),u{int8vec}_add(u{int8vec}_set({mplus}),u{int8vec}_broadcast({I//8}*((-pos{r})&{(vectorbits//I-1)})))))'
            else:
                mplus = ','.join(map(str,range(vectorbits//8,2*vectorbits//8)))
                rotated = f'{int8vec}_varextract({xdata},{int8vec}_sub({int8vec}_set({mplus}),{int8vec}_broadcast({I//8}*(pos{r}&{(vectorbits//I-1)}))))'
            mplus = ','.join(map(str,range(r*vectorwords,(r+1)*vectorwords)))
            if vectorunit == 'neon':
                control = f'{intIvec}_smaller_umask({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))'
            else:
                control = f'{intIvec}_smaller_mask({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))'
        elif (I,vectorwords) == (32,8):
            # permutevar8x32 indices wrap mod 8, so diff doubles as rotation
            # control and (via sign bits) as the in-range mask
            mplus = ','.join(map(str,range(r*vectorwords,(r+1)*vectorwords)))
            self.assign(f'diff{r}',f'{intIvec}_sub({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r}))')
            rotated = f'{intIvec}_varextract({xdata},diff{r})'
            control = f'diff{r}'
        elif (I,vectorwords) == (64,4):
            # same trick with 32-bit index pairs addressing 64-bit lanes
            mplus = ','.join(map(str,range(2*r*vectorwords,2*(r+1)*vectorwords)))
            self.assign(f'diff{r}',f'{int32vec}_sub({int32vec}_set({mplus}),{int32vec}_broadcast(2*pos{r}))')
            rotated = f'{int32vec}_varextract({xdata},diff{r})'
            control = f'diff{r}'
        else:
            raise Exception('unhandled partial load')
        if vectorunit == 'neon':
            self.assign(rphys,f'{intIvec}_ifthenelse({control},{rotated},infty)',self.layout[r])
        else:
            self.assign(rphys,f'{int8vec}_iftopthenelse({control},{rotated},infty)',self.layout[r])
    # warning: must do store in reverse order
    def store(self,r,partial=False):
        """Record the store of register r back to x; a partial store writes
        only the in-range lanes (overlapping-store trick without maskstore)."""
        rphys = self.phys[r]
        xor = 'vecxor^' if self.usexor else ''
        if not partial:
            self.operations.append(('store',f'{intIvec}_store(x+{vectorwords*r},{xor}{rphys})'))
            return
        if allowmaskstore:
            if not allowmaskload:
                self.assign(f'partial{r}',f'{self.partialmask(r*vectorwords)}')
            self.operations.append(('store',f'{intIvec}_partialstore(partial{r},x+{vectorwords*r},{xor}{rphys})'))
            return
        if allowmaskload:
            # load() used maskload, so pos{r} was not assigned there
            self.assignsmallindex(f'pos{r}',f'int32_min({(r+1)*vectorwords},n)')
        # this is why store has to be in reverse order
        if vectorbits == 128:
            if vectorunit == 'neon':
                mplus = ','.join(map(str,range(0,vectorbits//8)))
                storeshift = f'u{int8vec}_sub(u{int8vec}_set({mplus}),u{int8vec}_broadcast({I//8}*((-pos{r})&{(vectorbits//I-1)})))'
                xdata = f'{intIvec}_from_{int8vec}({int8vec}_varextract({int8vec}_from_{intIvec}({xor}{rphys}),{storeshift}))'
            else:
                mplus = ','.join(map(str,range(vectorbits//8,2*vectorbits//8)))
                storeshift = f'{int8vec}_add({int8vec}_set({mplus}),{int8vec}_broadcast({I//8}*(pos{r}&{(vectorbits//I-1)})))'
                xdata = f'{int8vec}_varextract({xor}{rphys},{storeshift})'
        elif (I,vectorwords) == (32,8):
            mplus = ','.join(map(str,range(vectorwords)))
            xdata = f'{intIvec}_varextract({xor}{rphys},{intIvec}_add({intIvec}_set({mplus}),{intIvec}_broadcast(pos{r})))'
        elif (I,vectorwords) == (64,4):
            mplus = ','.join(map(str,range(2*vectorwords)))
            xdata = f'{int32vec}_varextract({xor}{rphys},{int32vec}_add({int32vec}_set({mplus}),{int32vec}_broadcast(2*pos{r})))'
        else:
            raise Exception('unhandled partial store')
        self.operations.append(('store',f'{intIvec}_store(x+pos{r}-{vectorwords},{xdata})'))
    def vecswap(self,r,s):
        """Swap the roles of arch registers r and s (bookkeeping only)."""
        assert r != s
        self.phys[r],self.phys[s] = self.phys[s],self.phys[r]
        self.layout[r],self.layout[s] = self.layout[s],self.layout[r]
    def minmax(self,r,s):
        """Record a lane-wise compare-exchange: r gets the minima, s the maxima."""
        assert r != s
        rphys = self.phys[r]
        sphys = self.phys[s]
        newlayout0 = [min(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
        newlayout1 = [max(L0,L1) for L0,L1 in zip(self.layout[r],self.layout[s])]
        self.layout[r],self.layout[s] = newlayout0,newlayout1
        # fresh destination so the old r value stays readable for the max
        self.allocate(r)
        rphysnew = self.phys[r]
        if I == 32:
            self.assign(rphysnew,f'{intIvec}_min({rphys},{sphys})',newlayout0)
            self.assign(sphys,f'{intIvec}_max({rphys},{sphys})',newlayout1)
        else:
            if vectorunit == 'neon':
                self.assign('t',f'{intIvec}_from_u{intIvec}({intIvec}_smaller_umask({rphys},{sphys}))')
                self.assign(rphysnew,f'{intIvec}_ifthenelse(u{intIvec}_from_{intIvec}(t),{rphys},{sphys})',newlayout0)
                self.assign(sphys,f'{intIvec}_ifthenelse(u{intIvec}_from_{intIvec}(t),{sphys},{rphys})',newlayout1)
            else:
                # XXX: avx-512 has min_epi64 but avx2 does not
                self.assign('t',f'{intIvec}_smaller_mask({rphys},{sphys})')
                self.assign(rphysnew,f'{int8vec}_iftopthenelse(t,{rphys},{sphys})',newlayout0)
                self.assign(sphys,f'{int8vec}_iftopthenelse(t,{sphys},{rphys})',newlayout1)
    def shuffle1(self,r,L):
        r'''Rearrange layout of vector r to match L.'''
        L = list(L)
        oldL = self.layout[r]
        perm = [oldL.index(a) for a in L]
        rphys = self.phys[r]
        # prefer the cheap named permutations; fall back to generic extract
        if vectorwords == 2 and perm == [1,0]:
            self.assign(rphys,f'{intIvec}_10({rphys})',L)
        elif vectorwords == 4 and perm == [1,0,3,2]:
            self.assign(rphys,f'{intIvec}_1032({rphys})',L)
        elif vectorwords == 4 and perm == [2,3,0,1]:
            self.assign(rphys,f'{intIvec}_2301({rphys})',L)
        elif vectorwords == 4 and perm == [3,2,1,0]:
            self.assign(rphys,f'{intIvec}_3210({rphys})',L)
        elif (I,vectorwords) == (32,8) and perm[4:] == [perm[0]+4,perm[1]+4,perm[2]+4,perm[3]+4]:
            self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
        elif (I,vectorwords) == (64,4) and perm[2:] == [perm[0]+2,perm[1]+2]:
            self.assign(rphys,f'{intIvec}_constextract_eachside({rphys},{perm[0]},{perm[1]})',L)
        elif vectorwords == 8 and vectorunit != 'neon':
            self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]},{perm[4]},{perm[5]},{perm[6]},{perm[7]})',L)
        elif vectorwords == 4 and vectorunit != 'neon':
            self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]},{perm[2]},{perm[3]})',L)
        elif vectorwords == 2 and vectorunit != 'neon':
            self.assign(rphys,f'{intIvec}_extract({rphys},{perm[0]},{perm[1]})',L)
        else:
            raise Exception(f'unhandled permutation from {oldL} to {L}')
        self.layout[r] = L
    def shuffle2(self,r,s,L,M,exact=False):
        """Rearrange the pair (r,s) so their layouts become (L,M), trying a
        sequence of two-register permutation patterns; each candidate guards
        itself with asserts/index lookups and falls through on failure.  With
        exact=False some patterns may deliver a lane order equivalent up to a
        within-pair swap instead of exactly (L,M)."""
        oldL = self.layout[r]
        oldM = self.layout[s]
        rphys = self.phys[r]
        sphys = self.phys[s]
        # 64x4: one shuffle_pd-style blend per output register
        try:
            assert (I,vectorwords) == (64,4)
            p0 = oldL[:2].index(L[0])
            p1 = oldM[:2].index(L[1])
            p2 = oldL[2:].index(L[2])
            p3 = oldM[2:].index(L[3])
            q0 = oldL[:2].index(M[0])
            q1 = oldM[:2].index(M[1])
            q2 = oldL[2:].index(M[2])
            q3 = oldM[2:].index(M[3])
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphysnew,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
            self.assign(sphys,f'{intIvec}_constextract_a01b01a23b23({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
            return
        except Exception:
            pass
        # 32x8: shuffle_ps pattern, identical selection in both 128-bit halves
        try:
            assert (I,vectorwords) == (32,8)
            p0 = oldL[:4].index(L[0])
            p1 = oldL[:4].index(L[1])
            p2 = oldM[:4].index(L[2])
            p3 = oldM[:4].index(L[3])
            assert p0 == oldL[4:].index(L[4])
            assert p1 == oldL[4:].index(L[5])
            assert p2 == oldM[4:].index(L[6])
            assert p3 == oldM[4:].index(L[7])
            q0 = oldL[:4].index(M[0])
            q1 = oldL[:4].index(M[1])
            q2 = oldM[:4].index(M[2])
            q3 = oldM[:4].index(M[3])
            assert q0 == oldL[4:].index(M[4])
            assert q1 == oldL[4:].index(M[5])
            assert q2 == oldM[4:].index(M[6])
            assert q3 == oldM[4:].index(M[7])
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',L)
            self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',M)
            return
        except Exception:
            pass
        # 32x8, inexact: same pattern but accept the 0213 lane interleave
        try:
            assert not exact
            assert (I,vectorwords) == (32,8)
            p0 = oldL[:4].index(L[0])
            p1 = oldL[:4].index(L[2])
            p2 = oldM[:4].index(L[1])
            p3 = oldM[:4].index(L[3])
            assert p0 == oldL[4:].index(L[4])
            assert p1 == oldL[4:].index(L[6])
            assert p2 == oldM[4:].index(L[5])
            assert p3 == oldM[4:].index(L[7])
            q0 = oldL[:4].index(M[0])
            q1 = oldL[:4].index(M[2])
            q2 = oldM[:4].index(M[1])
            q3 = oldM[:4].index(M[3])
            assert q0 == oldL[4:].index(M[4])
            assert q1 == oldL[4:].index(M[6])
            assert q2 == oldM[4:].index(M[5])
            assert q3 == oldM[4:].index(M[7])
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = [L[0],L[2],L[1],L[3],L[4],L[6],L[5],L[7]]
            self.layout[s] = [M[0],M[2],M[1],M[3],M[4],M[6],M[5],M[7]]
            self.assign(rphysnew,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{p0},{p1},{p2},{p3})',self.layout[r])
            self.assign(sphys,f'{intIvec}_constextract_aabb_eachside({rphys},{sphys},{q0},{q1},{q2},{q3})',self.layout[s])
            return
        except Exception:
            pass
        # halves already in place: leftleft/rightright recombination
        try:
            half = vectorwords//2
            assert oldL[:half] == L[:half]
            assert oldL[half:] == M[:half]
            assert oldM[:half] == L[half:]
            assert oldM[half:] == M[half:]
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
            self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
            return
        except Exception:
            pass
        # zip + 2301 correction
        try:
            assert oldL == [L[2],L[0],M[2],M[0]]
            assert oldM == [L[3],L[1],M[3],M[1]]
            self.assign('t',f'{intIvec}_a0b0a1b1({rphys},{sphys})',[L[2],L[3],L[0],L[1]])
            self.assign('u',f'{intIvec}_a2b2a3b3({rphys},{sphys})',[M[2],M[3],M[0],M[1]])
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphys,f'{intIvec}_2301(t)',L)
            self.assign(sphys,f'{intIvec}_2301(u)',M)
            return
        except Exception:
            pass
        # neon transpose pattern
        try:
            assert vectorunit == 'neon'
            assert oldL == [L[0],M[0],L[2],M[2]]
            assert oldM == [L[1],M[1],L[3],M[3]]
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphysnew,f'{intIvec}_a0b0a2b2({rphys},{sphys})',L)
            self.assign(sphys,f'{intIvec}_a1b1a3b3({rphys},{sphys})',M)
            return
        except Exception:
            pass
        # neon zip pattern
        try:
            assert vectorunit == 'neon'
            assert oldL == [L[0],L[2],M[0],M[2]]
            assert oldM == [L[1],L[3],M[1],M[3]]
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphysnew,f'{intIvec}_a0b0a1b1({rphys},{sphys})',L)
            self.assign(sphys,f'{intIvec}_a2b2a3b3({rphys},{sphys})',M)
            return
        except Exception:
            pass
        # sse blend after a 1032 rotation; inexact: r ends pair-swapped
        try:
            assert vectorunit != 'neon'
            assert not exact
            # XXX: generalize this
            assert oldL == [L[0],M[0],L[2],M[2]]
            assert oldM == [L[1],M[1],L[3],M[3]]
            self.assign('t',f'{intIvec}_1032({rphys})',[M[0],L[0],M[2],L[2]])
            self.layout[r] = [L[1],L[0],L[3],L[2]]
            self.layout[s] = M
            self.assign(rphys,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},1,0,1,0)',self.layout[r])
            self.assign(sphys,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},0,1,0,1)',self.layout[s])
            return
        except Exception:
            pass
        # sse blend pattern, both registers rotated first
        try:
            assert vectorunit != 'neon'
            # XXX: generalize this
            assert oldL == [M[0],L[0],M[2],L[2]]
            assert oldM == [M[1],L[1],M[3],L[3]]
            self.assign('t',f'{intIvec}_1032({rphys})',[L[0],M[0],L[2],M[2]])
            # fixed annotation: 1032 of oldM is [L[1],M[1],L[3],M[3]]
            self.assign('u',f'{intIvec}_1032({sphys})',[L[1],M[1],L[3],M[3]])
            self.allocate(r)
            rphysnew = self.phys[r]
            self.layout[r] = L
            self.layout[s] = M
            self.assign(rphysnew,f'{intIvec}_constextract_ab0ab1ab2ab3(t,{sphys},0,1,0,1)',L)
            self.assign(sphys,f'{intIvec}_constextract_ab0ab1ab2ab3(u,{rphys},1,0,1,0)',M)
            return
        except Exception:
            pass
        # avx2 generic: per-half shuffles + constant blend, optionally
        # followed by a 128-bit-half recombination (bigflip)
        for bigflip in False,True:
            try:
                assert vectorbits == 256
                if bigflip:
                    half = vectorwords//2
                    z0 = L[:half]+M[:half]
                    z1 = L[half:]+M[half:]
                else:
                    z0 = L
                    z1 = M
                blend = []
                shuf0 = [None]*vectorwords
                shuf1 = [None]*vectorwords
                for i in range(vectorwords):
                    if oldL[i] == z0[i]:
                        blend.append(0)
                        shuf0[i] = oldL.index(z1[i])
                        shuf1[i] = 0
                        if i < vectorwords//2:
                            assert shuf0[i] < vectorwords//2
                        else:
                            assert shuf0[i] == shuf0[i-vectorwords//2]+vectorwords//2
                    else:
                        blend.append(1)
                        assert oldM[i] == z1[i]
                        shuf0[i] = 0
                        shuf1[i] = oldM.index(z0[i])
                        if i < vectorwords//2:
                            assert shuf1[i] < vectorwords//2
                        else:
                            assert shuf1[i] == shuf1[i-vectorwords//2]+vectorwords//2
                blend = ','.join(map(str,blend))
                s0 = ','.join(map(str,shuf0[:vectorwords//2]))
                s1 = ','.join(map(str,shuf1[:vectorwords//2]))
                # XXX: encapsulate temporaries better
                self.assign('u',f'{intIvec}_constextract_eachside({sphys},{s1})')
                self.assign('t',f'{intIvec}_constextract_eachside({rphys},{s0})')
                self.assign(rphys,f'{intIvec}_ifconstthenelse({blend},u,{rphys})',L)
                self.assign(sphys,f'{intIvec}_ifconstthenelse({blend},{sphys},t)',M)
                if bigflip:
                    self.allocate(r)
                    rphysnew = self.phys[r]
                    self.assign(rphysnew,f'{intIvec}_leftleft({rphys},{sphys})',L)
                    self.assign(sphys,f'{intIvec}_rightright({rphys},{sphys})',M)
                self.layout[r] = L
                self.layout[s] = M
                return
            except Exception:
                pass
        raise Exception(f'unhandled permutation from {oldL},{oldM} to {L},{M}')
    def rearrange_onestep(self,comparators):
        """Perform one repair step making each register pair (k,k+1) hold
        matching comparator endpoints lane-by-lane; returns True if a swap
        or shuffle was recorded, False once everything is aligned."""
        numvectors = len(self.layout)
        for k in range(0,numvectors,2):
            if all(comparators[a] == a for a in self.layout[k]):
                if all(comparators[a] == a for a in self.layout[k+1]):
                    continue
            collected = set(self.layout[k]+self.layout[k+1])
            if not all(comparators[a] in collected for a in collected):
                # pair is not closed under comparators: bring in the vector
                # holding the partners of layout[k]
                for j in range(numvectors):
                    if all(comparators[a] in self.layout[j] for a in self.layout[k]):
                        self.vecswap(j,k+1)
                        return True
            if any(comparators[a] in self.layout[k] for a in self.layout[k]):
                # some comparator is internal to vector k: redistribute nodes
                # across the pair so every comparator goes between the two
                newlayout0 = list(self.layout[k])
                newlayout1 = list(self.layout[k+1])
                for i in range(vectorwords):
                    a = newlayout0[i]
                    b = comparators[a]
                    if b == newlayout1[i]: continue
                    if b in newlayout0:
                        j = newlayout0.index(b)
                        assert j > i
                        newlayout1[i],newlayout0[j] = newlayout0[j],newlayout1[i]
                    else:
                        j = newlayout1.index(b)
                        assert j > i
                        newlayout1[i],newlayout1[j] = newlayout1[j],newlayout1[i]
                self.shuffle2(k,k+1,newlayout0,newlayout1)
                return True
            if [comparators[a] for a in self.layout[k]] != self.layout[k+1]:
                # comparators already cross the pair; just realign the lanes
                newlayout = [comparators[a] for a in self.layout[k]]
                self.shuffle1(k+1,newlayout)
                return True
        return False
    def rearrange(self,comparators):
        """Repeat rearrange_onestep until every comparator in the mapping is
        realizable as a single lane-wise minmax on some register pair."""
        comparators = dict(comparators)
        while self.rearrange_onestep(comparators): pass
def fixedsize(nlow,nhigh,xor=False,V=False):
    """Emit one NOINLINE static C routine sorting between nlow and nhigh
    elements with a fully unrolled vectorized network.

    Generated name: {intI}_[V_]sort_[{nlow}through]{nhigh}[_xor]; an n
    parameter is added when nlow < nhigh, and an xor parameter when xor
    is set.  Vectors past nlow are loaded/stored partially (padded with
    the largest value).  When V is set, the first half of the input is
    laid out reversed — presumably merging two pre-sorted halves rather
    than sorting from scratch (only the final subsort stages run)."""
    nlow = int(nlow)
    nhigh = int(nhigh)
    assert 0 <= nlow
    assert nlow <= nhigh
    # at most partiallimit elements may be handled via partial vectors
    assert nhigh-partiallimit <= nlow
    assert nhigh%vectorwords == 0
    assert nlow%vectorwords == 0
    # lgn: ceil(log2(nhigh)), at least 4
    lgn = 4
    while 2**lgn < nhigh: lgn += 1
    if nhigh < 2*vectorwords:
        raise Exception(f'unable to handle sizes below {2*vectorwords}')
    if V:
        funname = f'{intI}_V_sort_'
    else:
        funname = f'{intI}_sort_'
    funargs = f'{intI} *x'
    if nlow < nhigh:
        funname += f'{nlow}through'
        funargs += ',long long n'
    funname += f'{nhigh}'
    if xor:
        funname += '_xor'
        funargs += f',{intI} xor'
    print('NOINLINE')
    print(f'static void {funname}({funargs})')
    # ===== decide on initial layout of nodes
    numvectors = nhigh//vectorwords
    layout = {}
    if V:
        # first half reversed, second half forward
        for k in range(numvectors//2):
            layout[k] = list(reversed(range(vectorwords*(numvectors//2-1-k),vectorwords*(numvectors//2-k))))
        for k in range(numvectors//2,numvectors):
            layout[k] = list(range(vectorwords*k,vectorwords*(k+1)))
    else:
        # strided layout: vector `offset` holds nodes offset, offset+numvectors, ...
        # (the outer loop's step equals numvectors, so only k = 0 runs)
        for k in range(0,numvectors,nhigh//vectorwords):
            for offset in range(nhigh//vectorwords):
                layout[k+offset] = list(range(vectorwords*k+offset,vectorwords*(k+nhigh//vectorwords),nhigh//vectorwords))
    layout = [layout[k] for k in range(len(layout))]
    # ===== build network
    S = vectorsortingnetwork(layout,xor=xor)
    for k in range(numvectors):
        S.load(k,partial=k*vectorwords>=nlow)
    # subsorts of size 2**lgsubsort, each running stages lgsubsort-1 .. 0;
    # the hand-tuned comparator tables below override the generic a^stagemask
    # pattern for small (lgsubsort,stage) combinations
    for lgsubsort in range(1,lgn+1):
        if V and lgsubsort < lgn: continue
        for stage in reversed(range(lgsubsort)):
            if nhigh >= 2*vectorwords and (lgsubsort,stage) == (1,0):
                comparators = {a:a^1 for a in range(nhigh)}
            elif nhigh >= 4*vectorwords and (lgsubsort,stage) == (2,1):
                comparators = {a:a^2 for a in range(nhigh)}
            elif nhigh >= 4*vectorwords and (lgsubsort,stage) == (2,0):
                comparators = {a:a^(3*(1&((a>>0)^(a>>1)))) for a in range(nhigh)}
            elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,2):
                comparators = {a:a^4 for a in range(nhigh)}
            elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,1):
                comparators = {a:a+[0,0,2,2,-2,-2,0,0][a%8] for a in range(nhigh)}
            elif nhigh >= 8*vectorwords and (lgsubsort,stage) == (3,0):
                comparators = {a:a+[0,1,-1,1,-1,1,-1,0][a%8] for a in range(nhigh)}
            elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,3):
                comparators = {a:a^8 for a in range(nhigh)}
            elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,2):
                comparators = {a:a+[0,0,0,0,4,4,4,4,-4,-4,-4,-4,0,0,0,0][a%16] for a in range(nhigh)}
            elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,1):
                comparators = {a:a+[0,0,2,2,-2,-2,2,2,-2,-2,2,2,-2,-2,0,0][a%16] for a in range(nhigh)}
            elif nhigh >= 16*vectorwords and (lgsubsort,stage) == (4,0):
                comparators = {a:a+[0,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,1,-1,0][a%16] for a in range(nhigh)}
            else:
                # generic pattern: first stage of a subsort pairs a with its
                # mirror (mask (2<<stage)-1); later stages pair at 1<<stage
                if stage == lgsubsort-1:
                    stagemask = (2<<stage)-1
                else:
                    stagemask = 1<<stage
                comparators = {a:a^stagemask for a in range(nhigh)}
            if vectorwords == 8 and 2<<stage == nhigh and not V:
                # pre-shuffle lanes before the full-width mirror stage
                # (note: the comprehension's k shadows the loop's k)
                for k in range(0,numvectors,2):
                    p = S.layout[k]
                    newlayout = [p[k^((((k>>2)^(k>>1))&1)*6)] for k in range(vectorwords)] # XXX
                    S.shuffle1(k,newlayout)
            strcomparators = ' '.join(f'{a}:{comparators[a]}' for a in range(nhigh) if comparators[a] > a)
            S.comment(f'stage ({lgsubsort},{stage}) {strcomparators}')
            S.rearrange(comparators)
            for k in range(0,numvectors,2):
                if all(comparators[a] == a for a in S.layout[k]):
                    if all(comparators[a] == a for a in S.layout[k+1]):
                        continue
                S.minmax(k,k+1)
    # ===== restore the canonical in-memory order, pair by pair
    for k in range(0,numvectors,2):
        for offset in 0,1:
            for i in range(numvectors):
                if k*vectorwords+offset in S.layout[i]:
                    if i != k+offset:
                        S.vecswap(i,k+offset)
                    break
    for k in range(0,numvectors,2):
        y0 = list(range(k*vectorwords,(k+1)*vectorwords))
        y1 = list(range((k+1)*vectorwords,(k+2)*vectorwords))
        S.shuffle2(k,k+1,y0,y1,exact=True)
    # partial stores overlap, so without maskstore they must run in reverse
    for k in range(numvectors) if allowmaskstore or nlow == nhigh else reversed(range(numvectors)):
        S.store(k,partial=k*vectorwords>=nlow)
    S.print()
# ===== V_sort
def threestages(k,down=False,p=None,atleast=None):
    """Print a C routine applying three merge stages across k blocks.

    k (4..8) is the number of blocks; the routine keeps one running
    element per block in C locals x0..x{k-1}, applies the comparator
    pairs of the first three stages of an 8-way merge (pairs touching
    blocks >= k are dropped; all pairs reversed when down=True), and
    stores the results back.  Signature flavors:
      p given       -> fully specialized, name ends _<p>, no arguments
                       beyond x (p and n become C constants);
      atleast given -> takes p at runtime, name records the lower bound;
      neither       -> general routine taking p and n at runtime.
    """
    assert k in (4,5,6,7,8)
    print('NOINLINE')
    updown = 'down' if down else 'up'
    if p is not None:
        print(f'static void {intI}_threestages_{k}_{updown}_{p}({intI} *x)')
    elif atleast is not None:
        print(f'static void {intI}_threestages_{k}_{updown}_atleast{atleast}({intI} *x,long long p)')
    else:
        print(f'static void {intI}_threestages_{k}_{updown}({intI} *x,long long p,long long n)')
    print('{')
    print(' long long i;')
    if p is not None:
        print(f' long long p = {p};')
    if p is not None or atleast is not None:
        print(' long long n = p;')

    def lane(r):
        # C index expression for block r's current element.
        if r == 0: return 'i'
        if r == 1: return 'p+i'
        return f'{r}*p+i'

    # Comparator pairs for the three merge stages of an 8-way network.
    network = (0,4),(1,5),(2,6),(3,7),(0,2),(1,3),(4,6),(5,7),(0,1),(2,3),(4,5),(6,7)
    for vec in True,False: # vectorized pass first, then scalar leftovers
        if not vec:
            # no scalar pass needed when the length is known to be a
            # multiple of the vector width
            if p is not None and p%vectorwords == 0: break
            if atleast is not None and atleast%vectorwords == 0: break
        ctype = intIvec if vec else intI
        if vec:
            print(f' for (i = 0;i+{vectorwords} <= n;i += {vectorwords}) {{')
        else:
            # scalar loop continues from i left by the vector loop
            print(' for (;i < n;++i) {')
        for r in range(k):
            if vec:
                print(f' {ctype} x{r} = {ctype}_load(&x[{lane(r)}]);')
            else:
                print(f' {ctype} x{r} = x[{lane(r)}];')
        for a,b in network:
            if b >= k: continue
            if down:
                a,b = b,a
            print(f' {ctype}_MINMAX(x{a},x{b});')
        for r in range(k):
            if vec:
                print(f' {ctype}_store(&x[{lane(r)}],x{r});')
            else:
                print(f' x[{lane(r)}] = x{r};')
        print(' }')
    print('}')
    print('')
def V_sort():
    """Emit the vectorized merge layer of the generated C file.

    Prints, in order: the unrolled V kernels, the threestages helpers,
    the power-of-two merger *_xor, and the general recursive V_sort
    routine for arbitrary n.
    """
    # Unrolled vector kernels for small sizes.
    for nlow,nhigh in unrolledVsort:
        fixedsize(nlow,nhigh,V=True)
    for n in unrolledVsortxor:
        fixedsize(n,n,xor=True,V=True)
    # General three-stage mergers for 4..8 sub-blocks, used by the
    # generated V_sort below.
    threestages(8)
    threestages(7)
    threestages(6)
    threestages(5)
    threestages(4)
    # Specialized variants with p fixed at generation time, plus one
    # runtime-p "atleast" variant used by the recursive *_xor merger.
    threestages_min = Vsort2_min
    for p in threestages_specialize:
        assert p == threestages_min  # specialized sizes double up from Vsort2_min
        threestages_min *= 2
        threestages(8,p=p)
        threestages(8,down=True,p=p)
    threestages(8,down=True,atleast=threestages_min)
    threestages(6,down=True)
    # Recursive merger for power-of-2 n; xor (0 or -1) selects direction.
    print(f'''// XXX: currently xor must be 0 or -1
NOINLINE
static void {intI}_{Vsort2}_xor({intI} *x,long long n,{intI} xor)
{{''')
    # Dispatch the unrolled sizes first.
    for n in unrolledVsortxor:
        print(f' if (n == {n}) {{ {intI}_V_sort_{n}_xor(x,xor); return; }}')
    assert unrolledVsortxor[:3] == (Vsort2_min,Vsort2_min*2,Vsort2_min*4)
    # so n is at least Vsort2_min*8, justifying the following recursive calls
    for p in threestages_specialize:
        assert p in unrolledVsortxor
        print(f''' if (n == {p*8}) {{
if (xor)
{intI}_threestages_8_down_{p}(x);
else
{intI}_threestages_8_up_{p}(x);
for (long long i = 0;i < 8;++i)
{intI}_V_sort_{p}_xor(x+{p}*i,xor);
return;
}}''')
    # General case: three merge stages across 8 blocks of n/8, then
    # recurse on each block.  Closes *_xor and opens V_sort.
    print(f''' if (xor)
{intI}_threestages_8_down_atleast{threestages_min}(x,n>>3);
else
{intI}_threestages_8_up(x,n>>3,n>>3);
for (long long i = 0;i < 8;++i)
{intI}_{Vsort2}_xor(x+(n>>3)*i,n>>3,xor);
}}
/* q is power of 2; want only merge stages q,q/2,q/4,...,1 */
// XXX: assuming 8 <= q < n <= 2q; q is a power of 2
NOINLINE
static void {intI}_V_sort({intI} *x,long long q,long long n)
{{''')
    # Small n goes to the unrolled kernels; power-of-2 n goes to *_xor
    # (before or after the size dispatch depending on favorxor).
    assert any(nhigh == Vsort2_min for nlow,nhigh in unrolledVsort)
    for nlow,nhigh in unrolledVsort:
        if nhigh == Vsort2_min and favorxor:
            print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
        print(f' if (n <= {nhigh}) {{ {intI}_V_sort_{nlow}through{nhigh}(x,n); return; }}')
    if not favorxor:
        print(f' if (!(n & (n - 1))) {{ {intI}_{Vsort2}_xor(x,n,0); return; }}''')
    assert any(nhigh == 64 for nlow,nhigh in unrolledVsort)
    # General case: merge n into batches of q/4 (up to 8 blocks, last one
    # possibly short), recurse on each batch, handle the final partial
    # batch with small fixed networks when it drops below 8 entries.
    print(f'''
// 64 <= q < n < 2q
q >>= 2;
// 64 <= 4q < n < 8q
if (7*q < n) {{
{intI}_threestages_8_up(x,q,n-7*q);
{intI}_threestages_7_up(x+n-7*q,q,8*q-n);
}} else if (6*q < n) {{
{intI}_threestages_7_up(x,q,n-6*q);
{intI}_threestages_6_up(x+n-6*q,q,7*q-n);
}} else if (5*q < n) {{
{intI}_threestages_6_up(x,q,n-5*q);
{intI}_threestages_5_up(x+n-5*q,q,6*q-n);
}} else {{
{intI}_threestages_5_up(x,q,n-4*q);
{intI}_threestages_4_up(x+n-4*q,q,5*q-n);
}}
// now want to handle each batch of q entries separately
{intI}_V_sort(x,q>>1,q);
{intI}_V_sort(x+q,q>>1,q);
{intI}_V_sort(x+2*q,q>>1,q);
{intI}_V_sort(x+3*q,q>>1,q);
x += 4*q;
n -= 4*q;
while (n >= q) {{
{intI}_V_sort(x,q>>1,q);
x += q;
n -= q;
}}
// have n entries left in last batch, with 0 <= n < q
if (n <= 1) return;
while (q >= n) q >>= 1; // empty merge stage
// now 1 <= q < n <= 2q
if (q >= 8) {{ {intI}_V_sort(x,q,n); return; }}
if (n == 8) {{
{intI}_MINMAX(x[0],x[4]);
{intI}_MINMAX(x[1],x[5]);
{intI}_MINMAX(x[2],x[6]);
{intI}_MINMAX(x[3],x[7]);
{intI}_MINMAX(x[0],x[2]);
{intI}_MINMAX(x[1],x[3]);
{intI}_MINMAX(x[0],x[1]);
{intI}_MINMAX(x[2],x[3]);
{intI}_MINMAX(x[4],x[6]);
{intI}_MINMAX(x[5],x[7]);
{intI}_MINMAX(x[4],x[5]);
{intI}_MINMAX(x[6],x[7]);
return;
}}
if (4 <= n) {{
for (long long i = 0;i < n-4;++i)
{intI}_MINMAX(x[i],x[4+i]);
{intI}_MINMAX(x[0],x[2]);
{intI}_MINMAX(x[1],x[3]);
{intI}_MINMAX(x[0],x[1]);
{intI}_MINMAX(x[2],x[3]);
n -= 4;
x += 4;
}}
if (3 <= n)
{intI}_MINMAX(x[0],x[2]);
if (2 <= n)
{intI}_MINMAX(x[0],x[1]);
}}
''')
# ===== main sort
def main_sort_prep():
    """Emit the small-n and fixed-size scalar kernels used by main_sort."""
    smallsize()
    for lo,hi in unrolledsort:
        fixedsize(lo,hi)
    for size in unrolledsortxor:
        fixedsize(size,size,xor=True)
def main_sort():
    """Emit the recursive power-of-two *_xor sorter and the public
    {intI}_sort entry point of the generated C file."""
    print('// XXX: currently xor must be 0 or -1')
    print('NOINLINE')
    print(f'static void {intI}_{sort2}_xor({intI} *x,long long n,{intI} xor)')
    print('{')
    # Unrolled kernels for the small power-of-2 sizes.
    for n in unrolledsortxor:
        print(f' if (n == {n}) {{ {intI}_sort_{n}_xor(x,xor); return; }}')
    # Otherwise: sort the two halves in opposite directions, then merge.
    print(f' {intI}_{sort2}_xor(x,n>>1,~xor);')
    print(f' {intI}_{sort2}_xor(x+(n>>1),n>>1,xor);')
    print(f' {intI}_{Vsort2}_xor(x,n,xor);')
    print('}')
    # Public entry point: tiny sizes handled inline / by sort_3through7.
    print(fr'''
void {intI}_sort({intI} *x,long long n)
{{ long long q;
if (n <= 1) return;
if (n == 2) {{ {intI}_MINMAX(x[0],x[1]); return; }}
if (n <= 7) {{ {intI}_sort_3through7(x,n); return; }}''')
    # XXX: n cutoff here should be another variable to optimize
    nmin = 8
    # invariant: n in program is at least nmin
    assert any(nhigh == sort2_min for nlow,nhigh in unrolledsort)
    # Dispatch unrolled size ranges; emit the power-of-2 test before or
    # after depending on favorxor.  nmin tracks which sizes are already
    # ruled out, letting later range tests drop their lower bound.
    for nlow,nhigh in unrolledsort:
        if nhigh == sort2_min and favorxor:
            print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
        if nlow <= nmin:
            print(f' if (n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
            while nlow <= nmin and nmin <= nhigh: nmin += 1
        else:
            print(f' if ({nlow} <= n && n <= {nhigh}) {{ {intI}_sort_{nlow}through{nhigh}(x,n); return; }}')
    if not favorxor:
        print(f' if (!(n & (n - 1))) {{ {intI}_{sort2}_xor(x,n,0); return; }}''')
    # qmin: smallest power of 2 usable as q given n >= nmin.
    qmin = 1
    while qmin < nmin-qmin: qmin += qmin
    assert sort2_min <= qmin
    # Padding strategy for selected awkward ranges: copy into a buffer
    # padded with {intI}_largest up to a friendlier size and sort that.
    for padlow,padhigh in pads:
        padlowdown = padlow
        while padlowdown%vectorwords: padlowdown -= 1
        padhighup = padhigh
        while padhighup%vectorwords: padhighup += 1
        print(fr''' if ({padlow} <= n && n <= {padhigh}) {{
{intI} buf[{padhighup}];
for (long long i = {padlowdown};i < {padhighup};++i) buf[i] = {intI}_largest;
for (long long i = 0;i < n;++i) buf[i] = x[i];
{intI}_sort(buf,{padhighup});
for (long long i = 0;i < n;++i) x[i] = buf[i];
return;
}}''')
    assert sort2_min%2 == 0
    assert sort2_min//2 >= Vsort2_min
    # General case: pick power-of-2 q with q < n <= 2q, sort two pieces
    # (with special split strategies for favorable shapes), then merge
    # with V_sort.
    print(fr'''
q = {qmin};
while (q < n - q) q += q;
// {qmin} <= q < n < 2q
if ({sort2_min*16} <= n && n <= (7*q)>>2) {{
long long m = (3*q)>>2; // strategy: sort m, sort n-m, merge
long long r = q>>3; // at least {sort2_min} since q is at least {sort2_min*8}
{intI}_{sort2}_xor(x,4*r,0);
{intI}_{sort2}_xor(x+4*r,r,0);
{intI}_{sort2}_xor(x+5*r,r,-1);
{intI}_{Vsort2}_xor(x+4*r,2*r,-1);
{intI}_threestages_6_down(x,r,r);
for (long long i = 0;i < 6;++i)
{intI}_{Vsort2}_xor(x+i*r,r,-1);
{intI}_sort(x+m,n-m);
}} else if ({sort2_min*2} <= q && n == (3*q)>>1) {{
// strategy: sort q, sort q/2, merge
long long r = q>>2; // at least {sort2_min//2} since q is at least {sort2_min*2}
{intI}_{sort2}_xor(x,4*r,-1);
{intI}_{sort2}_xor(x+4*r,2*r,0);
{intI}_threestages_6_up(x,r,r);
for (long long i = 0;i < 6;++i)
{intI}_{Vsort2}_xor(x+i*r,r,0);
return;
}} else {{
{intI}_{sort2}_xor(x,q,-1);
{intI}_sort(x+q,n-q);
}}
{intI}_V_sort(x,q,n);
}}''')
# ===== driver
preamble()        # defined earlier in the file — presumably emits shared header/macros; confirm
main_sort_prep()  # small-n and fixed-size scalar kernels
V_sort()          # vectorized merge kernels and the V_sort driver
main_sort()       # recursive xor sorter and the public sort entry point