#!/usr/bin/env python3

import itertools
import sys

import claripy

verbose = False

# Constant patterns that may appear in rules.  Each entry maps the width b
# (in bits) of the constant to its integer value:
#   bit0/bit1       the 1-bit constants 0 and 1
#   constN          the fixed value N
#   0star1, 0star   the values 1 and 0 at any width
#   01star          2**(b-1)-1  (top bit clear, all other bits set)
#   10star          2**(b-1)    (top bit set, all other bits clear)
#   1star           2**b-1      (all bits set)
# Dict order matters to the parser: alternatives are tried in this order,
# so '0star1' has to come before its prefix '0star'.
constpatterns = {
    'bit0': (lambda b: 0),
    'bit1': (lambda b: 1),
    'const5': (lambda b: 5),
    'const26': (lambda b: 26),
    'const27': (lambda b: 27),
    'const31': (lambda b: 31),
    '0star1': (lambda b: 1),
    '0star': (lambda b: 0),
    '01star': (lambda b: 2**(b-1)-1),
    '10star': (lambda b: 2**(b-1)),
    '1star': (lambda b: 2**b-1),
}

# Binary ops whose result is always 1 bit wide.
_compareops = (
    'equal',
    'unsignedge', 'unsignedgt', 'unsignedle', 'unsignedlt',
    'signedle', 'signedlt',
)

# Binary ops whose result has the same width as both operands.
_arithops = (
    'signedmin', 'signedmax',
    'xor', 'and', 'or', 'add', 'sub',
    'lshift', 'unsignedrshift', 'signedrshift',
)

# Sizes tried for a variable whose width is not pinned down by the rule.
# XXX: in principle should check for sizes actually used in minmax
_defaultsizes = (1, 8, 16, 32, 64, 128, 256, 512)


def parse():
    """Read rewrite rules from stdin and parse them into pattern trees.

    Each input line is stripped and split on '->'; lines that do not split
    into exactly two parts are skipped.  Both halves are parsed as terms.

    A parsed term is a nested tuple whose first element is the op name:
    ('var', name), ('add', t1, t2), ('Bits', 'N', t), ('ZeroExt', t, 'N'),
    ('Extract', t, 'high', 'low'), ...  A constant pattern is the 1-tuple
    ('name#k',), where k counts the occurrences of that constant name
    seen so far within the current line.

    Returns a list of (line, left, right) tuples.
    Raises Exception if either half of a rule fails to parse.
    """
    from pyparsing import StringEnd, Literal, Word, ZeroOrMore, Forward, nums

    def lit(x):
        return Literal(x).suppress()

    def group(name, x):
        # Turn the tokens matched by x into a single tuple (name, *tokens).
        def action(r):
            return name, *r
        return x.copy().set_parse_action(action)

    # Per-line occurrence counts of constant patterns.  The parse actions
    # close over this variable, so rebinding it below resets the numbering.
    counter = {}

    def countinggroup(name, x):
        # Like group(), but tags the name with a running occurrence number.
        def action(r):
            counter[name] = counter.get(name, 0) + 1
            return f'{name}#{counter[name]}', *r
        return x.copy().set_parse_action(action)

    def fun(name, *args):
        # Grammar for name(arg0,arg1,...), producing (name, *argtokens).
        what = lit(name)
        if args:
            what += lparen
            for pos, arg in enumerate(args):
                if pos > 0:
                    what += comma
                what += arg
            what += rparen
        return group(name, what)

    lparen = lit('(')
    rparen = lit(')')
    comma = lit(',')

    Term = Forward()
    Termvar = group('var', Word('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))

    X = lparen+Term+rparen
    X |= group('assign', Termvar+lit('=')+Term)
    X |= fun('Bits', Word(nums), Term)
    X |= group('Concat', lit('Concat')+lparen+Term+ZeroOrMore(comma+Term)+rparen)
    X |= fun('Extract', Term, Word(nums), Word(nums))
    X |= fun('ZeroExt', Term, Word(nums))
    # The rest of this file already handles SignExt; accept it here too.
    X |= fun('SignExt', Term, Word(nums))
    for c in constpatterns:
        X |= countinggroup(c, lit(c))
    for op in 'invert', 'Reverse':
        X |= fun(op, Term)
    for op in ('add', 'sub', 'and', 'xor', 'or', 'equal',
               'unsignedgt', 'unsignedge', 'unsignedlt', 'unsignedle',
               'signedlt', 'signedle', 'signedmin', 'signedmax',
               'lshift', 'unsignedrshift', 'signedrshift'):
        X |= fun(op, Term, Term)
    X |= fun('If', Term, Term, Term)
    X |= Termvar
    Term <<= X

    fullterm = Term+StringEnd()

    def parseside(text):
        # Parse one side of a rule; keep the underlying error as the cause.
        try:
            return fullterm.parse_string(text)[0]
        except Exception as e:
            raise Exception(f'failure to parse {text}') from e

    result = []

    for line in sys.stdin:
        counter = {}
        line = line.strip()
        x = line.split('->')
        if len(x) != 2:
            continue
        left = parseside(x[0])
        right = parseside(x[1])
        result.append((line, left, right))

    return result


def variables(pattern, forcebits=None):
    """List the variables of a pattern with any width forced by context.

    Returns a tuple (bits0,v0),(bits1,v1),... where v0,v1,... are the
    variable names in pattern (including const patterns such as 01star#1)
    and bits0,bits1,... are the number of bits in those variables, or None
    if the number of bits is unconstrained at that occurrence.

    forcebits is the width the surrounding context requires of pattern,
    or None if the context imposes none.

    Raises AssertionError when the pattern contradicts forcebits, and
    Exception for an unrecognized op.
    """
    op = pattern[0]

    if len(pattern) == 1:  # constant
        constname = op.split('#')[0]
        if constname in ('bit0', 'bit1'):
            if forcebits is None:
                forcebits = 1
            assert forcebits == 1
        return (forcebits, op),

    if op == 'var':
        return (forcebits, pattern[1]),

    if op in ('Reverse', 'invert'):
        return variables(pattern[1], forcebits=forcebits)

    if op == 'assign':
        assert pattern[1][0] == 'var'
        V = pattern[1][1]
        return variables(pattern[2], forcebits=forcebits)+((forcebits, V),)

    if op == 'Bits':
        bits = int(pattern[1])
        if forcebits is None:
            forcebits = bits
        assert forcebits == bits
        return variables(pattern[2], forcebits=forcebits)

    if op in ('SignExt', 'ZeroExt'):
        # The operand is narrower than the result by the extension amount.
        if forcebits is not None:
            forcebits -= int(pattern[2])
        return variables(pattern[1], forcebits=forcebits)

    if op == 'Extract':
        # The result width says nothing about the operand width.
        return variables(pattern[1])

    if op == 'If':
        return (variables(pattern[1], forcebits=1)
                + variables(pattern[2], forcebits=forcebits)
                + variables(pattern[3], forcebits=forcebits))

    if op in _compareops:
        if forcebits is not None:
            assert forcebits == 1
        return variables(pattern[1])+variables(pattern[2])

    if op in _arithops:
        return (variables(pattern[1], forcebits=forcebits)
                + variables(pattern[2], forcebits=forcebits))

    if op == 'Concat':
        result = ()
        for p in pattern[1:]:
            result += variables(p)
        return result

    sys.stdout.flush()
    raise Exception(f'unrecognized op {op}')


def patternbits(pattern, sa):
    """Return the width in bits of pattern under size assignment sa.

    sa is a {v:bits} dict; it is the caller's responsibility to include
    all variables (and const patterns) occurring in pattern.

    Returns None if there is an incompatibility, e.g. operand widths
    differ, a constant does not fit in its width, or an Extract range is
    out of bounds.  Raises Exception for an unrecognized op.
    """
    op = pattern[0]

    if len(pattern) == 1:
        constname = op.split('#')[0]
        bits = sa[op]
        if constname in ('bit0', 'bit1') and bits != 1:
            return None
        value = constpatterns[constname](bits)
        if not 0 <= value < 2**bits:
            return None
        return bits

    if op == 'var':
        return sa[pattern[1]]

    if op == 'assign':
        assert pattern[1][0] == 'var'
        subbits = patternbits(pattern[2], sa)
        if subbits != sa[pattern[1][1]]:
            return None
        return subbits

    if op in ('Reverse', 'invert'):
        return patternbits(pattern[1], sa)

    if op == 'Bits':
        forcebits = int(pattern[1])
        if patternbits(pattern[2], sa) != forcebits:
            return None  # wrong bits or otherwise incompatible
        return forcebits

    if op in ('SignExt', 'ZeroExt'):
        subbits = patternbits(pattern[1], sa)
        if subbits is None:
            return None
        return subbits+int(pattern[2])

    if op == 'Extract':
        high, low = map(int, pattern[2:])
        subbits = patternbits(pattern[1], sa)
        if subbits is None:
            return None
        # Need low <= high < subbits for a non-empty, in-range slice.
        if subbits <= high or low > high:
            return None
        return 1+high-low

    if op == 'If':
        if patternbits(pattern[1], sa) != 1:
            return None
        firstbits = patternbits(pattern[2], sa)
        if firstbits is None:
            return None
        secondbits = patternbits(pattern[3], sa)
        if secondbits is None or firstbits != secondbits:
            return None
        return firstbits

    if op in _compareops or op in _arithops:
        firstbits = patternbits(pattern[1], sa)
        if firstbits is None:
            return None
        secondbits = patternbits(pattern[2], sa)
        if secondbits is None or firstbits != secondbits:
            return None
        return 1 if op in _compareops else firstbits

    if op == 'Concat':
        result = 0
        for p in pattern[1:]:
            bits = patternbits(p, sa)
            if bits is None:
                return None
            result += bits
        return result

    sys.stdout.flush()
    raise Exception(f'unrecognized op {op}')


# claripy distinguishes between BitVec 1 and Bool
def bitvec1tobool(v):
    """Convert a 1-bit bitvector expression into a claripy Bool."""
    return v == claripy.BVV(1, 1)


def booltobitvec1(v):
    """Convert a claripy Bool into a 1-bit bitvector expression."""
    return claripy.If(v, claripy.BVV(1, 1), claripy.BVV(0, 1))


# Builders for the two-operand ops.  Every comparison, including 'equal',
# is converted to a 1-bit vector so that it agrees with patternbits() and
# can be nested inside If, and, invert, ...
_binaryops = {
    'equal': lambda a, b: booltobitvec1(a == b),
    'add': lambda a, b: a+b,
    'sub': lambda a, b: a-b,
    'xor': lambda a, b: a ^ b,
    'and': lambda a, b: a & b,
    'or': lambda a, b: a | b,
    'unsignedge': lambda a, b: booltobitvec1(a.UGE(b)),
    'unsignedgt': lambda a, b: booltobitvec1(a.UGT(b)),
    'unsignedle': lambda a, b: booltobitvec1(a.ULE(b)),
    'unsignedlt': lambda a, b: booltobitvec1(a.ULT(b)),
    'signedle': lambda a, b: booltobitvec1(a.SLE(b)),
    'signedlt': lambda a, b: booltobitvec1(a.SLT(b)),
    'signedmin': lambda a, b: claripy.If(a.SLE(b), a, b),
    'signedmax': lambda a, b: claripy.If(a.SLE(b), b, a),
    'lshift': lambda a, b: a << b,
    'unsignedrshift': lambda a, b: a.LShR(b),
    'signedrshift': lambda a, b: a >> b,
}


def formula(pattern, sa, constraints):
    """Build the claripy expression for pattern under size assignment sa.

    sa is a {v:bits} dict.  For every 'assign' node the equation
    variable == value is appended to the constraints list in place.

    Returns a claripy bitvector expression.
    Raises Exception for an unrecognized op.
    """
    op = pattern[0]

    if len(pattern) == 1:
        constname = op.split('#')[0]
        bits = sa[op]
        return claripy.BVV(constpatterns[constname](bits), bits)

    if op == 'var':
        # explicit_name keeps the rule's own variable name in the solver.
        return claripy.BVS(pattern[1], sa[pattern[1]], explicit_name=True)

    if op == 'assign':
        result = formula(pattern[2], sa, constraints)
        constraints.append(formula(pattern[1], sa, constraints) == result)
        return result

    if op == 'Bits':
        # Width annotation only; checked by patternbits().
        return formula(pattern[2], sa, constraints)

    if op == 'Reverse':
        return claripy.Reverse(formula(pattern[1], sa, constraints))

    if op == 'invert':
        return ~formula(pattern[1], sa, constraints)

    if op == 'SignExt':
        return claripy.SignExt(int(pattern[2]), formula(pattern[1], sa, constraints))

    if op == 'ZeroExt':
        return claripy.ZeroExt(int(pattern[2]), formula(pattern[1], sa, constraints))

    if op == 'Extract':
        return claripy.Extract(int(pattern[2]), int(pattern[3]),
                               formula(pattern[1], sa, constraints))

    if op == 'If':
        return claripy.If(bitvec1tobool(formula(pattern[1], sa, constraints)),
                          formula(pattern[2], sa, constraints),
                          formula(pattern[3], sa, constraints))

    if op == 'Concat':
        return claripy.Concat(*(formula(p, sa, constraints) for p in pattern[1:]))

    if len(pattern) == 3:
        op1 = formula(pattern[1], sa, constraints)
        op2 = formula(pattern[2], sa, constraints)
        build = _binaryops.get(op)
        if build is not None:
            return build(op1, op2)

    sys.stdout.flush()
    raise Exception(f'unrecognized op {op}')


def verify_specific_size(left, right, sa):
    """Ask the SMT solver whether left and right agree under sizes sa.

    Prints 'SMT solver says ok' if no input distinguishes the two sides.
    Otherwise prints a counterexample and raises Exception.
    """
    constraints = []
    leftformula = formula(left, sa, constraints)
    rightformula = formula(right, sa, constraints)
    if verbose:
        print('left formula', leftformula)
        print('right formula', rightformula)
        print('constraints', constraints)

    # XXX warning: by default (without timeout=-1), Solver will produce
    # wrong results after 5 minutes
    smt = claripy.Solver(timeout=-1)
    smt.add(leftformula != rightformula)
    for c in constraints:
        smt.add(c)

    if smt.satisfiable():
        print('found mismatch:')
        for v in sa:
            if '#' in v:
                # Constant pattern: report its actual value instead of
                # asking the solver about an unrelated fresh symbol.
                value = constpatterns[v.split('#')[0]](sa[v])
            else:
                value = smt.eval(formula(('var', v), sa, constraints), 1)[0]
            print(f' {v} = {value} = {hex(value)}')
        sys.stdout.flush()
        raise Exception('aborting given mismatch')

    print('SMT solver says ok')


def sizeassignments(varbits):
    """Yield every candidate size assignment for varbits.

    varbits is a sequence of (v,bits) pairs; bits None means the width is
    free and every size in _defaultsizes is tried.  Each assignment is
    yielded as a tuple of (v,bits) pairs in the order of varbits, with the
    last variable's size varying fastest.
    """
    choices = [
        [(v, b) for b in (_defaultsizes if bits is None else (bits,))]
        for v, bits in varbits
    ]
    yield from itertools.product(*choices)


def verify(line, left, right):
    """Verify one rule left -> right for every compatible size assignment.

    Prints progress to stdout.  Raises Exception if no size assignment is
    compatible with both sides, or (from verify_specific_size) if the
    solver finds a mismatch.
    """
    print('rule:', line)
    leftvars = variables(left)
    rightvars = variables(right)
    if verbose:
        print('left', left)
        print('right', right)
        print('left variables', leftvars)
        print('right variables', rightvars)

    # Merge the width requirements of all occurrences of each variable,
    # remembering first-seen order for the banner below.
    printvarorder = []
    var2bits = {}
    for bits, var in leftvars+rightvars:
        if var not in var2bits:
            var2bits[var] = bits
            printvarorder.append(var)
        if bits is not None:
            if var2bits[var] is None:
                var2bits[var] = bits
            assert var2bits[var] == bits

    varbits = sorted(var2bits.items())

    foundsa = False
    for sa in sizeassignments(varbits):
        sa = dict(sa)
        leftbits = patternbits(left, sa)
        if leftbits is None:
            continue
        rightbits = patternbits(right, sa)
        if rightbits is None:
            continue
        if leftbits != rightbits:
            continue
        foundsa = True

        banner = 'considering sizes'
        for var in printvarorder:
            if var.startswith('bit'):
                continue
            bits = sa[var]
            if '#' in var:
                # Drop the '#k' suffix when the constant occurs only once.
                constname = var.split('#')[0]
                prefix = constname+'#'
                if sum(1 for var2 in printvarorder if var2.startswith(prefix)) == 1:
                    var = constname
            banner += f' {var}:{bits}'
        banner += f' output:{leftbits}'
        print(banner)

        verify_specific_size(left, right, sa)

    if not foundsa:
        sys.stdout.flush()
        raise Exception('does this have a size assignment?')

    print('=====')
    sys.stdout.flush()


rules = parse()
for line, left, right in rules:
    verify(line, left, right)