#!/usr/bin/env python3
"""Resolve package dependencies with a SAT solver.

Usage: <script> DATA_DIR [PACKAGE ...]

DATA_DIR must contain ``deps.json``, ``vers.json`` and ``conflicts.json``.
Every PACKAGE is an integer package id that has to be installed.

The script first looks for any satisfying assignment.  If there is none, it
re-runs the problem through picomus and reports which packages take part in
a minimal unsatisfiable subset.  Otherwise it walks over the packages of the
solution (highest package id first) and pins each one to the highest version
that still keeps the problem satisfiable.
"""

# the code is a bit Lispy, because it was rewritten from Racket...

# full version: maximize package versions, use MUS if conflict

import json
import sys
from collections import defaultdict

import my_utils
import SAT_lib

if len(sys.argv) < 2:
    sys.exit("usage: " + sys.argv[0] + " DATA_DIR [PACKAGE ...]")

path = sys.argv[1]


def _load_json(file_name):
    """Parse ``file_name`` from the data directory, closing the file afterwards."""
    with open(path + "/" + file_name) as f:
        return json.load(f)


# deps[p][v] is a sequence of (package, ver_lo, ver_high) requirements of
# version v of package p; vers[p] is the sequence of versions of package p.
deps = my_utils.string_keys_to_integers(_load_json("deps.json"))
vers = my_utils.string_keys_to_integers(_load_json("vers.json"))
# each conflict is a pair of (package, ver_lo, ver_high) triples
conflicts = _load_json("conflicts.json")

# NOTE(review): the loops below index vers[p] for every p in
# range(packages_total), so package ids are presumably dense starting at 0.
packages_total = max(vers.keys()) + 1


def _format_solution(sol):
    """Render a {package: version} mapping as "p:v; p:v; ", sorted by package."""
    return "; ".join([str(p) + ":" + str(sol[p]) for p in sorted(sol.keys())]) + "; "


def collect_from_solution_recursively(s, vars, sol, p):
    """Record the chosen version of package ``p`` and of everything it needs.

    s    -- solver object holding a solution
    vars -- mapping package -> version -> SAT variable
    sol  -- output mapping package -> chosen version, updated in place
    p    -- package id to start from
    """
    if p in sol:
        return  # already in
    for v in deps[p]:
        if s.get_var_from_solution(vars[p][v]):
            # sol[p] is set before recursing, which also stops the recursion
            # if the dependency graph leads back to p.
            sol[p] = v
            for dep in deps[p][v]:
                collect_from_solution_recursively(s, vars, sol, dep[0])


def get_first_solution(initial_pkgs, get_MUS):
    """Encode the dependency problem and solve it, or explain why it fails.

    initial_pkgs -- package ids that must be installed
    get_MUS      -- if true, use picomus, print the packages involved in the
                    unsatisfiable subset and terminate the process

    Returns (solver, vars, solution) when satisfiable, where ``vars`` maps
    package -> version -> SAT variable and ``solution`` maps package ->
    version.  Returns (None, None, None) when unsatisfiable.

    Raises ValueError if a requested package id or a version range cannot be
    matched against the loaded data.
    """
    print("get_first_solution, get_MUS=", get_MUS)
    print("initial_pkgs:", initial_pkgs)
    s = SAT_lib.SAT_lib(SAT_solver="picomus" if get_MUS else "libpicosat")

    sat_vars = defaultdict(dict)
    for p in range(packages_total):
        for v in vers[p]:
            sat_vars[p][v] = s.create_var()

    if get_MUS:
        # clause index -> set of packages that clause was generated for
        fix_clauses_packages = defaultdict(set)
        IMPLY_clauses_packages = defaultdict(set)
        conflicts_clauses_packages = defaultdict(set)

    def tag_new_clauses(table, first_clause, packages):
        # Attribute every clause index from first_clause up to the solver's
        # current CNF_next_idx (inclusive) to the given packages.
        for c in range(first_clause, s.CNF_next_idx + 1):
            table[c].update(packages)

    # at most one version of each package may be selected
    for p in range(packages_total):
        s.AtMost1(list(sat_vars[p].values()))

    # ver_lo/ver_high can accept values like 0000 and 9999:
    def version_range_to_list_of_vars(pkg, ver_lo, ver_high):
        # key=version, val=SAT var
        pkg_vars = sat_vars[pkg]
        # find 1st element >= ver_lo
        x = my_utils.find_1st_elem_GE(list(pkg_vars.keys()), ver_lo)
        if x is None:
            raise ValueError("package %s has no version >= %s" % (pkg, ver_lo))
        # find 1st element <= ver_high
        y = my_utils.find_1st_elem_LE(list(pkg_vars.keys()), ver_high)
        if y is None:
            raise ValueError("package %s has no version <= %s" % (pkg, ver_high))
        # NOTE(review): this treats SAT variables as integers and presumes the
        # variables of one package are consecutive and ordered like its
        # versions (they are created in vers[pkg] order above) -- confirm
        # against SAT_lib.create_var.
        return list(range(pkg_vars[x], pkg_vars[y] + 1))

    # a selected package version implies all of its dependencies
    for p in range(packages_total):
        for v in deps[p]:
            tmp = [
                s.OR_list(version_range_to_list_of_vars(dep[0], dep[1], dep[2]))
                for dep in deps[p][v]
            ]
            if tmp:
                # The range starts after the per-dependency OR_list calls, as
                # in the original code; only what follows is attributed to p.
                if get_MUS:
                    clause_start = s.CNF_next_idx + 1
                s.IMPLY_always(sat_vars[p][v], s.AND_list(tmp))
                if get_MUS:
                    tag_new_clauses(IMPLY_clauses_packages, clause_start, (p,))

    # every requested package must be present in some version
    for pkg in initial_pkgs:
        if not pkg < packages_total:
            raise ValueError("unknown package id: %s" % (pkg,))
        if get_MUS:
            clause_start = s.CNF_next_idx + 1
        SAT_vars_to_be_ORed = [sat_vars[pkg][v] for v in deps[pkg]]
        s.fix(s.OR_list(SAT_vars_to_be_ORed), True)
        if get_MUS:
            tag_new_clauses(fix_clauses_packages, clause_start, (pkg,))

    # add conflicts
    for conflict in conflicts:
        if get_MUS:
            clause_start = s.CNF_next_idx + 1
        c1, c2 = conflict[0], conflict[1]
        vars1 = version_range_to_list_of_vars(c1[0], c1[1], c1[2])
        vars2 = version_range_to_list_of_vars(c2[0], c2[1], c2[2])
        s.fix(s.NAND(s.OR_list(vars1), s.OR_list(vars2)), True)  # not both
        if get_MUS:
            tag_new_clauses(conflicts_clauses_packages, clause_start, (c1[0], c2[0]))

    if get_MUS:
        print("running picomus")
        MUS_clauses, _MUS_vars = s.get_MUS_vars()
        fix_clauses_packages_out = set()
        IMPLY_clauses_packages_out = set()
        conflicts_clauses_packages_out = set()
        for c in MUS_clauses:
            # .get() does not invoke the defaultdict factory, so clauses that
            # were never tagged do not create empty entries.
            fix_clauses_packages_out.update(fix_clauses_packages.get(c, ()))
            IMPLY_clauses_packages_out.update(IMPLY_clauses_packages.get(c, ()))
            conflicts_clauses_packages_out.update(conflicts_clauses_packages.get(c, ()))
        print("conflicted packages (set by user):", sorted(fix_clauses_packages_out))
        print("conflicted packages (IMPLY):", sorted(IMPLY_clauses_packages_out))
        print("conflicted packages (from conflicts.json):", sorted(conflicts_clauses_packages_out))
        sys.exit(0)

    print("going to run solver")
    if s.solve():
        print("SAT")
        sol = {}
        for p in initial_pkgs:
            collect_from_solution_recursively(s, sat_vars, sol, p)
        return s, sat_vars, sol

    print("UNSAT")
    return None, None, None


packages_to_install = my_utils.list_of_strings_to_list_of_ints(sys.argv[2:])

# fix packages_to_install. get initial solution
s, vars, solution = get_first_solution(packages_to_install, get_MUS=False)
if solution is None:
    # unsatisfiable: run again through picomus to report the culprits
    get_first_solution(packages_to_install, get_MUS=True)
    sys.exit(0)

print("first solution:")
print(_format_solution(solution))

fixed_versions = {}


def find_max_version_for_package(solution, s, vars, pkg, fixed_versions):
    """Pin ``pkg`` to the highest version that keeps the problem satisfiable.

    solution       -- previous solution; accepted for interface compatibility,
                      its value is not read
    s              -- solver object
    vars           -- mapping package -> version -> SAT variable
    pkg            -- package id to maximize
    fixed_versions -- mapping package -> pinned version, updated in place

    Returns the new {package: version} solution.  Raises AssertionError if no
    version of ``pkg`` is satisfiable.
    """
    for v in vers[pkg][::-1]:  # from highest to lowest
        print("trying version", v, "for package", pkg)
        if fixed_versions:
            print("fixed_versions:", fixed_versions)
        SAT_var = vars[pkg][v]
        s.assume(SAT_var)  # assume it's true
        if not s.solve():
            print("UNSAT")
            continue
        # new solution
        solution = {}
        for p in packages_to_install:
            collect_from_solution_recursively(s, vars, solution, p)
        print("SAT")
        print("solution: " + _format_solution(solution))
        # we stop here
        fixed_versions[pkg] = v
        s.fix(SAT_var, True)  # fix this package/version until the end
        return solution
    # pkg came from a satisfying solution, so some version was expected to work
    raise AssertionError("no satisfiable version found for package %s" % (pkg,))


while True:
    # find a package that is present in solution, but not in fixed_versions
    # (by computing set difference)
    in_solution_but_not_in_fixed_versions = set(solution.keys()) - set(fixed_versions.keys())
    # if we can't find one, finish
    if not in_solution_but_not_in_fixed_versions:
        print("final solution: " + _format_solution(solution))
        sys.exit(0)
    # pick the package with the highest number
    # (we start enumerating all versions at the highest package number)
    pkg = max(in_solution_but_not_in_fixed_versions)
    solution = find_max_version_for_package(solution, s, vars, pkg, fixed_versions)