-rw-r--r-- 2503 djbsort-20260210/sortbench/radixwrapper.c raw
/* This was written in a few minutes. It's purely experimental, probably buggy, not verified, not heavily tested, not reviewed, not constant-time. */ #include <immintrin.h> #include <string.h> #include "djbsort.h" #include "radixwrapper.h" #define CUTOFF 8192 /* use djbsort for n <= CUTOFF */ #define MAXSPLIT 1024 /* number of radix buckets allocated (could also use VLA) */ #define TARGET 50 /* preferred size for recursion */ #define int32 int32_t #define uint32 uint32_t typedef __m256i int32x8; #define int32x8_load(z) _mm256_loadu_si256((__m256i *) (z)) #define int32x8_min _mm256_min_epi32 #define int32x8_max _mm256_max_epi32 #define int32x8_constextract_eachside(v,p0,p1,p2,p3) _mm256_shuffle_epi32(v,_MM_SHUFFLE(p3,p2,p1,p0)) #define int32x8_10325476(a) int32x8_constextract_eachside(a,1,0,3,2) #define int32x8_23016745(a) int32x8_constextract_eachside(a,2,3,0,1) #define int32x8_45670123(a) _mm256_permute4x64_epi64(a,0x4e) #define int32x8_0(a) _mm256_extract_epi32(a,0) static void horizontal_minmax_int32_atleast8(int32 *xmin,int32 *xmax,const int32 *x,long long n) { int32x8 low = int32x8_load(x+n-8); int32x8 high = low; for (long long i = 0;i+8 <= n;i += 8) { int32x8 xi = int32x8_load(x+i); low = int32x8_min(low,xi); high = int32x8_max(high,xi); } low = int32x8_min(low,int32x8_10325476(low)); high = int32x8_max(high,int32x8_10325476(high)); low = int32x8_min(low,int32x8_23016745(low)); high = int32x8_max(high,int32x8_23016745(high)); low = int32x8_min(low,int32x8_45670123(low)); high = int32x8_max(high,int32x8_45670123(high)); *xmin = int32x8_0(low); *xmax = int32x8_0(high); } void radixwrapper_int32(int32 *x,int32 *y,long long n) { if (n <= CUTOFF) { djbsort_int32(x,n); return; } int32 xmin,xmax; long long i,t,c[MAXSPLIT]; long long split = n/TARGET; while (split > MAXSPLIT) split >>= 1; horizontal_minmax_int32_atleast8(&xmin,&xmax,x,n); uint32 range = xmax-xmin; if (range == 0) return; long long shift = 0; while (range >= split) { shift += 1; range >>= 1; } for (i = 0;i < split;++i) c[i] = 0; for (i = 0;i < n;++i) { int32 xi = x[i]; uint32 u = xi-xmin; u >>= shift; ++c[u]; } t = 0; for (i = 0;i < split;++i) { long long ci = c[i]; c[i] = t; t += ci; } for (i = 0;i < n;++i) { int32 xi = x[i]; uint32 u = xi-xmin; u >>= shift; y[c[u]++] = xi; } t = 0; for (i = 0;i < split;++i) { long long ci = c[i]; memcpy(x+t,y+t,4*(ci-t)); radixwrapper_int32(x+t,y+t,ci-t); t = ci; } }