-rw-r--r-- 2503 djbsort-20260210/sortbench/radixwrapper.c raw
/*
This was written in a few minutes.
It's purely experimental, probably buggy,
not verified, not heavily tested, not reviewed, not constant-time.
*/
#include <immintrin.h>
#include <string.h>
#include "djbsort.h"
#include "radixwrapper.h"
/* Tuning parameters for radixwrapper_int32. */
#define CUTOFF 8192 /* use djbsort for n <= CUTOFF */
#define MAXSPLIT 1024 /* number of radix buckets allocated (could also use VLA) */
#define TARGET 50 /* preferred size for recursion */
/* Short aliases for the fixed-width integer types. */
#define int32 int32_t
#define uint32 uint32_t
/* Eight int32 lanes held in one 256-bit vector. */
typedef __m256i int32x8;
/*
Unaligned load of 8 consecutive int32 values starting at z.
NOTE(review): the cast drops const when z points to const data; confirm this is intended.
*/
#define int32x8_load(z) _mm256_loadu_si256((__m256i *) (z))
/* Lane-wise signed 32-bit minimum and maximum. */
#define int32x8_min _mm256_min_epi32
#define int32x8_max _mm256_max_epi32
/*
Shuffle within each 128-bit half independently:
output lane j of a half receives input lane pj of the same half.
p0..p3 must be compile-time constants.
*/
#define int32x8_constextract_eachside(v,p0,p1,p2,p3) _mm256_shuffle_epi32(v,_MM_SHUFFLE(p3,p2,p1,p0))
/* The digits in each name list which input lane lands in output lanes 0..7. */
/* Swap neighboring lanes: 1,0,3,2,5,4,7,6. */
#define int32x8_10325476(a) int32x8_constextract_eachside(a,1,0,3,2)
/* Swap neighboring lane pairs: 2,3,0,1,6,7,4,5. */
#define int32x8_23016745(a) int32x8_constextract_eachside(a,2,3,0,1)
/* Swap the two 128-bit halves: 4,5,6,7,0,1,2,3. */
#define int32x8_45670123(a) _mm256_permute4x64_epi64(a,0x4e)
/* Extract lane 0 as a scalar. */
#define int32x8_0(a) _mm256_extract_epi32(a,0)
/*
Store the smallest entry of x[0..n-1] in *xmin and the largest in *xmax.
Requires n >= 8: the first load reads x[n-8..n-1].
*/
static void horizontal_minmax_int32_atleast8(int32 *xmin,int32 *xmax,const int32 *x,long long n)
{
  /* Seed both accumulators with the final 8 entries, so the tail is
     covered even when n is not a multiple of 8. */
  int32x8 lo = int32x8_load(x+n-8);
  int32x8 hi = lo;
  long long pos = 0;

  /* Sweep every complete group of 8 from the front. */
  while (pos+8 <= n) {
    int32x8 block = int32x8_load(x+pos);
    lo = int32x8_min(lo,block);
    hi = int32x8_max(hi,block);
    pos += 8;
  }

  /* Fold the 8 lanes of lo together: neighbors, then pairs, then halves.
     Afterwards every lane holds the overall minimum. */
  lo = int32x8_min(lo,int32x8_10325476(lo));
  lo = int32x8_min(lo,int32x8_23016745(lo));
  lo = int32x8_min(lo,int32x8_45670123(lo));

  /* Same folding for the maximum. */
  hi = int32x8_max(hi,int32x8_10325476(hi));
  hi = int32x8_max(hi,int32x8_23016745(hi));
  hi = int32x8_max(hi,int32x8_45670123(hi));

  *xmin = int32x8_0(lo);
  *xmax = int32x8_0(hi);
}
/*
Experimental radix-bucket front end for djbsort_int32.

x: array of n int32 values; rearranged in place.
y: scratch array with room for at least n int32 values; overwritten.
n: number of entries in x.

For n <= CUTOFF the input is handed straight to djbsort_int32.
Otherwise the entries are distributed into at most MAXSPLIT buckets keyed by
the high bits of x[i]-xmin (a key that never decreases as x[i] increases),
each bucket is copied back from y into x, and each bucket is processed
recursively with the matching slice of y as its scratch space.

Bucket indices depend on the data, so this is not constant-time.
*/
void radixwrapper_int32(int32 *x,int32 *y,long long n)
{
  if (n <= CUTOFF) { djbsort_int32(x,n); return; }

  int32 xmin,xmax;
  long long i,t,c[MAXSPLIT];

  /* Aim for about TARGET entries per bucket; halve until the count fits in c[]. */
  long long split = n/TARGET;
  while (split > MAXSPLIT) split >>= 1;

  /* n > CUTOFF here, so the n >= 8 requirement of the helper holds. */
  horizontal_minmax_int32_atleast8(&xmin,&xmax,x,n);

  /* Subtract in uint32: xmax-xmin can exceed INT32_MAX, and computing it
     as int32 would be signed overflow (undefined behavior). */
  const uint32 base = (uint32) xmin;
  uint32 range = (uint32) xmax - base;
  if (range == 0) return; /* all entries equal: nothing to rearrange */

  /* Smallest shift with (range>>shift) < split, so every key fits in c[0..split-1]. */
  long long shift = 0;
  while (range >= split) { shift += 1; range >>= 1; }

  /* Count the entries falling into each bucket. */
  for (i = 0;i < split;++i) c[i] = 0;
  for (i = 0;i < n;++i) {
    uint32 u = ((uint32) x[i] - base) >> shift;
    ++c[u];
  }

  /* Exclusive prefix sum: c[i] becomes the start offset of bucket i. */
  t = 0;
  for (i = 0;i < split;++i) { long long ci = c[i]; c[i] = t; t += ci; }

  /* Scatter into y; afterwards c[i] is the end offset of bucket i. */
  for (i = 0;i < n;++i) {
    int32 xi = x[i];
    uint32 u = ((uint32) xi - base) >> shift;
    y[c[u]++] = xi;
  }

  /* Copy each bucket back into x and recurse on it. */
  t = 0;
  for (i = 0;i < split;++i) {
    long long ci = c[i];
    memcpy(x+t,y+t,sizeof *x * (size_t) (ci-t));
    radixwrapper_int32(x+t,y+t,ci-t);
    t = ci;
  }
}