#!/usr/bin/env python3

# problem: libcadical consumes too much memory enumerating exact covers for DEK square (e)

import math, sys, itertools, pickle, os, random, copy, time
import SAT_lib, my_utils, latin_utils
import argparse

from typing import List
from typing import Any

# ---------------------------------------------------------------------------
# Command-line options.  Parsing happens at import time and the results are
# stored in module-level globals (args, _seed, SIZE, verbosity, SAT_solver,
# start_time) that the functions below read directly.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# Order of the Latin squares to search for.
parser.add_argument("--size", type=int, default=8, help="... or order. Default=8")
# Kissat requires a 31-bit seed, so the random 32-bit default is masked with
# 0x7fffffff.  The default value is drawn once, when this line executes.
parser.add_argument("--seed", type=int, default=my_utils.get_32_bits_from_urandom()&0x7fffffff, help="Seed. Default=random")
parser.add_argument("--normalize", help="Normalize/reduce (before search). Default=false", action="store_true")
# NOTE(review): with nargs="?", a bare "--mode" (no value) yields None, which
# then falls through to the "Unknown mode" branch at the bottom of the file.
parser.add_argument("--mode", type=str, nargs="?", default="enum", help="Can be enum/restart. Default=enum")
parser.add_argument("--randomize", help="Randomize LS before search (in 'restart' mode). Default=false", action="store_true")
parser.add_argument("--all-solutions", default=False, help="Produce all possible solutions. Default=false", action="store_true")
# Space-separated list of solver names supported by SAT_lib, shown in --help.
tmp=" ".join(SAT_lib.SAT_solvers_supported)
parser.add_argument("--solver", default="libcadical", help="Set SAT solver. Supported: "+tmp)
# Alternative default solver, kept for reference:
#parser.add_argument("--solver", default="kissat", help="Set SAT solver. Supported: "+tmp)
# Verbosity level forwarded to SAT_lib.  action="count" leaves args.v as None
# when -v is absent; that case is mapped to 0 below.
parser.add_argument("-v", action="count", help="Increase verbosity. May be -vv, -vvv, -vvvv")
parser.add_argument("--MOLS3", help="Find also 3-MOLS. Default=false", action="store_true")
# Short (base-36) form of a first square.  When given, only mates of that
# square are searched for, and the program exits afterwards.
parser.add_argument("--first", type=str, nargs="?", default=None, help="Set first square (in short form)")
# When set, the exact-cover stage is asked to produce a proof -- e.g. that no
# 2-MOLS exist, or that no more 2-MOLS exist.
parser.add_argument("--proof", help="Generate proof (during exact cover stage). Default=false", action="store_true")

args=parser.parse_args()

# https://stackoverflow.com/questions/4042452/display-help-message-with-python-argparse-when-script-is-called-without-any-argu
# NOTE(review): unlike the linked answer, help is printed on every run, not
# only when the script is called without arguments.
parser.print_help()
print ("")

# The same seed drives Python's `random` and is passed to every SAT_lib instance.
_seed=args.seed
print ("Setting seed", _seed)
random.seed(_seed)

SIZE=args.size
print ("Setting size", SIZE)

print ("Setting mode", args.mode)

if args.normalize:
    print ("Normalize mode")
else:
    print ("Non-normalize mode")

if args.proof:
    print ("Proof will be produced (during exact cover stage)")

# args.v is None when no -v flag was given (action="count").
if args.v!=None:
    verbosity=args.v
else:
    verbosity=0

SAT_solver=args.solver
print ("Setting solver", SAT_solver)

# Wall-clock start time (seconds); find_3MOLS() reports elapsed time from it.
start_time=int(time.time())

def find_transversals_sets(square):
    """Yield every partition of *square* into ``order`` disjoint transversals.

    All transversals of the Latin square are enumerated once with
    ``latin_utils.find_transversals``.  An exact-cover instance is then built:
    one row per transversal, one column per cell (``order**2`` columns).  It is
    solved with ``SAT_lib.solve_exact_cover``.  Each exact cover selects
    ``order`` transversals that together cover every cell exactly once.

    Yields ``(transversal_i, transversal_set_coords)`` for each cover:

    * ``transversal_i[r][c]`` is the index (0..order-1), within this cover, of
      the transversal containing cell ``(r, c)``;
    * ``transversal_set_coords[i]`` is the list of ``(row, col)`` cells of the
      i-th transversal of this cover.

    Yields nothing when the square has no transversals.
    """
    order = len(square)

    # Enumerate transversals once; the count is the length of the list.
    # (Previously the whole enumeration ran twice, the first time only to count.)
    transversals = []
    transversals_total = latin_utils.find_transversals(square, transversals)
    assert transversals_total == len(transversals)
    print ("transversals_total", transversals_total)
    if transversals_total == 0:
        print ("No transversals, nothing to do.")
        return

    # Optional pruning, currently disabled:
    #   order 9:  skip squares with fewer than 310 transversals
    #   order 10: skip squares with fewer than 920 transversals

    # Exact-cover matrix: row = transversal, column = cell (r, c) mapped to
    # r*order + c; True where the transversal passes through that cell.
    cells_total = order ** 2
    m = [[False] * cells_total for _ in range(transversals_total)]
    for t_i, t in enumerate(transversals):
        # each element of a transversal is a (row, col, symbol) triple
        for r, c, _symb in t:
            m[t_i][r * order + c] = True

    print ("Going to solve exact cover problem")
    covers = SAT_lib.solve_exact_cover(m, SAT_solver=SAT_solver, proof=args.proof, threads=10)
    for transversals_set_i, transversals_set in enumerate(covers):
        # a cover of order**2 cells by transversals of size `order` has `order` members
        assert len(transversals_set) == order
        print ("transversals_set_i", transversals_set_i)
        transversal_i = [[None] * order for _ in range(order)]
        transversal_set_coords = []
        for i, t in enumerate(transversals_set):
            transversal_coords = []
            for r, c, _symb in transversals[t]:
                transversal_i[r][c] = i
                transversal_coords.append((r, c))
            transversal_set_coords.append(transversal_coords)

        yield transversal_i, transversal_set_coords

def print_LS_with_colored_transversals(sq, transversal_i, transversal_set_coords):
    """Print Latin square *sq* with each cell colored by its transversal.

    Cell (r, c) is drawn with ANSI foreground color ``transversal_i[r][c] + 1``,
    followed by a color reset and a single space; every row ends with a newline.
    ``transversal_set_coords`` is accepted for call-site symmetry but unused.
    """
    n = len(sq)
    for row in range(n):
        cells = []
        for col in range(n):
            cells.append("%s%s%s " % (
                latin_utils.ANSI_set_foreground_color_2(transversal_i[row][col]+1),
                latin_utils.cell_to_str(sq[row][col], n),
                my_utils.ANSI_reset(),
            ))
        print ("".join(cells))

def find_mate_for_square_with_transversals(first, transversal_set_coords):
    """Find an orthogonal mate of *first* that is constant on the given transversals.

    A SAT instance is built with two Latin squares ``a`` and ``b``:

    * ``a`` is fixed to *first*;
    * ``b`` must be orthogonal to ``a``;
    * ``b``'s first column is fixed in increasing order (symmetry breaking);
    * all cells of each transversal in *transversal_set_coords* (a list of
      lists of ``(row, col)``) must hold the same symbol in ``b``.

    Returns the mate as produced by ``latin_utils.get_square`` on SAT, or None
    on UNSAT.  The solver is released with ``s.deinit()`` on every path,
    including exceptions.
    """
    order=len(first)
    s=SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, threads=10)
    try:
        a=[[s.alloc_BV(order) for _ in range(order)] for _ in range(order)]
        b=[[s.alloc_BV(order) for _ in range(order)] for _ in range(order)]

        latin_utils.latin_add_constraints(order, s, a)
        latin_utils.latin_add_constraints(order, s, b)

        latin_utils.make_mutually_orthogonal(order, s, a, b)

        latin_utils.fix_square_to_hardcoded(s, a, first, order)

        # Symmetry breaking: first column of the mate (b) in increasing order
        # (0..order-1).
        latin_utils.fix_first_col_increasing(s, b, order)

        # Symbols within each transversal must be the same in the mate.
        for transversal in transversal_set_coords:
            s.make_all_BVs_EQ ([b[r][c] for r, c in transversal])

        # TODO: check for other solutions that may be possible (?)
        if not s.solve():
            print ("find_mate_for_square_with_transversals() UNSAT")
            print ("find_mate_for_square_with_transversals() stop")
            return None

        print ("*** find_mate_for_square_with_transversals()")
        print ("find_mate_for_square_with_transversals() (SAT)")
        print ("First:")
        latin_utils.print_LS_from_SAT_vars(s, a)
        if order<=36:
            print ("Short form:", latin_utils.list_of_lists_of_int_to_base36_str(latin_utils.get_square(s, a)))
        print ("")
        print ("Mate:")
        latin_utils.print_LS_from_SAT_vars(s, b)
        if order<=36:
            print ("Short form:", latin_utils.list_of_lists_of_int_to_base36_str(latin_utils.get_square(s, b)))
        print ("")
        print ("Concatenated:")
        latin_utils.print_MOLS2_from_SAT_vars(s, a, b)
        print ("")
        sys.stdout.flush()
        # Read the model before the finally-clause releases the solver.
        return latin_utils.get_square(s, b)
    finally:
        s.deinit()

def find_3MOLS(SIZE, first, second):
    """Search for three mutually orthogonal Latin squares of order *SIZE*.

    Squares ``a``, ``b`` and ``c`` are made pairwise orthogonal.  ``a`` is
    fixed to *first* and ``b`` to *second* when they are not None; otherwise
    the solver chooses them.  With ``--normalize``, the first row and column
    of ``a`` and the first columns of ``b`` and ``c`` are fixed increasing.

    On SAT the three squares are printed and the process terminates with
    ``sys.exit(0)`` (a 3-MOLS is the final goal).  On UNSAT a message is
    printed and None is returned.  The solver is released with
    ``s.deinit()`` on every path.
    """
    print ("find_3MOLS() start")

    s=SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, threads=10)
    try:
        # "_" indices: the original reused "c" as an index inside the very
        # comprehension that defines square "c".
        a=[[s.alloc_BV(SIZE) for _ in range(SIZE)] for _ in range(SIZE)]
        b=[[s.alloc_BV(SIZE) for _ in range(SIZE)] for _ in range(SIZE)]
        c=[[s.alloc_BV(SIZE) for _ in range(SIZE)] for _ in range(SIZE)]

        for square in (a, b, c):
            latin_utils.latin_add_constraints(SIZE, s, square)

        latin_utils.make_mutually_orthogonal(SIZE, s, a, b)
        latin_utils.make_mutually_orthogonal(SIZE, s, a, c)
        latin_utils.make_mutually_orthogonal(SIZE, s, b, c)

        if args.normalize:
            latin_utils.fix_first_row_increasing(s, a, SIZE)
            latin_utils.fix_first_col_increasing(s, a, SIZE)
            latin_utils.fix_first_col_increasing(s, b, SIZE)
            latin_utils.fix_first_col_increasing(s, c, SIZE)

        if first is None:
            print ("find_3MOLS() warning first square is not set. will be found during search.")
        else:
            print ("find_3MOLS() setting first square as hardcoded")
            latin_utils.fix_square_to_hardcoded(s, a, first, SIZE)

        if second is None:
            print ("find_3MOLS() warning second square/mate is not set. will be found during search.")
        else:
            print ("find_3MOLS() setting second square/mate as hardcoded")
            latin_utils.fix_square_to_hardcoded(s, b, second, SIZE)

        if not s.solve():
            print ("find_3MOLS() (UNSAT)")
            return None

        print ("find_3MOLS() (SAT)")
        print ("wall time:", my_utils.human(int(time.time())-start_time))

        for title, square in (("First:", a), ("Mate 1:", b), ("Mate 2:", c)):
            print (title)
            latin_utils.print_LS_from_SAT_vars(s, square)
            if SIZE<=36:
                print ("Short form:", latin_utils.list_of_lists_of_int_to_base36_str(latin_utils.get_square(s, square)))
            print ("")

        print ("Concatenated:")
        latin_utils.print_MOLS3_from_SAT_vars(s, a, b, c)

        sys.stdout.flush()
        # sys.exit (not the interactive-only site builtin exit()); the
        # finally-clause still releases the solver on the way out.
        sys.exit(0)
    finally:
        s.deinit()

# find_mate() vs find_mates(): find_mate() uses only the FIRST exact cover
# (set of disjoint transversals) and terminates the process as soon as a mate
# is found, unless --MOLS3 is set (then it tries to extend the pair to 3-MOLS).
# find_mates() processes every exact cover and returns how many it saw.
def find_mate(first):
    """Try to find one orthogonal mate of *first* from its first transversal cover.

    Prints the square colored by the transversals of the first exact cover
    and solves for the corresponding mate.  If a mate is found:

    * without --MOLS3, the process exits with status 0;
    * with --MOLS3, ``find_3MOLS`` is attempted.

    Returns None when there is no cover or no mate, or after an unsuccessful
    3-MOLS attempt.
    """
    # Take only the first cover, then close the enumerator immediately.  This
    # releases the exact-cover search (and its solver) before the separate
    # mate SAT instance is built, instead of keeping both alive together.
    covers = find_transversals_sets(first)
    try:
        cover = next(covers, None)
    finally:
        covers.close()
    if cover is None:
        return None

    transversal_i, transversal_set_coords = cover
    print_LS_with_colored_transversals(first, transversal_i, transversal_set_coords)
    mate = find_mate_for_square_with_transversals(first, transversal_set_coords)
    if mate is None:
        return None
    if not args.MOLS3:
        sys.exit(0)
    find_3MOLS(SIZE, first, mate)
    return None

def find_mates(first):
    """Process every exact cover of *first* by disjoint transversals.

    For each cover: print the square colored by transversal and solve for the
    matching orthogonal mate.  With --MOLS3, any mate found is passed on to
    ``find_3MOLS``.  Prints and returns the number of covers processed.
    """
    n = len(first)
    count = 0
    for coloring, coords in find_transversals_sets(first):
        count += 1
        print_LS_with_colored_transversals(first, coloring, coords)
        candidate = find_mate_for_square_with_transversals(first, coords)
        if candidate is not None and args.MOLS3:
            find_3MOLS(n, first, candidate)
    print ("Finish. Mates total", count)
    return count

def get_all_solutions():
    """Enumerate Latin squares of order SIZE and look for a mate of each one.

    With ``--all-solutions`` it walks the solver's solutions via
    ``fetch_next_solution``; otherwise it stops after the first square.  Each
    square is printed (plus its base-36 short form for SIZE <= 36) and passed
    to ``find_mate``, which may terminate the process.  Finally the number of
    squares seen is printed and the solver is released.
    """
    solver = SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, verbose=verbosity, threads=10)

    grid = [[solver.alloc_BV(SIZE) for _ in range(SIZE)] for _ in range(SIZE)]
    latin_utils.latin_add_constraints(SIZE, solver, grid)

    if args.normalize:
        latin_utils.fix_first_row_increasing(solver, grid, SIZE)
        latin_utils.fix_first_col_increasing(solver, grid, SIZE)

    if args.randomize:
        latin_utils.randomize_first_square(solver, SIZE, grid, args.normalize)

    found = 0
    if not solver.solve():
        print ("get_all_solutions() UNSAT")
    else:
        while True:
            square = latin_utils.get_square(solver, grid)
            found += 1
            print ("*** Solution %d" % found)
            latin_utils.print_LS_from_SAT_vars(solver, grid)
            if SIZE <= 36:
                print ("Short form:", latin_utils.list_of_lists_of_int_to_base36_str(square))

            find_mate(square)

            print ("")
            sys.stdout.flush()
            # Stop unless every solution was requested and another one exists.
            if args.all_solutions == False or solver.fetch_next_solution() == False:
                break

    print ("Solutions: %d" % found)
    solver.deinit()

def get_all_solutions_start_afresh():
    """Find one Latin square of order SIZE with a fresh solver and look for its mate.

    Used by 'restart' mode, which may call it repeatedly.  Each call builds a
    new SAT instance; with --randomize, ``latin_utils.randomize_first_square``
    is applied first, presumably so successive calls yield different squares.
    The found square is printed and passed to ``find_mate``, which may
    terminate the process.  The solver is released on every path.
    """
    print ("get_all_solutions_start_afresh() start")
    s=SAT_lib.SAT_lib(seed=_seed, SAT_solver=SAT_solver, verbose=verbosity, threads=10)
    try:
        a=[[s.alloc_BV(SIZE) for _ in range(SIZE)] for _ in range(SIZE)]

        latin_utils.latin_add_constraints(SIZE, s, a)

        if args.normalize:
            latin_utils.fix_first_row_increasing(s, a, SIZE)
            latin_utils.fix_first_col_increasing(s, a, SIZE)

        if args.randomize:
            latin_utils.randomize_first_square(s, SIZE, a, args.normalize)

        if not s.solve():
            print ("get_all_solutions_start_afresh() UNSAT")
            return

        print ("SAT")
        sq=latin_utils.get_square(s, a)
        print ("*** Solution")
        latin_utils.print_LS_from_SAT_vars (s, a)
        # The base-36 short form is only meaningful for orders <= 36
        # (presumably one base-36 digit per symbol), so build it only when it
        # will be printed, matching the guard used by the other printers.
        if SIZE<=36:
            print ("Short form:", latin_utils.list_of_lists_of_int_to_base36_str(sq))

        find_mate(sq)

        sys.stdout.flush()
    finally:
        s.deinit()

# Entry point.  An explicit --first square takes precedence over the search
# modes: only mates of that square are enumerated, then the program exits.
if args.first is not None:
    print ("Setting first square")
    sq=latin_utils.base36_str_to_list_of_lists(args.first)
    if args.normalize:
        sq=latin_utils.normalize(sq)
        print ("Normalizing first square, now it's: "+latin_utils.list_of_lists_of_int_to_base36_str(sq))
    find_mates(sq)
    # sys.exit, not the site builtin exit(), which is meant for interactive use.
    sys.exit(0)

if args.mode=="enum":
    get_all_solutions()
elif args.mode=="restart":
    # Keep restarting with a fresh solver while --all-solutions is set.
    while True:
        get_all_solutions_start_afresh()
        if not args.all_solutions:
            break
else:
    print ("Error. Unknown mode:", args.mode)
    # Signal the error to the caller/shell instead of exiting with status 0.
    sys.exit(1)

