#!/usr/bin/env python3

import math, sys, itertools, pickle, os, random, copy, time
import SAT_lib, my_utils, latin_utils
import argparse

from typing import List
from typing import Any

parser = argparse.ArgumentParser()
parser.add_argument("--size", type=int, default=8, help="... or order. Default=8")
# Kissat requires 31-bit seed:
parser.add_argument("--seed", type=int, default=my_utils.get_32_bits_from_urandom()&0x7fffffff, help="Seed. Default=random")
parser.add_argument("--normalize", help="Normalize/reduce (before search)", action="store_true")
parser.add_argument("--mode", type=str, nargs="?", default="restart", help="Can be enum/count/restart. Default=restart")
parser.add_argument("--randomize", help="Randomize LS before search (in 'restart' mode). Default=false", action="store_true")
parser.add_argument("--all-solutions", help="Produce all possible solutions. Default=false", action="store_true")
tmp=" ".join(SAT_lib.SAT_solvers_supported)
parser.add_argument("--solver", default="libcadical", help="Set SAT solver. Supported: "+tmp)
#parser.add_argument("--solver", default="kissat", help="Set SAT solver. Supported: "+tmp)
# verbose for SAT_lib:
parser.add_argument("-v", action="count", help="Increase verbosity. May be -vv, -vvv, -vvvv")

args=parser.parse_args()

# https://stackoverflow.com/questions/4042452/display-help-message-with-python-argparse-when-script-is-called-without-any-argu
parser.print_help()
print ("")

_seed=args.seed
print ("Setting seed", _seed)
random.seed(_seed)

SIZE=args.size
print ("Setting size", SIZE)

print ("Setting mode", args.mode)

if args.normalize:
    print ("Normalize mode")
else:
    print ("Non-normalize mode")

if args.v!=None:
    verbosity=args.v
else:
    verbosity=0

SAT_solver=args.solver
print ("Setting solver", SAT_solver)

def only_count_solutions():
    """Count the Latin squares of order SIZE and print the total.

    Builds a fresh SAT_lib instance and allocates a SIZE x SIZE grid of
    ``s.alloc_BV(SIZE)`` cell variables. It adds the constraints from
    ``latin_utils.latin_add_constraints``. With --normalize, it also fixes
    the first row and first column to be increasing. It then prints the
    value returned by ``s.count_solutions()``.

    The solver is always released with ``s.deinit()``, as in the other modes.
    """
    s = SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, verbose=verbosity, threads=6)
    try:
        a = [[s.alloc_BV(SIZE) for _col in range(SIZE)] for _row in range(SIZE)]

        latin_utils.latin_add_constraints(SIZE, s, a)

        if args.normalize:
            latin_utils.fix_first_row_increasing(s, a, SIZE)
            latin_utils.fix_first_col_increasing(s, a, SIZE)

        solutions = s.count_solutions()
        print("Solutions", solutions)
    finally:
        # The original never released the solver here, unlike the enum and
        # restart modes.
        s.deinit()

def pick_2nd_element_from_each_tuple_in_list(x):
    """Return a list of the second element (index 1) of each item in *x*.

    *x* may be any iterable of indexable items. It is currently referenced
    only by disabled code that prints a transversal "(cols only)".
    Presumably each item there is a (row, col) pair; confirm against
    ``latin_utils.find_transversals``.
    """
    return [item[1] for item in x]

def get_all_solutions():
    """Enumerate Latin squares of order SIZE with a single incremental solver.

    Builds the SAT instance once and applies the optional --normalize and
    --randomize constraints. It then prints the first square found. With
    --all-solutions, it keeps calling ``s.fetch_next_solution()`` and prints
    each further square until that call reports no more solutions.

    For every square it prints:
    - the grid;
    - its base-36 short form (only when SIZE <= 36);
    - its transversal count.

    Finally it prints the number of squares printed. The solver is always
    released with ``s.deinit()``.
    """
    s = SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, verbose=verbosity, threads=6)
    try:
        a = [[s.alloc_BV(SIZE) for _col in range(SIZE)] for _row in range(SIZE)]

        latin_utils.latin_add_constraints(SIZE, s, a)

        if args.normalize:
            latin_utils.fix_first_row_increasing(s, a, SIZE)
            latin_utils.fix_first_col_increasing(s, a, SIZE)

        if args.randomize:
            latin_utils.randomize_first_square(s, SIZE, a, args.normalize)

        sol = 0
        if s.solve():
            while True:
                sq = latin_utils.get_square(s, a)
                sol += 1
                print("*** Solution %d" % sol)
                latin_utils.print_LS_from_SAT_vars(s, a)

                # Base 36 has only 36 digits, so the short form is produced only
                # for SIZE <= 36. The original still encoded it unconditionally
                # on the transversals line below.
                short_form = latin_utils.list_of_lists_of_int_to_base36_str(sq) if SIZE <= 36 else None
                if short_form is not None:
                    print("Short form:", short_form)

                transversals_total = latin_utils.count_transversals(sq)
                # TODO: clean up and merge with the similar fragment in
                # get_all_solutions_start_afresh(). To list the transversals:
                # call latin_utils.find_transversals(sq, y) with y = [], then
                # print pick_2nd_element_from_each_tuple_in_list(z) for each z in y.
                if short_form is not None:
                    print("transversals_total", transversals_total, "short_form", short_form)
                else:
                    print("transversals_total", transversals_total)

                print("")

                sys.stdout.flush()
                if not args.all_solutions:
                    break
                if not s.fetch_next_solution():
                    break
        else:
            print("get_all_solutions() UNSAT")

        print("Solutions: %d" % sol)
    finally:
        s.deinit()

def get_all_solutions_start_afresh():
    """Build a brand-new solver, print one Latin square of order SIZE, release it.

    The 'restart' mode calls this repeatedly. Each call constructs a new
    SAT_lib instance with the same seed and constraints, plus the optional
    --normalize and --randomize constraints.

    On SAT it prints:
    - the square;
    - its base-36 short form (only when SIZE <= 36);
    - its transversal count.

    Otherwise it prints an UNSAT message. The solver is always released
    with ``s.deinit()``.
    """
    print("get_all_solutions_start_afresh() start")
    s = SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, verbose=verbosity, threads=6)
    try:
        a = [[s.alloc_BV(SIZE) for _col in range(SIZE)] for _row in range(SIZE)]

        latin_utils.latin_add_constraints(SIZE, s, a)

        if args.normalize:
            latin_utils.fix_first_row_increasing(s, a, SIZE)
            latin_utils.fix_first_col_increasing(s, a, SIZE)

        if args.randomize:
            latin_utils.randomize_first_square(s, SIZE, a, args.normalize)

        if s.solve():
            print("SAT")
            sq = latin_utils.get_square(s, a)
            print("*** Solution")
            latin_utils.print_LS_from_SAT_vars(s, a)

            # Base 36 has only 36 digits, so the short form is produced only
            # for SIZE <= 36. The original still encoded it unconditionally
            # on the transversals line below.
            short_form = latin_utils.list_of_lists_of_int_to_base36_str(sq) if SIZE <= 36 else None
            if short_form is not None:
                print("Short form:", short_form)

            transversals_total = latin_utils.count_transversals(sq)
            # TODO: clean up and merge with the similar fragment in
            # get_all_solutions(). To list the transversals: call
            # latin_utils.find_transversals(sq, y) with y = [], then print
            # pick_2nd_element_from_each_tuple_in_list(z) for each z in y.
            if short_form is not None:
                print("transversals_total", transversals_total, "short_form", short_form)
            else:
                print("transversals_total", transversals_total)

            sys.stdout.flush()
        else:
            print("get_all_solutions_start_afresh() UNSAT")
    finally:
        s.deinit()

# Dispatch on --mode.
if args.mode == "enum":
    get_all_solutions()
elif args.mode == "count":
    only_count_solutions()
elif args.mode == "restart":
    # Each pass builds a new solver with the same seed.
    # NOTE(review): without --randomize, passes may keep producing the same
    # square, so --all-solutions here can loop forever.
    while True:
        get_all_solutions_start_afresh()
        if not args.all_solutions:
            break
else:
    print("Error. Unknown mode:", args.mode)
    # Signal the failure to the caller/shell; previously the exit status was 0.
    sys.exit(1)

