-rwxr-xr-x 12617 djbsort-20180729/verif/minmax
#!/usr/bin/env python2
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
# Number of inputs (and outputs) of the program, from the first command-line
# argument.
n = int(sys.argv[1])
# Variable names the straight-line program uses for its n 32-bit inputs
# (pre-seeded into the graph below) and for its n outputs (read at the end).
inputname = tuple(map(lambda i: 'in_%d_32' % i, range(n)))
outputname = tuple(map(lambda i: 'out%d' % i, range(n)))
def group(s):
    """Return a pyparsing parse action that tags a statement with label s.

    The returned action turns the matched tokens [target, arg, ...] into a
    single nested element [s, target, arg, ...], so every parsed statement
    becomes one list in the overall parse result.  A lone token is passed
    through unchanged.
    """
    def tagger(tokens):
        items = list(tokens)
        return items if len(items) == 1 else [[s] + items]
    return tagger
# Grammar for the straight-line program read from stdin.  Each statement is
#     target = op(arg, ..., arg)      or      target = source
# and its parse action (built by group) yields [op, target, arg, ...].
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()
number = Word(nums)
name = Word(alphas,alphas+nums+"_")


def _opcall(op, *args):
    """Return the rule  name = op(a1, ..., ak)  tagged with group(op).

    Each element of args is the expression for one argument; consecutive
    arguments are separated by a comma.  The rule is built with the same
    left-to-right chain of '+' operations as a hand-written expression.
    """
    rule = name + equal + Literal(op).suppress() + lparen + args[0]
    for arg in args[1:]:
        rule = rule + comma + arg
    return (rule + rparen).setParseAction(group(op))


# Alternatives are tried in this order; the bare copy rule must stay last.
assignment = _opcall('constant', number, number)
assignment |= _opcall('invert', name)
assignment |= _opcall('Reverse', name)
for binary in ['xor','or','and','add','sub','mul','signedrshift','signedle','signedlt','equal']:
    assignment |= _opcall(binary, name, name)
assignment |= _opcall('If', name, name, name)
assignment |= _opcall('Extract', name, number, number)
# Concat takes one or more names, so it cannot use the fixed-arity helper.
assignment |= (
    name + equal + Literal('Concat').suppress()
    + lparen + name + ZeroOrMore(comma + name) + rparen
).setParseAction(group('Concat'))
# target = source  (an alias, tagged 'copy')
assignment |= (name + equal + name).setParseAction(group('copy'))
assignments = ZeroOrMore(assignment) + StringEnd()
# Parse all of stdin as one program.  StringEnd() in the grammar requires the
# whole input to match.  Each list element is one tagged statement
# [op, target, arg, ...].
program = list(assignments.parseString(sys.stdin.read()))
# --- Build the dataflow graph ------------------------------------------------
# Every computed value gets a fresh positive integer id ("value number").
# A plain copy "a = b" creates no new value; a simply names b's value number.
nextvalue = 0
# indexed by variable name:
value = dict()      # variable name -> value number it denotes
# indexed by value:
operation = dict()  # [opname] plus any literal arguments (constant, Extract)
parents = dict()    # operand value numbers, in argument order
bits = dict()       # bit width of the value
# The named inputs are seeded as 32-bit 'input' values with no operands.
for v in inputname:
    nextvalue += 1
    value[v] = nextvalue
    operation[nextvalue] = ['input',v]
    parents[nextvalue] = []
    bits[nextvalue] = 32
for p in program:
    # p is [opname, target, arg, ...] as produced by the group() parse actions.
    if p[1] in value:
        # Each variable (inputs included) may be bound only once.  The name is
        # %-formatted into the message; passing it as a second Exception
        # argument would only display the raw argument tuple.
        raise Exception('%s assigned twice' % p[1])
    if p[0] == 'copy':
        value[p[1]] = value[p[2]]
        continue
    nextvalue += 1
    operation[nextvalue] = [p[0]]
    if p[0] == 'constant':
        # constant(width, value): no operands, both numbers kept as literals
        parents[nextvalue] = []
        operation[nextvalue] += [int(p[2]),int(p[3])]
        bits[nextvalue] = int(p[2])
    elif p[0] in ['signedle','signedlt','equal']:
        # binary operation producing bit
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = 1
    elif p[0] in ['invert','Reverse']:
        # unary size-preserving operation
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] in ['xor','or','and','add','sub','mul','signedrshift']:
        # binary size-preserving operation
        assert bits[value[p[2]]] == bits[value[p[3]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[2]]]
    elif p[0] == 'If':
        # If(c, s, t): c is a single bit, s and t have equal widths
        assert bits[value[p[2]]] == 1
        assert bits[value[p[3]]] == bits[value[p[4]]]
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = bits[value[p[3]]]
    elif p[0] == 'Extract':
        # Extract(source, top, bot): bits top..bot (inclusive) of source
        top = int(p[3])
        bot = int(p[4])
        assert top >= bot
        assert bits[value[p[2]]] > top
        assert bot >= 0
        operation[nextvalue] += [top,bot]
        parents[nextvalue] = [value[p[2]]]
        bits[nextvalue] = top + 1 - bot
    elif p[0] == 'Concat':
        # width is the sum of the operand widths
        parents[nextvalue] = [value[v] for v in p[2:]]
        bits[nextvalue] = sum(bits[v] for v in parents[nextvalue])
    else:
        raise Exception('unknown internal operation %s' % p[0])
    value[p[1]] = nextvalue
# --- Simplify the graph to a fixed point ------------------------------------
# Each pass applies every local rewrite rule, then merges duplicate nodes and
# deletes unused ones; passes repeat until one changes nothing.  A rewrite
# "as c" overwrites node v in place with c's operation and parents lists
# (shared by reference, not copied), so consumers of v need no update.

# signedmin <-> signedmax; loop-invariant, used by the invert/minmax rule.
flip = {'signedmax':'signedmin','signedmin':'signedmax'}
progress = True
while progress:
    progress = False

    # rewrite invert(invert(c)) as c
    # rewrite Reverse(Reverse(c)) as c
    for v in operation:
        if operation[v][0] in ['invert','Reverse']:
            x = parents[v][0]
            if operation[x][0] == operation[v][0]:
                y = parents[x][0]
                operation[v] = operation[y]
                parents[v] = parents[y]
                progress = True

    # rewrite xor(s,xor(s,t)) as t   (operands of both xors in either order)
    for v in operation:
        if operation[v][0] != 'xor': continue
        s,x = parents[v]
        if operation[x][0] != 'xor':
            x,s = parents[v]
            if operation[x][0] != 'xor': continue
        y,t = parents[x]
        if y != s:
            t,y = parents[x]
            if y != s: continue
        operation[v] = operation[t]
        parents[v] = parents[t]
        progress = True

    # if c is constant(b,2**(b-1)-1), the largest signed b-bit value:
    #   rewrite signedmin(s,c) as s
    #   rewrite signedmax(s,c) as c
    for v in operation:
        b = bits[v]
        if operation[v][0] not in ['signedmin','signedmax']: continue
        s,c = parents[v]
        if operation[c] != ['constant',b,2**(b-1)-1]:
            c,s = parents[v]
            if operation[c] != ['constant',b,2**(b-1)-1]: continue
        if operation[v][0] == 'signedmin':
            operation[v] = operation[s]
            parents[v] = parents[s]
            progress = True
            continue
        if operation[v][0] == 'signedmax':
            operation[v] = operation[c]
            parents[v] = parents[c]
            progress = True
            continue

    # if c is constant(b,2**(b-1)), the bit pattern of the smallest signed
    # b-bit value:
    #   rewrite signedmin(s,c) as c
    #   rewrite signedmax(s,c) as s
    for v in operation:
        b = bits[v]
        if operation[v][0] not in ['signedmin','signedmax']: continue
        s,c = parents[v]
        if operation[c] != ['constant',b,2**(b-1)]:
            c,s = parents[v]
            if operation[c] != ['constant',b,2**(b-1)]: continue
        if operation[v][0] == 'signedmax':
            operation[v] = operation[s]
            parents[v] = parents[s]
            progress = True
            continue
        if operation[v][0] == 'signedmin':
            operation[v] = operation[c]
            parents[v] = parents[c]
            progress = True
            continue

    # if c is constant(b,2**b-1), i.e. all ones:
    #   rewrite xor(c,s) as invert(s)
    for v in operation:
        b = bits[v]
        if operation[v][0] != 'xor': continue
        c,s = parents[v]
        if operation[c] != ['constant',b,2**b-1]:
            s,c = parents[v]
            if operation[c] != ['constant',b,2**b-1]: continue
        operation[v] = ['invert']
        parents[v] = [s]
        progress = True

    # rewrite signedmin(invert(s),invert(t)) as invert(signedmax(s,t))
    # rewrite signedmax(invert(s),invert(t)) as invert(signedmin(s,t))
    # Candidates are collected first because the rewrite adds new nodes,
    # which must not happen while iterating over operation.
    invertminmax = []
    for v in operation:
        if operation[v][0] in ['signedmin','signedmax']:
            x,y = parents[v]
            if operation[x][0] == 'invert' and operation[y][0] == 'invert':
                invertminmax += [(v,parents[x][0],parents[y][0])]
    for v,s,t in invertminmax:
        nextvalue += 1
        operation[nextvalue] = [flip[operation[v][0]]]
        parents[nextvalue] = [s,t]
        bits[nextvalue] = bits[v]
        operation[v] = ['invert']
        parents[v] = [nextvalue]
        progress = True

    # rewrite invert(s)[high:low] as invert(s[high:low])
    extractinvert = []
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'invert': continue
        s = parents[x][0]
        extractinvert += [(v,s)]
    for v,s in extractinvert:
        nextvalue += 1
        operation[nextvalue] = operation[v]
        parents[nextvalue] = [s]
        bits[nextvalue] = bits[v]
        operation[v] = ['invert']
        parents[v] = [nextvalue]
        progress = True

    # rewrite s[bits-1:0] as s   (an Extract of the full width)
    for v in operation:
        if operation[v][0] != 'Extract': continue
        s = parents[v][0]
        if operation[v][1:] != [bits[s]-1,0]: continue
        operation[v] = operation[s]
        parents[v] = parents[s]
        progress = True

    # rewrite s[high:low][high2:low2] as s[high2+low:low2+low]
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'Extract': continue
        s = parents[x][0]
        high2,low2 = operation[v][1:]
        # only the inner window's low bound shifts the outer window
        low = operation[x][2]
        operation[v] = ['Extract',high2 + low,low2 + low]
        parents[v] = [s]
        progress = True

    # rewrite Reverse(s)[high:low] as Reverse(s[...]) if possible
    # Only windows with low % 8 == 0 and high % 8 == 7 qualify; the mirrored
    # bounds below are valid for such windows if Reverse reverses byte order
    # (NOTE(review): inferred from this alignment check).
    extractreverse = []
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'Reverse': continue
        high,low = operation[v][1:]
        if low % 8 != 0: continue
        if high % 8 != 7: continue
        extractreverse += [(v,parents[x][0],high,low)]
    for v,s,high,low in extractreverse:
        nextvalue += 1
        operation[nextvalue] = ['Extract',bits[s] - 1 - low,bits[s] - 1 - high]
        parents[nextvalue] = [s]
        bits[nextvalue] = high - low + 1
        assert bits[nextvalue] == bits[v]
        operation[v] = ['Reverse']
        parents[v] = [nextvalue]
        progress = True

    # rewrite Concat(...)[high:low] as ...[high-pos:low-pos] if possible
    # This applies when the window lies inside a single operand.  The last
    # operand is treated as the least-significant, so operands are walked in
    # reverse while pos accumulates each operand's starting bit.
    for v in operation:
        if operation[v][0] != 'Extract': continue
        x = parents[v][0]
        if operation[x][0] != 'Concat': continue
        high,low = operation[v][1:]
        pos = 0
        for y in reversed(parents[x]):
            if pos <= low and high < pos + bits[y]:
                operation[v] = ['Extract',high - pos,low - pos]
                parents[v] = [y]
                progress = True
                break
            pos += bits[y]

    # rewrite If(signedle(s,t),s,t) as signedmin(s,t)
    # rewrite If(signedlt(s,t),s,t) as signedmin(s,t)
    for v in operation:
        if operation[v][0] == 'If':
            x = parents[v][0]
            if operation[x][0] in ['signedle','signedlt']:
                if parents[v][1] == parents[x][0]:
                    if parents[v][2] == parents[x][1]:
                        operation[v] = ['signedmin']
                        parents[v] = parents[x]
                        progress = True

    # rewrite If(signedle(t,s),s,t) as signedmax(s,t)
    # rewrite If(signedlt(t,s),s,t) as signedmax(s,t)
    for v in operation:
        if operation[v][0] == 'If':
            x = parents[v][0]
            if operation[x][0] in ['signedle','signedlt']:
                if parents[v][1] == parents[x][1]:
                    if parents[v][2] == parents[x][0]:
                        operation[v] = ['signedmax']
                        parents[v] = parents[x]
                        progress = True

    # rewrite If(c,constant(1,1),constant(1,0)) as c
    for v in operation:
        if operation[v][0] == 'If':
            c,x,y = parents[v]
            if operation[x] == ['constant',1,1]:
                if operation[y] == ['constant',1,0]:
                    operation[v] = operation[c]
                    parents[v] = parents[c]
                    progress = True

    # rewrite xor(c,constant(1,1)) as invert(c)
    for v in operation:
        if operation[v][0] == 'xor':
            c,x = parents[v]
            if operation[x] == ['constant',1,1]:
                operation[v] = ['invert']
                parents[v] = [c]
                progress = True

    # rewrite equal(c,constant(1,0)) as invert(c)
    for v in operation:
        if operation[v][0] == 'equal':
            c,x = parents[v]
            if operation[x] == ['constant',1,0]:
                operation[v] = ['invert']
                parents[v] = [c]
                progress = True

    # --- Common-subexpression merging and dead-node deletion ---
    # children[z] = consumers of z; -1 marks "z is a program output".
    children = dict()
    for z in operation: children[z] = set()
    for v in outputname: children[value[v]].add(-1)
    for z in operation:
        for x in parents[z]:
            children[x].add(z)
    # Nodes with no consumers are dead.
    deleting = set(v for v in operation if len(children[v]) == 0)
    # A node takes part in at most one merge per pass, and dead nodes in none.
    merging = deleting.copy()
    merge = []
    for x in operation:
        c = list(children[x])
        # all pairs of children of x, in a fixed (j, then i < j) order
        for y,z in [(c[i],c[j]) for j in range(len(c)) for i in range(j)]:
            if y == -1: continue
            if z == -1: continue
            if operation[y] != operation[z]: continue
            parentsmatch = False
            if parents[y] == parents[z]: parentsmatch = True
            # signedmin/signedmax are commutative: compare operand sets
            if operation[y][0] in ['signedmin','signedmax']:
                if set(parents[y]) == set(parents[z]):
                    parentsmatch = True
            if not parentsmatch: continue
            assert bits[y] == bits[z]
            if y in merging: continue
            if z in merging: continue
            merge += [(y,z)]
            merging.add(y)
            merging.add(z)
    # Eliminate z in favor of y: repoint outputs and consumers, then delete z.
    for y,z in merge:
        for t in children[z]:
            if t == -1:
                for v in outputname:
                    if value[v] == z:
                        value[v] = y
            else:
                for j in range(len(parents[t])):
                    if parents[t][j] == z:
                        parents[t][j] = y
        deleting.add(z)
        progress = True
    for v in deleting:
        del operation[v]
        del parents[v]
        del bits[v]
        progress = True
# --- Emit the simplified program ---------------------------------------------
# Value numbers already printed, or currently being printed, by do().
done = set()


def _print_node(v):
    """Print the definition line of value v, naming operands v<number>."""
    if operation[v][0] == 'input':
        print('v%d = %s' % (v,operation[v][1]))
    else:
        p = ['v%s' % x for x in parents[v]]
        p += ['%s' % x for x in operation[v][1:]]
        print('v%d = %s(%s)' % (v,operation[v][0],','.join(p)))


def do(v):
    """Print value v after first printing every not-yet-printed operand.

    Produces the same depth-first post-order over parents[] as a recursive
    walk would.  It uses an explicit stack, so long dependency chains cannot
    hit Python's recursion limit.  Values already in `done` are skipped.
    """
    if v in done: return
    done.add(v)
    stack = [(v, iter(parents[v]))]
    while stack:
        node, pending = stack[-1]
        for x in pending:
            if x not in done:
                # descend into the first unprinted operand; resume later
                done.add(x)
                stack.append((x, iter(parents[x])))
                break
        else:
            # every operand of node is printed, so node can be printed now
            stack.pop()
            _print_node(node)
# For each output, in order: print any not-yet-printed values it depends on
# (via do), then the line binding the output name to its final value number.
for v in outputname:
    root = value[v]
    do(root)
    print('%s = v%d' % (v, root))