-rwxr-xr-x 7003 djbsort-20180710/verif/minmax
#!/usr/bin/env python2
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
# Number of inputs (and outputs) of the network, taken from the command line.
n = int(sys.argv[1])
# Input variables are named in_<i>_32 and outputs out<i>, for i = 0 .. n-1.
inputname = tuple(map(lambda i: 'in_%d_32' % i, range(n)))
outputname = tuple(map(lambda i: 'out%d' % i, range(n)))
def group(s):
    """Return a parse action that labels a token sequence with s.

    The returned callable converts its tokens to a list. A single token is
    returned unchanged (as a one-element list); any other number of tokens
    is wrapped into one nested list headed by the label: [[s, tok0, ...]].
    """
    def tag(tokens):
        tokens = list(tokens)
        if len(tokens) != 1:
            return [[s] + tokens]
        return tokens
    return tag
# Token grammar. Punctuation is matched but suppressed, so parse results
# contain only names and numbers.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas, alphas + nums + "_")

def operationassignment(opname, arguments):
    """Grammar for `target = opname(arguments)`.

    A match is relabelled by group(opname), producing the single token
    [opname, target, argument tokens...].
    """
    return (name + equal + Literal(opname).suppress()
            + lparen + arguments + rparen).setParseAction(group(opname))

assignSLE = operationassignment('SLE', name + comma + name)
assignIf = operationassignment('If', name + comma + name + comma + name)
assignExtract = operationassignment('Extract',
                                    name + comma + number + comma + number)
assignConcat = operationassignment('Concat', name + ZeroOrMore(comma + name))
# plain copy: target = source
assigncopy = (name + equal + name).setParseAction(group('copy'))
# Alternatives are tried left to right; the copy form comes last because
# its `name = name` prefix also matches the start of every operation.
assignment = assignSLE | assignIf | assignExtract | assignConcat | assigncopy
assignments = ZeroOrMore(assignment) + StringEnd()
# Parse all of stdin as a sequence of assignments. Each element of the
# resulting list is [opname, target, operand, ...].
program = list(assignments.parseString(sys.stdin.read()))
# Build the dataflow graph. Every distinct value gets a positive integer id
# (a "value number"). A plain copy reuses its source's id instead of
# creating a new node.
nextvalue = 0
# indexed by variable name: the value number bound to that name
value = dict()
# indexed by value number:
operation = dict()  # [opname] plus non-value operands, e.g. ['Extract', top, bot]
parents = dict()    # value numbers of the operands, in argument order
bits = dict()       # bit width of the value
for v in inputname:
    nextvalue += 1
    value[v] = nextvalue
    operation[nextvalue] = ['input', v]
    parents[nextvalue] = []
    bits[nextvalue] = 32
for p in program:
    # p = [opname, target, operand, ...] as produced by the group() parse actions
    if p[1] in value:
        # format the message (the old code passed a tuple of exception args)
        raise Exception('%s assigned twice' % p[1])
    if p[0] == 'copy':
        value[p[1]] = value[p[2]]
        continue
    nextvalue += 1
    operation[nextvalue] = [p[0]]
    if p[0] == 'SLE':
        # comparison SLE(a, b) of two equal-width operands; 1-bit result
        assert bits[value[p[2]]] == bits[value[p[3]]]
        value[p[1]] = nextvalue
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = 1
    elif p[0] == 'If':
        # If(cond, a, b): 1-bit condition choosing between equal-width a and b
        assert bits[value[p[2]]] == 1
        assert bits[value[p[3]]] == bits[value[p[4]]]
        value[p[1]] = nextvalue
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[3]]]
    elif p[0] == 'Extract':
        # Extract(a, top, bot): bit positions top down to bot of a, inclusive
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        value[p[1]] = nextvalue
        operation[nextvalue] += [top, bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] == 'Concat':
        # Concat(a, b, ...): width is the sum of the operand widths
        value[p[1]] = nextvalue
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
    else:
        raise Exception('unknown internal operation %s' % p[0])
def computechildren():
    """Return the use map of the current graph.

    Maps every live value number to the set of value numbers that list it
    among their parents. The pseudo-user -1 marks a value bound to an output.
    """
    uses = dict((z, set()) for z in operation)
    for v in outputname:
        uses[value[v]].add(-1)
    for z in operation:
        for x in parents[z]:
            uses[x].add(z)
    return uses

def redirectuses(old, new, users):
    """Make every user in `users` that refers to value `old` refer to `new`.

    User -1 stands for the outputs: any output bound to `old` is rebound.
    Operand lists are updated in place.
    """
    for t in users:
        if t == -1:
            for v in outputname:
                if value[v] == old:
                    value[v] = new
        else:
            operands = parents[t]
            for j in range(len(operands)):
                if operands[j] == old:
                    operands[j] = new

# Repeatedly simplify the graph until a full pass changes nothing.
progress = True
while progress:
    progress = False

    # rewrite If(SLE(s,t),s,t) as Min(s,t)
    for v in operation:
        if operation[v][0] == 'If':
            x = parents[v][0]
            if (operation[x][0] == 'SLE'
                    and parents[v][1] == parents[x][0]
                    and parents[v][2] == parents[x][1]):
                operation[v] = ['Min']
                # copy, so v and x never share one mutable operand list
                parents[v] = list(parents[x])
                progress = True

    # rewrite If(SLE(t,s),s,t) as Max(s,t)
    for v in operation:
        if operation[v][0] == 'If':
            x = parents[v][0]
            if (operation[x][0] == 'SLE'
                    and parents[v][1] == parents[x][1]
                    and parents[v][2] == parents[x][0]):
                operation[v] = ['Max']
                parents[v] = list(parents[x])
                progress = True

    # rewrite If(c,s[top:bot],t[top:bot])
    # as If(c,s,t)[top:bot]
    ifextract = []
    for z in operation:
        if operation[z][0] == 'If':
            c, x, y = parents[z]
            if (operation[x][0] == 'Extract'
                    and operation[y] == operation[x]  # extract same bits
                    # If(c,s,t) is only well-formed when s and t have equal
                    # widths; otherwise leave this If alone rather than crash
                    and bits[parents[x][0]] == bits[parents[y][0]]):
                ifextract += [z]
    for z in ifextract:
        c, x, y = parents[z]
        s = parents[x][0]
        t = parents[y][0]
        assert bits[s] == bits[t]
        nextvalue += 1
        operation[nextvalue] = ['If']
        parents[nextvalue] = [c, s, t]
        bits[nextvalue] = bits[s]
        operation[z] = list(operation[x])
        parents[z] = [nextvalue]
        progress = True

    # Common-subexpression elimination: among the users of any one value,
    # merge pairs that apply the same operation to the same operands.
    # Values with no users at all are scheduled for deletion.
    children = computechildren()
    deleting = set(v for v in operation if len(children[v]) == 0)
    merging = deleting.copy()
    merge = []
    for x in operation:
        c = list(children[x])
        for y, z in [(c[i], c[j]) for j in range(len(c)) for i in range(j)]:
            if y == -1 or z == -1:
                continue
            if operation[y] != operation[z]:
                continue
            parentsmatch = parents[y] == parents[z]
            # Min and Max are commutative: operand order does not matter
            if (operation[y][0] in ['Min', 'Max']
                    and set(parents[y]) == set(parents[z])):
                parentsmatch = True
            if not parentsmatch:
                continue
            assert bits[y] == bits[z]
            if y in merging or z in merging:
                continue
            merge += [(y, z)]
            merging.add(y)
            merging.add(z)
    for y, z in merge:
        # eliminate z in favor of y
        redirectuses(z, y, children[z])
        deleting.add(z)
        progress = True
    for v in deleting:
        del operation[v]
        del parents[v]
        del bits[v]
        progress = True

    # rewrite Concat(s[w-1:a], s[a-1:b], ..., s[k-1:0]) as s itself,
    # where w is the width of s
    children = computechildren()
    concatextract = []
    for z in operation:
        if operation[z][0] != 'Concat':
            continue
        pieces = parents[z]
        if not all(operation[y][0] == 'Extract' for y in pieces):
            continue
        source = parents[pieces[0]][0]
        if not all(parents[y][0] == source for y in pieces):
            continue
        # A contiguous run starting at bit bits[z]-1 always ends at bit 0.
        # It equals source only when it spans all of source's bits; a
        # narrower run is just source's low bits and must not be replaced.
        if bits[source] != bits[z]:
            continue
        remainingbits = bits[z]
        ok = True
        for y in pieces:
            if operation[y][1] == remainingbits - 1:
                remainingbits = operation[y][2]
            else:
                ok = False
        if ok:
            concatextract += [(z, source)]
    for y, source in concatextract:
        redirectuses(y, source, children[y])
        progress = True
# Value numbers whose definitions have already been printed by do().
done = set()
def do(v):
    """Print the definition of value v, preceded by those of its operands.

    Operands are printed before their users. Each value is printed at most
    once over all calls (tracked in the global `done`), so values shared by
    several outputs appear only once. Inputs print as `v<id> = <input name>`
    and other nodes as `v<id> = Op(v<a>,...,extra...)`, where the extra
    items are operation[v][1:] (e.g. an Extract's bit positions).

    The walk uses an explicit stack instead of recursion, so long dependency
    chains cannot exceed Python's recursion limit. Output order is the same
    post-order a recursive walk would produce.
    """
    if v in done:
        return
    done.add(v)
    # Each entry: (node, iterator over the operands not yet visited).
    stack = [(v, iter(parents[v]))]
    while stack:
        node, pending = stack[-1]
        for x in pending:
            if x not in done:
                done.add(x)
                stack.append((x, iter(parents[x])))
                break
        else:
            # Every operand has been printed; now print this node.
            stack.pop()
            if operation[node][0] == 'input':
                line = 'v%d = %s' % (node, operation[node][1])
            else:
                args = ['v%s' % a for a in parents[node]]
                args += ['%s' % a for a in operation[node][1:]]
                line = 'v%d = %s(%s)' % (node, operation[node][0], ','.join(args))
            sys.stdout.write(line + '\n')
# For each output in order: print the code computing its value (anything an
# earlier output already printed is not repeated), then bind the output
# name to that value.
for outname in outputname:
    root = value[outname]
    do(root)
    sys.stdout.write('%s = v%d\n' % (outname, root))