#!/usr/bin/env python3

# the code is a bit Lispy, because it was rewritten from Racket...

# full version: maximize package versions; if the request is unsatisfiable, report a MUS (minimal unsatisfiable subset)

import my_utils, SAT_lib, sys, json
from collections import defaultdict

# Directory holding the three JSON inputs; taken from the first CLI argument.
path = sys.argv[1]

# Each file is opened in a `with` block so its handle is closed as soon as
# the JSON has been parsed.

# deps[pkg][version] is a list of requirements; the rest of this script reads
# each requirement as (package, lowest version, highest version).
with open(path + "/deps.json") as deps_file:
    deps = my_utils.string_keys_to_integers(json.load(deps_file))

# vers[pkg] is the sequence of versions known for that package; the rest of
# this script walks it in reverse to go "from highest to lowest".
with open(path + "/vers.json") as vers_file:
    vers = my_utils.string_keys_to_integers(json.load(vers_file))

# Each conflict is a pair of (package, lowest version, highest version) ranges
# that must not be installed together.
with open(path + "/conflicts.json") as conflicts_file:
    conflicts = json.load(conflicts_file)

# Packages are numbered 0..packages_total-1.
packages_total = max(vers.keys()) + 1

def collect_from_solution_recursively (s, vars, sol, p):
    """Record the chosen version of package `p` and of everything it pulls in.

    s    -- solver object; queried through get_var_from_solution()
    vars -- vars[package][version] -> SAT variable
    sol  -- dict being filled in place: package -> selected version
    p    -- package to start from

    Nothing is returned; `sol` is the output.  A package that is already a key
    of `sol` is skipped, which also stops the recursion on dependency cycles.
    """
    if p in sol:
        return  # already in

    for version, requirements in deps[p].items():
        if not s.get_var_from_solution(vars[p][version]):
            continue
        # This version is switched on in the solver's model: remember it
        # before descending, then follow each requirement's package number.
        sol[p] = version
        for requirement in requirements:
            collect_from_solution_recursively(s, vars, sol, requirement[0])

def get_first_solution(initial_pkgs, get_MUS):
    """Encode the installation problem as SAT and run the solver once.

    One SAT variable is created for every (package, version) pair listed in
    the module-level `vers`.  The constraints added are, in this order:

    * at most one version per package;
    * for every package version with requirements in `deps`: that version
      implies, for each requirement, that some version inside the required
      range is selected;
    * every package in `initial_pkgs` has some version selected;
    * for every entry of `conflicts`: the two version ranges are not both
      selected.

    Args:
        initial_pkgs: package numbers requested by the user; each must be
            below `packages_total`.
        get_MUS: when true, the "picomus" backend is used, the clause
            numbers returned by s.get_MUS_vars() are mapped back to package
            numbers and printed, and the process exits with status 0 -- the
            function does not return in this mode.  When false, the
            "libpicosat" backend is used.

    Returns:
        (s, vars, sol) when satisfiable: the solver object, the
        vars[package][version] -> SAT variable mapping, and a dict mapping
        every package reachable from `initial_pkgs` to its selected version.
        (None, None, None) when unsatisfiable.
    """
    print("get_first_solution, get_MUS=", get_MUS)
    print("initial_pkgs:", initial_pkgs)

    s = SAT_lib.SAT_lib(SAT_solver="picomus" if get_MUS else "libpicosat")

    # sat_vars[package][version] -> SAT variable.  Variables are created
    # package by package, in the order versions appear in vers[package].
    sat_vars = defaultdict(dict)
    for p in range(packages_total):
        for v in vers[p]:
            sat_vars[p][v] = s.create_var()

    if get_MUS:
        # clause number -> packages on whose behalf that clause was emitted,
        # one table per kind of constraint.
        fix_clauses_packages = defaultdict(set)
        IMPLY_clauses_packages = defaultdict(set)
        conflicts_clauses_packages = defaultdict(set)

    def tag_new_clauses(table, first_clause, packages):
        # Attribute every clause numbered first_clause..s.CNF_next_idx
        # (inclusive) to `packages`.
        for clause in range(first_clause, s.CNF_next_idx + 1):
            table[clause].update(packages)

    # each variable is one-hot
    for p in range(packages_total):
        s.AtMost1(list(sat_vars[p].values()))

    # ver_lo/ver_high can accept values like 0000 and 9999:
    def version_range_to_list_of_vars(pkg, ver_lo, ver_high):
        # Returns the SAT variables of every version of `pkg` between the two
        # bounds, as a contiguous integer range.  This relies on the variables
        # of one package having been created consecutively (see the creation
        # loop in the enclosing function).
        pkg_vars = sat_vars[pkg]  # key=version, val=SAT var
        # find 1st element >= ver_lo
        first = my_utils.find_1st_elem_GE(list(pkg_vars.keys()), ver_lo)
        assert first is not None
        # find 1st element <= ver_high
        last = my_utils.find_1st_elem_LE(list(pkg_vars.keys()), ver_high)
        assert last is not None
        return list(range(pkg_vars[first], pkg_vars[last] + 1))

    # dependencies: a selected version implies all of its requirements
    for p in range(packages_total):
        for v, requirements in deps[p].items():
            # One OR per requirement.  These are built before the clause
            # counter is sampled, so only the AND/IMPLY clauses below are
            # attributed to the package.
            alternatives = [
                s.OR_list(version_range_to_list_of_vars(dep[0], dep[1], dep[2]))
                for dep in requirements
            ]
            if not alternatives:
                continue
            if get_MUS:
                first_clause = s.CNF_next_idx + 1
            s.IMPLY_always(sat_vars[p][v], s.AND_list(alternatives))
            if get_MUS:
                tag_new_clauses(IMPLY_clauses_packages, first_clause, (p,))

    # user request: some version of each requested package must be selected
    for pkg in initial_pkgs:
        assert pkg < packages_total
        if get_MUS:
            first_clause = s.CNF_next_idx + 1
        SAT_vars_to_be_ORed = [sat_vars[pkg][v] for v in deps[pkg].keys()]
        s.fix(s.OR_list(SAT_vars_to_be_ORed), True)
        if get_MUS:
            tag_new_clauses(fix_clauses_packages, first_clause, (pkg,))

    # add conflicts
    for conflict in conflicts:
        if get_MUS:
            first_clause = s.CNF_next_idx + 1
        c1, c2 = conflict[0], conflict[1]
        vars1 = version_range_to_list_of_vars(c1[0], c1[1], c1[2])
        vars2 = version_range_to_list_of_vars(c2[0], c2[1], c2[2])
        s.fix(s.NAND(s.OR_list(vars1), s.OR_list(vars2)), True)  # not both
        if get_MUS:
            tag_new_clauses(conflicts_clauses_packages, first_clause, (c1[0], c2[0]))

    if get_MUS:
        print("running picomus")
        # Only the clause list is needed here; the second value is unused.
        MUS_clauses, _MUS_vars = s.get_MUS_vars()
        fix_clauses_packages_out = set()
        IMPLY_clauses_packages_out = set()
        conflicts_clauses_packages_out = set()
        for c in MUS_clauses:
            # .get() does not insert into the defaultdict tables.
            fix_clauses_packages_out.update(fix_clauses_packages.get(c, ()))
            IMPLY_clauses_packages_out.update(IMPLY_clauses_packages.get(c, ()))
            conflicts_clauses_packages_out.update(conflicts_clauses_packages.get(c, ()))

        print("conflicted packages (set by user):", sorted(fix_clauses_packages_out))
        print("conflicted packages (IMPLY):", sorted(IMPLY_clauses_packages_out))
        print("conflicted packages (from conflicts.json):", sorted(conflicts_clauses_packages_out))
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use only.
        sys.exit(0)

    print("going to run solver")
    if not s.solve():
        print("UNSAT")
        return None, None, None

    print("SAT")
    sol = {}
    for p in initial_pkgs:
        collect_from_solution_recursively(s, sat_vars, sol, p)
    return s, sat_vars, sol

# Package numbers requested on the command line (everything after the path).
packages_to_install = my_utils.list_of_strings_to_list_of_ints(sys.argv[2:])

# fix packages_to_install. get initial solution
s, vars, solution = get_first_solution(packages_to_install, get_MUS=False)
if solution is None:
    # Unsatisfiable: run again in MUS mode, which prints the packages
    # involved in the conflict and terminates the process by itself.
    get_first_solution(packages_to_install, get_MUS=True)
    sys.exit(0)

print("first solution:")
# The loop variable is deliberately not called `s`, so it cannot be confused
# with the solver object bound above.
print("; ".join([str(p) + ":" + str(solution[p]) for p in sorted(solution.keys())]) + "; ")

# package -> version that has already been maximized and pinned in the solver
fixed_versions = {}

def find_max_version_for_package(solution, s, vars, pkg, fixed_versions):
    """Pin `pkg` to the highest version for which the problem is still SAT.

    Versions of `pkg` are tried in the reverse order of vers[pkg] (from highest to
    lowest).  Each candidate is passed to s.assume() and the solver is
    re-run; the first candidate that is satisfiable is recorded in
    `fixed_versions`, fixed in the solver for all later calls, and the
    solution collected from that run is returned.

    Args:
        solution: the previous solution; not read, kept for the call
            signature.
        s: solver object.
        vars: vars[package][version] -> SAT variable.
        pkg: package whose version is to be maximized.
        fixed_versions: dict package -> pinned version; updated in place.

    Returns:
        A new dict package -> version for every package reachable from the
        module-level `packages_to_install`.

    Raises:
        AssertionError: if no version of `pkg` is satisfiable.
    """
    for v in reversed(vers[pkg]):  # from highest to lowest
        print("trying version", v, "for package", pkg)
        if fixed_versions:
            print("fixed_versions:", fixed_versions)
        SAT_var = vars[pkg][v]
        s.assume(SAT_var)  # assume it's true

        if not s.solve():
            print("UNSAT")
            continue

        # new solution
        new_solution = {}
        for p in packages_to_install:
            collect_from_solution_recursively(s, vars, new_solution, p)
        print("SAT")
        t = "; ".join([str(p) + ":" + str(new_solution[p]) for p in sorted(new_solution.keys())]) + "; "
        print("solution: " + t)
        # we stop here
        fixed_versions[pkg] = v
        s.fix(SAT_var, True)  # fix this package/version until the end
        return new_solution

    # Raised explicitly instead of through `assert`, so the check is not
    # stripped under `python -O`, where the function would otherwise fall
    # through and return None.
    raise AssertionError("no satisfiable version found for package " + str(pkg))

# Maximize versions one package at a time until every package in the current
# solution has been pinned.
while True:
    # find a package that is present in solution, but not in fixed_versions
    # (by computing set difference)
    in_solution_but_not_in_fixed_versions = set(solution.keys()) - set(fixed_versions.keys())
    # if we can't find one, finish
    if not in_solution_but_not_in_fixed_versions:
        # Kept in its own name: `s` is the solver object and must not be
        # overwritten with the report string.
        final_report = "; ".join([str(p) + ":" + str(solution[p]) for p in sorted(solution.keys())]) + "; "
        print("final solution: " + final_report)
        sys.exit(0)

    # found
    # pick a package with highest number
    # (we start enumerating all versions at the highest package number)
    pkg = max(in_solution_but_not_in_fixed_versions)
    solution = find_max_version_for_package(solution, s, vars, pkg, fixed_versions)

