-rwxr-xr-x 7420 djbsort-20180717/verif/minmax
#!/usr/bin/env python2
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
# n comes from the first command-line argument and fixes how many input
# variables (in_0_32 ... in_<n-1>_32) and output variables (out0 ...
# out<n-1>) the program read from stdin is expected to use.
n = int(sys.argv[1])
inputname = tuple(map(lambda i: 'in_%d_32' % i, range(n)))
outputname = tuple(map(lambda i: 'out%d' % i, range(n)))
def group(s):
    """Return a pyparsing parse action that tags a match with the label s.

    The returned callable takes the matched tokens and handles them as
    follows:
    - Exactly one token: the tokens are returned unchanged, as a list.
    - Otherwise: a one-element list is returned whose only item is
      [s, token, token, ...].
    """
    def tag_tokens(tokens):
        items = list(tokens)
        if len(items) != 1:
            return [[s] + items]
        return items
    return tag_tokens
# Punctuation is matched but suppressed from the parse results, so each
# assignment yields only names and numbers.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")

def _operation(opname, *pieces):
    """Return the rule `<name> = <opname>(<pieces>)`.

    pieces are the sub-rules between the parentheses, in order, including
    any separating commas. The whole match is tagged by group(opname), so a
    successful parse yields [opname, target, argument, ...].
    """
    rule = name + equal + Literal(opname).suppress() + lparen
    for piece in pieces:
        rule = rule + piece
    return (rule + rparen).setParseAction(group(opname))

# The alternatives are combined with |=, so they are tried in this order.
# The bare copy `x = y` comes last so that `x = op(...)` is not cut short
# after `x = op`.
_alternatives = [
    _operation('constant', number, comma, number),
    _operation('invert', name),
]
_alternatives.extend(
    _operation(op, name, comma, name)
    for op in ['xor','or','and','add','sub','mul','signedrshift','signedle','signedlt','equal'])
_alternatives.append(_operation('If', name, comma, name, comma, name))
_alternatives.append(_operation('Extract', name, comma, number, comma, number))
_alternatives.append(_operation('Concat', name, ZeroOrMore(comma + name)))
_alternatives.append((name + equal + name).setParseAction(group('copy')))

assignment = _alternatives[0]
for _rule in _alternatives[1:]:
    assignment |= _rule

# A program is any number of assignments and nothing else.
assignments = ZeroOrMore(assignment) + StringEnd()
# Parse all of stdin as a sequence of assignments. Each element of program
# is a plain list [opname, target, argument, ...] built by a group() action.
program = list(assignments.parseString(sys.stdin.read()))
# Value numbering: every node of the dataflow graph gets a positive integer
# id. Ids 1..n are the inputs; each non-copy assignment then gets a fresh id.
nextvalue = 0
# indexed by variable name: id of the node the variable names
value = dict()
# indexed by node id:
operation = dict()  # [opname] followed by any non-node arguments
parents = dict()    # ids of the node's operands, in argument order
bits = dict()       # width in bits of the node's result
for v in inputname:
    # inputs are 32-bit leaves
    nextvalue += 1
    value[v] = nextvalue
    operation[nextvalue] = ['input',v]
    parents[nextvalue] = []
    bits[nextvalue] = 32
for p in program:
    # p is [opname, target, argument, ...] as built by group().
    if p[1] in value:
        # Every variable, inputs included, may be assigned only once.
        raise Exception('%s assigned twice' % p[1])
    if p[0] == 'copy':
        # target = source: alias the existing node; no new id
        value[p[1]] = value[p[2]]
        continue
    nextvalue += 1
    operation[nextvalue] = [p[0]]
    if p[0] == 'constant':
        # constant(width, val): a leaf whose width and value are stored in
        # the operation itself
        parents[nextvalue] = []
        operation[nextvalue] += [int(p[2]),int(p[3])]
        bits[nextvalue] = int(p[2])
    elif p[0] in ['signedle','signedlt','equal']:
        # binary operation on equal-width operands producing one bit
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = 1
    elif p[0] in ['invert']:
        # unary size-preserving operation
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] in ['xor','or','and','add','sub','mul','signedrshift']:
        # binary size-preserving operation
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] == 'If':
        # If(cond, a, b): cond is one bit; a and b share the result width
        assert bits[value[p[2]]] == 1
        assert bits[value[p[3]]] == bits[value[p[4]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[3]]]
    elif p[0] == 'Extract':
        # Extract(x, top, bot): a slice of x of width top+1-bot; top and bot
        # are stored in the operation, and only x is an operand
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        operation[nextvalue] += [top,bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] == 'Concat':
        # result width is the sum of the operand widths
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
    else:
        raise Exception('unknown internal operation %s' % p[0])
    value[p[1]] = nextvalue
# Simplify the dataflow graph until a pass changes nothing. Each pass:
#   1. applies the local pattern rewrites below, in place;
#   2. recomputes each node's users ("children"), with -1 marking a node
#      that is the value of some output variable;
#   3. merges pairs of nodes that share a parent and have the same
#      operation and the same operands (in any order for signedmin and
#      signedmax), keeping the first node of each pair;
#   4. deletes nodes that had no users at step 2, plus every merged-away node.
# Any rewrite, merge or deletion sets progress and forces another pass. A
# node that loses its last user this pass is caught on the next pass.
progress = True
while progress:
    progress = False
    # rewrite If(signedle(s,t),s,t) as signedmin(s,t)
    # rewrite If(signedlt(s,t),s,t) as signedmin(s,t)
    for v in operation:
        if operation[v][0] == 'If':
            x = parents[v][0]
            if operation[x][0] in ['signedle','signedlt']:
                if parents[v][1] == parents[x][0]:
                    if parents[v][2] == parents[x][1]:
                        operation[v] = ['signedmin']
                        # v now shares the comparison's operand list object
                        parents[v] = parents[x]
                        progress = True
    # rewrite If(signedle(t,s),s,t) as signedmax(s,t)
    # rewrite If(signedlt(t,s),s,t) as signedmax(s,t)
    for v in operation:
        if operation[v][0] == 'If':
            x = parents[v][0]
            if operation[x][0] in ['signedle','signedlt']:
                if parents[v][1] == parents[x][1]:
                    if parents[v][2] == parents[x][0]:
                        operation[v] = ['signedmax']
                        # operands stay in the comparison's order [t,s] (a
                        # shared list object); the merge step treats min/max
                        # operands as unordered
                        parents[v] = parents[x]
                        progress = True
    # rewrite If(c,constant(1,1),constant(1,0)) as c
    for v in operation:
        if operation[v][0] == 'If':
            c,x,y = parents[v]
            if operation[x] == ['constant',1,1]:
                if operation[y] == ['constant',1,0]:
                    # v becomes a duplicate of c (same operation and operand
                    # list objects); step 3 can later merge v with c
                    operation[v] = operation[c]
                    parents[v] = parents[c]
                    progress = True
    # rewrite invert(invert(c)) as c
    for v in operation:
        if operation[v][0] == 'invert':
            x = parents[v][0]
            if operation[x][0] == 'invert':
                y = parents[x][0]
                # NOTE(review): if y is an 'input' node, v becomes a second
                # ['input', name] node; nodes without parents are never
                # paired by the merge step below
                operation[v] = operation[y]
                parents[v] = parents[y]
                progress = True
    # rewrite xor(c,constant(1,1)) as invert(c)
    # (c is one bit here: construction asserts xor operands have equal
    # widths; only the constant in second position is recognized)
    for v in operation:
        if operation[v][0] == 'xor':
            c,x = parents[v]
            if operation[x] == ['constant',1,1]:
                operation[v] = ['invert']
                parents[v] = [c]
                progress = True
    # rewrite equal(c,constant(1,0)) as invert(c)
    # (c is one bit here: construction asserts equal operands have equal
    # widths; only the constant in second position is recognized)
    for v in operation:
        if operation[v][0] == 'equal':
            c,x = parents[v]
            if operation[x] == ['constant',1,0]:
                operation[v] = ['invert']
                parents[v] = [c]
                progress = True
    # children[x]: ids of nodes using x as an operand; -1 if x is an output
    children = dict()
    for z in operation: children[z] = set()
    for v in outputname: children[value[v]].add(-1)
    for z in operation:
        for x in parents[z]:
            children[x].add(z)
    # unused nodes; they are also barred from taking part in a merge
    deleting = set(v for v in operation if len(children[v]) == 0)
    # nodes already claimed this pass: each node joins at most one merge
    merging = deleting.copy()
    merge = []
    # Look for duplicate computations among the users of each node x.
    # Which node of a pair comes first (and survives) depends on the
    # iteration order of the children set.
    for x in operation:
        c = list(children[x])
        for y,z in [(c[i],c[j]) for j in range(len(c)) for i in range(j)]:
            if y == -1: continue
            if z == -1: continue
            if operation[y] != operation[z]: continue
            parentsmatch = False
            if parents[y] == parents[z]: parentsmatch = True
            # min and max are commutative: compare operands as sets
            if operation[y][0] in ['signedmin','signedmax']:
                if set(parents[y]) == set(parents[z]):
                    parentsmatch = True
            if not parentsmatch: continue
            assert bits[y] == bits[z]
            if y in merging: continue
            if z in merging: continue
            merge += [(y,z)]
            merging.add(y)
            merging.add(z)
    # Redirect every use of z (output variables and operand slots of its
    # users) to y, then schedule z for deletion. parents lists are edited
    # in place, so lists shared between nodes by the rewrites above are
    # updated for all of those nodes.
    for y,z in merge:
        # print 'eliminating %s in favor of %s' % (z,y)
        for t in children[z]:
            if t == -1:
                for v in outputname:
                    if value[v] == z:
                        value[v] = y
            else:
                for j in range(len(parents[t])):
                    if parents[t][j] == z:
                        parents[t][j] = y
        deleting.add(z)
        progress = True
    for v in deleting:
        del operation[v]
        del parents[v]
        del bits[v]
        progress = True
# Node ids already printed (or scheduled to be printed) by do().
done = set()

def _emit(v):
    """Print the single definition line for node v."""
    if operation[v][0] == 'input':
        print('v%d = %s' % (v,operation[v][1]))
    else:
        # node operands first, then any non-node arguments (constant
        # width/value, Extract bounds)
        p = ['v%s' % x for x in parents[v]]
        p += ['%s' % x for x in operation[v][1:]]
        print('v%d = %s(%s)' % (v,operation[v][0],','.join(p)))

def do(v):
    """Print node v, preceded by every not-yet-printed node it depends on.

    Nodes are emitted in depth-first post-order: each node's operands, in
    parents[...] order, come before the node itself. An explicit stack is
    used instead of recursion, so long dependency chains cannot exceed
    Python's recursion limit.
    """
    if v in done: return
    done.add(v)
    stack = [(v, iter(parents[v]))]
    while stack:
        node, pending = stack[-1]
        for x in pending:
            if x not in done:
                done.add(x)
                stack.append((x, iter(parents[x])))
                break
        else:
            # every operand of node has been printed
            stack.pop()
            _emit(node)

for v in outputname:
    do(value[v])
    print('%s = v%d' % (v,value[v]))