-rwxr-xr-x 4871 djbsort-20260127/sortverif/unroll raw
#!/usr/bin/env python3
import sys
# Command line:
#   bits n impl implname implcompiler binary littleendian|bigendian [lib ...]
# bits         width in bits of each symbolic input word and of each output word
# n            number of input words fed on stdin (and output packets expected)
# impl, implname, implcompiler
#              forwarded unchanged to the binary's argv
# binary       program to execute symbolically
# lib ...      extra libraries force-loaded into the angr project
if len(sys.argv) < 8:
    sys.exit(
        f'usage: {sys.argv[0]} bits n impl implname implcompiler binary '
        'littleendian|bigendian [lib ...]'
    )
bits = int(sys.argv[1])
n = int(sys.argv[2])
impl = int(sys.argv[3])
implname = sys.argv[4]
implcompiler = sys.argv[5]
binary = sys.argv[6]
_endianness = {'littleendian': True, 'bigendian': False}
if sys.argv[7] not in _endianness:
    # Explicit message instead of a bare KeyError traceback.
    sys.exit(
        f'{sys.argv[0]}: endianness must be littleendian or bigendian, '
        f'not {sys.argv[7]!r}'
    )
littleendian = _endianness[sys.argv[7]]
libs = sys.argv[8:]
import angr
import claripy
# angr state options switched on for the initial state handed to
# full_init_state: lazy solving, symbolic write addresses, conservative
# read/write strategies, and fresh symbols for unconstrained memory/registers.
add_options = {
    getattr(angr.options, name)
    for name in (
        'LAZY_SOLVES',
        'SYMBOLIC_WRITE_ADDRESSES',
        'CONSERVATIVE_READ_STRATEGY',
        'CONSERVATIVE_WRITE_STRATEGY',
        'SYMBOL_FILL_UNCONSTRAINED_MEMORY',
        'SYMBOL_FILL_UNCONSTRAINED_REGISTERS',
    )
}
# angr state options switched off: no automatic simplification of
# constraints, expressions, memory/register reads and writes, or returns.
remove_options = {
    getattr(angr.options, name)
    for name in (
        'SIMPLIFY_CONSTRAINTS',
        'SIMPLIFY_EXPRS',
        'SIMPLIFY_MEMORY_READS',
        'SIMPLIFY_MEMORY_WRITES',
        'SIMPLIFY_REGISTER_READS',
        'SIMPLIFY_REGISTER_WRITES',
        'SIMPLIFY_RETS',
    )
}
# Globally disable two claripy simplifications: empty the
# `extract_distributable` table and drop the registered xor simplifier.
# NOTE(review): presumably this keeps emitted expressions close to the
# operations the binary performs — confirm before changing.
claripy.simplifications.extract_distributable = {}
del claripy.simplifications.simpleton._simplifiers['__xor__']
# walk() recurses once per AST level, so allow far deeper recursion than
# Python's default.
sys.setrecursionlimit(1000000)
# ===== patch cpuid to enable avx2
# XXX: should share with outsim; in any case keep synchronized
prevcpuid = angr.engines.vex.heavy.dirty.CORRECT_amd64g_dirtyhelper_CPUID_avx_and_cx16

# Haswell CPUID data as (leaf, register, value). The entries are stored in
# exactly this order, and each store is conditional on the requested leaf.
_HASWELL_CPUID = (
    (1, 'rax', 0x000306c3),
    (1, 'rbx', 0x04100800),
    (1, 'rcx', 0x7ffafbff),
    (1, 'rdx', 0xbfebfbff),
    (7, 'rax', 0x00000000),
    (7, 'rbx', 0x000027ab),
    (7, 'rcx', 0x00000000),
    (7, 'rdx', 0x9c000600),
)


def cpuid(state, _):
    """CPUID helper installed in place of angr's amd64 AVX2 CPUID helper.

    Runs angr's CORRECT_amd64g_dirtyhelper_CPUID_avx_and_cx16 first, then
    overwrites rax/rbx/rcx/rdx with Haswell values when the requested leaf
    (low 32 bits of rax) is 1 or 7. Returns ``(None, [])``.
    """
    # Capture the requested leaf before the stock helper overwrites rax.
    leaf = state.regs.rax[31:0]
    prevcpuid(state, _)
    for wanted, reg, value in _HASWELL_CPUID:
        state.registers.store(reg, value, size=8, condition=(leaf == wanted))
    return None, []


angr.engines.vex.heavy.dirty.amd64g_dirtyhelper_CPUID_avx2 = cpuid
# ===== main
# Feed the program n fresh symbolic words x_<i>_<bits> on stdin, each
# byte-reversed (claripy.Reverse) when the target is little-endian.
inputwords = []
for i in range(n):
    word = claripy.BVS(f'x_{i}_{bits}', bits, explicit_name=True)
    inputwords.append(claripy.Reverse(word) if littleendian else word)
stdin = angr.SimFile('/dev/stdin', content=claripy.Concat(*inputwords), has_end=True)
proj = angr.Project(binary, auto_load_libs=False, force_load_libs=libs)
state = proj.factory.full_init_state(
    args=[binary, str(n), str(impl), implname, implcompiler],
    add_options=add_options,
    remove_options=remove_options,
    stdin=stdin,
)
simgr = proj.factory.simgr(state)
simgr.run()
# Explicit checks instead of `assert`. Under `python -O` the asserts would
# vanish, and errored paths would silently be left out of the merged result.
if simgr.errored:
    raise RuntimeError(
        f'symbolic execution hit {len(simgr.errored)} errored state(s): {simgr.errored!r}'
    )
exits = simgr.deadended
if not exits:
    raise RuntimeError('symbolic execution produced no deadended states')
if len(exits) > 1:
    # Merge all exit paths into one state, each guarded by its own path constraints.
    mergedexit, _, _ = exits[0].merge(
        *exits[1:], merge_conditions=[e2.solver.constraints for e2 in exits]
    )
else:
    mergedexit = exits[0]
# Everything written to stdout by the merged state, as a list of packets
# (consumed below as packet[0] = data, packet[1] = length in bytes).
packets = mergedexit.posix.stdout.content
# claripy operator name -> spelling used in the emitted program text.
_RENAMED_OPS = {
    '__add__': 'add',
    '__sub__': 'sub',
    '__mul__': 'mul',
    '__or__': 'or',
    '__xor__': 'xor',
    '__and__': 'and',
    '__invert__': 'invert',
    '__eq__': 'equal',
    '__ge__': 'unsignedge',
    '__gt__': 'unsignedgt',
    '__le__': 'unsignedle',
    '__lt__': 'unsignedlt',
    '__lshift__': 'lshift',
    '__rshift__': 'signedrshift',
    'LShR': 'unsignedrshift',
    'SLE': 'signedle',
    'SLT': 'signedlt',
}


def rename(op):
    """Return the output spelling of claripy operator *op*.

    Operators without an entry in the table are returned unchanged.
    """
    return _RENAMED_OPS.get(op, op)
# Memo table: AST node -> index k of the variable v<k> already printed for it.
walked = {}
# Index of the most recently printed variable (0 = none printed yet).
walknext = 0


def walk(t):
    """Print definitions for claripy AST *t* and its subterms; return t's index.

    Each new node gets one line ``v<k> = ...`` on stdout. Results are memoized
    in ``walked``, so a shared subterm is printed once and later referenced
    by its index. N-ary xor is printed as a left-to-right chain of binary
    xors. Signed SGE/SGT are printed as SLE/SLT with swapped operands. Other
    operator names are mapped through ``rename``. Recursion depth equals the
    AST depth.
    """
    global walknext
    if t in walked:
        return walked[t]
    # Work on a local copy of the operator name. Rewriting t.op in place would
    # mutate a claripy AST node that may be shared with other expressions.
    op = t.op
    if op == '__xor__':
        inputs = [walk(a) for a in t.args]
        result = inputs[0]
        for x in inputs[1:]:
            walknext += 1
            print('v%d = xor(v%d,v%d)' % (walknext, result, x))
            result = walknext
        walked[t] = result
        return result
    if op == 'BVV':
        walknext += 1
        print('v%d = constant(%d,%d)' % (walknext, t.size(), t.args[0]))
    elif op == 'BVS':
        walknext += 1
        print('v%d = %s' % (walknext, t.args[0]))
    elif op == 'Extract':
        # args: (high bit, low bit, operand)
        assert len(t.args) == 3
        operand = 'v%d' % walk(t.args[2])
        walknext += 1
        print('v%d = Extract(%s,%d,%d)' % (walknext, operand, t.args[0], t.args[1]))
    elif op in ('SignExt', 'ZeroExt'):
        # args: (number of extension bits, operand)
        assert len(t.args) == 2
        operand = 'v%d' % walk(t.args[1])
        walknext += 1
        print('v%d = %s(%s,%d)' % (walknext, op, operand, t.args[0]))
    else:
        inputs = ['v%d' % walk(a) for a in t.args]
        walknext += 1
        # Only SLE/SLT are emitted: a >= b becomes b <= a, and a > b becomes b < a.
        if op == 'SGE':
            op = 'SLE'
            inputs.reverse()
        elif op == 'SGT':
            op = 'SLT'
            inputs.reverse()
        print('v%d = %s(%s)' % (walknext, rename(op), ','.join(inputs)))
    walked[t] = walknext
    return walknext
# Every stdout packet must be one concrete `bits`-bit word. Emit each as
# y_<pos>_<bits> = v<k>, undoing the byte reversal on little-endian targets.
# Explicit checks (not assert) so the verification still fails under `python -O`.
if len(packets) != n:
    raise RuntimeError(f'expected {n} output packets, got {len(packets)}')
for ppos, p in enumerate(packets):
    data, length = p[0], p[1]
    if not (length.op == 'BVV' and 8 * length.args[0] == bits):
        raise RuntimeError(f'output packet {ppos} is not a concrete {bits}-bit word')
    result = claripy.Reverse(data) if littleendian else data
    print(f'y_{ppos}_{bits} = v{walk(result)}')