#!/usr/bin/env python3

# my own SAT/CNF library
# dennis(a)yurichev.com, 2017-2023

# "BV" stands for bitvector

# TODO: check signed/unsigned issues in adder/multiplier/divider
# TODO: instance class? term class?
# TODO: get rid of *fix* functions?

from ctypes import *
import subprocess, os, itertools, sys, my_utils
import signal
from typing import List
from typing import Tuple
from typing import Dict
import operator, functools, math, time, copy

# Enabling this slows everything down, but it is needed for the ACL2 proof checker, AFAIR
CNF_REMOVE_TAUTOLOGICAL=False
CNF_REMOVE_DUPLICATES=False

#USE_IPAMIR=True
USE_IPAMIR=False

SAT_solvers_supported=["libcadical", "kissat", "gimsatul", "minisat", "plingeling", "picosat", "libpicosat"]

def one_hot_to_number(one_hot):
    """Return the index of the set bit of a one-hot value (1 -> 0, 2 -> 1, 8 -> 3).

    For a positive int that is not one-hot this is floor(log2(one_hot)).
    Integers are handled with exact arithmetic (int.bit_length()). The old
    int(math.log(one_hot, 2)) goes through floating point, and it can round
    across an integer boundary for large inputs.

    Raises ValueError for zero or negative input, as math.log() did.
    """
    if isinstance(one_hot, int):
        if one_hot<=0:
            raise ValueError("one_hot_to_number(): expected a positive value, got "+str(one_hot))
        return one_hot.bit_length()-1
    # non-int input (e.g. float): keep the historical floating-point behaviour
    return int(math.log(one_hot, 2))

"""
Recursive version, couldn't stand my tests :(
But easier to understand, as usual

def is_tautological(cls:List[int]):
    if len(cls)==1:
        return False
    head=cls[0]
    rest=cls[1:]
    if -head in rest:
        return True
    return is_tautological(rest)
"""

def is_tautological(cls:List[int]) -> bool:
    """Return True if clause `cls` contains some literal together with its negation.

    Single left-to-right pass with a set of already-seen literals: O(n),
    instead of the previous pairwise O(n^2) scan. It reports exactly the same
    pairs (an earlier literal whose negation appears later).
    """
    seen=set()
    for lit in cls:
        # the negation of this literal occurred earlier -> clause is always true
        if -lit in seen:
            return True
        seen.add(lit)
    return False

# BV=[MSB...LSB]
def BV_to_number(BV:List[int]) -> int:
    """Convert a bit list in [MSB..LSB] order (e.g. 0/1 values) to an integer."""
    # coeff=1, 2^1, 2^2 ... 2^(len(BV)-1)
    coeff=1
    rt=0
    for v in my_utils.rvr(BV):
        rt=rt+coeff*v
        coeff=coeff*2
    return rt

# bit order: [MSB..LSB]
# 'size' is desired width of bitvector, in bits:
# return: list of false/true values!
def n_to_BV (n:int, size:int) -> List[bool]:
    """Return `n` as a list of `size` booleans in [MSB..LSB] order.

    NOTE(review): raises IndexError when n needs more than `size` bits; for a
    negative n, bin() yields '-0b...', so the result is not meaningful.
    """
    out=[False]*size
    i=0
    # my_utils.rvr() reverses the digit string, so this walks from the LSB up
    for var in my_utils.rvr(list(bin(n)[2:])):
        if var=='1':
            out[i]=True
        else:
            out[i]=False
        i=i+1
    return my_utils.rvr(out)

class picosat:
    """Thin ctypes wrapper around libpicosat.so (searched for in every sys.path directory)."""
    def __init__(self):
        self.lib=None
        #print (sys.path)
        for d in sys.path:
            fname=d+"/libpicosat.so"
            if os.path.exists(fname):
                #print (fname)
                self.lib=CDLL(fname)
                break
        if self.lib==None:
            print ("Error: Can't find libpicosat.so, I searched in: ", sys.path)
            exit(0)
        self.lib.picosat_init.restype=c_void_p
        self.lib.picosat_reset.argtypes=[c_void_p]
        self.lib.picosat_add.argtypes=[c_void_p, c_int]
        self.lib.picosat_add.restype=c_int
        # int picosat_add_lits (PicoSAT *, int * lits);
        self.lib.picosat_add_lits.argtypes=[c_void_p, POINTER(c_int)]
        self.lib.picosat_add_lits.restype=c_int
        self.lib.picosat_assume.argtypes=[c_void_p, c_int]
        self.lib.picosat_sat.argtypes=[c_void_p, c_int]
        self.lib.picosat_sat.restype=c_int
        self.lib.picosat_deref.argtypes=[c_void_p, c_int]
        self.lib.picosat_deref.restype=c_int
        self.ctx=self.lib.picosat_init()

    def reset (self):
        self.lib.picosat_reset(self.ctx)

    def add (self, i):
        return self.lib.picosat_add(self.ctx, i)

    def add_lits (self, lits):
        # SAT_lib.add_clause() passes a 0-terminated list here
        arg2=(c_int * len(lits))(*lits)
        return self.lib.picosat_add_lits(self.ctx, arg2)

    def assume (self, i):
        self.lib.picosat_assume(self.ctx, i)

    def sat (self):
        return self.lib.picosat_sat(self.ctx, -1)

    def deref (self, i):
        return self.lib.picosat_deref(self.ctx, i)

class cadical:
    """Thin ctypes wrapper around libcadical_wrapper.so (searched for in every sys.path directory)."""
    def __init__(self):
        self.lib=None
        #print (sys.path)
        for d in sys.path:
            fname=d+"/libcadical_wrapper.so"
            if os.path.exists(fname):
                #print (fname)
                self.lib=CDLL(fname)
                break
        if self.lib==None:
            print ("Error: Can't find libcadical_wrapper.so, I searched in: ", sys.path)
            exit(0)
        self.lib.ccadical_wrapper_init.restype=c_void_p
        self.lib.ccadical_wrapper_release.argtypes=[c_void_p]
        self.lib.ccadical_wrapper_add.argtypes=[c_void_p, c_int]
        self.lib.ccadical_wrapper_add.restype=c_int
        self.lib.ccadical_wrapper_add_lits.argtypes=[c_void_p, POINTER(c_int)]
        self.lib.ccadical_wrapper_add_lits.restype=c_int
        self.lib.ccadical_wrapper_assume.argtypes=[c_void_p, c_int]
        self.lib.ccadical_wrapper_solve.argtypes=[c_void_p, c_int]
        self.lib.ccadical_wrapper_solve.restype=c_int
        self.lib.ccadical_wrapper_val.argtypes=[c_void_p, c_int]
        self.lib.ccadical_wrapper_val.restype=c_int
        self.ctx=self.lib.ccadical_wrapper_init()

    def release (self):
        self.lib.ccadical_wrapper_release(self.ctx)

    def add (self, i):
        return self.lib.ccadical_wrapper_add(self.ctx, i)

    def add_lits (self, lits):
        # SAT_lib.add_clause() passes a 0-terminated list here
        arg2=(c_int * len(lits))(*lits)
        return self.lib.ccadical_wrapper_add_lits(self.ctx, arg2)

    def assume (self, i):
        self.lib.ccadical_wrapper_assume(self.ctx, i)

    def solve (self):
        return self.lib.ccadical_wrapper_solve(self.ctx, -1)

    def value (self, i):
        return self.lib.ccadical_wrapper_val(self.ctx, i)

"""
test:

o=picosat()
o.add(1)
o.add(0)
#o.add(-2)
#o.add(0)
print (o.sat())
o.assume(-2)
print (o.sat())
print ("calling deref")
print (o.deref(1))
print (o.deref(2))
"""

class IPAMIR:
    """Thin ctypes wrapper around the IPAMIR MaxSAT API of libuwrmaxsat.so (searched for in sys.path)."""
    def __init__(self):
        self.lib=None
        #print (sys.path)
        for d in sys.path:
            fname=d+"/libuwrmaxsat.so"
            if os.path.exists(fname):
                #print (fname)
                self.lib=CDLL(fname)
                break
        if self.lib==None:
            print ("Error: Can't find libuwrmaxsat.so, I searched in: ", sys.path)
            exit(0)
        # IPAMIR entry points used here:
        #   ipamir_add_hard ipamir_add_soft_lit ipamir_assume ipamir_init ipamir_release
        #   ipamir_set_terminate ipamir_signature ipamir_solve ipamir_val_lit ipamir_val_obj
        #
        #   IPAMIR_API void * ipamir_init ();
        #   IPAMIR_API void ipamir_release (void * solver);
        #   IPAMIR_API void ipamir_add_hard (void * solver, int32_t lit_or_zero);
        #   IPAMIR_API void ipamir_add_soft_lit (void * solver, int32_t lit, uint64_t weight);
        #   IPAMIR_API int ipamir_solve (void * solver);
        #   IPAMIR_API int32_t ipamir_val_lit (void * solver, int32_t lit);
        self.lib.ipamir_init.restype=c_void_p
        self.lib.ipamir_release.argtypes=[c_void_p]
        self.lib.ipamir_add_hard.argtypes=[c_void_p, c_int32]
        self.lib.ipamir_add_soft_lit.argtypes=[c_void_p, c_int32, c_uint64]
        #self.lib.ipamir_add_soft_lit.argtypes=[c_void_p, c_int64, c_uint64]
        self.lib.ipamir_solve.argtypes=[c_void_p]
        self.lib.ipamir_solve.restype=c_int
        self.lib.ipamir_val_lit.argtypes=[c_void_p, c_int32]
        self.lib.ipamir_val_lit.restype=c_int32
        self.ctx=self.lib.ipamir_init()

    def release (self):
        self.lib.ipamir_release(self.ctx)

    def add_soft (self, lit, weight):
        #print ("IPAMIR.add_soft", lit, weight)
        #self.lib.ipamir_add_soft_lit(self.ctx, lit, weight)
        # The literal is negated before it is handed to IPAMIR.
        # NOTE(review): presumably because IPAMIR charges `weight` when the given
        # literal is TRUE, so -lit rewards `lit` being true - confirm against ipamir.h.
        self.lib.ipamir_add_soft_lit(self.ctx, -lit, weight)

    def add_hard (self, lit):
        #print ("IPAMIR.add_hard", lit)
        self.lib.ipamir_add_hard(self.ctx, lit)

    def solve (self):
        return self.lib.ipamir_solve(self.ctx)

    def value (self, i):
        return self.lib.ipamir_val_lit(self.ctx, i)

class SAT_lib:
    """CNF/WCNF builder plus a front-end for several SAT/MaxSAT solvers.

    Clauses go in one of three places:
      - straight into an in-process solver (libpicosat, libcadical);
      - into IPAMIR (MaxSAT with USE_IPAMIR);
      - or they are collected and written to a DIMACS file that solve() hands
        to an external solver.
    """
    def __init__(self, maxsat:bool=False, maxsat_inc:bool=False, maxsat_inc_timeout:int=60, SAT_solver:str="kissat", verbose:int=0, seed:int=-1, proof:bool=False, threads:int=1):
        """
        maxsat: build a weighted (WCNF) problem and solve it with a MaxSAT solver.
        maxsat_inc / maxsat_inc_timeout: use the incomplete MaxSAT solver, stopped after that many seconds.
        SAT_solver: one of SAT_solvers_supported.
        verbose: 0 = quiet; >=2 traces every gate/clause.
        seed: solver seed (-1 = don't pass one).
        proof: ask the solver for a proof file and keep the CNF file.
        threads: thread count, used for gimsatul only.
        """
        self.SAT_solver=SAT_solver
        self.verbose=verbose
        self.proof=proof
        self.threads=threads # for gimsatul
        # next variable number to hand out (variable 0 does not exist in DIMACS)
        self.last_var=1
        if SAT_solver == "libpicosat":
            self.libpicosat=picosat()
            self.CNF_initialized=False
            self.CNF_write_comments=False
        elif SAT_solver == "libcadical":
            self.cadical=cadical()
            self.CNF_initialized=False
            self.CNF_write_comments=False
        else:
            # clauses and comments to be written to the CNF file, keyed by a
            # shared running position (CNF_next_idx) that keeps their order:
            self.CNF_next_idx=0
            self.CNF_clauses={} # type: Dict[int, List[int]]
            self.CNF_comments={} # type: Dict[int, str]
            self.CNF_initialized=True
            self.CNF_write_comments=True
        self.maxsat=maxsat
        if self.maxsat==False:
            self.CNF_fname="SAT_lib."+str(os.getpid())+".cnf"
        else:
            self.maxsat_inc=maxsat_inc
            self.maxsat_inc_timeout=maxsat_inc_timeout
            if USE_IPAMIR:
                self.CNF_initialized=False
                self.IPAMIR=IPAMIR()
                if self.maxsat_inc:
                    print ("Error, incomplete maxsat isn't supported with IPAMIR")
                    assert False
            else:
                self.CNF_fname="SAT_lib."+str(os.getpid())+".wcnf"
        self.clauses_total=0
        #self.HARD_CLAUSE=10000
        # weight written in front of hard clauses in WCNF files
        self.HARD_CLAUSE=50000000
        self.remove_CNF_file=True
        #self.remove_CNF_file=False
        if proof:
            self.remove_CNF_file=False
        self.seed=seed
        # allocate a single variable fixed to False:
        self.const_false=self.create_var()
        self.fix(self.const_false, False)
        # allocate a single variable fixed to True:
        self.const_true=self.create_var()
        self.fix(self.const_true, True)
        # per-gate caches, so that the same gate over the same inputs is encoded only once:
        self.OR_list_already_generated={} # type: Dict[Tuple[int, ...], int]
        self.XOR_list_already_generated={} # type: Dict[Tuple[int, ...], int]
        self.AND_already_generated={} # type: Dict[Tuple[int, int], int]
        self.AND_list_already_generated={} # type: Dict[Tuple[int, ...], int]
        self.NAND_already_generated={} # type: Dict[Tuple[int, int], int]
        self.NOT_already_generated={} # type: Dict[int, int]
        self.XOR_already_generated={} # type: Dict[Tuple[int, int], int]
        self.EQ_already_generated={} # type: Dict[Tuple[int, int], int]
        self.IMPLY_already_generated={} # type: Dict[Tuple[int, int], int]
        self.ITE_already_generated={} # type: Dict[Tuple[int, int, int], int]

    # only for libpicosat and libcadical:
    def assume (self, v:int):
        """Pass assumption literal `v` to the in-process solver for the next solve."""
        if self.SAT_solver=="libpicosat":
            self.libpicosat.assume(v)
        elif self.SAT_solver=="libcadical":
            self.cadical.assume(v)
        else:
            print ("Error: that solver doesn't support assume:", self.SAT_solver)
            assert False

    def run_minisat (self):
        """Run minisat on CNF_fname; return {var: bool}, or None if UNSAT (exit code 20)."""
        results_fname="results.txt."+str(os.getpid())
        #print (results_fname)
        try:
            child = subprocess.Popen(["minisat", self.CNF_fname, results_fname], stdout=subprocess.PIPE)
        except FileNotFoundError:
            print ("minisat not found. install it: 'sudo apt install minisat' or 'sudo pkg_add -i minisat'")
            exit(0)
        # communicate() drains the stdout pipe; wait() alone can deadlock once
        # minisat has printed more than the pipe buffer holds
        child.communicate()
        # 10 is SAT, 20 is UNSAT
        if child.returncode==20:
            os.remove (results_fname)
            return None
        if child.returncode!=10:
            print ("(minisat) unknown retcode: ", child.returncode)
            exit(0)
        #print "minisat done"
        t=my_utils.read_lines_from_file(results_fname)[1].split(" ")
        # remove last "variable", which is 0
        assert t[-1]=='0'
        t=t[:-1]
        # there was a time when $t$ list was returned as solution!
        # now it's dict
        solution={}
        for i in t:
            if i.startswith("-"):
                solution[-int(i)]=False
            else:
                solution[int(i)]=True
        os.remove (results_fname)
        return solution

    # cmd_line is an array
    def run_plingeling_or_open_wbo (self, cmd_line:List[str]):
        """Run an external solver that prints DIMACS-style 's'/'v' lines.

        Returns {var: bool} parsed from the 'v' lines. Returns None on an
        "s UNSAT..." line, or when no 'v' line was printed (e.g. timeout/UNKNOWN).
        Exits the process if the command is missing, if anything was written
        to stderr, or on Ctrl-C.
        """
        #print (cmd_line)
        tmp_fname="SAT_lib.tmp.out."+str(os.getpid())
        err_fname="SAT_lib.tmp.err."+str(os.getpid())
        logfile=open(tmp_fname, "w")
        errfile=open(err_fname, "w")
        try:
            child = subprocess.Popen(cmd_line, stdout=logfile, stderr=errfile)
        except FileNotFoundError:
            print ("can't run: "+" ".join(cmd_line))
            exit(0)
        try:
            #print ("\ngoing to call child.wait()\n")
            child.wait()
            #print ("\nchild.wait() done\n")
        except KeyboardInterrupt:
            print ("Interrupt by user")
            logfile.close()
            errfile.close()
            os.remove(tmp_fname)
            os.remove(err_fname)
            if self.remove_CNF_file:
                os.remove(self.CNF_fname)
            exit(0)
        logfile.flush()
        logfile.close()
        errfile.flush()
        errfile.close()
        errfile=open(err_fname, "r")
        line = errfile.readline()
        if len(line)>0:
            print ("error log:", line)
            exit(0)
        errfile.close()
        tmp=[]
        logfile=open(tmp_fname, "r")
        while True:
            line = logfile.readline()
            #print ("line=", line)
            if line.startswith("s UNSAT"):
                logfile.close()
                os.remove(tmp_fname)
                os.remove(err_fname)
                return None
            elif line.startswith("v "):
                tmp.append (line[2:].rstrip())
            elif line=='':
                break
            else:
                pass
        logfile.close()
        os.remove(tmp_fname)
        os.remove(err_fname)
        if len(tmp)==0:
            # timeout. UNKNOWN
            return None
        # split() rather than split(" "): repeated spaces must not yield '' tokens
        t=" ".join(tmp).split()
        # there was a time when $t$ list was returned as solution!
        # now it's dict (the terminating 0 ends up as solution[0]=True)
        solution={}
        for i in t:
            if i.startswith("-"):
                solution[-int(i)]=False
            else:
                solution[int(i)]=True
        return solution

    # cmd_line is an array
    def run_open_wbo_inc (self, cmd_line:List[str]):
        """Run an incomplete MaxSAT solver for at most maxsat_inc_timeout seconds.

        If the solver is still running when the time budget is used up, it
        gets SIGTERM. The model is read from its single 'v' line, which holds
        one '0'/'1' character per variable. Returns {var: bool}, or None on
        "s UNSAT...".
        """
        #print (cmd_line)
        tmp_fname="SAT_lib.tmp.out."+str(os.getpid())
        err_fname="SAT_lib.tmp.err."+str(os.getpid())
        logfile=open(tmp_fname, "w")
        errfile=open(err_fname, "w")
        try:
            child = subprocess.Popen(cmd_line, stdout=logfile, stderr=errfile)
        except FileNotFoundError:
            print ("can't run: "+" ".join(cmd_line))
            exit(0)
        try:
            try:
                # return as soon as the solver finishes; only the budget is an upper bound
                child.wait(timeout=self.maxsat_inc_timeout)
            except subprocess.TimeoutExpired:
                child.send_signal(signal.SIGTERM)
                if self.verbose>0:
                    print ("")
                    print (f"SIGTERM signal sent to incomplete solver after {self.maxsat_inc_timeout} seconds.")
                #print ("going to run child.wait()")
                child.wait()
                #print ("\nchild.wait() finished")
        except KeyboardInterrupt:
            print ("Interrupt by user")
            logfile.close()
            errfile.close()
            os.remove(tmp_fname)
            os.remove(err_fname)
            if self.remove_CNF_file:
                os.remove(self.CNF_fname)
            exit(0)
        logfile.flush()
        logfile.close()
        errfile.flush()
        errfile.close()
        errfile=open(err_fname, "r")
        line = errfile.readline()
        if len(line)>0:
            print ("error log:", line)
            exit(0)
        errfile.close()
        tmp=[]
        logfile=open(tmp_fname, "r")
        while True:
            line = logfile.readline()
            #print ("line=", line)
            if line.startswith("s UNSAT"):
                logfile.close()
                os.remove(tmp_fname)
                os.remove(err_fname)
                return None
            elif line.startswith("v "):
                tmp.append (line[2:].rstrip())
            elif line=='':
                break
            else:
                pass
        logfile.close()
        os.remove(tmp_fname)
        os.remove(err_fname)
        #print ("tmp=", tmp)
        if len(tmp)!=1:
            print (f"Error, unexpected {tmp=}")
            assert False
        # there was a time when $t$ list was returned as solution!
        # now it's dict; character k (1-based) is the value of variable k
        solution={}
        for var, c in enumerate(tmp[0], start=1):
            if c=='0':
                solution[var]=False
            elif c=='1':
                solution[var]=True
            else:
                print ("c", c)
                assert False
        return solution

    def run_libpicosat(self):
        """Solve with libpicosat; return {var: bool} if SAT (10), None if UNSAT (20)."""
        t=self.libpicosat.sat()
        if t==10: # SAT
            solution={}
            for v in range(1, self.last_var):
                x=self.libpicosat.deref(v)
                if x==-1:
                    solution[v]=False
                elif x==1:
                    solution[v]=True
                else:
                    raise ValueError ("unknown code returned by libpicosat.deref("+str(v)+"): "+str(x))
            return solution
        elif t==20: # UNSAT
            return None
        else:
            raise ValueError ("unknown code returned by libpicosat.sat(): "+str(t))

    def run_libcadical(self):
        """Solve with libcadical; return {var: bool} if SAT (10), None if UNSAT (20)."""
        t=self.cadical.solve()
        if t==10: # SAT
            solution={}
            for v in range(1, self.last_var):
                x=self.cadical.value(v)
                if x<0:
                    solution[v]=False
                elif x>0:
                    solution[v]=True
                else:
                    raise ValueError ("unknown code returned by cadical.value("+str(v)+"): "+str(x))
            return solution
        elif t==20: # UNSAT
            return None
        else:
            raise ValueError ("unknown code returned by cadical.solve(): "+str(t))

    def run_sat_solver(self):
        """Dispatch to the solver named by self.SAT_solver; return {var: bool} or None."""
        if self.SAT_solver=="minisat":
            if self.seed!=-1:
                print ("FIXME: run minisat with seed")
            if self.proof:
                print ("FIXME: run minisat with proof")
            return self.run_minisat()
        elif self.SAT_solver=="kissat" or self.SAT_solver=="gimsatul":
            tmp=[self.SAT_solver]
            if self.seed!=-1:
                if self.SAT_solver=="gimsatul":
                    print ("Warning: gimsatul doesn't support seed")
                else:
                    tmp.append("--seed="+str(self.seed)) # but kissat does
                #print ("Setting seed to", self.seed)
            tmp.append(self.CNF_fname)
            if self.proof:
                proof_fname="SAT_lib."+str(os.getpid())+"."+self.SAT_solver+".proof"
                print ("Setting proof filename to", proof_fname)
                tmp.append(proof_fname)
            if self.SAT_solver=="gimsatul" and self.threads>1:
                tmp.append("--threads="+str(self.threads))
            return self.run_plingeling_or_open_wbo(tmp)
        elif self.SAT_solver=="plingeling":
            if self.seed!=-1:
                print ("Warning: impossible to run plingeling with seed")
            if self.proof:
                print ("FIXME: run plingeling with proof")
            return self.run_plingeling_or_open_wbo(["plingeling", self.CNF_fname])
        elif self.SAT_solver=="picosat":
            #if self.seed!=-1:
            #    return self.run_plingeling_or_open_wbo(["picosat", "-s "+str(self.seed), self.CNF_fname]) # BUG. NOT WORKING
            #else:
            #    return self.run_plingeling_or_open_wbo(["picosat", self.CNF_fname])
            if self.seed!=-1:
                print ("FIXME: run picosat with seed")
            if self.proof:
                print ("FIXME: run picosat with proof")
            return self.run_plingeling_or_open_wbo(["picosat", self.CNF_fname])
        elif self.SAT_solver=="libpicosat":
            if self.seed!=-1:
                # FIXME: print this only once!
                #print ("FIXME: run libpicosat with seed")
                pass
            if self.proof:
                # FIXME: print this only once!
                #print ("FIXME: run libpicosat with proof")
                pass
            return self.run_libpicosat()
        elif self.SAT_solver=="libcadical":
            if self.seed!=-1:
                # FIXME: print this only once!
                #print ("FIXME: run libpicosat with seed")
                pass
            if self.proof:
                # FIXME: print this only once!
                print ("FIXME: run libcadical with proof")
            return self.run_libcadical()
        else:
            raise ValueError ("unknown/unsupported SAT solver: "+self.SAT_solver)

    def run_open_wbo_solver (self):
        #return self.run_plingeling_or_open_wbo(["open-wbo", "-algorithm=5", "-cpu-lim=2", self.CNF_fname])
        #return self.run_plingeling_or_open_wbo(["open-wbo", "-algorithm=5", "-cpu-lim=10", self.CNF_fname])
        return self.run_plingeling_or_open_wbo(["open-wbo", "-algorithm=5", self.CNF_fname])
        #return self.run_plingeling_or_open_wbo(["./tt-open-wbo-inc-Glucose4_1_static", self.CNF_fname])

    def run_open_wbo_inc_solver (self):
        return self.run_open_wbo_inc(["./tt-open-wbo-inc-Glucose4_1_static", self.CNF_fname])

    def run_IPAMIR(self):
        """Solve with IPAMIR; return {var: bool} on code 30, None on code 20.

        Per the IPAMIR interface, 30 means an optimal solution was found and 20
        means the hard clauses are unsatisfiable. Returning None for 20 matches
        every other run_*() method, which solve() relies on. Any other code
        raises ValueError.
        """
        t=self.IPAMIR.solve()
        if t==30: # SAT (optimum)
            solution={}
            for v in range(1, self.last_var):
                x=self.IPAMIR.value(v)
                if x<0:
                    solution[v]=False
                elif x>0:
                    solution[v]=True
                else:
                    raise ValueError (f"unknown code returned by IPAMIR.val_lit({v}): {x}")
            return solution
        elif t==20: # UNSAT (hard clauses)
            return None
        else:
            raise ValueError (f"unknown code returned by IPAMIR.solve(): {t}")

    def run_maxsat_solver (self):
        if self.maxsat_inc==False:
            if USE_IPAMIR:
                return self.run_IPAMIR()
            else:
                return self.run_open_wbo_solver()
        else:
            return self.run_open_wbo_inc_solver()

    def run_picomus (self):
        """Run picomus on CNF_fname; return the integers from its 'v' lines (0-terminated list)."""
        tmp_fname="tmp.out."+str(os.getpid())
        err_fname="tmp.err."+str(os.getpid())
        logfile=open(tmp_fname, "w")
        errfile=open(err_fname, "w")
        try:
            child = subprocess.Popen(["picomus", self.CNF_fname], stdout=logfile, stderr=errfile)
        except FileNotFoundError:
            print ("can't run picomus")
            exit(0)
        child.wait()
        logfile.flush()
        logfile.close()
        errfile.flush()
        errfile.close()
        errfile=open(err_fname, "r")
        line = errfile.readline()
        if len(line)>0:
            print ("error log:", line)
            exit(0)
        errfile.close()
        tmp=[]
        logfile=open(tmp_fname, "r")
        while True:
            line = logfile.readline()
            #print ("line=", line)
            if line.startswith("v "):
                tmp.append (int(line[2:].rstrip()))
            elif line=='':
                break
            else:
                pass
        logfile.close()
        os.remove(tmp_fname)
        os.remove(err_fname)
        return tmp

    def get_MUS_vars (self):
        """Run picomus; return (clause numbers it printed, sorted vars used in those clauses).

        Clause number k (1-based, as the old clause-1 indexing assumed) is the
        k-th *clause* in the order write_CNF() emits them. CNF_clauses shares
        its keys with comment lines (CNF_next_idx), so CNF_clauses[k-1] picked
        the wrong clause, or raised KeyError, once add_comment() had been used.
        """
        self.write_CNF()
        clauses=self.run_picomus ()
        #print ("get_MUS_vars(), clauses:", clauses)
        # clauses in file order, without the interleaved comment positions
        ordered=[self.CNF_clauses[i] for i in sorted(self.CNF_clauses)]
        found=set()
        for clause in clauses:
            if clause==0:
                break
            s=ordered[clause-1]
            #print ("get_MUS_vars(), s:", s)
            for tmp in s:
                found.add (abs(int(tmp)))
        #print ("get_MUS_vars(), vars:", found)
        return clauses, sorted(found)

    def write_CNF(self):
        """Write the header, clauses and comments to self.CNF_fname (DIMACS cnf, or wcnf for MaxSAT)."""
        VARS_TOTAL=self.last_var-1
        with open(self.CNF_fname, "w") as f:
            if self.maxsat==False:
                f.write ("p cnf "+str(VARS_TOTAL)+" "+str(self.clauses_total)+"\n")
            else:
                f.write ("p wcnf "+str(VARS_TOTAL)+" "+str(self.clauses_total)+" "+str(self.HARD_CLAUSE)+"\n")
            for i in range(self.CNF_next_idx):
                if i in self.CNF_clauses:
                    line=" ".join(map(str, self.CNF_clauses[i]+[0]))+"\n"
                elif i in self.CNF_comments:
                    line="c "+self.CNF_comments[i]+"\n"
                else:
                    continue
                f.write(line)
        if self.verbose>0:
            print ("write_CNF() clauses=%d" % self.clauses_total)
        self.CNF_written=True

    def create_var(self):
        """Allocate and return a fresh variable number."""
        self.last_var=self.last_var+1
        return self.last_var-1

    # FIXME: get rid of it!
def neg(self, v:int) -> int: #print ("neg:", v) #if v==None: # raise ValueError if v==0: raise ValueError return -v def neg_if(self, cond:bool, var:int) -> int: if cond: return self.neg(var) else: return var def BV_neg(self, lst:List[int]) -> List[int]: return [self.neg(l) for l in lst] # to be tested... def add_comment(self, comment:str): if self.CNF_write_comments==False: return self.CNF_comments[self.CNF_next_idx]=comment self.CNF_next_idx=self.CNF_next_idx+1 def add_clause(self, cls:List[int]): if self.verbose>=2: print ("add_clause", cls) if self.SAT_solver=="libcadical": #for v in cls: # self.cadical.add(v) #self.cadical.add(0) self.cadical.add_lits(cls+[0]) return if self.SAT_solver=="libpicosat": #for v in cls: # self.libpicosat.add(v) #self.libpicosat.add(0) self.libpicosat.add_lits(cls+[0]) return # remove tautological clauses if CNF_REMOVE_TAUTOLOGICAL: if is_tautological(cls): return # filter our duplicate clauses # FIXME: SLOW! if CNF_REMOVE_DUPLICATES: cls=my_utils.uniq_list(cls) if cls in self.CNF_clauses.values(): return if self.maxsat==False: self.clauses_total=self.clauses_total+1 self.CNF_clauses[self.CNF_next_idx]=cls self.CNF_next_idx=self.CNF_next_idx+1 else: if USE_IPAMIR: for c in cls: self.IPAMIR.add_hard(c) self.IPAMIR.add_hard(0) else: self.clauses_total=self.clauses_total+1 self.CNF_clauses[self.CNF_next_idx]=[self.HARD_CLAUSE]+cls self.CNF_next_idx=self.CNF_next_idx+1 #if (self.clauses_total % 1000000)==0: # print "(hearbeat) add_clause(). clauses_total=", self.clauses_total def deinit(self): if self.SAT_solver=="libcadical": self.cadical.release() return if self.SAT_solver=="libpicosat": self.libpicosat.reset() return if self.maxsat and USE_IPAMIR: #print ("calling IPAMIR.release()") self.IPAMIR.release() def add_clauses(self, clauses:List[List[int]]): for cls in clauses: self.add_clause(cls) # BUG. 
must be only one lit, not clause def add_soft_clause(self, cls:List[int], weight:int): assert self.maxsat==True assert weight>0 assert type(cls)==list #print ("cls=", cls) #print ("weight=", weight) assert len(cls)==1 if USE_IPAMIR: self.IPAMIR.add_soft(cls[0], weight) #self.IPAMIR.add_hard(0) #self.IPAMIR.add_soft(0, weight) else: self.clauses_total=self.clauses_total+1 self.CNF_clauses[self.CNF_next_idx]=[weight]+cls self.CNF_next_idx=self.CNF_next_idx+1 def AND_Tseitin(self, v1:int, v2:int, out:int): self.add_clause([self.neg(v1), self.neg(v2), out]) self.add_clause([v1, self.neg(out)]) self.add_clause([v2, self.neg(out)]) def AND(self, v1:int, v2:int) -> int: if (v1, v2) in self.AND_already_generated: return self.AND_already_generated[(v1,v2)] out=self.create_var() self.AND_Tseitin(v1, v2, out) self.AND_already_generated[(v1,v2)]=out return out def AND_list(self, l:List[int]) -> int: assert(len(l)>=1) # this is correct! if len(l)==1: return l[0] if len(l)==2: return self.AND(l[0], l[1]) if tuple(l) in self.AND_list_already_generated: return self.AND_list_already_generated[tuple(l)] out=self.AND(l[0], self.AND_list(l[1:])) self.AND_list_already_generated[tuple(l)]=out return out # AKA 'not both' def NAND(self, v1:int, v2:int) -> int: if self.verbose>=1: print ("NAND", v1, v2) if (v1, v2) in self.NAND_already_generated: return self.NAND_already_generated[(v1,v2)] out=self.NOT(self.AND(v1, v2)) self.NAND_already_generated[(v1,v2)]=out return out def BV_AND(self, x:List[int], y:List[int]) -> List[int]: assert type(x)==list assert type(y)==list rt=[] for pair in zip(x, y): rt.append(self.AND(pair[0],pair[1])) return rt # as in Tseitin transformations. 
# N.B.: previously called "OR" def OR_list(self, vals:List[int]) -> int: if self.verbose>=2: print ("OR_list", vals) if tuple(vals) in self.OR_list_already_generated: rt=self.OR_list_already_generated[tuple(vals)] if self.verbose>=2: print ("OR_list -> cached", rt) return rt #print (vals) if len(vals)==0: raise ValueError ("OR_list() requires list of >=1 vars") if len(vals)==1: return vals[0] # this is correct! out=self.create_var() self.add_clause(vals+[self.neg(out)]) for v in vals: self.add_clause([self.neg(v), out]) self.OR_list_already_generated[tuple(vals)]=out return out def OR_always(self, vals:List[int]): self.add_clause(vals) # to be used only for small number of inputs (less that ~12) def XOR_list(self, inputs:List[int]) -> int: if self.verbose>=2: print ("XOR_list", inputs) if tuple(inputs) in self.XOR_list_already_generated: rt=self.XOR_list_already_generated[tuple(inputs)] if self.verbose>=2: print ("XOR_list -> cached", rt) return rt inputs_t=len(inputs) if len(inputs)<=1: raise ValueError ("XOR_list() requires list of >=1 vars") out=self.create_var() # see the book for explanation of this slightly esoteric function for row in range(2**inputs_t): #print ("XOR_list, row", row) cls=[] tmp=[] for col in range(inputs_t): if (row>>col)&1==0: cls.append(self.neg(inputs[col])) tmp.append(False) else: cls.append(inputs[col]) tmp.append(True) tmp2=functools.reduce (operator.xor, tmp) if (inputs_t&1)==0: tmp2=not(tmp2) if tmp2==False: cls.append(out) else: cls.append(self.neg(out)) self.add_clause(cls) self.XOR_list_already_generated[tuple(inputs)]=out return out def alloc_BV(self, n:int) -> List[int]: return [self.create_var() for i in range(n)] def fix_soft(self, var:int, b:bool, weight:int): if b: self.add_soft_clause([var], weight) else: self.add_soft_clause([self.neg(var)], weight) def fix_soft_always_true(self, var:int, weight:int): self.fix_soft(var, True, weight) def fix_soft_always_true_all_bits_in_BV(self, BV:List[int], weight:int): for b in BV: 
self.fix_soft_always_true(b, weight) def fix(self, var:int, b:bool): if self.verbose>=2: print ("fix", var, b) if b: self.add_clause([var]) else: self.add_clause([self.neg(var)]) def fix_always_false(self, var:int): self.fix(var, False) def fix_always_true(self, var:int): self.fix(var, True) # BV is a list of True/False def fix_BV(self, _vars:List[int], BV:List[bool]): #print _vars, BV assert len(_vars)==len(BV) for var, _bool in zip(_vars, BV): self.fix (var, _bool) def fix_BV_all_bits_1(self, _vars:List[int]): for var in _vars: self.fix_always_true (var) # BV is a list of True/False def fix_BV_soft(self, _vars:List[int], BV:List[bool], weight:int): assert len(_vars)==len(BV) for var, _bool in zip(_vars, BV): self.fix_soft (var, _bool, weight) def get_var_from_solution(self, var:int) -> int: #print ("self.solution", self.solution) # 1 if var is present in solution, 0 if present in negated form: if self.solution[int(var)]: return 1 else: return 0 raise ValueError ("incorrect var number: "+str(var)) def get_BV_from_solution(self, BV:List[int]) -> List[int]: return [self.get_var_from_solution(var) for var in BV] def solve(self) -> bool: if self.CNF_initialized: self.write_CNF() if self.maxsat: self.solution=self.run_maxsat_solver() else: self.solution=self.run_sat_solver() if self.CNF_initialized and self.remove_CNF_file and os.path.exists(self.CNF_fname): os.remove(self.CNF_fname) if self.solution==None: return False else: return True def NOT(self, x:int) -> int: if self.verbose>=2: print ("NOT", x) if x in self.NOT_already_generated: return self.NOT_already_generated[x] rt=self.create_var() self.add_clause([self.neg(rt), self.neg(x)]) self.add_clause([rt, x]) self.NOT_already_generated[x]=rt return rt def BV_NOT(self, x:List[int]) -> List[int]: rt=[] for b in x: rt.append(self.NOT(b)) return rt def XOR(self, x:int, y:int) -> int: if self.verbose>=2: print ("XOR", x, y) if (x, y) in self.XOR_already_generated: return self.XOR_already_generated[(x,y)] 
rt=self.create_var() self.add_clause([self.neg(x), self.neg(y), self.neg(rt)]) self.add_clause([x, y, self.neg(rt)]) self.add_clause([x, self.neg(y), rt]) self.add_clause([self.neg(x), y, rt]) self.XOR_already_generated[(x,y)]=rt return rt def BV_OR(self, x:List[int], y:List[int]) -> List[int]: rt=[] for pair in zip(x,y): rt.append(self.OR_list([pair[0], pair[1]])) return rt def BV_OR_list(self, l:List[List[int]]) -> List[int]: assert(len(l)>=2) if len(l)==2: return self.BV_OR(l[0], l[1]) return self.BV_OR(l[0], self.BV_OR_list(l[1:])) def BV_XOR(self, x:List[int], y:List[int]) -> List[int]: #print ("BV_XOR: start") rt=[] for pair in zip(x,y): rt.append(self.XOR(pair[0], pair[1])) #print ("BV_XOR: finish") return rt def BV_XOR_list(self, l:List[List[int]]) -> List[int]: assert(len(l)>=2) if len(l)==2: return self.BV_XOR(l[0], l[1]) return self.BV_XOR(l[0], self.BV_XOR_list(l[1:])) def EQ(self, x:int, y:int) -> int: if self.verbose>=2: print ("EQ", x, y) if (x, y) in self.EQ_already_generated: return self.EQ_already_generated[(x,y)] out=self.NOT(self.XOR(x,y)) self.EQ_already_generated[(x,y)]=out return out def NEQ(self, x:int, y:int) -> int: return self.XOR(x,y) # p => q def IMPLY(self, p:int, q:int) -> int: if self.verbose>=1: print ("IMPLY", p, q) if (p, q) in self.IMPLY_already_generated: return self.IMPLY_already_generated[(p, q)] out=self.OR_list([self.NOT(p), q]) self.IMPLY_already_generated[(p, q)]=out return out def IMPLY_always(self, p:int, q:int) -> int: return self.fix(self.IMPLY(p, q), True) # naive/pairwise/quadratic encoding def AtMost1_pairwise(self, lst:List[int]): for pair in itertools.combinations(lst, r=2): self.add_clause([self.neg(pair[0]), self.neg(pair[1])]) # "commander" (?) encoding # TODO: maybe should be tuned? tested? 
def AtMost1_commander(self, lst:List[int]) -> int: parts=my_utils.partition(lst, 3) c=[] # type: List[int] for part in parts: if len(part)<12: self.AtMost1_pairwise(part) c.append(self.OR_list(part)) else: c.append(self.AtMost1_commander(part)) self.AtMost1_pairwise(c) return self.OR_list(c) def AtMost1(self, lst:List[int]): #self.AtMost1_pairwise(lst) #return if len(lst)<12: self.AtMost1_pairwise(lst) else: self.AtMost1_commander(lst) # previously named POPCNT1 # make one-hot (AKA unitary) variable def make_one_hot(self, lst:List[int]): self.AtMost1(lst) self.OR_always(lst) # TODO tests def mult_one_hots(self, x:List[int], y:List[int]) -> List[int]: assert len(x)==len(y) in_size=len(x) z=[] for i in range(in_size): for j in range(in_size): z.append(self.AND(x[i], y[j])) assert len(z)==in_size**2 return z def neg_nth_elem_in_lst(self, lst:List[int], n:int) -> List[int]: rt=[] # type: List[int] assert n int: #print len(l1), len(l2) assert len(l1)==len(l2) self.add_comment("BV_EQ") t=[] for p in zip(l1, l2): t.append(self.NOT(self.EQ(p[0], p[1]))) return self.NOT(self.OR_list(t)) # bitvectors must be different. 
def fix_BV_NEQ(self, l1:List[int], l2:List[int]): #print len(l1), len(l2) assert len(l1)==len(l2) self.add_comment("fix_BV_NEQ") t=[self.XOR(l1[i], l2[i]) for i in range(len(l1))] self.add_clause(t) # full-adder, as found by Mathematica using truth table: def FA (self, a:int, b:int, cin:int) -> Tuple[int, int]: s=self.create_var() cout=self.create_var() self.add_clause([self.neg(a), self.neg(b), self.neg(cin), s]) self.add_clause([self.neg(a), self.neg(b), cout]) self.add_clause([self.neg(a), self.neg(cin), cout]) self.add_clause([self.neg(a), cout, s]) self.add_clause([a, b, cin, self.neg(s)]) self.add_clause([a, b, self.neg(cout)]) self.add_clause([a, cin, self.neg(cout)]) self.add_clause([a, self.neg(cout), self.neg(s)]) self.add_clause([self.neg(b), self.neg(cin), cout]) self.add_clause([self.neg(b), cout, s]) self.add_clause([b, cin, self.neg(cout)]) self.add_clause([b, self.neg(cout), self.neg(s)]) self.add_clause([self.neg(cin), cout, s]) self.add_clause([cin, self.neg(cout), self.neg(s)]) return s, cout # bit order: [MSB..LSB] # n-bit adder: def adder(self, X:List[int], Y:List[int]) -> Tuple[List[int], int]: assert len(X)==len(Y) # first full-adder could be half-adder # start with lowest bits: inputs=my_utils.rvr(list(zip(X,Y))) carry=self.const_false sums=[] for pair in inputs: # "carry" variable is replaced at each iteration. # so it is used in the each FA() call from the previous FA() call. s, carry = self.FA(pair[0], pair[1], carry) sums.append(s) return my_utils.rvr(sums), carry # bit is 0 or 1. 
# i.e., if it's 0, output is 0 (all bits) # if it's 1, output=input def mult_by_bit(self, X:List[int], bit:int) -> List[int]: return [self.AND(i, bit) for i in X] # bit order: [MSB..LSB] # build multiplier using adders and mult_by_bit blocks: def multiplier(self, X:List[int], Y:List[int]) -> List[int]: assert len(X)==len(Y) out=[] #initial: prev=[self.const_false]*len(X) # first adder can be skipped, but I left thing "as is" to make it simpler for Y_bit in my_utils.rvr(Y): s, carry = self.adder(self.mult_by_bit(X, Y_bit), prev) out.append(s[-1]) prev=[carry] + s[:-1] return prev + my_utils.rvr(out) def NEG(self, x:List[int]) -> List[int]: # invert all bits tmp=self.BV_NOT(x) # add 1 one=self.alloc_BV(len(tmp)) self.fix_BV(one, n_to_BV(1, len(tmp))) return self.adder(tmp, one)[0] # untested (?) def shift_left (self, x:List[int], cnt:int) -> List[int]: return x[cnt:]+[self.const_false]*cnt def shift_left_1 (self, x:List[int]) -> List[int]: return x[1:]+[self.const_false] def shift_right (self, x:List[int], cnt:int) -> List[int]: return [self.const_false]*cnt+x[cnt:] def shift_right_1 (self, x:List[int]) -> List[int]: return [self.const_false]+x[:-1] def create_MUX(self, ins:List[int], sels:List[int]) -> int: assert 2**len(sels)==len(ins) x=self.create_var() for sel in range(len(ins)): # for example, 32 for 5-bit selector tmp=[self.neg_if((sel>>i)&1==1, sels[i]) for i in range(len(sels))] # 5 for 5-bit selector self.add_clause([self.neg(ins[sel])] + tmp + [x]) self.add_clause([ins[sel]] + tmp + [self.neg(x)]) return x # for 1-bit sel # ins=[[outputs for sel==0], [outputs for sel==1]] def create_wide_MUX (self, ins:List[List[int]], sels:List[int]) -> List[int]: out=[] for i in range(len(ins[0])): inputs=[x[i] for x in ins] out.append(self.create_MUX(inputs, sels)) return out # untested: def ITE(self, s:int, f:int, t:int) -> int: if (s, f, t) in self.ITE_already_generated: return self.ITE_already_generated[(s,f,t)] if s==0: raise ValueError if f==0: raise ValueError if 
t==0: raise ValueError x=self.create_var() if x==0: raise ValueError # as found by my util self.add_clause([self.neg(s),self.neg(t),x]) self.add_clause([self.neg(s),t,self.neg(x)]) self.add_clause([s,self.neg(f),x]) self.add_clause([s,f,self.neg(x)]) self.ITE_already_generated[(s,f,t)]=x return x def subtractor(self, minuend:List[int], subtrahend:List[int]) -> Tuple[List[int], int]: # same as adder(), buf: 1) subtrahend is inverted; 2) input carry-in bit is 1 X=minuend Y=self.BV_NOT(subtrahend) inputs=my_utils.rvr(list(zip(X,Y))) carry=self.const_true sums=[] for pair in inputs: # "carry" variable is replaced at each iteration. # so it is used in the each FA() call from the previous FA() call. st, carry = self.FA(pair[0], pair[1], carry) sums.append(st) return my_utils.rvr(sums), carry # 0 if a=b def comparator_GE(self, a:List[int], b:List[int]) -> int: tmp, carry = self.subtractor(a, b) return carry def div_blk(self, enable, divident:List[int], divisor:List[int]) -> Tuple[List[int], int]: assert len(divident)==len(divisor) diff, _ = self.subtractor(minuend=divident, subtrahend=divisor) cmp_res = self.AND(enable, self.comparator_GE(divident, divisor)) out=self.alloc_BV(len(divident)) # is it used?! 
return self.create_wide_MUX([divident, diff], [cmp_res]), cmp_res def divider(self, divident:List[int], divisor:List[int]) -> Tuple[List[int], List[int]]: assert len(divident)==len(divisor) BITS=len(divisor) wide_divisor=self.shift_left([self.const_false]*BITS+divisor, BITS-1) quotient=[] for b in range(BITS): enable=self.NOT(self.OR_list(wide_divisor[:BITS])) divident, q_bit=self.div_blk(enable, divident, wide_divisor[BITS:]) quotient.append(q_bit) wide_divisor=self.shift_right_1(wide_divisor) # remainder is left in divident: return quotient, divident def add_negated_solution_as_clause (self): negated_solution=[] for v in range(1, self.last_var): negated_solution.append(self.neg_if(self.get_var_from_solution(v), v)) self.add_clause(negated_solution) def fetch_next_solution(self): self.add_negated_solution_as_clause() return self.solve() def count_solutions(self): if self.solve()==False: return 0 cnt=1 while True: if self.fetch_next_solution()==False: break cnt=cnt+1 return cnt def get_all_solutions(self): if self.solve()==False: return None rt=[self.solution] while True: if self.fetch_next_solution()==False: break rt.append(self.solution) return rt def BV_not_zero(self, bv:List[int]) -> int: return self.OR_list(bv) def BV_zero(self, bv:List[int]) -> int: return self.NOT(self.OR_list(bv)) def get_val_from_solution(self, var:List[int]) -> int: return BV_to_number(self.get_BV_from_solution(var)) def make_distinct_BVs (self, lst:List[List[int]]): assert type(lst)==list assert type(lst[0])==list for pair in itertools.combinations(lst, r=2): self.fix_BV_NEQ(pair[0], pair[1]) # ... 
to each other def make_all_BVs_EQ (self, lst:List[List[int]]): assert type(lst)==list assert type(lst[0])==list for pair in itertools.combinations(lst, r=2): self.fix_BV_EQ(pair[0], pair[1]) def sort_unit(self, a:int, b:int): return self.OR_list([a,b]), self.AND(a,b) # BUGGY """ def sorting_network_make_ladder(self, lst:List[int]) -> List[int]: if len(lst)==2: return list(self.sort_unit(lst[0], lst[1])) tmp=self.sorting_network_make_ladder(lst[1:]) # lst without head first, second=self.sort_unit(lst[0], tmp[0]) return [first, second] + tmp[1:] # recursive! def sorting_network(self, lst:List[int]) -> List[int]: # simplest possible, bubble sort if len(lst)==2: return self.sorting_network_make_ladder(lst) tmp=self.sorting_network_make_ladder(lst) return self.sorting_network(tmp[:-1]) + [tmp[-1]] """ # like: https://en.wikipedia.org/wiki/File:Six-wire-bubble-sorting-network.svg # iterative def sorting_network(self, lst:List[int]) -> List[int]: lst=copy.deepcopy(lst) lst_len=len(lst) for i in range(lst_len): for j in range((lst_len-i)-1): x, y = self.sort_unit(lst[j], lst[j+1]) lst[j], lst[j+1] = x, y return lst def POPCNT(self, n:int, vars:List[int]): sorted=self.sorting_network(vars) self.fix_always_false(sorted[n]) if n!=0: self.fix_always_true(sorted[n-1]) def BV_zero_extend(self, lst:List[int], final_size:int): #print (len(lst)) #print (final_size) assert len(lst)<=final_size if len(lst)==final_size: # do nothing return lst bits_to_add = final_size-len(lst) return ([self.const_false] * bits_to_add) + lst # bit order: [MSB..LSB] # 'size' is desired width of bitvector, in bits: # return: list of SAT variables def n_to_BV (self, n:int, size:int) -> List[int]: rt=[] lst=n_to_BV(n, size) for l in lst: if l: rt.append(self.const_true) else: rt.append(self.const_false) return rt def all_diff_for_one_hot_vars (self, vars_:List[List[int]]): one_hot_width=len(vars_[0]) mask=(1< bool return self.BV_EQ(self.BV_OR_list(vars_), self.n_to_BV(mask, one_hot_width)) # this is bool 
# G=list of tuples. each tuple is edge between two vertices
# each number is vertex
# total=number of vertices
# return: list, color for each vertex
def find_2_coloring_of_graph (G:List[Tuple[int,int]], total:int) -> List[int]:
    """Assign a color to every vertex so that the endpoints of each edge differ.

    Asserts that a solution exists. The solver is always released before
    returning.

    NOTE(review): each color is a 2-bit bitvector, so up to 4 colors (0..3)
    may be used despite the function name -- confirm whether a 1-bit (true
    2-)coloring was intended.
    """
    s=SAT_lib(False)
    try:
        colors=[s.alloc_BV(2) for _ in range(total)]
        for edge in G:
            s.fix_BV_NEQ(colors[edge[0]], colors[edge[1]])
        assert s.solve()
        # read the solution before deinit() runs in the finally clause
        return [BV_to_number(s.get_BV_from_solution(color)) for color in colors]
    finally:
        s.deinit()

# https://en.wikipedia.org/wiki/Exact_cover#Matrix_and_hypergraph_representations
def solve_exact_cover(m:List[List[int]], SAT_solver, proof, threads):
    """Generator yielding every exact cover of the 0/1 matrix m.

    Each yielded value is the list of selected row indices (ascending). Every
    column must be covered by exactly one selected row. Yields nothing if no
    cover exists. The solver is released on every exit path, including when
    the caller stops iterating early (the generator's close() runs finally).
    """
    s=SAT_lib(SAT_solver=SAT_solver, proof=proof, threads=threads)
    try:
        rows_total=len(m)
        cols_total=len(m[0])
        variables=[s.create_var() for _ in range(rows_total)]
        for c in range(cols_total):
            x=[variables[r] for r in range(rows_total) if m[r][c]]
            # to satisfy col $c$, one of rows in $x$ must be set (but only one!)
            s.make_one_hot(x)
        if s.solve()==False:
            # no solutions
            return
        while True:
            yield [r for r, var in enumerate(variables) if s.get_var_from_solution(var)]
            if s.fetch_next_solution()==False:
                break
    finally:
        s.deinit()