-rwxr-xr-x 4805 djbsort-20260127/autogen/useint raw
#!/usr/bin/env python3
"""Generate djbsort's ``useint`` sort implementations.

Every variant below is implemented on top of ``djbsort_int{bits}``:
the input array is mapped in place onto int{bits}_t values whose
ascending signed order is the wanted order, sorted with
``djbsort_int{bits}``, and then mapped back.

For each variant the script writes ``{fun}/{vec}useint{bits}/sort.c``
(relative to the output root, the cwd by default).  For AVX2 variants it
also writes an ``architectures`` file next to it.
"""

import os

# XXX: vectorize for more platforms
# XXX: can integrate xor loops with int{bits}_sort
# XXX: for int{bits}_sortdown, can reverse the output array as an alternative

HEADER = '/* WARNING: auto-generated (by autogen/useint); do not edit */\n\n'
AVX2_ARCHITECTURES = 'amd64 avx2\nx86 avx2\n'


def variants():
    """Yield every (bits, down, what, vec) combination that gets generated.

    bits is 32 or 64; down is '' (ascending) or 'down' (descending);
    what is 'int', 'uint' or 'float'; vec is '' (portable) or 'avx2'.
    The signed ascending case is skipped: it is djbsort_int{bits} itself.
    """
    for bits in 32, 64:
        for down in '', 'down':
            for what in 'int', 'uint', 'float':
                for vec in '', 'avx2':
                    if what == 'int' and down == '':
                        continue
                    yield bits, down, what, vec


def _vector_shape(bits):
    """Return (lanes, C type name) for int{bits} lanes in a 256-bit register."""
    lanes = 256 // bits
    return lanes, f'int{bits}x{lanes}'


def preamble(bits, what, vec, fun):
    """Return the sort.c header.

    The header consists of the warning banner, optional AVX2 helper macros,
    and the #include lines.
    """
    lanes, vtype = _vector_shape(bits)
    out = [HEADER]
    if vec == 'avx2':
        # The 64-bit broadcast intrinsic is spelled _mm256_set1_epi64x.
        set1 = f'_mm256_set1_epi{bits}' + ('x' if bits == 64 else '')
        out.append(f'''#include <immintrin.h>

typedef __m256i {vtype};
#define {vtype}_load(z) _mm256_loadu_si256((__m256i *) (z))
#define {vtype}_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i))
#define {vtype}_broadcast {set1}
''')
        if what == 'float':
            # floatmask(y) is 0x7f..f in lanes whose sign bit is set, else 0.
            if bits == 64:
                # avx2 does not have srai_epi64: compute (y & topbit) - (y >> 63)
                topbit = hex(1 << (bits - 1)) + 'ULL'
                out.append(
                    f'#define {vtype}_floatmask(y) _mm256_sub_epi{bits}'
                    f'((y)&{vtype}_broadcast({topbit}),'
                    f'_mm256_srli_epi{bits}(y,{bits - 1}))\n'
                )
            else:
                out.append(
                    f'#define {vtype}_floatmask(y) _mm256_srli_epi{bits}'
                    f'(_mm256_srai_epi{bits}(y,{bits - 1}),1)\n'
                )
    out.append(f'#include "djbsort.h"\n#include "{fun}_sort.h"\n')
    if what == 'float':
        out.append(f'#include "crypto_int{bits}.h"\n')
    return ''.join(out)


def intxor(bits, vec, xor, cast=''):
    """Return a sort body that xors x[0..n) with xor, sorts, then xors back.

    xor is the C constant text (or an int) to xor every element with.
    Xor is its own inverse, so the same loop maps into and out of the
    signed domain.  cast is a C cast prefix applied to x when calling
    djbsort_int{bits} (e.g. '(int32_t *) ' for unsigned arrays).
    """
    lanes, vtype = _vector_shape(bits)
    if vec == 'avx2':
        step = 2 * lanes
        # `^=` on __m256i relies on GCC/Clang vector-extension operators.
        loop = f'''  for (j = 0;j+{step} <= n;j += {step}) {{
    {vtype} x0 = {vtype}_load(x+j);
    {vtype} x1 = {vtype}_load(x+j+{lanes});
    x0 ^= vecxor;
    x1 ^= vecxor;
    {vtype}_store(x+j,x0);
    {vtype}_store(x+j+{lanes},x1);
  }}
  for (;j < n;++j) x[j] ^= {xor};
'''
        setup = f'  long long j;\n  {vtype} vecxor = {vtype}_broadcast({xor});\n'
    else:
        loop = f'  for (j = 0;j < n;++j) x[j] ^= {xor};\n'
        setup = '  long long j;\n'
    return setup + loop + f'  djbsort_int{bits}({cast}x,n);\n' + loop


def floatxor(bits, vec, down):
    """Return a sort body for float/double arrays, sorting via their bit patterns.

    For values with the sign bit set, every other bit is flipped, so that
    signed-integer order of the bit patterns follows numeric order.  For
    'down', the result is additionally complemented (^ -1).  After
    djbsort_int{bits} the two steps are undone in reverse order.  The
    sign bit is unchanged by the mask, which is why the same mask undoes
    itself.

    NOTE(review): the emitted C reads the float array through an int{bits}_t
    pointer; that type-punning formally relies on the compiler not exploiting
    strict aliasing -- confirm the build flags.
    """
    lanes, vtype = _vector_shape(bits)
    xordown = ' ^ -1' if down == 'down' else ''
    unmask = f'    yj ^= ((uint{bits}_t) crypto_int{bits}_negative_mask(yj)) >> 1;\n'
    to_int = f'    int{bits}_t yj = y[j];\n' + unmask + f'    y[j] = yj{xordown};\n'
    from_int = f'    int{bits}_t yj = y[j]{xordown};\n' + unmask + '    y[j] = yj;\n'

    if vec == 'avx2':
        vecxordown = f' ^ {vtype}_broadcast(-1)' if down == 'down' else ''
        step = 2 * lanes

        def vector_pass(load_xor, store_xor):
            # Two registers per iteration; the scalar loop handles the tail.
            return f'''  for (j = 0;j+{step} <= n;j += {step}) {{
    {vtype} y0 = {vtype}_load(y+j){load_xor};
    {vtype} y1 = {vtype}_load(y+j+{lanes}){load_xor};
    y0 ^= {vtype}_floatmask(y0);
    y1 ^= {vtype}_floatmask(y1);
    {vtype}_store(y+j,y0{store_xor});
    {vtype}_store(y+j+{lanes},y1{store_xor});
  }}
'''

        forward = vector_pass('', vecxordown)   # mask, then complement if down
        backward = vector_pass(vecxordown, '')  # undo complement, then mask
        tail_head = '  for (;j < n;++j) {\n'
    else:
        forward = backward = ''
        tail_head = '  for (j = 0;j < n;++j) {\n'

    return (
        f'  int{bits}_t *y = (int{bits}_t *) x;\n'
        '  long long j;\n'
        + forward + tail_head + to_int + '  }\n'
        + f'  djbsort_int{bits}(y,n);\n'
        + backward + tail_head + from_int + '  }\n'
    )


def sort_c(bits, down, what, vec):
    """Return the complete sort.c text for one (bits, down, what, vec) variant.

    Raises ValueError for an unknown element kind `what`.
    """
    fun = f'{what}{bits}{down}'
    ctype = {'float32': 'float', 'float64': 'double'}.get(
        f'{what}{bits}', f'{what}{bits}_t')
    parts = [
        preamble(bits, what, vec, fun),
        f'\nvoid {fun}_sort({ctype} *x,long long n)\n{{\n',
    ]
    if what == 'int':
        # ~x (x ^ -1) reverses signed order.
        parts.append(intxor(bits, vec, -1))
    elif what == 'uint':
        # x ^ topbit maps unsigned order onto signed order;
        # x ^ (topbit - 1) == ~(x ^ topbit) maps it onto reversed signed order.
        top = 1 << (bits - 1)
        xor = hex(top - 1 if down == 'down' else top)
        if bits == 64:
            xor += 'ULL'
        parts.append(intxor(bits, vec, xor, f'(int{bits}_t *) '))
    elif what == 'float':
        parts.append(floatxor(bits, vec, down))
    else:
        raise ValueError(f'unknown element kind {what!r}')
    parts.append('}\n')
    return ''.join(parts)


def main(root='.'):
    """Write every variant's sort.c (plus AVX2 architectures files) under root."""
    for bits, down, what, vec in variants():
        impldir = os.path.join(root, f'{what}{bits}{down}', f'{vec}useint{bits}')
        os.makedirs(impldir, exist_ok=True)
        if vec == 'avx2':
            with open(os.path.join(impldir, 'architectures'), 'w') as f:
                f.write(AVX2_ARCHITECTURES)
        with open(os.path.join(impldir, 'sort.c'), 'w') as f:
            f.write(sort_c(bits, down, what, vec))


if __name__ == '__main__':
    main()