#!/usr/bin/env python3

# dennis(a)yurichev.com, 2017-2022

import my_utils, SAT_lib
import collections, operator, functools

def div_test():
    # Check the SAT-encoded 32-bit divider: 1234567890 / 123 must give
    # quotient 10037137 and remainder 39.
    # The default backend should also work: SAT_lib.SAT_lib()
    solver=SAT_lib.SAT_lib(SAT_solver="libpicosat")

    width=32
    dividend_bv=solver.alloc_BV(width)
    divisor_bv=solver.alloc_BV(width)

    # Pin both operands to constants (dividend first, then divisor).
    for bv, value in ((dividend_bv, 1234567890), (divisor_bv, 123)):
        solver.fix_BV(bv, SAT_lib.n_to_BV(value, width))

    q_bv, r_bv=solver.divider(dividend_bv, divisor_bv)

    assert solver.solve()==True

    def as_int(bv):
        # Read a bit vector back from the model as an integer.
        return SAT_lib.BV_to_number(solver.get_BV_from_solution(bv))

    assert as_int(q_bv)==10037137
    assert as_int(r_bv)==39

def SumIsNot1_test():
    # Four free variables have 2**4=16 assignments. The expected count of 12
    # is 16 minus the 4 assignments in which exactly one variable is set.
    solver=SAT_lib.SAT_lib()
    solver.SumIsNot1(solver.alloc_BV(4))
    assert solver.count_solutions()==12

def AND_list_test():
    # Forcing the AND of four variables to False excludes only the all-true
    # assignment, leaving 16-1=15 solutions.
    solver=SAT_lib.SAT_lib()
    bits=solver.alloc_BV(4)
    conjunction=solver.AND_list(bits)
    solver.fix(conjunction, False)
    assert solver.count_solutions()==15

def make_one_hot_tst():
    """Check that make_one_hot() only admits assignments with a single 1.

    Builds a one-hot constraint over 1000 variables and walks the first 10
    solutions. For each one it asserts that exactly one variable is 1 and
    that another solution follows.
    """
    s=SAT_lib.SAT_lib()
    # Named one_hot_vars so it does not shadow the builtin vars().
    one_hot_vars=[s.create_var() for _ in range(1000)]
    s.make_one_hot(one_hot_vars)
    assert s.solve()==True
    for _ in range(10):
        sol=[s.get_var_from_solution(v) for v in one_hot_vars]
        assert sol.count(1)==1 # only single 1
        assert s.fetch_next_solution()==True

def IMPLY_test():
    # Enumerate every model of (p -> q). The only pair that must be missing
    # is p=1, q=0.
    # These constructors should also work:
    #   SAT_lib.SAT_lib(maxsat=False), SAT_lib.SAT_lib(maxsat=True)
    s=SAT_lib.SAT_lib(SAT_solver="libpicosat")

    p, q=s.create_var(), s.create_var()
    s.fix(s.IMPLY(p, q), True)

    # FIXME: make a function: get all solutions for specific vars...
    assert s.solve()==True
    models=[(s.get_var_from_solution(p), s.get_var_from_solution(q))]
    while s.fetch_next_solution()!=False:
        models.append((s.get_var_from_solution(p), s.get_var_from_solution(q)))
    assert sorted(models)==[(0, 0), (0, 1), (1, 1)]

def XOR_lst_test_helper(inputs):
    """Check XOR_list() against a reference XOR over *inputs* input variables.

    inputs: number of input variables (an int, used with %d and range()).

    Enumerates the solutions of output == XOR_list(input_vars). It asserts
    that in each one the output bit equals the XOR of the input bits, and
    that every one of the 2**inputs input combinations was seen. The second
    check stops the test from passing vacuously.
    """
    print ("XOR_lst_test_helper(): testing for %d inputs" % inputs)

    s=SAT_lib.SAT_lib(verbose=0)

    output=s.create_var()
    # Use a separate name rather than rebinding the int parameter to a list.
    input_vars=[s.create_var() for _ in range(inputs)]

    s.fix_EQ(output, s.XOR_list(input_vars))

    seen_inputs=set()
    for solution in s.get_all_solutions():
        bits=[solution[i] for i in input_vars]
        assert solution[output] == functools.reduce(operator.xor, bits)
        seen_inputs.add(tuple(bits))

    # The only constraint ties the output to the XOR of the inputs, so every
    # input combination is satisfiable. An exhaustive enumeration must
    # therefore include all of them.
    assert len(seen_inputs) == 2**inputs

def XOR_lst_test():
    # Exercise XOR_list() for 2 through 6 inputs. Raise MAX_INPUTS to test
    # wider XORs (earlier versions went up to 8 or 12).
    MIN_INPUTS, MAX_INPUTS=2, 6
    for n_inputs in range(MIN_INPUTS, MAX_INPUTS+1):
        XOR_lst_test_helper(n_inputs)

def solve_exact_cover_test():
    """Solve Wikipedia's exact-cover example and check the first answer.

    Example from https://en.wikipedia.org/wiki/Exact_cover#Matrix_and_hypergraph_representations
    The goal is a set of rows in which every column is 1 exactly once.
    """
    matrix=[
        [1,0,0,1,0,0,1],
        [1,0,0,1,0,0,0],
        [0,0,0,1,1,0,1],
        [0,0,1,0,1,1,0],
        [0,1,1,0,0,1,1],
        [0,1,0,0,0,0,1],
    ]

    # 0-based rows 1, 3 and 5 sum column-wise to all ones.
    expected_rows=[1, 3, 5]
    first_cover=next(SAT_lib.solve_exact_cover(matrix, "kissat", False, 1))
    assert first_cover==expected_rows

def sorting_network_test1():
    # Push the fixed input (1,0,0,1,0) through a 5-wide sorting network.
    # The output must have both ones at the front: (1,1,0,0,0).
    s=SAT_lib.SAT_lib()
    inputs=[s.create_var() for _ in range(5)]
    for var, bit in zip(inputs, (1, 0, 0, 1, 0)):
        if bit:
            s.fix_always_true(var)
        else:
            s.fix_always_false(var)

    sorted_vars=s.sorting_network(inputs)

    assert s.solve()==True

    expected=(1, 1, 0, 0, 0)
    for idx, bit in enumerate(expected):
        assert s.get_var_from_solution(sorted_vars[idx])==bit

def sorting_network_test2():
    """Count the inputs a 5-wide sorting network maps to (1,1,1,0,0).

    The sorted output is pinned to three ones followed by two zeros. The
    expected count of 10 equals C(5,3), the number of ways to place three
    ones among five input positions.
    """
    s=SAT_lib.SAT_lib()
    buf=[s.create_var() for _ in range(5)]

    buf2=s.sorting_network(buf)

    # Pin the sorted output, in order, to 1,1,1,0,0.
    for idx, bit in enumerate((1, 1, 1, 0, 0)):
        if bit:
            s.fix_always_true(buf2[idx])
        else:
            s.fix_always_false(buf2[idx])

    assert s.count_solutions()==10

# Run every test only when this file is executed as a script, so importing
# the module does not trigger the (potentially slow) SAT solving.
if __name__ == "__main__":
    sorting_network_test1()
    sorting_network_test2()
    div_test()
    SumIsNot1_test()
    AND_list_test()
    make_one_hot_tst()
    IMPLY_test()
    XOR_lst_test()
    solve_exact_cover_test()

