-rwxr-xr-x 4050 djbsort-20180710/verif/decompose
#!/usr/bin/env python2
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
# Number of inputs (and outputs), given as the first command-line argument.
n = int(sys.argv[1])
_indices = range(n)
# Variable names the program is expected to use for its inputs and outputs,
# in index order.
inputname = tuple('in_%d_32' % k for k in _indices)
outputname = tuple('out%d' % k for k in _indices)
def group(s):
    """Return a parse action that labels a multi-token match with s.

    The returned callable receives the matched tokens.  A match consisting of
    exactly one token is handed back unchanged (as a one-element list); any
    other match is wrapped into a single nested list whose first element is s,
    followed by the tokens.
    """
    def action(x):
        tokens = list(x)
        if len(tokens) != 1:
            return [[s] + tokens]
        return tokens
    return action
# Punctuation tokens; suppressed so they never show up in the parse results.
lparen = Literal('(').suppress()
rparen = Literal(')').suppress()
comma = Literal(',').suppress()
equal = Literal('=').suppress()

# Variable names: first character from alphas, the rest from letters,
# digits and underscore.
name = Word(alphas,alphas+nums+"_")

def _binary_assignment(op):
    # Grammar for "dst = op(src1,src2)".  The operator keyword is suppressed
    # and re-added as the leading label by group(op), so a match becomes
    # [op, dst, src1, src2].
    statement = (name + equal + Literal(op).suppress()
                 + lparen + name + comma + name + rparen)
    return statement.setParseAction(group(op))

assignMin = _binary_assignment('Min')
assignMax = _binary_assignment('Max')
# "dst = src" becomes ['copy', dst, src].
assigncopy = (name + equal + name).setParseAction(group('copy'))

# Min/Max forms are tried before the plain copy form.
assignment = assignMin | assignMax | assigncopy
# The whole input must consist of assignments and nothing else.
assignments = ZeroOrMore(assignment) + StringEnd()

# Read the complete program from standard input and parse it into a list of
# instructions.
program = list(assignments.parseString(sys.stdin.read()))
# Special case for fewer than two inputs: the program may contain only
# copies.  Check that every copy reads an already-assigned variable, that no
# variable is assigned twice, and that every output ends up assigned; then
# exit successfully without running the merge analysis below.
if n < 2:
    assigned = set(inputname)
    for instr in program:
        if instr[0] != 'copy':
            raise Exception('only copies allowed for n<2')
        if instr[2] not in assigned:
            raise Exception('%s used before assigned' % instr[2])
        if instr[1] in assigned:
            raise Exception('%s assigned twice' % instr[1])
        assigned.add(instr[1])
    for out in outputname:
        if out not in assigned:
            raise Exception('output %s not assigned' % out)
    sys.exit(0)
# Initially every input variable forms a batch of its own.
# invariants kept while batches are merged:
#   the batches together partition the n values being sorted
#   every batch is nonempty
#   every batch is a tuple of variable names whose values are sorted
batch = set((v,) for v in inputname)
def findmerge():
    """Find the first pair of distinct batches that the program combines.

    Walks the global program in order, tracking for every variable the batch
    it belongs to.  A copy inherits the batch of its source.  A Min/Max whose
    two operands lie in the same batch stays in that batch; the first Min/Max
    whose operands lie in different batches determines the result.

    Returns the two batches (operand order) as a pair.
    Raises Exception('no progress') if no instruction combines two batches.
    """
    owner = {}
    for part in batch:
        owner.update((var, part) for var in part)
    for instr in program:
        op = instr[0]
        if op == 'copy':
            owner[instr[1]] = owner[instr[2]]
        elif op in ('Min', 'Max'):
            left, right = owner[instr[2]], owner[instr[3]]
            if left != right:
                return left, right
            owner[instr[1]] = left
    raise Exception('no progress')
def merge(b,c):
    """Check that the program merges batches b and c, then record the result.

    b, c: two tuples of variable names, each listing variables whose values
    are sorted (see the invariants on batch).

    The instructions of the global program that depend only on variables
    derived from b and c are split off and evaluated on len(b)+1 test
    inputs (b being the shorter batch).  In test k the first k elements of
    b are smaller than every element of c and the remaining elements of b
    are larger than every element of c.  Test 0 fixes the order of the merged
    outputs; every later test must produce the values 0,1,2,... in that same
    order.

    Side effects: b and c are removed from batch, the merged tuple of output
    names is added to batch, and the global program is replaced by the
    instructions that were not consumed.

    Raises Exception if the number of resulting sink variables is not
    len(b)+len(c), if the sinks do not take the values 0..len(b)+len(c)-1 in
    test 0, or if a later test yields a different order.
    """
    global program
    batch.discard(b)
    batch.discard(c)
    if len(b) > len(c): b,c = c,b # simplifies tests later
    total = len(b) + len(c)
    # Materialised as a list: a list never compares equal to a range object,
    # so comparing against range(...) directly only works where range()
    # itself returns a list.
    expected = list(range(total))
    known = set(b + c) # variables depending only on b,c
    outputs = set(b + c) # sinks
    program1 = [] # instructions depending only on b,c
    program2 = [] # other instructions
    for p in program:
        if p[0] == 'copy':
            if p[2] in known:
                # the source is consumed; the destination becomes a sink
                outputs.discard(p[2])
                known.add(p[1])
                outputs.add(p[1])
                program1.append(p)
            else:
                program2.append(p)
        elif p[0] in ['Min','Max']:
            if p[2] in known and p[3] in known:
                outputs.discard(p[2])
                outputs.discard(p[3])
                known.add(p[1])
                outputs.add(p[1])
                program1.append(p)
            else:
                program2.append(p)
    if len(outputs) != total:
        raise Exception('unexpected %d combinations of %s, %s' % (len(outputs),b,c))
    # using fact that len(b) <= len(c)
    for test in range(len(b) + 1):
        # b[:test] < all of c < b[test:], each batch in increasing order
        values = dict()
        for i in range(test): values[b[i]] = i
        for i in range(len(c)): values[c[i]] = test + i
        for i in range(test,len(b)): values[b[i]] = len(c) + i
        for p in program1:
            if p[0] == 'copy':
                values[p[1]] = values[p[2]]
            elif p[0] == 'Min':
                values[p[1]] = min(values[p[2]],values[p[3]])
            elif p[0] == 'Max':
                values[p[1]] = max(values[p[2]],values[p[3]])
        if test == 0:
            # the first test determines the order of the merged batch
            bc = sorted((values[o],o) for o in outputs)
            if [o[0] for o in bc] != expected:
                raise Exception('bad combinations of %s, %s' % (b,c))
            bc = tuple(o[1] for o in bc)
        else:
            if [values[o] for o in bc] != expected:
                raise Exception('test %d fails for %s, %s' % (test,b,c))
    batch.add(bc)
    program = program2
# Repeatedly pick two batches that the program combines, check the merge and
# replace them by the merged batch, until a single batch remains.
while len(batch) > 1:
    b,c = findmerge()
    merge(b,c)
# The remaining batch must list the outputs in exactly their declared order.
b = batch.pop()
if b != outputname:
    # b is a tuple, so it is wrapped in a 1-tuple to be formatted by one %s
    # (passing it as a second Exception argument left the %s unformatted).
    raise Exception('sorted but in wrong order %s' % (b,))