#!/usr/bin/env python3

from z3 import *
import struct

# Knobs: keep exactly one of the configurations below uncommented.

# CRC-16 on https://www.lammertbies.nl/comm/info/crc-calculation.html
#width=16
#samples=["\x01", "\x02"]
#must_be=[0xC0C1, 0xC181]
#sample_len=1

# CRC-16 (Modbus) on https://www.lammertbies.nl/comm/info/crc-calculation.html
#width=16
#samples=["\x01", "\x02"]
#must_be=[0x807E, 0x813E]
#sample_len=1

# crc-32, popular parameters (the active configuration)
width = 32
samples = ["\x01", "\x02"]
must_be = [
    0xA505DF1B,  # expected CRC of samples[0]
    0x3C0C8EA1,  # expected CRC of samples[1]
]
sample_len = 1  # number of bytes taken from each sample

# crc64_jones.c (redis):
#width=64
#samples=["\x01", "\x02"]
#must_be=[0x7ad870c830358979, 0xf5b0e190606b12f2]
#sample_len=1

# crc64_xz.c:
#width=64
#samples=["\x01", "\x02"]
#must_be=[0xac83edcd67c06036, 0xeb299724cc279f02]
#sample_len=1

# http://www.unit-conversion.info/texttools/crc/
#width=32
#samples=["0","1"]
#must_be=[0xf4dbdf21, 0x83dcefb7]
#sample_len=1

# recipe-259177-1.py, CRC-64-ISO
#width=64
#samples=["\x01", "\x02"]
#must_be=[0x01B0000000000000, 0x0360000000000000]
#sample_len=1

# recipe-259177-1.py, CRC-64-ISO
# many solutions!
#width=64
#samples=["\x01"]
#must_be=[0x01B0000000000000]
#sample_len=1

# crc-32, popular parameters
# slower
#width=32
#samples=["12","ab"]
#must_be=[0x4F5344CD, 0x9E83486D]
#sample_len=2

# crc64 ecma 182
#width=64
#samples=["\x01", "\x02"]
#must_be=[0x42f0e1eba9ea3693, 0x85e1c3d753d46d26]
#sample_len=1

# Plain-Python shift, NOT z3's LShR() -- the trailing underscore tells them apart.
def LShR_(x, cnt):
    """Return x shifted right by cnt bits using Python's ``>>``.

    For the non-negative ints this script passes in, ``>>`` acts as a
    logical shift.
    """
    shifted = x >> cnt
    return shifted

def bitrev8(x):
    """Reverse the order of the 8 bits of x (intended for 8-bit values).

    Classic swap network: swap nibbles, then 2-bit pairs, then single bits.
    """
    # swap the two nibbles
    x = (x << 4) | LShR_(x, 4)
    # swap adjacent 2-bit pairs (the masks also drop anything above bit 7)
    x = ((x & 0x33) << 2) | LShR_(x & 0xCC, 2)
    # swap adjacent single bits
    return ((x & 0x55) << 1) | LShR_(x & 0xAA, 1)

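# These "unoptimized" versions nest like Russian dolls: each one splits its input into two halves, bit-reverses each half with the next-smaller function, and swaps the halves.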

def bitrev16_unoptimized(x):
    """Reverse the bit order of a 16-bit value.

    Each byte is bit-reversed with bitrev8() and the two bytes trade places.
    """
    lo = x & 0xff
    hi = LShR_(x, 8)
    return bitrev8(hi) | (bitrev8(lo) << 8)

def bitrev32_unoptimized(x):
    """Reverse the bit order of a 32-bit value.

    Each 16-bit half is reversed with bitrev16_unoptimized() and the halves
    trade places.
    """
    lo = x & 0xffff
    hi = LShR_(x, 16)
    return bitrev16_unoptimized(hi) | (bitrev16_unoptimized(lo) << 16)

def bitrev64_unoptimized(x):
    """Reverse the bit order of a 64-bit value.

    Each 32-bit half is reversed with bitrev32_unoptimized() and the halves
    trade places.
    """
    lo = x & 0xffffffff
    hi = LShR_(x, 32)
    # NOTE(review): alternative formulation kept from the original, which says
    # both versions must work; it uses a `bitrev32` helper not defined in this
    # file and z3's LShR:
    #return (bitrev32(x & 0xffffffff) << 32) | bitrev32(LShR(x, 32))
    return bitrev32_unoptimized(hi) | (bitrev32_unoptimized(lo) << 32)

def bitrev(width, val):
    """Reverse the bit order of val, treated as a `width`-bit integer.

    Supported widths are 8, 16, 32 and 64. val is assumed to satisfy
    0 <= val < 2**width; larger values are not reversed correctly by the
    underlying helpers.

    Raises:
        AssertionError: if width is not one of the supported widths.
    """
    if width == 64:
        return bitrev64_unoptimized(val)
    if width == 32:
        return bitrev32_unoptimized(val)
    if width == 16:
        return bitrev16_unoptimized(val)
    if width == 8:
        # lets the solver handle CRC-8 configurations too
        return bitrev8(val)
    raise AssertionError("bitrev: unsupported width %r (expected 8, 16, 32 or 64)" % (width,))

mask = (1 << width) - 1
poly = BitVec('poly', width)

# Symbolic CRC register values, indexed as states[sample][i][bit]:
#   states[sample][0][8]           - initial register value
#   states[sample][i][0]  (i >= 1) - register right after XORing in input byte i-1
#   states[sample][i][k]  (k=1..8) - register after the k-th shift/XOR step on that byte
#   states[sample][sample_len][8]  - final register value
# (states[sample][0][0..7] are never constrained.)
states = [
    [
        [BitVec('state_%d_%d_%d' % (sample, i, bit), width) for bit in range(9)]
        for i in range(sample_len + 1)
    ]
    for sample in range(len(samples))
]
s = Solver()

reflect_in = Bool("reflect_in")

def invert(val):
    """Return the one's complement of val, truncated to the CRC width.

    Uses the module-level ``mask`` (all ones in the low `width` bits).
    """
    return mask & ~val

# All samples are processed with the SAME CRC parameters, so they share one
# initial register value, which is either all zeros or all ones.
s.add(Or(states[0][0][8] == mask, states[0][0][8] == 0))

for sample in range(len(samples)):
    # tie this sample's initial register to sample 0's (a no-op for sample 0)
    s.add(states[sample][0][8] == states[0][0][8])

    # implement basic CRC algorithm (right-shifting register)
    for i in range(sample_len):
        x = BitVecVal(ord(samples[sample][i]), width)          # ReflectIn True
        y = BitVecVal(bitrev8(ord(samples[sample][i])), width)  # ReflectIn False
        # XOR the input byte into the register...
        s.add(states[sample][i+1][0] == (states[sample][i][8] ^ If(reflect_in, x, y)))

        # ...then run 8 shift/conditional-XOR steps.
        for bit in range(8):
            cur = states[sample][i+1][bit]
            # LShR() is z3's logical shift (>> on a BitVec is arithmetic); the
            # CRC register needs the logical one.
            s.add(states[sample][i+1][bit+1] == (LShR(cur, 1) ^ If((cur & 1) == 1, poly, 0)))

# XORout (0 or all ones) and ReflectOut are also single, global parameters.
# One output transform must therefore map EVERY sample's final register to its
# expected CRC. A per-sample disjunction would let each sample pick its own
# transform and admit spurious solutions.
output_transforms = [
    lambda v: v,                         # XORout=0,  ReflectOut=True
    invert,                              # XORout=-1, ReflectOut=True
    lambda v: bitrev(width, v),          # XORout=0,  ReflectOut=False
    lambda v: invert(bitrev(width, v)),  # XORout=-1, ReflectOut=False
]
s.add(Or([
    And([states[k][sample_len][8] == transform(must_be[k]) for k in range(len(samples))])
    for transform in output_transforms
]))

# Enumerate all solutions: print each model, then add a clause forbidding
# exactly that assignment and ask the solver again.
results = []
final_state_0 = states[0][sample_len][8]
while s.check() == sat:
    m = s.model()
    final_value = m[final_state_0].as_long()
    expected = must_be[0]
    # Which output transform turns sample 0's final register into its
    # expected CRC? Checked in a fixed order; the first match wins.
    if final_value == expected:
        outparams = "XORout=0, ReflectOut=True"
    elif invert(final_value) == expected:
        outparams = "XORout=-1, ReflectOut=True"
    elif final_value == bitrev(width, expected):
        outparams = "XORout=0, ReflectOut=False"
    elif invert(final_value) == bitrev(width, expected):
        outparams = "XORout=-1, ReflectOut=False"
    else:
        raise AssertionError("final state of sample 0 matches no output transform")

    print ("poly=0x%x, init=0x%x, ReflectIn=%s, %s" % (m[poly].as_long(), m[states[0][0][8]].as_long(), str(m[reflect_in]), outparams))

    results.append(m)
    # block this exact model: at least one declared constant must change
    s.add(Or([d() != m[d] for d in m]))

print ("total results", len(results))
