-rwxr-xr-x 12171 djbsort-20260127/sortverif/checkrules raw
#!/usr/bin/env python3
import sys
import claripy
verbose = False  # set True to print parsed patterns, variable lists and formulas

# Named constants that may appear in rules.  Each entry maps a bit width to
# the constant's value at that width ("star" = the digit repeated to fill it).
# Insertion order is kept on purpose: '0star1' stays ahead of its prefix
# '0star' in case the grammar tries these names in dict order.
constpatterns = dict((
    ('bit0', lambda bits: 0),                   # single bit 0
    ('bit1', lambda bits: 1),                   # single bit 1
    ('const5', lambda bits: 5),
    ('const26', lambda bits: 26),
    ('const27', lambda bits: 27),
    ('const31', lambda bits: 31),
    ('0star1', lambda bits: 1),                 # 00...01
    ('0star', lambda bits: 0),                  # 00...00
    ('01star', lambda bits: 2**(bits-1)-1),     # 01...11
    ('10star', lambda bits: 2**(bits-1)),       # 10...00
    ('1star', lambda bits: 2**bits-1),          # 11...11
))
def parse():
    """Read rewrite rules from stdin and parse both sides of each one.

    Every input line of the form ``LEFT -> RIGHT`` is parsed; lines that do
    not split into exactly two parts on ``->`` are skipped.  Each side becomes
    a nested-tuple pattern:

      ('var', NAME)          an uppercase variable
      ('NAME#k',)            the k-th occurrence, counted across both sides of
                             the current rule, of a constant from constpatterns
      (op, arg, ...)         an operation; the numeric arguments of Bits,
                             Extract, ZeroExt and SignExt are kept as strings

    Returns a list of (line, left, right) triples.
    Raises Exception (chained to the pyparsing error) if a side fails to parse.
    """
    from pyparsing import StringEnd,Literal,Word,ZeroOrMore,Forward,nums,ParseBaseException

    def lit(x):
        # literal that must be present but is dropped from the results
        return Literal(x).suppress()

    def group(name,x):
        # turn the tokens matched by x into one tuple (name, *tokens)
        def action(r): return name,*r
        return x.copy().set_parse_action(action)

    # occurrences of each constant seen so far in the current rule;
    # cleared for every input line so numbering restarts at 1
    counter = {}

    def countinggroup(name,x):
        # like group, but give each occurrence a fresh tag 'name#k'
        def action(r):
            counter[name] = counter.get(name,0)+1
            return f'{name}#{counter[name]}',*r
        return x.copy().set_parse_action(action)

    lparen = lit('(')
    rparen = lit(')')
    comma = lit(',')

    def fun(name,*args):
        # name(arg1,...,argn), or just name when there are no args
        what = lit(name)
        if args:
            what += lparen
            for pos,arg in enumerate(args):
                if pos > 0: what += comma
                what += arg
            what += rparen
        return group(name,what)

    Term = Forward()
    Termvar = group('var',Word('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))
    X = lparen+Term+rparen
    X |= group('assign',Termvar+lit('=')+Term)
    X |= fun('Bits',Word(nums),Term)
    X |= group('Concat',lit('Concat')+lparen+Term+ZeroOrMore(comma+Term)+rparen)
    X |= fun('Extract',Term,Word(nums),Word(nums))
    # SignExt is understood by variables/patternbits/formula, so accept it too
    for ext in 'ZeroExt','SignExt':
        X |= fun(ext,Term,Word(nums))
    # Literal matching has no word boundary: try longer names first so a
    # constant such as '0star' can never match just the start of '0star1'
    for c in sorted(constpatterns,key=len,reverse=True):
        X |= countinggroup(c,lit(c))
    for op in 'invert','Reverse':
        X |= fun(op,Term)
    for op in 'add','sub','and','xor','or','equal','unsignedgt','unsignedge','unsignedlt','unsignedle','signedlt','signedle','signedmin','signedmax','lshift','unsignedrshift','signedrshift':
        X |= fun(op,Term,Term)
    X |= fun('If',Term,Term,Term)
    X |= Termvar
    Term <<= X

    side_parser = Term+StringEnd()  # a side must be one complete term
    result = []
    for line in sys.stdin:
        counter.clear()
        line = line.strip()
        sides = line.split('->')
        if len(sides) != 2: continue
        parsed = []
        for side in sides:
            try:
                parsed.append(side_parser.parse_string(side)[0])
            except ParseBaseException as e:
                raise Exception(f'failure to parse {side}') from e
        left,right = parsed
        result.append((line,left,right))
    return result
# returns (bits0,v0),(bits1,v1),...
# where v0,v1,... are variable names in pattern
# (including const patterns such as 01star)
# and bits0,bits1,... are the numbers of bits in those variables
# (or None if the number of bits is unconstrained);
# the same variable may appear more than once
def variables(pattern,forcebits=None):
    """List a (bits,name) pair for every variable occurrence in pattern.

    forcebits is the width the surrounding context requires of pattern's
    value, or None if the context leaves it open.  Each returned pair holds a
    name (a variable, or a constant occurrence such as '01star#1') and the
    width forced on it, or None when that width is not pinned down here.
    An assignment's target is listed after the variables of its right side.
    Raises AssertionError on a width conflict, Exception on an unknown op.
    """
    head = pattern[0]
    if len(pattern) == 1:
        # constant occurrence; only bit0/bit1 carry an intrinsic width
        constname,_occurrence = head.split('#')
        if constname in ('bit0','bit1'):
            forcebits = 1 if forcebits is None else forcebits
            assert forcebits == 1
        return ((forcebits,head),)
    if head == 'var':
        return ((forcebits,pattern[1]),)
    if head == 'assign':
        assert pattern[1][0] == 'var'
        target = pattern[1][1]
        return variables(pattern[2],forcebits=forcebits)+((forcebits,target),)
    if head in ('Reverse','invert'):
        # width-preserving unary ops
        return variables(pattern[1],forcebits=forcebits)
    if head == 'Bits':
        declared = int(pattern[1])
        if forcebits is None:
            forcebits = declared
        assert forcebits == declared
        return variables(pattern[2],forcebits=forcebits)
    if head in ('SignExt','ZeroExt'):
        # operand is narrower than the result by the extension amount
        narrower = None if forcebits is None else forcebits-int(pattern[2])
        return variables(pattern[1],forcebits=narrower)
    if head == 'Extract':
        # the operand's width is not determined by the slice taken from it
        return variables(pattern[1])
    if head == 'If':
        condition = variables(pattern[1],forcebits=1)
        whentrue = variables(pattern[2],forcebits=forcebits)
        whenfalse = variables(pattern[3],forcebits=forcebits)
        return condition+whentrue+whenfalse
    if head in ('equal','unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt'):
        # comparisons yield one bit; their operands' width stays open
        if forcebits is not None:
            assert forcebits == 1
        return variables(pattern[1])+variables(pattern[2])
    if head in ('signedmin','signedmax','xor','and','or','add','sub','lshift','unsignedrshift','signedrshift'):
        lhs = variables(pattern[1],forcebits=forcebits)
        rhs = variables(pattern[2],forcebits=forcebits)
        return lhs+rhs
    if head == 'Concat':
        return sum((variables(part) for part in pattern[1:]),())
    sys.stdout.flush()
    raise Exception(f'unrecognized op {pattern[0]}')
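# sa is a {v:bits} dict giving the size of each variable (including constant patterns)
# returns the number of bits of pattern's value,
# or None if there is an incompatibility
# caller's responsibility to include all variables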
def _samebits(first,second,sa):
    """Common bit width of two operands under sa, or None if they differ or either is incompatible."""
    firstbits = patternbits(first,sa)
    if firstbits is None: return
    # second is only examined once first is known to be well-formed
    if patternbits(second,sa) != firstbits: return
    return firstbits


def patternbits(pattern,sa):
    """Return the bit width of pattern's value under the size assignment sa.

    sa maps every variable name and every constant occurrence name (such as
    '01star#1') in pattern to a bit width.  Returns None when sa is
    incompatible with pattern: operand widths disagree, a Bits() annotation
    does not match, a constant does not fit its width, an Extract range is
    out of bounds or reversed, or Reverse is applied to a width that is not a
    whole number of bytes.
    Raises Exception for an unrecognized op.
    """
    if len(pattern) == 1:
        # constant occurrence, e.g. 'const31#2'
        constname,_ = pattern[0].split('#')
        bits = sa[pattern[0]]
        if constname in ('bit0','bit1') and bits != 1: return
        value = constpatterns[constname](bits)
        if value < 0 or value >= 2**bits: return  # does not fit
        return bits
    op = pattern[0]
    if op == 'var':
        return sa[pattern[1]]
    if op == 'assign':
        assert pattern[1][0] == 'var'
        subbits = patternbits(pattern[2],sa)
        if subbits != sa[pattern[1][1]]: return
        return subbits
    if op == 'invert':
        return patternbits(pattern[1],sa)
    if op == 'Reverse':
        # byte reversal only makes sense on whole bytes (claripy rejects
        # other widths), so skip such sizes instead of crashing later
        subbits = patternbits(pattern[1],sa)
        if subbits is None or subbits % 8 != 0: return
        return subbits
    if op == 'Bits':
        forcebits = int(pattern[1])
        if patternbits(pattern[2],sa) != forcebits: return  # wrong bits or otherwise incompatible
        return forcebits
    if op in ('SignExt','ZeroExt'):
        morebits = int(pattern[2])
        subbits = patternbits(pattern[1],sa)
        if subbits is None: return
        return subbits+morebits
    if op == 'Extract':
        high,low = map(int,pattern[2:])
        subbits = patternbits(pattern[1],sa)
        if subbits is None: return
        # need subbits > high >= low; a reversed range has no valid width
        if high >= subbits or low > high: return
        return 1+high-low
    if op == 'If':
        if patternbits(pattern[1],sa) != 1: return  # condition must be one bit
        return _samebits(pattern[2],pattern[3],sa)
    if op in ('equal','unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt'):
        if _samebits(pattern[1],pattern[2],sa) is None: return
        return 1
    if op in ('signedmin','signedmax','xor','and','or','add','sub','lshift','unsignedrshift','signedrshift'):
        return _samebits(pattern[1],pattern[2],sa)
    if op == 'Concat':
        total = 0
        for p in pattern[1:]:
            bits = patternbits(p,sa)
            if bits is None: return
            total += bits
        return total
    sys.stdout.flush()
    raise Exception(f'unrecognized op {pattern[0]}')
# claripy distinguishes between BitVec 1 and Bool
def bitvec1tobool(v):
    """Turn a 1-bit claripy BV into a claripy Bool that holds exactly when v is 1."""
    one = claripy.BVV(1,1)
    return v == one
def booltobitvec1(v):
    """Turn a claripy Bool into a 1-bit BV: 1 when v holds, 0 otherwise."""
    when_true = claripy.BVV(1,1)
    when_false = claripy.BVV(0,1)
    return claripy.If(v,when_true,when_false)
# returns a claripy expression for pattern under sizes sa;
# appends one equation per assign to the constraints list in place
def formula(pattern,sa,constraints):
    """Translate a parsed pattern into a claripy bitvector expression.

    pattern      nested-tuple pattern as produced by parse()
    sa           {name: bits} for every variable and constant occurrence
    constraints  list; for each ('assign', var, expr) sub-pattern the
                 equation var == expr is appended in place

    Every result is a claripy BV.  Comparisons (equal, unsigned*, signed*)
    produce a 1-bit BV rather than a claripy Bool, so they can be nested in
    If, Concat or arithmetic like any other 1-bit value.
    Raises Exception for an unrecognized op.
    """
    if len(pattern) == 1:
        # constant occurrence such as '01star#2'; its width comes from sa
        constname,_ = pattern[0].split('#')
        bits = sa[pattern[0]]
        value = constpatterns[constname](bits)
        return claripy.BVV(value,bits)
    op = pattern[0]
    if op == 'var':
        bits = sa[pattern[1]]
        return claripy.BVS(pattern[1],bits,explicit_name=True)
    if op == 'assign':
        result = formula(pattern[2],sa,constraints)
        constraints.append(formula(pattern[1],sa,constraints) == result)
        return result
    if op == 'Bits': return formula(pattern[2],sa,constraints)
    if op == 'Reverse': return claripy.Reverse(formula(pattern[1],sa,constraints))
    if op == 'invert': return ~formula(pattern[1],sa,constraints)
    if op == 'SignExt': return claripy.SignExt(int(pattern[2]),formula(pattern[1],sa,constraints))
    if op == 'ZeroExt': return claripy.ZeroExt(int(pattern[2]),formula(pattern[1],sa,constraints))
    if op == 'Extract': return claripy.Extract(int(pattern[2]),int(pattern[3]),formula(pattern[1],sa,constraints))
    if op == 'If':
        condition = bitvec1tobool(formula(pattern[1],sa,constraints))
        whentrue = formula(pattern[2],sa,constraints)
        whenfalse = formula(pattern[3],sa,constraints)
        return claripy.If(condition,whentrue,whenfalse)
    if op == 'Concat': return claripy.Concat(*(formula(p,sa,constraints) for p in pattern[1:]))
    if len(pattern) == 3:
        op1 = formula(pattern[1],sa,constraints)
        op2 = formula(pattern[2],sa,constraints)
        # claripy comparisons yield Bool; wrap every one as a 1-bit BV,
        # including equal, so all ops return the same kind of value
        if op == 'equal': return booltobitvec1(op1 == op2)
        if op == 'add': return op1+op2
        if op == 'sub': return op1-op2
        if op == 'xor': return op1^op2
        if op == 'and': return op1&op2
        if op == 'or': return op1|op2
        if op == 'unsignedge': return booltobitvec1(op1.UGE(op2))
        if op == 'unsignedgt': return booltobitvec1(op1.UGT(op2))
        if op == 'unsignedle': return booltobitvec1(op1.ULE(op2))
        if op == 'unsignedlt': return booltobitvec1(op1.ULT(op2))
        if op == 'signedle': return booltobitvec1(op1.SLE(op2))
        if op == 'signedlt': return booltobitvec1(op1.SLT(op2))
        if op == 'signedmin': return claripy.If(op1.SLE(op2),op1,op2)
        if op == 'signedmax': return claripy.If(op1.SLE(op2),op2,op1)
        if op == 'lshift': return op1 << op2
        if op == 'unsignedrshift': return op1.LShR(op2)
        if op == 'signedrshift': return op1 >> op2  # claripy's >> is arithmetic
    sys.stdout.flush()
    raise Exception(f'unrecognized op {pattern[0]}')
def verify_specific_size(left,right,sa):
    """Ask the SMT solver whether left and right can differ under sizes sa.

    Both sides are translated with formula(); any assign constraints they
    produce are added to the solver.  Prints 'SMT solver says ok' when no
    distinguishing input exists.
    Raises Exception after printing a counterexample if one exists.
    """
    constraints = []
    leftformula = formula(left,sa,constraints)
    rightformula = formula(right,sa,constraints)
    if verbose:
        print('left formula',leftformula)
        print('right formula',rightformula)
        print('constraints',constraints)
    smt = claripy.Solver(timeout=-1) # XXX warning: by default (without timeout=-1), Solver will produce wrong results after 5 minutes
    smt.add(leftformula != rightformula)
    for c in constraints:
        smt.add(c)
    if smt.satisfiable():
        print('found mismatch:')
        for v in sa:
            # constant occurrences ('01star#1', ...) are not solver symbols:
            # show their fixed value rather than an unconstrained fresh BVS
            term = (v,) if '#' in v else ('var',v)
            vexample = smt.eval(formula(term,sa,constraints),1)
            print(f' {v} = {vexample[0]} = {hex(vexample[0])}')
        sys.stdout.flush()
        raise Exception('aborting given mismatch')
    print('SMT solver says ok')
# yields each sa as a tuple of (v,bits) pairs
def sizeassignments(varbits):
    """Lazily enumerate every combination of widths for the given variables.

    varbits is a sequence of (v,bits) pairs; bits None means the width is
    free and is taken from a fixed list of candidate sizes.  Yields tuples of
    (v,bits) pairs in the input order, with the last variable varying fastest.
    """
    # XXX: in principle should check for sizes actually used in minmax
    names = [v for v,_ in varbits]
    choices = [(1,8,16,32,64,128,256,512) if bits is None else (bits,) for _,bits in varbits]
    digits = [0]*len(choices)  # odometer: current index into each choice list
    while True:
        yield tuple((name,options[d]) for name,options,d in zip(names,choices,digits))
        pos = len(digits)-1
        while pos >= 0:
            digits[pos] += 1
            if digits[pos] < len(choices[pos]):
                break
            digits[pos] = 0  # wrap around and carry into the previous position
            pos -= 1
        else:
            return  # every position wrapped: all combinations produced
def _mergevarbits(varlist):
    """Combine (bits,name) pairs into one {name: bits} dict in first-seen order.

    bits None means "not constrained"; a later non-None width fills it in.
    Raises Exception if the same name is given two different widths.
    """
    var2bits = {}
    for bits,var in varlist:
        known = var2bits.get(var)
        if known is None:
            var2bits[var] = bits
        elif bits is not None and bits != known:
            # explicit raise rather than assert: must survive python -O
            raise Exception(f'inconsistent sizes for {var}: {known} vs {bits}')
    return var2bits


def verify(line,left,right):
    """Check one rewrite rule left -> right at every compatible size.

    Each size assignment under which both sides are well-formed and have the
    same output width is passed to verify_specific_size, which raises on a
    mismatch.
    Raises Exception if variable widths conflict or no size assignment fits.
    """
    print('rule:',line)
    leftvars = variables(left)
    rightvars = variables(right)
    if verbose:
        print('left',left)
        print('right',right)
        print('left variables',leftvars)
        print('right variables',rightvars)
    var2bits = _mergevarbits(leftvars+rightvars)

    # banner label for each name, computed once: bit0/bit1 constants are not
    # shown, and a constant occurring only once is shown without its '#k'
    labels = {}
    for var in var2bits:
        if var.startswith('bit'): continue
        label = var
        if '#' in var:
            constname = var.split('#')[0]
            if sum(other.startswith(constname+'#') for other in var2bits) == 1:
                label = constname
        labels[var] = label

    foundsa = False
    for sa in sizeassignments(sorted(var2bits.items())):
        sa = dict(sa)
        leftbits = patternbits(left,sa)
        if leftbits is None: continue
        if patternbits(right,sa) != leftbits: continue
        foundsa = True
        sizes = ''.join(f' {label}:{sa[var]}' for var,label in labels.items())
        print(f'considering sizes{sizes} output:{leftbits}')
        verify_specific_size(left,right,sa)
    if not foundsa:
        sys.stdout.flush()
        raise Exception('does this have a size assignment?')
    print('=====')
    sys.stdout.flush()
# read every rule from stdin first, then check them one at a time
rules = parse()
for line, left, right in rules:
    verify(line, left, right)