#!/usr/bin/env python3
"""Peephole-optimize a bit-vector program down to signedmin/signedmax.

Usage: <script> XBITS N < program

XBITS (sys.argv[1]) is the bit width of each input and N (sys.argv[2]) the
number of inputs; the inputs are named x_0_XBITS ... x_{N-1}_XBITS and the
outputs y_0_XBITS ... y_{N-1}_XBITS.  The program is read from stdin as a
list of assignments, turned into a value graph one step at a time (merging
identical values, propagating constants, applying rewrite rules), and the
part of the graph reachable from the outputs is printed to stdout.

Every value also carries a concrete (bits, value) example computed from
random inputs; rewrites are checked against it, and at the end the example
outputs must equal the sorted example inputs.

NOTE(review): this file was reconstructed from a copy whose line structure
had been collapsed and in which every text span of the form "<letter...>"
had been stripped.  Places where code had to be re-derived are marked
NOTE(review) below.  In particular the peephole rule text inside
peephole_init() and most of the constpatterns table could NOT be recovered
and must be restored from a pristine copy before this tool is used.
"""

import random
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums

xbits = int(sys.argv[1])
n = int(sys.argv[2])
inputname = tuple(f'x_{i}_{xbits}' for i in range(n))
outputname = tuple(f'y_{i}_{xbits}' for i in range(n))

# ===== operation graph

nextvalue = 0

# indexed by variable name:
value = {}

# indexed by value:
operation = {}
parents = {}
bits = {}
concrete = {} # partly for sanity check, but also conveniently enables peephole constant propagation

# indexed by (operation,parents):
oppa2value = {}

# ===== parse program

def lit(x):
    """Return a pyparsing literal for x that is matched but dropped from the tokens."""
    return Literal(x).suppress()

def group(s):
    """Return a parse action that wraps the matched tokens as one list [s, tok...].

    A single-token match is passed through unchanged.
    """
    def t(x):
        x = list(x)
        if len(x) == 1: return x
        return [[s] + x]
    return t

lparen = lit('(')
rparen = lit(')')
comma = lit(',')
equal = lit('=')

number = Word(nums)
name = Word(alphas,alphas+nums+"_")

# each alternative yields [opname, target, arg, ...]
assignment = (
    name + equal + lit('constant') + lparen + number + comma + number + rparen
).set_parse_action(group('constant'))
for unary in ['invert','Reverse']:
    assignment |= (
        name + equal + lit(unary) + lparen + name + rparen
    ).set_parse_action(group(unary))
for binary in ['xor','or','and','add','sub','mul','lshift','signedrshift','unsignedrshift','unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt','equal']:
    assignment |= (
        name + equal + lit(binary) + lparen + name + comma + name + rparen
    ).set_parse_action(group(binary))
assignment |= (
    name + equal + lit('If') + lparen + name + comma + name + comma + name + rparen
).set_parse_action(group('If'))
assignment |= (
    name + equal + lit('Extract') + lparen + name + comma + number + comma + number + rparen
).set_parse_action(group('Extract'))
assignment |= (
    name + equal + lit('SignExt') + lparen + name + comma + number + rparen
).set_parse_action(group('SignExt'))
assignment |= (
    name + equal + lit('ZeroExt') + lparen + name + comma + number + rparen
).set_parse_action(group('ZeroExt'))
for manyary in ['Concat']:
    assignment |= (
        name + equal + lit(manyary) + lparen + name + ZeroOrMore(comma + name) + rparen
    ).set_parse_action(group(manyary))
assignment |= (
    name + equal + name
).set_parse_action(group('copy'))
assignments = ZeroOrMore(assignment) + StringEnd()

# operations whose two parents may be swapped
commutative = 'and','or','xor','signedmin','signedmax'

program = sys.stdin.read()
program = assignments.parse_string(program)
program = list(program)

# ===== concrete calculations, including number of bits as first component

concreteinputlist = [random.randrange(2**xbits) for i in range(n)]
concreteinput = {inputname[i]:concreteinputlist[i] for i in range(n)}

def calcconcrete(op,parents):
    """Evaluate operation op on concrete parent values.

    op is a tuple (opname, extra...); parents is a sequence of (bits, value)
    pairs.  Returns the result as a (bits, value) pair with value held as a
    nonnegative integer below 2**bits.
    """
    if op[0] == 'input':
        assert len(parents) == 0
        return xbits,concreteinput[op[1]]
    if op[0] == 'constant':
        assert len(parents) == 0
        return tuple(op[1:])
    if op[0] == 'invert':
        assert len(parents) == 1
        fbits,f = parents[0]
        # NOTE(review): expression after "(1<" was stripped; rebuilt as bitwise complement
        return fbits,(1<<fbits)-1-f
    if op[0] == 'Reverse':
        # NOTE(review): the head of this branch was stripped; rebuilt from the
        # surviving tail, which reverses the order of the fbits//8 bytes
        assert len(parents) == 1
        fbits,f = parents[0]
        x = [255&(f>>(8*j)) for j in range(fbits//8)]
        x.reverse()
        return fbits,sum(x[j]<<(8*j) for j in range(fbits//8))
    if op[0] in ('unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt','equal','xor','or','and','add','sub','mul','lshift','unsignedrshift','signedrshift','signedmin','signedmax'):
        assert len(parents) == 2
        fbits,f = parents[0]
        gbits,g = parents[1]
        assert fbits == gbits
        if op[0] == 'equal': return 1,int(f == g)
        if op[0] == 'xor': return fbits,f^g
        if op[0] == 'and': return fbits,f&g
        if op[0] == 'or': return fbits,f|g
        if op[0] == 'add': return fbits,(f+g)%(1<<fbits)
        # NOTE(review): 'sub', 'mul' and the definition of flip were stripped;
        # rebuilt as arithmetic mod 2**fbits and the sign-bit mask used below
        if op[0] == 'sub': return fbits,(f-g)%(1<<fbits)
        if op[0] == 'mul': return fbits,(f*g)%(1<<fbits)
        flip = 1<<(fbits-1) # xor with flip maps signed order onto unsigned order
        if op[0] == 'unsignedge': return 1,int(f >= g)
        if op[0] == 'unsignedgt': return 1,int(f > g)
        if op[0] == 'unsignedle': return 1,int(f <= g)
        if op[0] == 'unsignedlt': return 1,int(f < g)
        if op[0] == 'signedle': return 1,int(f^flip <= g^flip)
        if op[0] == 'signedlt': return 1,int(f^flip < g^flip)
        if op[0] == 'signedmin': return fbits,min(f^flip,g^flip)^flip
        if op[0] == 'signedmax': return fbits,max(f^flip,g^flip)^flip
        # NOTE(review): the bodies of the three shifts were partly stripped
        # (only "assert g", ">>g", ">>gsigned" and the final signedrshift
        # return survived); rebuilt assuming shift distances below fbits
        if op[0] == 'lshift':
            assert g<fbits
            return fbits,(f<<g)%(1<<fbits)
        if op[0] == 'unsignedrshift':
            assert g<fbits
            return fbits,f>>g
        if op[0] == 'signedrshift':
            assert g<fbits
            fsigned = (f^flip)-flip
            fgsigned = fsigned>>g
            return fbits,(fgsigned+flip)^flip
        raise Exception('internal error')
    if op[0] == 'If':
        assert len(parents) == 3
        cbits,c = parents[0]
        fbits,f = parents[1]
        gbits,g = parents[2]
        assert fbits == gbits
        assert cbits == 1
        if c: return fbits,f
        return gbits,g
    if op[0] == 'Extract':
        top,bot = op[1:]
        assert len(parents) == 1
        fbits,f = parents[0]
        assert fbits > top
        assert top >= bot
        assert bot >= 0
        return top+1-bot,((f&((2<<top)-1))>>bot)
    if op[0] in ('SignExt','ZeroExt'):
        ext = op[1]
        assert len(parents) == 1
        fbits,f = parents[0]
        if op[0] == 'SignExt' and f&(1<<(fbits-1)):
            return fbits+ext,f+(1<<(fbits+ext))-(1<<fbits)
        # NOTE(review): everything from here to the end of this function was
        # stripped and is reconstructed
        return fbits+ext,f
    if op[0] == 'Concat':
        # first parent is most significant, matching the Extract-of-Concat
        # rewrite in newvalue, which counts bit positions from the last parent
        resultbits = 0
        result = 0
        for fbits,f in parents:
            result = (result<<fbits)|f
            resultbits += fbits
        return resultbits,result
    raise Exception(f'unknown operation {op}')

# ===== peephole rules

# constant patterns usable in rules: name -> function from bit width to value
# NOTE(review): the original table was stripped.  Only the names 'bit0' and
# 'bit1' are referenced elsewhere in this file; their values here are
# presumed, and any further entries must be restored from a pristine copy.
constpatterns = {
    'bit0': lambda bits: 0,
    'bit1': lambda bits: 1,
}

def peephole_init():
    """Parse the rewrite rules and return them as a list of (left, right) patterns.

    Each rule is one line of the form "left -> right"; lines that do not
    split into exactly two parts on '->' are ignored.  A parsed pattern is a
    nested list whose first element is the pattern kind.
    """
    # NOTE(review): the rule text was stripped from the copy this file was
    # reconstructed from and cannot be re-derived; with no rules the DSL
    # peephole pass in newvalue never fires.  Restore before use.
    rules = ''

    # NOTE(review): this helper was stripped; rebuilt from its call sites,
    # which require every match (even an empty one) to become [name, tok...]
    def group(name,what):
        def wrap(x):
            return [[name] + list(x)]
        return what.set_parse_action(wrap)

    # NOTE(review): the first lines of fun (up to "if len(args) >") were
    # stripped; rebuilt from the surviving remainder
    def fun(name,*args):
        what = lit(name)
        if len(args) > 0:
            what += lparen
            for pos,arg in enumerate(args):
                if pos > 0: what += comma
                what += arg
            what += rparen
        return group(name,what)

    lparen = lit('(')
    rparen = lit(')')
    comma = lit(',')

    Term = Forward()
    Termvar = group('var',Word('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))
    X = lparen+Term+rparen
    X |= group('assign',Termvar+lit('=')+Term)
    X |= fun('Bits',Word(nums),Term)
    X |= group('Concat',lit('Concat')+lparen+Term+ZeroOrMore(comma+Term)+rparen)
    X |= fun('Extract',Term,Word(nums),Word(nums))
    X |= fun('ZeroExt',Term,Word(nums))
    for c in constpatterns:
        X |= fun(c)
    for op in 'invert','Reverse':
        X |= fun(op,Term)
    for op in 'add','sub','and','xor','or','equal','unsignedgt','unsignedge','unsignedlt','unsignedle','signedlt','signedle','signedmin','signedmax','lshift','unsignedrshift','signedrshift':
        X |= fun(op,Term,Term)
    X |= fun('If',Term,Term,Term)
    X |= Termvar
    Term <<= X

    result = []
    for line in rules.splitlines():
        line = line.split('->')
        if len(line) != 2: continue
        try:
            left = (Term+StringEnd()).parse_string(line[0])[0]
        except Exception as e:
            raise Exception(f'failure to parse {line[0]}') from e
        try:
            right = (Term+StringEnd()).parse_string(line[1])[0]
        except Exception as e:
            raise Exception(f'failure to parse {line[1]}') from e
        result.append((left,right))

    return result

peephole_list = peephole_init()

def op2topop(operation):
    """Return the rule-lookup key for a graph operation tuple.

    Extract/ZeroExt/SignExt are keyed with their numeric parameters; every
    other operation is keyed by name alone.
    """
    if operation[0] in ('Extract','ZeroExt','SignExt'):
        return tuple(operation)
    return operation[0]

def pattern2topop(pattern):
    """Return the rule-lookup key (same shape as op2topop) for a left-side pattern."""
    if len(pattern) == 1: return 'constant'
    if len(pattern) == 2:
        assert pattern[0] != 'var'
        return pattern[0]
    if pattern[0] in ('ZeroExt','SignExt'):
        return pattern[0],int(pattern[2])
    if pattern[0] == 'Extract':
        return pattern[0],int(pattern[2]),int(pattern[3])
    if pattern[0] in ('assign','Bits'):
        # wrappers: key on the wrapped pattern
        return pattern2topop(pattern[2])
    return pattern[0]

peephole_by_op = {}
for left,right in peephole_list:
    op = pattern2topop(left)
    if op not in peephole_by_op:
        peephole_by_op[op] = []
    peephole_by_op[op].append((left,right))

# returns tuple of new match dictionaries compatible with mforce
def peephole_match(pattern,newoperation,newparents,newbits,var=None,mforce=None):
    """Match pattern against the value described by (newoperation, newparents, newbits).

    var is the graph id of that value if it already has one (None for the
    value currently being built).  mforce maps pattern variables to graph ids
    that are already fixed; it is read but never modified.  Each returned
    dictionary maps pattern variables to graph ids; an empty tuple means no
    match.
    """
    if mforce is None: mforce = {}

    if len(pattern) == 1:
        # constant pattern
        if newoperation[0] != 'constant': return ()
        if pattern[0] in ('bit0','bit1') and newbits != 1: return ()
        if newoperation[1] != newbits: return ()
        if newoperation[2] != constpatterns[pattern[0]](newbits): return ()
        return {},

    if len(pattern) == 2:
        if pattern[0] == 'var':
            if var is None: return ()
            if mforce.get(pattern[1],var) != var: return ()
            return {pattern[1]:var},
        # unary operation
        if newoperation[0] == pattern[0]:
            p, = newparents
            return peephole_match(pattern[1],operation[p],parents[p],bits[p],var=p,mforce=mforce)
        return ()

    if pattern[0] == 'assign':
        assert pattern[1][0] == 'var'
        if var is None: return ()
        V = pattern[1][1]
        results = ()
        for m in peephole_match(pattern[2],newoperation,newparents,newbits,var=var,mforce=mforce):
            if V in m: continue
            m[V] = var
            results += m,
        return results

    if pattern[0] == 'Bits':
        if newbits != int(pattern[1]): return ()
        return peephole_match(pattern[2],newoperation,newparents,newbits,var=var,mforce=mforce)

    if newoperation[0] != pattern[0]: return ()

    if pattern[0] == 'Extract':
        if newoperation != (pattern[0],int(pattern[2]),int(pattern[3])): return ()
        if len(newparents) != 1: return ()
        p, = newparents
        return peephole_match(pattern[1],operation[p],parents[p],bits[p],var=p,mforce=mforce)

    if pattern[0] == 'ZeroExt': # would also work for SignExt
        if newoperation != (pattern[0],int(pattern[2])): return ()
        if len(newparents) != 1: return ()
        p, = newparents
        return peephole_match(pattern[1],operation[p],parents[p],bits[p],var=p,mforce=mforce)

    if len(newparents) != len(pattern[1:]): return ()

    todo = [newparents]
    if pattern[0] in commutative:
        todo += [[newparents[1],newparents[0]]]

    results = ()
    for parentlist in todo:
        partialresults = {},
        # conceptually: match each parent in turn, extending partial result
        # but there can be multiple matches so have list of partial results
        for subpattern,p in zip(pattern[1:],parentlist):
            newpartialresults = ()
            for m in partialresults:
                submforce = mforce.copy()
                submforce.update(m)
                for subm in peephole_match(subpattern,operation[p],parents[p],bits[p],var=p,mforce=submforce):
                    assert all(m[v] == subm[v] for v in subm if v in m)
                    subm.update(m)
                    newpartialresults += subm,
            partialresults = newpartialresults
        results += partialresults

    return results

def newvalue(newoperation,newparents):
    """Return the graph id of the value newoperation(newparents), creating it if needed.

    Before a new node is created the request is repeatedly simplified:
    commutative parents are sorted, an existing identical value is reused,
    all-constant parents are folded into a constant, the DSL peephole rules
    are tried, and Extract is pushed through invert/Extract/Reverse/Concat.
    Every rewrite is checked against the concrete example value.
    """
    global nextvalue

    newconcrete = calcconcrete(newoperation,[concrete[x] for x in newparents])
    newbits = newconcrete[0]

    while True:
        if newoperation[0] in commutative:
            newparents = tuple(sorted(newparents))

        # is this value already in the graph?
        oppa = tuple(newoperation),tuple(newparents)
        if oppa in oppa2value:
            v = oppa2value[oppa]
            assert newbits == bits[v]
            assert newconcrete == concrete[v]
            return v

        # is this value an orphan?
        if len(newparents) == 0:
            nextvalue += 1
            operation[nextvalue] = newoperation
            parents[nextvalue] = newparents
            bits[nextvalue] = newbits
            concrete[nextvalue] = newconcrete
            oppa = tuple(newoperation),tuple(newparents)
            assert oppa not in oppa2value # otherwise would have merged
            oppa2value[oppa] = nextvalue
            return nextvalue

        assert newoperation[0] not in ('constant','input')

        # constant propagation
        if all(operation[p][0] == 'constant' for p in newparents):
            newoperation = 'constant',newconcrete[0],newconcrete[1]
            newparents = ()
            continue

        # peephole from DSL (note: there are also a few non-DSL rewrites below)
        topop = op2topop(newoperation)
        for lrpos,(left,right) in enumerate(peephole_by_op.get(topop,[])):
            mtuple = peephole_match(left,newoperation,newparents,newbits)
            if len(mtuple) == 0: continue
            if lrpos > 0:
                # move the rule that matched to the front of its list
                lrlist = peephole_by_op[topop]
                lrlist = lrlist[lrpos:lrpos+1]+lrlist[:lrpos]+lrlist[lrpos+1:]
                peephole_by_op[topop] = lrlist
            m = mtuple[0] # or could heuristically pick "best" match
            result = newvaluepattern(right,m,forcebits=newbits)
            if bits[result] != newbits:
                raise Exception(f'replacing {left} with {right} for {newoperation}({newparents}) substituting {m} changed sizes from {newbits} to {bits[result]}')
            if concrete[result] != newconcrete:
                raise Exception(f'replacing {left} with {right} for {newoperation}({newparents}) substituting {m} changed concrete values from {newconcrete} to {concrete[result]}')
            oppa = tuple(newoperation),tuple(newparents)
            assert oppa not in oppa2value # otherwise would have merged
            oppa2value[oppa] = result
            return result

        if newoperation[0] == 'Extract':
            w = newparents[0]

            # rewrite invert(s)[high:low] as invert(s[high:low])
            if operation[w][0] == 'invert':
                newparents = newvalue(newoperation,parents[w]),
                newoperation = 'invert',
                assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                continue

            # rewrite s[high:low][high2:low2] as s[high2+low:low2+low]
            if operation[w][0] == 'Extract':
                s = parents[w][0]
                high2,low2 = newoperation[1:]
                high,low = operation[w][1:]
                newoperation = 'Extract',high2+low,low2+low
                newparents = s,
                assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                continue

            # rewrite Reverse(s)[high:low] as Reverse(s[...]) if possible
            if operation[w][0] == 'Reverse':
                high,low = newoperation[1:]
                if low%8 == 0 and high%8 == 7:
                    s = parents[w][0]
                    newparents = newvalue(('Extract',bits[s]-1-low,bits[s]-1-high),(s,)),
                    newoperation = 'Reverse',
                    assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                    continue

            # rewrite Concat(...)[high:low] as ...[high-pos:low-pos] if possible
            progress = False
            if operation[w][0] == 'Concat':
                high,low = newoperation[1:]
                pos = 0
                for y in reversed(parents[w]):
                    if pos <= low and high < pos+bits[y]:
                        newoperation = 'Extract',high-pos,low-pos
                        newparents = y,
                        assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                        progress = True
                        break
                    pos += bits[y]
            if progress: continue

        # no further rewrites
        nextvalue += 1
        operation[nextvalue] = newoperation
        parents[nextvalue] = newparents
        bits[nextvalue] = newbits
        concrete[nextvalue] = newconcrete
        oppa = tuple(newoperation),tuple(newparents)
        assert oppa not in oppa2value # otherwise would have merged
        oppa2value[oppa] = nextvalue
        return nextvalue

# create new value according to pattern
# with m describing substitutions of pattern variables
def newvaluepattern(pattern,m,forcebits=None):
    """Build the value described by a right-side pattern and return its graph id.

    m maps pattern variables to graph ids.  forcebits, when not None, is the
    bit width the result must have; it is required for constant patterns,
    which carry no width of their own.
    """
    if pattern[0] == 'var':
        return m[pattern[1]]

    if pattern[0] in constpatterns:
        assert forcebits is not None
        if pattern[0] in ('bit0','bit1'):
            assert forcebits == 1
        newoperation = 'constant',forcebits,constpatterns[pattern[0]](forcebits)
        newparents = ()
        return newvalue(newoperation,newparents)

    if pattern[0] == 'Bits':
        if forcebits is None: forcebits = int(pattern[1])
        assert forcebits == int(pattern[1])
        result = newvaluepattern(pattern[2],m,forcebits=forcebits)
        assert forcebits == bits[result]
        return result

    if pattern[0] == 'ZeroExt':
        morebits = int(pattern[2])
        subforcebits = None if forcebits is None else forcebits-morebits
        newoperation = 'ZeroExt',morebits
        newparents = newvaluepattern(pattern[1],m,forcebits=subforcebits),
        return newvalue(newoperation,newparents)

    if pattern[0] == 'Extract':
        top,bot = map(int,pattern[2:])
        if forcebits is not None:
            assert forcebits == 1+top-bot
        newoperation = 'Extract',top,bot
        newparents = newvaluepattern(pattern[1],m),
        return newvalue(newoperation,newparents)

    newoperation = pattern[0],
    if pattern[0] == 'If':
        parentcond = newvaluepattern(pattern[1],m,forcebits=1)
        parentthen = newvaluepattern(pattern[2],m,forcebits=forcebits)
        parentelse = newvaluepattern(pattern[3],m,forcebits=forcebits)
        newparents = parentcond,parentthen,parentelse
    elif pattern[0] == 'Concat':
        newparents = tuple(newvaluepattern(p,m) for p in pattern[1:])
    elif pattern[0] in ('unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt'):
        # result width (1) says nothing about the operand width
        newparents = tuple(newvaluepattern(p,m) for p in pattern[1:])
    elif pattern[0] in ('signedmin','signedmax','xor','and','or','invert','add','sub','signedrshift'):
        # operands have the same width as the result
        newparents = tuple(newvaluepattern(p,m,forcebits=forcebits) for p in pattern[1:])
    else:
        raise Exception(f'unrecognized operation on right side of rule: {pattern[0]}')

    return newvalue(newoperation,newparents)

# ===== run through program, optimizing each step as it appears

for v in inputname:
    value[v] = newvalue(('input',v),())

for p in program:
    # p is [opname, target, arg, ...]
    if p[1] in value:
        raise Exception(f'{p[1]} assigned twice')

    if p[0] == 'copy':
        value[p[1]] = value[p[2]]
        continue

    if p[0] == 'constant':
        newoperation = p[0],int(p[2]),int(p[3])
        newparents = ()
    elif p[0] == 'Extract':
        newoperation = p[0],int(p[3]),int(p[4])
        newparents = value[p[2]],
    elif p[0] in ('SignExt','ZeroExt'):
        newoperation = p[0],int(p[3])
        newparents = value[p[2]],
    elif p[0] in ('unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt','equal','invert','Reverse','xor','or','and','add','sub','mul','lshift','unsignedrshift','signedrshift','If','Concat'):
        newoperation = p[0],
        newparents = tuple(value[v] for v in p[2:])
    else:
        raise Exception(f'unknown internal operation {p[0]}')

    value[p[1]] = newvalue(newoperation,newparents)

concreteoutputlist = [concrete[value[v]][1] for v in outputname]

# compare as signed integers
flip = 1<<(xbits-1)
signedconcreteinputlist = [(u^flip)-flip for u in concreteinputlist]
signedconcreteoutputlist = [(u^flip)-flip for u in concreteoutputlist]
if signedconcreteoutputlist != sorted(signedconcreteinputlist):
    raise Exception(f'example output {signedconcreteoutputlist} is not sorted input {sorted(signedconcreteinputlist)}')

# ===== output

done = set()
finaloperations = set()

def do(v):
    """Print value v after everything it depends on, each value once.

    Records the name of every non-input operation printed in finaloperations.
    """
    if v in done: return
    done.add(v)
    for x in parents[v]: do(x)
    if operation[v][0] == 'input':
        print(f'v{v} = {operation[v][1]}')
    else:
        p = [f'v{x}' for x in parents[v]]
        p += [f'{x}' for x in operation[v][1:]]
        finaloperations.add(operation[v][0])
        args = ','.join(p)
        print(f'v{v} = {operation[v][0]}({args})')

for v in outputname:
    do(value[v])
    print(f'{v} = v{value[v]}')

sys.stdout.flush()

# decompose will also refuse to parse this, but error here is more informative
if any(op not in ('signedmin','signedmax') for op in finaloperations):
    raise Exception(f'resulting operations {sorted(finaloperations)} are not just signedmin, signedmax')