#!/usr/bin/env python3 import sys from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums printstats = False if printstats: import resource def maxrsskb(): return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss xbits = int(sys.argv[1]) n = int(sys.argv[2]) # ===== parse program inputname = tuple(f'x_{i}_{xbits}' for i in range(n)) outputname = tuple(f'y_{i}_{xbits}' for i in range(n)) def lit(x): return Literal(x).suppress() def group(s): def t(x): x = list(x) if len(x) == 1: return x return [[s] + x] return t lparen = lit('(') rparen = lit(')') comma = lit(',') equal = lit('=') number = Word(nums) name = Word(alphas,alphas+nums+"_") assignment = ( name + equal + lit('constant') + lparen + number + comma + number + rparen ).set_parse_action(group('constant')) for unary in ['invert','Reverse']: assignment |= ( name + equal + lit(unary) + lparen + name + rparen ).set_parse_action(group(unary)) for binary in ['xor','or','and','add','sub','mul','lshift','signedrshift','unsignedrshift','unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt','equal']: assignment |= ( name + equal + lit(binary) + lparen + name + comma + name + rparen ).set_parse_action(group(binary)) assignment |= ( name + equal + lit('If') + lparen + name + comma + name + comma + name + rparen ).set_parse_action(group('If')) assignment |= ( name + equal + lit('Extract') + lparen + name + comma + number + comma + number + rparen ).set_parse_action(group('Extract')) assignment |= ( name + equal + lit('SignExt') + lparen + name + comma + number + rparen ).set_parse_action(group('SignExt')) assignment |= ( name + equal + lit('ZeroExt') + lparen + name + comma + number + rparen ).set_parse_action(group('ZeroExt')) for manyary in ['Concat']: assignment |= ( name + equal + lit(manyary) + lparen + name + ZeroOrMore(comma + name) + rparen ).set_parse_action(group(manyary)) assignment |= ( name + equal + name).set_parse_action(group('copy') ) assignments 
# NOTE(review): this file appears to be damaged and cannot run as-is.
#   * Its original line breaks and indentation have been collapsed, so each
#     physical line holds many statements. Because the line above starts
#     with '#!', all of it is now a single comment.
#   * Text running from a '<' up to the next '>' seems to have been stripped
#     in several places. For example, the invert/Reverse code in
#     calcconcrete now reads '(1<>(8*j))' and the 'add' result reads
#     '(f+g)%(1<= g)'. The largest gap removes the end of calcconcrete and
#     the start of peephole_init. It presumably also removes the definitions
#     of constpatterns, flip, nextvalue, the graph dicts (operation, parents,
#     bits, concrete, value) and wherever `rules` comes from, since none of
#     those are defined in the surviving text.
#   * The lost text cannot be reconstructed from what survives; restore the
#     file from version control before editing it.
#
# What the surviving code does. Usage: SCRIPT xbits n, with a program on
# stdin.
#   * It parses a straight-line program of `name = op(args)` assignments.
#   * It works on a graph of numbered nodes. Each node has an operation,
#     parent nodes, a bit width and a concrete (bits, value) pair.
#   * It repeats cleanups and peephole rewrites until a pass changes nothing.
#   * Concrete values come from one random draw of the n xbits-bit inputs.
#     Every rewrite asserts that its node's concrete value is unchanged.
#     This is a single random test point: it catches broken rewrites but
#     does not prove them correct.
#   * For each output it prints the not-yet-printed ancestor nodes as
#     `vN = ...` lines, then `y_i_xbits = vN`. It finally raises unless
#     only signedmin/signedmax operations were printed.
#
# The line above contains:
#   * the shebang and imports, and optional resource-usage stats;
#   * the xbits and n arguments;
#   * input names x_{i}_{xbits} and output names y_{i}_{xbits};
#   * a pyparsing grammar with one alternative per supported operation.
#     group(s) turns a multi-token match into [[s, tokens...]].
= ZeroOrMore(assignment) + StringEnd() commutative = 'and','or','xor','signedmin','signedmax' program = sys.stdin.read() program = assignments.parse_string(program) program = list(program) # ===== calculating number of bits produced by various operations def calcbits(op,parentbits): if op[0] == 'input': return xbits if op[0] == 'constant': return op[1] if op[0] in ('unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt','equal'): # binary operation producing bit assert len(parentbits) == 2 assert parentbits[0] == parentbits[1] return 1 if op[0] in ('invert','Reverse'): # unary size-preserving operation assert len(parentbits) == 1 return parentbits[0] if op[0] in ('xor','or','and','add','sub','mul','lshift','unsignedrshift','signedrshift','signedmin','signedmax'): # binary size-preserving operation assert len(parentbits) == 2 assert parentbits[0] == parentbits[1] return parentbits[0] if op[0] == 'If': assert len(parentbits) == 3 assert parentbits[0] == 1 assert parentbits[1] == parentbits[2] return parentbits[1] if op[0] == 'Extract': top,bot = op[1:] assert len(parentbits) == 1 assert top >= bot assert parentbits[0] > top assert bot >= 0 return top+1-bot if op[0] in ('SignExt','ZeroExt'): morebits = op[1] assert len(parentbits) == 1 return parentbits[0]+morebits if op[0] == 'Concat': return sum(parentbits) raise Exception(f'unknown operation {op}') # ===== concrete calculations import random concreteinputlist = [random.randrange(2**xbits) for i in range(n)] concreteinput = {inputname[i]:concreteinputlist[i] for i in range(n)} def calcconcrete(op,parents): if op[0] == 'input': assert len(parents) == 0 return xbits,concreteinput[op[1]] if op[0] == 'constant': assert len(parents) == 0 return tuple(op[1:]) if op[0] == 'invert': assert len(parents) == 1 fbits,f = parents[0] return fbits,(1<>(8*j)) for j in range(fbits//8)] x.reverse() return fbits,sum(x[j]<<(8*j) for j in range(fbits//8)) if op[0] in 
# The line above contains:
#   * the end of the grammar;
#   * the `commutative` operations, used when matching rules and when
#     merging duplicate nodes. signedmin/signedmax are not in the input
#     grammar; they only appear through rewrites.
#   * parsing of stdin;
#   * calcbits(op, parentbits): the result width of an operation, asserting
#     that the operand widths are compatible;
#   * the start of calcconcrete(op, parents), which evaluates an operation
#     on (bits, value) pairs.
# NOTE(review): pyparsing returns `number` tokens as strings, yet
# calcbits/calcconcrete compare and add op[1:] as integers. The missing
# graph-building code presumably converts them; confirm when restoring.
('unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt','equal','xor','or','and','add','sub','mul','lshift','unsignedrshift','signedrshift','signedmin','signedmax'): assert len(parents) == 2 fbits,f = parents[0] gbits,g = parents[1] assert fbits == gbits if op[0] == 'equal': return 1,int(f == g) if op[0] == 'xor': return fbits,f^g if op[0] == 'and': return fbits,f&g if op[0] == 'or': return fbits,f|g if op[0] == 'add': return fbits,(f+g)%(1<= g) if op[0] == 'unsignedgt': return 1,int(f > g) if op[0] == 'unsignedle': return 1,int(f <= g) if op[0] == 'unsignedlt': return 1,int(f < g) if op[0] == 'signedle': return 1,int(f^flip <= g^flip) if op[0] == 'signedlt': return 1,int(f^flip < g^flip) if op[0] == 'signedmin': return fbits,min(f^flip,g^flip)^flip if op[0] == 'signedmax': return fbits,max(f^flip,g^flip)^flip if op[0] == 'lshift': assert g>g if op[0] == 'signedrshift': assert g>gsigned return fbits,(fgsigned+flip)^flip raise Exception('internal error') if op[0] == 'If': assert len(parents) == 3 cbits,c = parents[0] fbits,f = parents[1] gbits,g = parents[2] assert fbits == gbits assert cbits == 1 if c: return fbits,f return gbits,g if op[0] == 'Extract': top,bot = op[1:] assert len(parents) == 1 fbits,f = parents[0] assert fbits > top assert top >= bot assert bot >= 0 return top+1-bot,((f&((2<>bot) if op[0] in ('SignExt','ZeroExt'): ext = op[1] assert len(parents) == 1 fbits,f = parents[0] if op[0] == 'SignExt' and f&(1<<(fbits-1)): return fbits+ext,f+(1<<(fbits+ext))-(1< 0: what += lparen for pos,arg in enumerate(args): if pos > 0: what += comma what += arg what += rparen return group(name,what) lparen = lit('(') rparen = lit(')') comma = lit(',') Term = Forward() Termvar = group('var',Word('ABCDEFGHIJKLMNOPQRSTUVWXYZ')) X = lparen+Term+rparen X |= group('assign',Termvar+lit('=')+Term) X |= fun('Bits',Word(nums),Term) X |= group('Concat',lit('Concat')+lparen+Term+ZeroOrMore(comma+Term)+rparen) X |= fun('Extract',Term,Word(nums),Word(nums)) X |= 
# The line above contains:
#   * the rest of calcconcrete: two-operand operations, If, Extract and
#     SignExt/ZeroExt. Several of these branches are truncated by the
#     damage described at the top.
#   * after the large gap, the tail of peephole_init's term grammar:
#     parenthesised terms, pattern variables spelled in uppercase letters,
#     `V=term` bindings, Bits(n,term) width constraints, Concat and Extract.
fun('ZeroExt',Term,Word(nums)) for c in constpatterns: X |= fun(c) for op in 'invert','Reverse': X |= fun(op,Term) for op in 'add','sub','and','xor','or','equal','unsignedgt','unsignedge','unsignedlt','unsignedle','signedlt','signedle','signedmin','signedmax','lshift','unsignedrshift','signedrshift': X |= fun(op,Term,Term) X |= fun('If',Term,Term,Term) X |= Termvar Term <<= X result = [] for line in rules.splitlines(): line = line.split('->') if len(line) != 2: continue try: left = (Term+StringEnd()).parse_string(line[0])[0] except: raise Exception(f'failure to parse {line[0]}') try: right = (Term+StringEnd()).parse_string(line[1])[0] except: raise Exception(f'failure to parse {line[1]}') result.append((left,right)) return result peephole_list = peephole_init() def topop(pattern): if len(pattern) == 1: return 'constant' if len(pattern) == 2: assert pattern[0] != 'var' return pattern[0] assert len(pattern) >= 3 if pattern[0] in ('assign','Bits'): return topop(pattern[2]) return pattern[0] peephole_by_op = {} for left,right in peephole_list: op = topop(left) if op not in peephole_by_op: peephole_by_op[op] = [] peephole_by_op[op].append((left,right)) # returns tuple of new match dictionaries compatible with mforce def peephole_match(pattern,v,mforce={}): if len(pattern) == 1: if operation[v][0] != 'constant': return () if pattern[0] in ('bit0','bit1') and bits[v] != 1: return () if operation[v] == ['constant',bits[v],constpatterns[pattern[0]](bits[v])]: return {}, return () if len(pattern) == 2: if pattern[0] == 'var': if mforce.get(pattern[1],v) != v: return () return {pattern[1]:v}, if operation[v][0] == pattern[0]: p, = parents[v] return peephole_match(pattern[1],p,mforce=mforce) return () if pattern[0] == 'assign': assert pattern[1][0] == 'var' V = pattern[1][1] results = () for m in peephole_match(pattern[2],v,mforce=mforce): if V in m: continue m[V] = v results += m, return results if pattern[0] == 'Bits': if bits[v] != int(pattern[1]): return () return 
# The line above contains:
#   * the rest of the term grammar: named constants from constpatterns,
#     unary/binary operations and If;
#   * parsing of every line of `rules` that contains exactly one '->' into
#     a (left, right) pattern pair;
#   * topop(): the operation at the root of a pattern, looking through
#     `assign` and `Bits` wrappers. It returns 'constant' for a length-1
#     pattern, which peephole_match treats as a named constant.
#   * the peephole_by_op index from root operation to its rules;
#   * the start of peephole_match(pattern, v, mforce). It returns a tuple of
#     binding dicts (pattern variable -> node) consistent with mforce.
# NOTE(review): the bare `except:` around parse_string also catches
# KeyboardInterrupt/SystemExit and discards the parser's own message.
# Catching pyparsing's ParseException and chaining with `raise ... from e`
# would be more informative.
# NOTE(review): mforce={} is a mutable default. It is only read and copied
# here, never mutated, so it is harmless, but mforce=None would be the
# conventional spelling.
# NOTE(review): the `assign` case skips matches that already bind V, but it
# never checks V against mforce. A rule binding the same variable in two
# operands would therefore trip the consistency assert in the operand loop
# instead of simply failing to match. Confirm whether any rule does this.
peephole_match(pattern[2],v,mforce=mforce) if pattern[0] == 'ZeroExt': # would also work for SignExt if operation[v] != [pattern[0],int(pattern[2])]: return () if len(parents[v]) != 1: return () return peephole_match(pattern[1],parents[v][0],mforce=mforce) if pattern[0] == 'Extract': if operation[v] != [pattern[0],int(pattern[2]),int(pattern[3])]: return () if len(parents[v]) != 1: return () return peephole_match(pattern[1],parents[v][0],mforce=mforce) if operation[v][0] != pattern[0]: return () if len(parents[v]) != len(pattern[1:]): return () todo = [parents[v]] if pattern[0] in commutative: todo += [[parents[v][1],parents[v][0]]] results = () for parentlist in todo: partialresults = {}, # conceptually: match each parent in turn, extending partial result # but there can be multiple matches so have list of partial results for subpattern,p in zip(pattern[1:],parentlist): newpartialresults = () for m in partialresults: submforce = mforce.copy() submforce.update(m) for subm in peephole_match(subpattern,p,mforce=submforce): assert all(m[v] == subm[v] for v in subm if v in m) subm.update(m) newpartialresults += subm, partialresults = newpartialresults results += partialresults return results # create new variable according to pattern # except substituting pattern variables according to m # or put new variable in place of v if v is not None def allocate(pattern,m,v=None,forcebits=None): global nextvalue if pattern[0] == 'var': y = m[pattern[1]] if v is not None: assert concrete[v] == concrete[y] operation[v] = operation[y] parents[v] = parents[y] return y if pattern[0] in constpatterns: assert forcebits is not None if pattern[0] in ('bit0','bit1'): assert forcebits == 1 newoperation = ['constant',forcebits,constpatterns[pattern[0]](forcebits)] newparents = [] newconcrete = calcconcrete(newoperation,[concrete[x] for x in newparents]) if v is None: nextvalue += 1 v = nextvalue bits[v] = forcebits concrete[v] = newconcrete else: assert newconcrete == concrete[v] 
# The line above contains the rest of peephole_match:
#   * ZeroExt and Extract must match their numeric arguments exactly.
#   * Other operations match operand by operand. Both operand orders are
#     tried for commutative operations, and every consistent combination of
#     bindings is kept.
# It also contains the start of allocate(pattern, m, v, forcebits):
#   * It builds a rule's right-hand side from the bindings m, either as new
#     nodes or by rewriting node v in place.
#   * It asserts that widths and concrete values come out as expected.
#   * The `var` case makes v share y's operation and parents lists (the same
#     objects, not copies) and returns y.
operation[v] = newoperation parents[v] = newparents assert bits[v] == forcebits assert bits[v] == calcbits(operation[v],[bits[p] for p in parents[v]]) assert concrete[v][0] == bits[v] return v if pattern[0] == 'Bits': if forcebits is None: forcebits = int(pattern[1]) assert forcebits == int(pattern[1]) return allocate(pattern[2],m,v=v,forcebits=forcebits) if pattern[0] == 'ZeroExt': morebits = int(pattern[2]) subforcebits = None if forcebits is None else forcebits-morebits newparents = [allocate(pattern[1],m,forcebits=subforcebits)] newoperation = ['ZeroExt',morebits] newconcrete = calcconcrete(newoperation,[concrete[x] for x in newparents]) if v is None: nextvalue += 1 v = nextvalue bits[v] = calcbits(operation[v],[bits[p] for p in parents[v]]) concrete[v] = newconcrete else: assert newconcrete == concrete[v] assert bits[v] == calcbits(operation[v],[bits[p] for p in parents[v]]) operation[v] = newoperation parents[v] = newparents if forcebits is not None: assert bits[v] == forcebits assert concrete[v][0] == bits[v] return v if pattern[0] == 'Extract': top,bot = map(int,pattern[2:]) if forcebits is not None: assert forcebits == 1+top-bot forcebits = 1+top-bot newparents = [allocate(pattern[1],m)] newoperation = ['Extract',top,bot] newconcrete = calcconcrete(newoperation,[concrete[x] for x in newparents]) if v is None: nextvalue += 1 v = nextvalue bits[v] = forcebits concrete[v] = newconcrete else: assert newconcrete == concrete[v] operation[v] = newoperation parents[v] = newparents assert bits[v] == forcebits assert bits[v] == calcbits(operation[v],[bits[p] for p in parents[v]]) assert concrete[v][0] == bits[v] return v if pattern[0] == 'If': parentcond = allocate(pattern[1],m,forcebits=1) parentthen = allocate(pattern[2],m,forcebits=forcebits) parentelse = allocate(pattern[3],m,forcebits=forcebits) newparents = [parentcond,parentthen,parentelse] elif pattern[0] == 'Concat': newparents = [allocate(p,m) for p in pattern[1:]] newbits = calcbits(['Concat'],[bits[p] 
# The line above contains allocate's constant-pattern, Bits, ZeroExt and
# Extract branches, and the start of the generic branch (If, Concat).
# NOTE(review): likely bug in the ZeroExt branch.
#   * When v is None, a fresh id is taken from nextvalue.
#   * `bits[v] = calcbits(operation[v], [bits[p] for p in parents[v]])`
#     then runs before operation[v] and parents[v] are assigned.
#   * Assuming nextvalue yields an unused id, this raises KeyError.
#   * It should read calcbits(newoperation, [bits[p] for p in newparents]),
#     as the generic branch does.
for p in newparents]) if forcebits is not None: assert forcebits == newbits forcebits = newbits elif pattern[0] in ('unsignedge','unsignedgt','unsignedle','unsignedlt','signedle','signedlt'): newparents = [allocate(p,m) for p in pattern[1:]] elif pattern[0] in ('signedmin','signedmax','xor','and','or','invert','add','sub','signedrshift'): newparents = [allocate(p,m,forcebits=forcebits) for p in pattern[1:]] else: raise Exception(f'unrecognized operation on right side of rule: {pattern[0]}') newoperation = [pattern[0]] newconcrete = calcconcrete(newoperation,[concrete[x] for x in newparents]) newbits = calcbits(newoperation,[bits[p] for p in newparents]) if forcebits is not None: assert forcebits == newbits if v is None: nextvalue += 1 v = nextvalue bits[v] = newbits concrete[v] = newconcrete else: assert concrete[v] == newconcrete assert bits[v] == newbits operation[v] = newoperation parents[v] = newparents assert concrete[v][0] == bits[v] return v if printstats: sys.stderr.write(f'starting main loop, {len(operation)} nodes, maxrsskb {maxrsskb()}\n') sys.stderr.flush() progress = True while progress: progress = False # ===== cleanups children = {} for z in operation: children[z] = set() for z in operation: for x in parents[z]: children[x].add(z) outputvalues = set(value[v] for v in outputname) deleting = set(v for v in operation if len(children[v]) == 0 and v not in outputvalues) merging = set() merge = [] for x in operation: c = list(children[x]) if len(c) < 2: continue part2children = {} for y in sorted(children[x]): if y in merging or y in deleting: continue # compute part so that two children with same part can be merged if operation[y][0] in commutative: yparents = tuple(sorted(parents[y])) else: yparents = tuple(parents[y]) part = tuple(operation[y]),yparents if part not in part2children: part2children[part] = [] part2children[part].append(y) for _,d in part2children.items(): if len(d) < 2: continue merge.append(d) for y in d: merging.add(y) # merge constants 
# The line above contains the end of allocate:
#   * Comparisons take their operand widths from the bindings.
#   * min/max, bitwise operations, invert, add, sub and signedrshift force
#     every operand to the result width.
#   * Any other operation on a rule's right-hand side raises.
# NOTE(review): the operations that raise include equal, mul, lshift,
# unsignedrshift, Reverse and SignExt. Add them here if a rule ever needs
# them on its right-hand side.
# The line then has optional stats and the start of the main fixpoint loop.
# Each pass first:
#   * recomputes children;
#   * marks for deletion every node with no children that is not an output;
#   * groups the children of each node by (operation, parents), with
#     parents sorted for commutative operations, so duplicates can be merged.
orphans2x = {} for x in operation: if len(parents[x]) > 0: continue xop = tuple(operation[x]) if xop not in orphans2x: orphans2x[xop] = [] orphans2x[xop].append(x) for xop in orphans2x: if len(orphans2x[xop]) < 2: continue if any(y in merging or y in deleting for y in orphans2x[xop]): continue merge.append(tuple(orphans2x[xop])) for y in orphans2x[xop]: merging.add(y) # actually do the merging for ylist in merge: y = ylist[0] for z in ylist[1:]: # eliminate z in favor of y if bits[y] != bits[z] or concrete[y] != concrete[z]: raise Exception(f'internal error, merging value {y} concretely {concrete[y]} with value {z} concretely {concrete[z]}') for v in outputname: if value[v] == z: value[v] = y for t in children[z]: for j in range(len(parents[t])): if parents[t][j] == z: parents[t][j] = y deleting.add(z) for v in deleting: del operation[v] del parents[v] del bits[v] del concrete[v] progress = True if printstats: sys.stderr.write(f'after cleanup: {len(operation)} nodes, maxrsskb {maxrsskb()}\n') sys.stderr.flush() if len(deleting)*10 > len(operation): # XXX: try tuning 10 continue # ===== peephole from DSL todo = [] for v in operation: if operation[v][0] not in ('constant','input'): if all(operation[p][0] == 'constant' for p in parents[v]): assert bits[v] == concrete[v][0] operation[v] = ['constant',bits[v],concrete[v][1]] parents[v] = [] progress = True for v in operation: for left,right in peephole_by_op.get(operation[v][0],[]): mtuple = peephole_match(left,v) if len(mtuple) == 0: continue m = mtuple[0] # or could heuristically pick "best" match todo.append((left,right,m,v)) progress = True break for left,right,m,v in todo: vconcrete = concrete[v] allocate(right,m,v=v,forcebits=bits[v]) if vconcrete != concrete[v]: raise Exception(f'replacing {left} with {right} changed concrete values from {vconcrete} to {concrete[v]}') # ===== peephole beyond DSL # rewrite s[high:low][high2:low2] as s[high2+low:low2+low] for v in operation: if operation[v][0] != 'Extract': 
# The line above contains:
#   * Merging of parentless nodes with identical operations. Inputs are
#     never merged with each other because their operations carry distinct
#     names.
#   * The merges themselves. Each first checks that widths and concrete
#     values agree, then redirects outputs and the children's parent slots
#     from the dropped node to the kept one.
#   * Deletion of marked nodes. If the number deleted exceeds a tenth of the
#     remaining nodes, the pass restarts at the cleanups.
#   * Constant folding. A node whose parents are all constants becomes a
#     constant holding its concrete value; this is safe because that value
#     cannot depend on the random inputs.
#   * DSL rewrites. At most one rule is queued per node: the first rule that
#     matches, with its first binding. It is applied with allocate, which
#     raises if the node's concrete value changed.
#   * The hand-written rewrites, which start at the end of the line.
continue x = parents[v][0] if operation[x][0] != 'Extract': continue s = parents[x][0] high2,low2 = operation[v][1:] high,low = operation[x][1:] newoperation = ['Extract',high2+low,low2+low] newparents = [s] assert concrete[v] == calcconcrete(newoperation,[concrete[p] for p in newparents]) assert bits[v] == calcbits(newoperation,[bits[p] for p in newparents]) operation[v] = newoperation parents[v] = newparents progress = True # rewrite Concat(...)[high:low] as ...[high-pos:low-pos] if possible for v in operation: if operation[v][0] != 'Extract': continue x = parents[v][0] if operation[x][0] != 'Concat': continue high,low = operation[v][1:] pos = 0 for y in reversed(parents[x]): if pos <= low and high < pos+bits[y]: newoperation = ['Extract',high-pos,low-pos] newparents = [y] assert concrete[v] == calcconcrete(newoperation,[concrete[p] for p in newparents]) assert bits[v] == calcbits(newoperation,[bits[p] for p in newparents]) operation[v] = newoperation parents[v] = newparents progress = True break pos += bits[y] # rewrite invert(s)[high:low] as invert(s[high:low]) extractinvert = [] for v in operation: if operation[v][0] != 'Extract': continue x = parents[v][0] if operation[x][0] != 'invert': continue s = parents[x][0] extractinvert.append((v,s)) for v,s in extractinvert: nextvalue += 1 operation[nextvalue] = operation[v] parents[nextvalue] = [s] bits[nextvalue] = bits[v] concrete[nextvalue] = calcconcrete(operation[nextvalue],[concrete[x] for x in parents[nextvalue]]) newoperation = ['invert'] newparents = [nextvalue] assert concrete[v] == calcconcrete(newoperation,[concrete[p] for p in newparents]) assert bits[v] == calcbits(newoperation,[bits[p] for p in newparents]) operation[v] = newoperation parents[v] = newparents progress = True # rewrite Reverse(s)[high:low] as Reverse(s[...]) if possible extractreverse = [] for v in operation: if operation[v][0] != 'Extract': continue x = parents[v][0] if operation[x][0] != 'Reverse': continue high,low = operation[v][1:] 
# The line above contains the hand-written rewrites:
#   * Extract of Extract becomes a single Extract of the inner source, with
#     the bit offsets added.
#   * Extract of Concat becomes an Extract of the one operand covering the
#     whole range. The last Concat operand is treated as the least
#     significant.
#   * Extract of invert becomes invert of a new Extract node.
#   * Extract of Reverse starts at the end of the line.
if low%8 != 0: continue if high%8 != 7: continue extractreverse.append((v,parents[x][0],high,low)) for v,s,high,low in extractreverse: nextvalue += 1 operation[nextvalue] = ['Extract',bits[s]-1-low,bits[s]-1-high] parents[nextvalue] = [s] bits[nextvalue] = high-low+1 assert bits[nextvalue] == bits[v] concrete[nextvalue] = calcconcrete(operation[nextvalue],[concrete[x] for x in parents[nextvalue]]) newoperation = ['Reverse'] newparents = [nextvalue] assert concrete[v] == calcconcrete(newoperation,[concrete[p] for p in newparents]) assert bits[v] == calcbits(newoperation,[bits[p] for p in newparents]) operation[v] = newoperation parents[v] = newparents progress = True if printstats: sys.stderr.write(f'after peephole: {len(operation)} nodes, maxrsskb {maxrsskb()}\n') sys.stderr.flush() # ===== output done = set() finaloperations = set() def do(v): if v in done: return done.add(v) for x in parents[v]: do(x) if operation[v][0] == 'input': print('v%d = %s' % (v,operation[v][1])) else: p = ['v%s' % x for x in parents[v]] p += ['%s' % x for x in operation[v][1:]] finaloperations.add(operation[v][0]) print('v%d = %s(%s)' % (v,operation[v][0],','.join(p))) for v in outputname: do(value[v]) print('%s = v%d' % (v,value[v])) sys.stdout.flush() if any(op not in ('signedmin','signedmax') for op in finaloperations): raise Exception(f'resulting operations {sorted(finaloperations)} are not just signedmin, signedmax') # decompose will also refuse to parse this, but error here is more informative
# The line above contains:
#   * Extract of Reverse. When the range is whole bytes (low%8 == 0 and
#     high%8 == 7), it becomes Reverse of an Extract of the mirrored bit
#     range of the source.
#   * The output stage. do(v) prints v after all of its ancestors. It
#     recurses once per level, so a very deep graph may hit Python's
#     default recursion limit.
#   * The final check, which runs only after everything has been printed
#     and flushed. A rejected program therefore still reaches stdout before
#     the exception.
#   * Input nodes are exempt from the check because they are never added to
#     finaloperations.