-rwxr-xr-x 19732 djbsort-20260210/sortverif/minmax raw
#!/usr/bin/env python3
import sys
from pyparsing import StringEnd,Literal,Word,ZeroOrMore,OneOrMore,Optional,Forward,alphas,nums
xbits = int(sys.argv[1])  # bit width of each element
n = int(sys.argv[2])  # number of elements
# program variable names for the inputs and outputs, e.g. x_0_32 and y_0_32
inputname = tuple(f'x_{index}_{xbits}' for index in range(n))
outputname = tuple(f'y_{index}_{xbits}' for index in range(n))
# ===== operation graph
# Number of the most recently allocated value; values are numbered from 1.
nextvalue = 0
# program variable name -> value number
value = dict()
# value number -> operation tuple, e.g. ('xor',) or ('Extract', top, bot)
operation = dict()
# value number -> tuple of parent value numbers
parents = dict()
# value number -> bit width
bits = dict()
# value number -> (bits, integer) on the random example input; partly a
# sanity check, but also conveniently enables peephole constant propagation
concrete = dict()
# (operation, parents) -> value number, so identical computations are shared
oppa2value = dict()
# ===== parse program
def lit(x):
    """Return a pyparsing matcher for the literal text x, suppressed from results."""
    matcher = Literal(x)
    return matcher.suppress()
def group(s):
    """Return a parse action that tags a multi-token match with the name s.

    A match consisting of exactly one token is passed through unchanged
    (as a one-element list); otherwise the tokens become the single list
    [s, token0, token1, ...].
    """
    def action(x):
        tokens = list(x)
        if len(tokens) != 1:
            return [[s, *tokens]]
        return tokens
    return action
lparen = lit('(')
rparen = lit(')')
comma = lit(',')
equal = lit('=')
number = Word(nums)
name = Word(alphas, alphas + nums + "_")


def _opcall(opname, args):
    """Grammar for one line `dest = opname(args)`, tagged as [opname, dest, ...]."""
    syntax = name + equal + lit(opname) + lparen + args + rparen
    return syntax.set_parse_action(group(opname))


# Alternatives, in the order they are tried; the bare copy `dest = src` is last.
_alternatives = [_opcall('constant', number + comma + number)]
_alternatives += [_opcall(op, name) for op in ('invert', 'Reverse')]
_alternatives += [
    _opcall(op, name + comma + name)
    for op in (
        'xor', 'or', 'and', 'add', 'sub', 'mul', 'lshift', 'signedrshift',
        'unsignedrshift', 'unsignedge', 'unsignedgt', 'unsignedle',
        'unsignedlt', 'signedle', 'signedlt', 'equal',
    )
]
_alternatives.append(_opcall('If', name + comma + name + comma + name))
_alternatives.append(_opcall('Extract', name + comma + number + comma + number))
_alternatives += [_opcall(op, name + comma + number) for op in ('SignExt', 'ZeroExt')]
_alternatives.append(_opcall('Concat', name + ZeroOrMore(comma + name)))
_alternatives.append((name + equal + name).set_parse_action(group('copy')))

assignment = _alternatives[0]
for _alternative in _alternatives[1:]:
    assignment |= _alternative
assignments = ZeroOrMore(assignment) + StringEnd()
# operations whose two operands may be swapped: newvalue sorts their parents
# into a canonical order and peephole_match tries both operand orders
commutative = ('and', 'or', 'xor', 'signedmin', 'signedmax')
# parse all of stdin; each element is a tagged list such as ['xor', dest, a, b]
program = list(assignments.parse_string(sys.stdin.read()))
# ===== concrete calculations, including number of bits as first component
import random
# one random example input; every value's concrete result on it is tracked
concreteinputlist = [random.randrange(2 ** xbits) for _ in range(n)]
concreteinput = dict(zip(inputname, concreteinputlist))
def calcconcrete(op, parents):
    """Evaluate operation op on concrete operand values.

    Each concrete value is a pair (bits, integer). parents holds one such
    pair per operand. Returns the (bits, integer) pair of the result.
    Arity and width mismatches fail an assert; an unknown operation raises
    Exception.
    """
    kind = op[0]

    if kind == 'input':
        assert len(parents) == 0
        return xbits, concreteinput[op[1]]

    if kind == 'constant':
        assert len(parents) == 0
        return tuple(op[1:])

    if kind == 'invert':
        assert len(parents) == 1
        width, x = parents[0]
        return width, (1 << width) - 1 - x

    if kind == 'Reverse':
        assert len(parents) == 1
        width, x = parents[0]
        assert width % 8 == 0
        # little-endian byte list of x, then reassembled in the opposite order
        chunks = [(x >> shift) & 255 for shift in range(0, width, 8)]
        return width, sum(byte << (8 * j) for j, byte in enumerate(reversed(chunks)))

    if kind in ('unsignedge', 'unsignedgt', 'unsignedle', 'unsignedlt', 'signedle', 'signedlt',
                'equal', 'xor', 'or', 'and', 'add', 'sub', 'mul', 'lshift', 'unsignedrshift',
                'signedrshift', 'signedmin', 'signedmax'):
        assert len(parents) == 2
        width, f = parents[0]
        gwidth, g = parents[1]
        assert width == gwidth

        if kind == 'equal':
            return 1, int(f == g)

        if kind in ('xor', 'and', 'or', 'add', 'sub', 'mul'):
            arithmetic = {
                'xor': lambda: f ^ g,
                'and': lambda: f & g,
                'or': lambda: f | g,
                'add': lambda: (f + g) % (1 << width),
                'sub': lambda: (f - g) % (1 << width),
                'mul': lambda: (f * g) % (1 << width),
            }
            return width, arithmetic[kind]()

        # toggling the sign bit maps two's complement order onto unsigned order
        flip = 1 << (width - 1)

        comparisons = {
            'unsignedge': f >= g,
            'unsignedgt': f > g,
            'unsignedle': f <= g,
            'unsignedlt': f < g,
            'signedle': (f ^ flip) <= (g ^ flip),
            'signedlt': (f ^ flip) < (g ^ flip),
        }
        if kind in comparisons:
            return 1, int(comparisons[kind])

        if kind == 'signedmin':
            return width, min(f ^ flip, g ^ flip) ^ flip
        if kind == 'signedmax':
            return width, max(f ^ flip, g ^ flip) ^ flip

        if kind == 'lshift':
            assert g < width  # XXX
            return width, (f << g) % (1 << width)
        if kind == 'unsignedrshift':
            assert g < width  # XXX
            return width, f >> g
        if kind == 'signedrshift':
            assert g < width  # XXX
            fsigned = (f ^ flip) - flip
            gsigned = (g ^ flip) - flip
            assert 0 <= gsigned and gsigned < width
            # arithmetic shift on the signed value, then back to unsigned form
            return width, ((fsigned >> gsigned) + flip) ^ flip

        raise Exception('internal error')

    if kind == 'If':
        assert len(parents) == 3
        cwidth, c = parents[0]
        width, f = parents[1]
        gwidth, g = parents[2]
        assert width == gwidth
        assert cwidth == 1
        return (width, f) if c else (gwidth, g)

    if kind == 'Extract':
        top, bot = op[1:]
        assert len(parents) == 1
        width, x = parents[0]
        assert width > top
        assert top >= bot
        assert bot >= 0
        return top + 1 - bot, (x & ((2 << top) - 1)) >> bot

    if kind in ('SignExt', 'ZeroExt'):
        ext = op[1]
        assert len(parents) == 1
        width, x = parents[0]
        if kind == 'SignExt' and x & (1 << (width - 1)):
            # fill the ext new high bits with copies of the sign bit
            x = x + (1 << (width + ext)) - (1 << width)
        return width + ext, x

    if kind == 'Concat':
        # the last parent supplies the least significant bits
        total, packed = 0, 0
        for width, x in reversed(parents):
            packed = packed + (x << total)
            total = total + width
        return total, packed

    raise Exception(f'unknown operation {op}')
# ===== constant patterns (sometimes depending on bit size)
# XXX: make the UI more regular; e.g. bit0,bit1 can be handled via Bits
# Name usable in the rules DSL -> function from bit width b to the constant's
# unsigned value. Key order matters: the DSL tries the names in this order,
# so e.g. '0star1' must come before its prefix '0star'.
constpatterns = {
    'bit0': (lambda b: 0),
    'bit1': (lambda b: 1),
    'const5': (lambda b: 5),
    'const26': (lambda b: 26),
    'const27': (lambda b: 27),
    'const31': (lambda b: 31),
    '0star1': (lambda b: 1),                # zeros then a final one
    '0star': (lambda b: 0),                 # all zeros
    '01star': (lambda b: 2 ** (b - 1) - 1), # zero then all ones
    '10star': (lambda b: 2 ** (b - 1)),     # one then all zeros
    '1star': (lambda b: 2 ** b - 1),        # all ones
}
# ===== peephole optimizations
def peephole_init():
    """Read and parse the rewrite rules from the file 'rules'.

    Each rule is a line `left -> right`. Lines that do not split into exactly
    two parts on '->' are skipped. Both sides are terms in a small DSL:
    uppercase pattern variables (A, X, ...), `V=term` to bind a variable to a
    whole subterm, `Bits(n,term)` to require a bit width, the constant names
    in constpatterns, and operations such as `xor(term,term)` or
    `Extract(term,top,bot)`.

    Returns a list of (left, right) pairs of parsed patterns, in file order.
    Each pattern is a tuple whose first entry names the construct ('var',
    'assign', 'Bits', an operation name, or a constant name).

    Raises Exception naming the text of a rule side that fails to parse,
    with the underlying parse error chained as its cause.
    """
    with open('rules', encoding='utf-8') as f:
        rules = f.read()

    def group(name, x):
        # tag the parse results of x as the tuple (name, *results)
        def action(r):
            return name, *r
        return x.copy().set_parse_action(action)

    def fun(name, *args):
        # grammar for a bare `name`, or `name(arg,...,arg)` when args are given
        what = lit(name)
        if len(args) > 0:
            what += lparen
            for pos, arg in enumerate(args):
                if pos > 0:
                    what += comma
                what += arg
            what += rparen
        return group(name, what)

    lparen = lit('(')
    rparen = lit(')')
    comma = lit(',')
    Term = Forward()
    Termvar = group('var', Word('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))
    X = lparen + Term + rparen
    X |= group('assign', Termvar + lit('=') + Term)
    X |= fun('Bits', Word(nums), Term)
    X |= group('Concat', lit('Concat') + lparen + Term + ZeroOrMore(comma + Term) + rparen)
    X |= fun('Extract', Term, Word(nums), Word(nums))
    X |= fun('ZeroExt', Term, Word(nums))
    for c in constpatterns:
        X |= fun(c)
    for op in 'invert', 'Reverse':
        X |= fun(op, Term)
    for op in ('add', 'sub', 'and', 'xor', 'or', 'equal', 'unsignedgt', 'unsignedge',
               'unsignedlt', 'unsignedle', 'signedlt', 'signedle', 'signedmin',
               'signedmax', 'lshift', 'unsignedrshift', 'signedrshift'):
        X |= fun(op, Term, Term)
    X |= fun('If', Term, Term, Term)
    X |= Termvar
    Term <<= X

    # a rule side must be exactly one term with nothing left over
    whole_term = Term + StringEnd()

    def parse_side(text):
        try:
            return whole_term.parse_string(text)[0]
        except Exception as err:
            # keep the pyparsing error (with its position) as the cause
            raise Exception(f'failure to parse {text}') from err

    result = []
    for line in rules.splitlines():
        sides = line.split('->')
        if len(sides) != 2:
            continue
        left = parse_side(sides[0])
        right = parse_side(sides[1])
        result.append((left, right))
    return result


peephole_list = peephole_init()
def op2topop(operation):
    """Return the key under which rules for this operation tuple are indexed.

    Extract, ZeroExt and SignExt are keyed by the whole tuple (including
    their integer arguments); every other operation by its name alone.
    """
    opname = operation[0]
    parametrized = opname in ('Extract', 'ZeroExt', 'SignExt')
    return tuple(operation) if parametrized else opname
def pattern2topop(pattern):
    """Return the op2topop-style key of the operation at the top of a rule pattern.

    A named constant keys as 'constant'. 'assign' and 'Bits' wrappers are
    looked through to the term they wrap. A bare variable fails an assert.
    """
    if len(pattern) == 1:
        # a constant name such as 0star
        return 'constant'
    head = pattern[0]
    if len(pattern) == 2:
        # a one-operand operation
        assert head != 'var'
        return head
    if head in ('assign', 'Bits'):
        return pattern2topop(pattern[2])
    if head == 'Extract':
        return head, int(pattern[2]), int(pattern[3])
    if head in ('ZeroExt', 'SignExt'):
        return head, int(pattern[2])
    return head
# rules grouped by the key of their left side's top operation, so newvalue
# only tries rules that could match; lists keep the file order of the rules
peephole_by_op = {}
for left, right in peephole_list:
    op = pattern2topop(left)
    peephole_by_op.setdefault(op, []).append((left, right))
# returns tuple of new match dictionaries compatible with mforce
def peephole_match(pattern, newoperation, newparents, newbits, var=None, mforce=None):
    """Match a rule pattern against a value.

    pattern: parsed rule term (see peephole_init).
    newoperation, newparents, newbits: the value's operation tuple, parent
        value numbers, and bit width.
    var: the value's number if it is already in the graph; None for a value
        still being built (which therefore cannot bind a pattern variable).
    mforce: pattern variable name -> value number, bindings already fixed by
        other parts of the match. It is only read, never modified.

    Returns a tuple of match dictionaries (pattern variable -> value number),
    each consistent with mforce. An empty tuple means no match.
    """
    if mforce is None:
        mforce = {}
    if len(pattern) == 1:
        # a named constant, whose value may depend on the bit width
        if newoperation[0] != 'constant':
            return ()
        if pattern[0] in ('bit0', 'bit1') and newbits != 1:
            return ()
        if newoperation[1] != newbits:
            return ()
        if newoperation[2] != constpatterns[pattern[0]](newbits):
            return ()
        return ({},)
    if len(pattern) == 2:
        if pattern[0] == 'var':
            # a variable matches an existing value unless already bound elsewhere
            if var is None:
                return ()
            if mforce.get(pattern[1], var) != var:
                return ()
            return ({pattern[1]: var},)
        # one-operand pattern; check the arity before unpacking, since e.g.
        # Concat(A) must not crash on a Concat with several parents
        if newoperation[0] != pattern[0] or len(newparents) != 1:
            return ()
        p, = newparents
        return peephole_match(pattern[1], operation[p], parents[p], bits[p], var=p, mforce=mforce)
    if pattern[0] == 'assign':
        # V=term: match term and bind V to this whole value
        assert pattern[1][0] == 'var'
        if var is None:
            return ()
        V = pattern[1][1]
        # like a plain variable, V must agree with any earlier binding;
        # otherwise the caller's consistency assert would fail
        if mforce.get(V, var) != var:
            return ()
        results = ()
        for m in peephole_match(pattern[2], newoperation, newparents, newbits, var=var, mforce=mforce):
            if V in m:
                continue
            m[V] = var
            results += (m,)
        return results
    if pattern[0] == 'Bits':
        if newbits != int(pattern[1]):
            return ()
        return peephole_match(pattern[2], newoperation, newparents, newbits, var=var, mforce=mforce)
    if newoperation[0] != pattern[0]:
        return ()
    if pattern[0] == 'Extract':
        if newoperation != (pattern[0], int(pattern[2]), int(pattern[3])):
            return ()
        if len(newparents) != 1:
            return ()
        p, = newparents
        return peephole_match(pattern[1], operation[p], parents[p], bits[p], var=p, mforce=mforce)
    if pattern[0] == 'ZeroExt':  # would also work for SignExt
        if newoperation != (pattern[0], int(pattern[2])):
            return ()
        if len(newparents) != 1:
            return ()
        p, = newparents
        return peephole_match(pattern[1], operation[p], parents[p], bits[p], var=p, mforce=mforce)
    if len(newparents) != len(pattern[1:]):
        return ()
    # try the operands in order, and also swapped for commutative operations
    todo = [newparents]
    if pattern[0] in commutative:
        todo += [[newparents[1], newparents[0]]]
    results = ()
    for parentlist in todo:
        partialresults = ({},)
        # conceptually: match each parent in turn, extending partial result
        # but there can be multiple matches so have list of partial results
        for subpattern, p in zip(pattern[1:], parentlist):
            newpartialresults = ()
            for m in partialresults:
                submforce = mforce.copy()
                submforce.update(m)
                for subm in peephole_match(subpattern, operation[p], parents[p], bits[p], var=p, mforce=submforce):
                    assert all(m[v] == subm[v] for v in subm if v in m)
                    subm.update(m)
                    newpartialresults += (subm,)
            partialresults = newpartialresults
        results += partialresults
    return results
def newvalue(newoperation,newparents):
    """Return the value number for operation newoperation applied to newparents.

    The concrete result on the random example input is computed first, via
    calcconcrete. Then the loop repeats the following steps:
    - For commutative operations, sort the parents into a canonical order.
    - Return the existing value if (operation, parents) is already in oppa2value.
    - Allocate parentless values (inputs and constants) directly.
    - Fold operations whose parents are all constants into a constant.
    - Try the DSL peephole rules for this operation. A rule that matches is
      moved to the front of its list.
    - Apply a few built-in rewrites of Extract.
    - Otherwise, allocate a new value.

    Every rewrite is checked to preserve the bit width and the concrete value.
    A mismatch from a DSL rule raises Exception.
    """
    global nextvalue
    # concrete (bits, integer) result on the example input; newbits is its width
    newconcrete = calcconcrete(newoperation,[concrete[x] for x in newparents])
    newbits = newconcrete[0]
    # each pass either returns or rewrites (newoperation,newparents) into an
    # equivalent form and continues
    while True:
        # canonical operand order for commutative operations
        if newoperation[0] in commutative: newparents = tuple(sorted(newparents))
        # is this value already in the graph?
        oppa = tuple(newoperation),tuple(newparents)
        if oppa in oppa2value:
            v = oppa2value[oppa]
            assert newbits == bits[v]
            assert newconcrete == concrete[v]
            return v
        # is this value an orphan?
        if len(newparents) == 0:
            nextvalue += 1
            operation[nextvalue] = newoperation
            parents[nextvalue] = newparents
            bits[nextvalue] = newbits
            concrete[nextvalue] = newconcrete
            oppa = tuple(newoperation),tuple(newparents)
            assert oppa not in oppa2value # otherwise would have merged
            oppa2value[oppa] = nextvalue
            return nextvalue
        assert newoperation[0] not in ('constant','input')
        # constant propagation
        if all(operation[p][0] == 'constant' for p in newparents):
            newoperation = 'constant',newconcrete[0],newconcrete[1]
            newparents = ()
            continue
        # peephole from DSL (note: there are also a few non-DSL rewrites below)
        topop = op2topop(newoperation)
        for lrpos,(left,right) in enumerate(peephole_by_op.get(topop,[])):
            mtuple = peephole_match(left,newoperation,newparents,newbits)
            if len(mtuple) == 0: continue
            # move the matching rule to the front of its list so that it is
            # tried first next time (a new list is stored; the one being
            # iterated is not mutated)
            if lrpos > 0:
                lrlist = peephole_by_op[topop]
                lrlist = lrlist[lrpos:lrpos+1]+lrlist[:lrpos]+lrlist[lrpos+1:]
                peephole_by_op[topop] = lrlist
            m = mtuple[0] # or could heuristically pick "best" match
            result = newvaluepattern(right,m,forcebits=newbits)
            # a rule must preserve both width and the example's concrete value
            if bits[result] != newbits:
                raise Exception(f'replacing {left} with {right} for {newoperation}({newparents}) substituting {m} changed sizes from {newbits} to {bits[result]}')
            if concrete[result] != newconcrete:
                raise Exception(f'replacing {left} with {right} for {newoperation}({newparents}) substituting {m} changed concrete values from {newconcrete} to {concrete[result]}')
            # record the rewrite so this (operation,parents) maps straight to result
            oppa = tuple(newoperation),tuple(newparents)
            assert oppa not in oppa2value # otherwise would have merged
            oppa2value[oppa] = result
            return result
        if newoperation[0] == 'Extract':
            w = newparents[0]
            # rewrite invert(s)[high:low] as invert(s[high:low])
            if operation[w][0] == 'invert':
                newparents = newvalue(newoperation,parents[w]),
                newoperation = 'invert',
                assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                continue
            # rewrite s[high:low][high2:low2] as s[high2+low:low2+low]
            if operation[w][0] == 'Extract':
                s = parents[w][0]
                high2,low2 = newoperation[1:]
                high,low = operation[w][1:]
                newoperation = 'Extract',high2+low,low2+low
                newparents = s,
                assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                continue
            # rewrite Reverse(s)[high:low] as Reverse(s[...]) if possible
            # (only when the slice covers whole bytes)
            if operation[w][0] == 'Reverse':
                high,low = newoperation[1:]
                if low%8 == 0 and high%8 == 7:
                    s = parents[w][0]
                    newparents = newvalue(('Extract',bits[s]-1-low,bits[s]-1-high),(s,)),
                    newoperation = 'Reverse',
                    assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                    continue
            # rewrite Concat(...)[high:low] as ...[high-pos:low-pos] if possible
            # (pos is the bit offset of parent y; the last parent is lowest)
            progress = False
            if operation[w][0] == 'Concat':
                high,low = newoperation[1:]
                pos = 0
                for y in reversed(parents[w]):
                    if pos <= low and high < pos+bits[y]:
                        newoperation = 'Extract',high-pos,low-pos
                        newparents = y,
                        assert newconcrete == calcconcrete(newoperation,[concrete[p] for p in newparents])
                        progress = True
                        break
                    pos += bits[y]
            if progress: continue
        # no further rewrites
        nextvalue += 1
        operation[nextvalue] = newoperation
        parents[nextvalue] = newparents
        bits[nextvalue] = newbits
        concrete[nextvalue] = newconcrete
        oppa = tuple(newoperation),tuple(newparents)
        assert oppa not in oppa2value # otherwise would have merged
        oppa2value[oppa] = nextvalue
        return nextvalue
# create new value according to pattern
# with m describing substitutions of pattern variables
def newvaluepattern(pattern, m, forcebits=None):
    """Build (or look up) the graph value described by a rule's right side.

    pattern: parsed rule term (see peephole_init).
    m: match dictionary from peephole_match (pattern variable -> value number).
    forcebits: required bit width of the result when the caller knows it,
        else None. Width-dependent constants need it.

    Returns the value number. Raises Exception for an operation that is not
    supported on the right side of a rule.
    """
    if pattern[0] == 'var':
        return m[pattern[1]]
    if pattern[0] in constpatterns:
        assert forcebits is not None
        if pattern[0] in ('bit0', 'bit1'):
            assert forcebits == 1
        newoperation = 'constant', forcebits, constpatterns[pattern[0]](forcebits)
        newparents = ()
        return newvalue(newoperation, newparents)
    if pattern[0] == 'Bits':
        if forcebits is None:
            forcebits = int(pattern[1])
        assert forcebits == int(pattern[1])
        result = newvaluepattern(pattern[2], m, forcebits=forcebits)
        assert forcebits == bits[result]
        return result
    if pattern[0] == 'ZeroExt':
        morebits = int(pattern[2])
        subforcebits = None if forcebits is None else forcebits - morebits
        newoperation = 'ZeroExt', morebits
        newparents = newvaluepattern(pattern[1], m, forcebits=subforcebits),
        return newvalue(newoperation, newparents)
    if pattern[0] == 'Extract':
        top, bot = map(int, pattern[2:])
        if forcebits is not None:
            assert forcebits == 1 + top - bot
        newoperation = 'Extract', top, bot
        newparents = newvaluepattern(pattern[1], m),
        return newvalue(newoperation, newparents)
    newoperation = pattern[0],
    if pattern[0] == 'If':
        parentcond = newvaluepattern(pattern[1], m, forcebits=1)
        parentthen = newvaluepattern(pattern[2], m, forcebits=forcebits)
        parentelse = newvaluepattern(pattern[3], m, forcebits=forcebits)
        newparents = parentcond, parentthen, parentelse
    elif pattern[0] in ('Concat', 'equal', 'unsignedge', 'unsignedgt', 'unsignedle',
                        'unsignedlt', 'signedle', 'signedlt'):
        # operand widths are not determined by the result width
        newparents = tuple(newvaluepattern(p, m) for p in pattern[1:])
    elif pattern[0] in ('signedmin', 'signedmax', 'xor', 'and', 'or', 'invert', 'Reverse',
                        'add', 'sub', 'mul', 'lshift', 'unsignedrshift', 'signedrshift'):
        # calcconcrete gives these results the width of their operand(s),
        # and requires equal operand widths for the binary ones
        newparents = tuple(newvaluepattern(p, m, forcebits=forcebits) for p in pattern[1:])
    else:
        raise Exception(f'unrecognized operation on right side of rule: {pattern[0]}')
    return newvalue(newoperation, newparents)
# ===== run through program, optimizing each step as it appears
def _valueof(varname):
    """Return the graph value of program variable varname; reject unassigned names."""
    if varname not in value:
        raise Exception(f'{varname} used before assignment')
    return value[varname]


for v in inputname:
    value[v] = newvalue(('input', v), ())
for p in program:
    # p is [operation, destination, arguments...]
    if p[1] in value:
        raise Exception(f'{p[1]} assigned twice')
    if p[0] == 'copy':
        value[p[1]] = _valueof(p[2])
        continue
    if p[0] == 'constant':
        newoperation = p[0], int(p[2]), int(p[3])
        newparents = ()
    elif p[0] == 'Extract':
        newoperation = p[0], int(p[3]), int(p[4])
        newparents = _valueof(p[2]),
    elif p[0] in ('SignExt', 'ZeroExt'):
        newoperation = p[0], int(p[3])
        newparents = _valueof(p[2]),
    elif p[0] in ('unsignedge', 'unsignedgt', 'unsignedle', 'unsignedlt', 'signedle', 'signedlt',
                  'equal', 'invert', 'Reverse', 'xor', 'or', 'and', 'add', 'sub', 'mul', 'lshift',
                  'unsignedrshift', 'signedrshift', 'If', 'Concat'):
        newoperation = p[0],
        newparents = tuple(_valueof(v) for v in p[2:])
    else:
        raise Exception(f'unknown internal operation {p[0]}')
    value[p[1]] = newvalue(newoperation, newparents)
# the outputs computed on the random example must be the inputs sorted as
# two's complement (signed) integers
concreteoutputlist = [concrete[value[name]][1] for name in outputname]
flip = 1 << (xbits - 1)  # sign bit; (u ^ flip) - flip reinterprets u as signed
signedconcreteinputlist = [(u ^ flip) - flip for u in concreteinputlist]
signedconcreteoutputlist = [(u ^ flip) - flip for u in concreteoutputlist]
if sorted(signedconcreteinputlist) != signedconcreteoutputlist:
    raise Exception(f'example output {signedconcreteoutputlist} is not sorted input {sorted(signedconcreteinputlist)}')
# ===== output
done = set()
finaloperations = set()


def do(v):
    """Print v's definition after the definitions of all its ancestors.

    Values already in `done` are skipped, so each value is printed at most
    once across calls. An input prints as `v<N> = <input name>`. Any other
    value prints as `v<N> = op(v<parent>,...,<op arguments>...)` and its
    operation name is added to `finaloperations`.

    Uses an explicit stack instead of recursion, so long dependency chains
    cannot exceed Python's recursion limit. The output order is the same as
    a recursive parents-first traversal.
    """
    stack = [(v, False)]
    while stack:
        x, ready = stack.pop()
        if ready:
            # all of x's parents have been printed by now
            if operation[x][0] == 'input':
                print(f'v{x} = {operation[x][1]}')
            else:
                p = [f'v{y}' for y in parents[x]]
                p += [f'{y}' for y in operation[x][1:]]
                finaloperations.add(operation[x][0])
                args = ','.join(p)
                print(f'v{x} = {operation[x][0]}({args})')
            continue
        if x in done:
            continue
        done.add(x)
        stack.append((x, True))
        # pushed in reverse so the first parent is handled first
        stack.extend((y, False) for y in reversed(parents[x]))


for v in outputname:
    do(value[v])
    print(f'{v} = v{value[v]}')
sys.stdout.flush()
# decompose will also refuse to parse this, but error here is more informative
if any(op not in ('signedmin', 'signedmax') for op in finaloperations):
    raise Exception(f'resulting operations {sorted(finaloperations)} are not just signedmin, signedmax')