-rwxr-xr-x 4805 djbsort-20260127/autogen/useint raw
#!/usr/bin/env python3
import os
# XXX: vectorize for more platforms
# XXX: can integrate xor loops with int{bits}_sort
# XXX: for int{bits}_sortdown, can reverse the output array as an alternative
def preamble():
    """Write the head of a generated C file to the global output file `f`.

    Reads the module-level names `vec`, `what`, `bits`, `intIvec` and `fun`.
    Emits, in order: the auto-generated warning comment; for `vec == 'avx2'`
    the <immintrin.h> include plus the vector typedef and load/store/broadcast
    macros (and a `_floatmask` macro when `what == 'float'`); the djbsort and
    `{fun}_sort.h` includes; and, for floats, the `crypto_int{bits}.h` include.
    """
    out = ['/* WARNING: auto-generated (by autogen/useint); do not edit */\n\n']
    is_float = what == 'float'

    if vec == 'avx2':
        # the 64-bit broadcast intrinsic is spelled with a trailing 'x'
        set1_suffix = 'x' if bits == 64 else ''
        out += [
            '#include <immintrin.h>\n',
            f'typedef __m256i {intIvec};\n',
            f'#define {intIvec}_load(z) _mm256_loadu_si256((__m256i *) (z))\n',
            f'#define {intIvec}_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i))\n',
            f'#define {intIvec}_broadcast _mm256_set1_epi{bits}{set1_suffix}\n',
        ]
        if is_float:
            shift = bits - 1
            if bits == 64:
                # avx2 does not have srai_epi64
                signbit = hex(1 << shift) + 'ULL'
                mask = (f'_mm256_sub_epi{bits}((y)&{intIvec}_broadcast({signbit}),'
                        f'_mm256_srli_epi{bits}(y,{shift}))')
            else:
                mask = f'_mm256_srli_epi{bits}(_mm256_srai_epi{bits}(y,{shift}),1)'
            out.append(f'#define {intIvec}_floatmask(y) {mask}\n')

    out.append('#include "djbsort.h"\n')
    out.append(f'#include "{fun}_sort.h"\n')
    if is_float:
        out.append(f'#include "crypto_int{bits}.h"\n')

    f.write(''.join(out))
def intxor():
    """Write the body of an integer sort wrapper to the global output file `f`.

    The emitted C XORs every element of `x` with `xor`, calls
    `djbsort_int{bits}` on the array (through `cast`), then applies the same
    XOR again.  With `vec == 'avx2'` each XOR pass handles two vectors of
    `vectorlen` elements per iteration and finishes with a scalar tail loop.

    Reads the module-level names `vec`, `intIvec`, `vectorlen`, `xor`, `bits`
    and `cast`.
    """
    prologue = [' long long j;']

    if vec == 'avx2':
        stride = 2 * vectorlen
        prologue.append(f'{intIvec} vecxor = {intIvec}_broadcast({xor});')
        xor_pass = [
            f'for (j = 0;j+{stride} <= n;j += {stride}) {{',
            f'{intIvec} x0 = {intIvec}_load(x+j);',
            f'{intIvec} x1 = {intIvec}_load(x+j+{vectorlen});',
            'x0 ^= vecxor;',
            'x1 ^= vecxor;',
            f'{intIvec}_store(x+j,x0);',
            f'{intIvec}_store(x+j+{vectorlen},x1);',
            '}',
            # scalar tail for the elements the vector loop did not reach
            f'for (;j < n;++j) x[j] ^= {xor};',
        ]
    else:
        xor_pass = [f'for (j = 0;j < n;++j) x[j] ^= {xor};']

    sort_call = [f'djbsort_int{bits}({cast}x,n);']

    # the same XOR pass is emitted before and after the sort
    body = prologue + xor_pass + sort_call + xor_pass
    f.write(''.join(line + '\n' for line in body))
def floatxor():
    """Write the body of a floating-point sort wrapper to the global file `f`.

    The emitted C views `x` as an `int{bits}_t` array `y`, XORs each element
    with its negative mask shifted right by one (followed by `xordown`),
    calls `djbsort_int{bits}`, and then applies `xordown` and the masked XOR
    again.  With `vec == 'avx2'` the bulk of each pass is done two vectors at
    a time via the `{intIvec}_floatmask` macro, followed by a scalar tail.

    Reads the module-level names `vec`, `bits`, `intIvec`, `vectorlen`,
    `xordown` and (AVX2 only) `down`.
    """
    itype = f'int{bits}_t'
    flip = f'yj ^= ((uint{bits}_t) crypto_int{bits}_negative_mask(yj)) >> 1;'

    # scalar loop bodies (including the closing brace), shared by both variants
    scalar_before = [f'{itype} yj = y[j];', flip, f'y[j] = yj{xordown};', '}']
    scalar_after = [f'{itype} yj = y[j]{xordown};', flip, 'y[j] = yj;', '}']

    if vec == 'avx2':
        vecxordown = f' ^ {intIvec}_broadcast(-1)' if down == 'down' else ''
        stride = 2 * vectorlen
        vec_head = f'for (j = 0;j+{stride} <= n;j += {stride}) {{'
        masks = [
            f'y0 ^= {intIvec}_floatmask(y0);',
            f'y1 ^= {intIvec}_floatmask(y1);',
        ]
        tail_head = 'for (;j < n;++j) {'
        head_before = [
            vec_head,
            f'{intIvec} y0 = {intIvec}_load(y+j);',
            f'{intIvec} y1 = {intIvec}_load(y+j+{vectorlen});',
            *masks,
            f'{intIvec}_store(y+j,y0{vecxordown});',
            f'{intIvec}_store(y+j+{vectorlen},y1{vecxordown});',
            '}',
            tail_head,
        ]
        head_after = [
            vec_head,
            f'{intIvec} y0 = {intIvec}_load(y+j){vecxordown};',
            f'{intIvec} y1 = {intIvec}_load(y+j+{vectorlen}){vecxordown};',
            *masks,
            f'{intIvec}_store(y+j,y0);',
            f'{intIvec}_store(y+j+{vectorlen},y1);',
            '}',
            tail_head,
        ]
    else:
        head_before = ['for (j = 0;j < n;++j) {']
        head_after = ['for (j = 0;j < n;++j) {']

    body = [
        f' {itype} *y = ({itype} *) x;',
        'long long j;',
        *head_before,
        *scalar_before,
        f'djbsort_int{bits}(y,n);',
        *head_after,
        *scalar_after,
    ]
    f.write(''.join(line + '\n' for line in body))
# Driver: for every combination of element width, sort direction, element
# kind and vectorization, create {fun}/{vec}useint{bits}/sort.c (plus an
# `architectures` file for the AVX2 variants).  The helper functions above
# read these loop variables as module-level globals, so their names are fixed.
for bits in (32, 64):
    # one 256-bit vector holds this many elements
    vectorlen = 256 // bits
    intIvec = f'int{bits}x{vectorlen}'

    for down in ('', 'down'):
        for what in ('int', 'uint', 'float'):
            # No wrapper is generated for the ascending signed-int case
            # (presumably because djbsort_int{bits} already is that sort).
            if (what, down) == ('int', ''):
                continue

            fun = f'{what}{bits}{down}'

            # C element type of the array passed to {fun}_sort
            if what == 'float':
                T = 'float' if bits == 32 else 'double'
            else:
                T = f'{what}{bits}_t'

            for vec in ('', 'avx2'):
                impldir = f'{fun}/{vec}useint{bits}'
                os.makedirs(impldir, exist_ok=True)

                if vec == 'avx2':
                    with open(f'{impldir}/architectures', 'w') as f:
                        f.write('amd64 avx2\nx86 avx2\n')

                with open(f'{impldir}/sort.c', 'w') as f:
                    preamble()
                    f.write(f'\nvoid {fun}_sort({T} *x,long long n)\n{{\n')

                    if what == 'int':
                        xor = -1
                        cast = ''
                        intxor()
                    elif what == 'uint':
                        signbit = 1 << (bits - 1)
                        # descending: every bit except the top one;
                        # ascending: only the top bit
                        xor = hex(signbit - 1 if down == 'down' else signbit)
                        if bits == 64:
                            xor += 'ULL'
                        cast = f'(int{bits}_t *) '
                        intxor()
                    elif what == 'float':
                        xordown = ' ^ -1' if down == 'down' else ''
                        floatxor()

                    f.write('}\n')