// my version of D.Knuth's sat0w.w SAT solver (algorithm B)

// Initial version: Sep-2019.
// Updated/revised: Jan-2021.

#include <assert.h>

#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <string>
#include <map>
#include <stack>
#include <vector>
#include <string>
#include <sstream>
#include <iterator>
#include <iostream>
#include <fstream>

// Input format: true for Knuth-style *.sat files, false for DIMACS *.cnf
// (main() sets it from the file extension).
bool knuth_style=true;
// in case of Knuth-style SAT file:
// mapping between 1-based variable numbers and variable names, both ways
std::map<unsigned, std::string> var_n_name;
std::map<std::string, unsigned> var_name_n;

// verbosity level: 0 = quiet, 1 = "-v", 2 = "-vv" (debug dumps)
unsigned verbose=0;

// number of variables; 1-based variable v is stored at index v-1 (e.g. in move[])
unsigned vars_t;
int* WL; // size=clauses_t. // idx=clause_no. each value - next clause_no on the same watch list, or -1 (if this is the last clause_no in the list)
unsigned literals_t=0; // number of literal indices, = vars_t*2

// Literal encoding (see var2lit()): signed literal +v -> index 2*(v-1)+1,
// -v -> index 2*(v-1).
int* literals; // size=literals_t. // each value is the first clause_no on this literal's watch list (continue via WL[]), or -1 (if WL is empty for this literal)

unsigned clauses_t;
unsigned *clause_len; // array. idx=clause number. len=clauses_t
// Signed literals (+v / -v, 1-based v) of all clauses, back to back.
// Clause c occupies clause_DB[clause_start[c] .. clause_start[c]+clause_len[c]-1];
// position 0 of a clause is the literal it currently watches.
int *clause_DB; // size=literals_total_in_all_clauses. each value is a signed literal of a clause
unsigned *clause_start; // array. len=clauses_t

// number of search nodes visited (incremented on each entry to a level in try_bits())
unsigned nodes=0;

/*
from TAOCP 7.2.2.2:

m_j = 0 means we're trying x_j = 1 and haven't yet tried x_j = 0. 
m_j = 1 means we're trying x_j = 0 and haven't yet tried x_j = 1. 
m_j = 2 means we're trying x_j = 1 after x_j = 0 has failed.      
m_j = 3 means we're trying x_j = 0 after x_j = 1 has failed.      
*/

#define TRUE_AND_FALSE_HAVENT_TRIED 0
#define FALSE_AND_TRUE_HAVENT_TRIED 1
#define TRUE_AFTER_FALSE_FAILED     2
#define FALSE_AFTER_TRUE_FAILED     3

// move[j] is the m_j state (one of the four codes above) of the 0-based variable j
unsigned char *move; // len=vars_t

// Push clause c_no onto the front of the watch list of literal `lit`.
// literals[lit] is the list head and WL[] links each clause to the next one
// on the same list (-1 ends a list).  This does NOT unlink the clause from a
// list it is currently on; the caller has to do that.
void reattach_clause_to_lit (unsigned lit, unsigned c_no)
{
        if (verbose>1)
                printf ("%s (lit=%u, c_no=%u)\n", __func__, lit, c_no);

        // keep the old head signed: it is -1 when the list was empty
        int old_head=literals[lit];
        literals[lit]=c_no;
        WL[c_no]=old_head;
};

// Convert a signed literal (+v or -v, v >= 1) into its watch-list index.
// Variable v owns two consecutive indices: 2*(v-1) for the negative literal
// and 2*(v-1)+1 for the positive one.
unsigned var2lit (int var)
{
        assert(var!=0);

        const int magnitude = (var < 0) ? -var : var;
        const unsigned even_index = static_cast<unsigned>(magnitude - 1) * 2;
        return (var > 0) ? even_index + 1 : even_index;
};

// True if the variable of signed literal `var` currently has the value false,
// i.e. its move[] code is one of the two "trying x_j = 0" states.
// The only caller uses it for variables that are already decided.
bool var_already_assigned_to_false (int var)
{
        switch (move[abs(var)-1])
        {
                case FALSE_AND_TRUE_HAVENT_TRIED:
                case FALSE_AFTER_TRUE_FAILED:
                        return true;
                default:
                        return false;
        }
}

// True if the variable of signed literal `var` currently has the value true,
// i.e. its move[] code is one of the two "trying x_j = 1" states.
// The only caller uses it for variables that are already decided.
bool var_already_assigned_to_true (int var)
{
        switch (move[abs(var)-1])
        {
                case TRUE_AND_FALSE_HAVENT_TRIED:
                case TRUE_AFTER_FALSE_FAILED:
                        return true;
                default:
                        return false;
        }
}

// Exchange the values pointed to by a and b (a == b is allowed).
void swap_ints (int *a, int *b)
{
        const int saved_b=*b;
        *b=*a;
        *a=saved_b;
};

// Try to move clause c_no away from its watched literal (position 0 of the
// clause), which the value just given to variable `level` has made false.
// A replacement must be another literal of the clause that is not false:
// its variable is still unassigned (index > level), or it is already true.
// On success the clause is pushed onto that literal's watch list and the
// literal is swapped into position 0.  The caller unlinks the clause from
// its old list.  Returns false if no replacement exists (unit clause, or
// every other literal is false too).
bool find_better_place_for_clause(unsigned level, unsigned c_no)
{
        if (verbose>1)
                printf ("%s c_no=%u\n", __func__, c_no);

        const unsigned start=clause_start[c_no];
        const unsigned len=clause_len[c_no];

        if (len==1)
                return false; // unit clause

        // position 0 is the watched (false) literal; scan the others:
        for (unsigned i=1; i<len; i++)
        {
                int var=clause_DB[start+i];

                const bool unassigned=static_cast<unsigned>(abs(var)-1) > level;
                const bool literal_true=!unassigned &&
                        ((var<0 && var_already_assigned_to_false(var)) ||
                         (var>0 && var_already_assigned_to_true(var)));
                if (!unassigned && !literal_true)
                        continue; // this literal is false as well

                const unsigned new_lit=var2lit(var);
                reattach_clause_to_lit (new_lit, c_no);
                // keep the watched literal at position 0
                swap_ints(&clause_DB[start], &clause_DB[start+i]);
                if (verbose>1)
                        printf ("%s() var %d (lit %u) connected to the '%s' literal\n",
                                __func__, var, new_lit, unassigned ? "future" : "past");
                return true;
        };

        // we made a full round
        // so we can't reconnect this clause
        if (verbose>1)
                printf("%s() we made a full round. return false\n", __func__);
        return false;
};

// Debug helper: for every literal index, print the clause numbers on its
// watch list (head in literals[], links in WL[], -1 ends a list).
void dump_literals ()
{
        for (unsigned l=0; l<literals_t; l++)
        {
                printf ("literal=%u var=%u parity=%u. clauses: ", l, l>>1, l&1);
                // int cursor: the lists store and end with -1
                for (int c_no=literals[l]; c_no!=-1; c_no=WL[c_no])
                        printf ("%d ", c_no);
                printf("\n");
        };
};

// redistribute all clauses
//
// Called after move[level] has been set.  The value falsifies one literal of
// variable `level`: the positive one (odd index) when it is false, the
// negative one (even index) when it is true.  Every clause watching that
// literal is moved to another non-false literal by
// find_better_place_for_clause().
// Returns true if all of them were moved (the list ends up empty).  Returns
// false at the first clause that can't be moved.  Clauses moved before that
// stay on their new lists; the failing clause and the rest stay on this list.
bool reconnect_all_clauses(unsigned level)
{
        if (verbose>=2)
        {
                printf ("%s() begin. level=%u\n", __func__, level);
                printf ("literals now:\n");
                dump_literals();
        };

        const bool assigned_false=(move[level]==FALSE_AND_TRUE_HAVENT_TRIED ||
                                   move[level]==FALSE_AFTER_TRUE_FAILED);
        const unsigned lit=level*2 + (assigned_false ? 1 : 0);

        if (verbose>1)
                printf ("%s() reconnecting all clauses for lit=%u\n", __func__, lit);

        // int cursors: the lists store and end with -1
        for (int c_no=literals[lit]; c_no!=-1; )
        {
                if (verbose>1)
                        printf ("%s() c_no=%d\n", __func__, c_no);
                // save the link first: WL[c_no] is overwritten when the clause
                // is attached to its new literal
                int next_c_no=WL[c_no];
                if (verbose>1)
                        printf ("%s() next_c_no=%d\n", __func__, next_c_no);
                if (!find_better_place_for_clause (level, c_no))
                {
                        if (verbose>=2)
                        {
                                printf ("%s() Clause %d contradicted\n", __func__, c_no);
                                printf ("%s() -> false\n", __func__);
                        };
                        return false;
                };
                // unlink c_no: the head of this list moves past it
                literals[lit]=next_c_no;
                c_no=next_c_no;
        };
        if (verbose>=2)
                printf ("%s() -> true\n", __func__);
        return true;
};

// Depth-first search over the variables in index order (Algorithm B).
// `level` is the 0-based index of the variable being decided; variables
// 0..level-1 already have their states recorded in move[].  Recursion is
// imitated with gotos:
//   begin:                      enter `level` (counts one search node)
//   call_reconnect_all_clauses: apply move[level], i.e. move every clause
//                               watching the literal this value falsifies
//   switch_and_backtrack:       this value failed; try the other one, or,
//                               if both have failed, go back to level-1
// Never returns.  It calls exit(0) after printing "SAT" and the assignment,
// or after printing "UNSAT" once both values of variable 0 have failed.
void try_bits (unsigned level)
{
// D.Knuth's code is rich for GOTOs, so is mine...
// Rationale: imitate recursion without recursion...
begin:
        nodes++;
        // debug: print the states of the variables decided so far
        if (verbose>1)
                if (level>0)
                {
                        printf ("move: ");
                        for (unsigned v=0; v<level; v++)
                                printf ("%d", move[v]);
                        printf ("\n");
                };

        if (verbose>=2)
                if (level>0)
                {
                        printf ("literals:\n");
                        dump_literals ();
                };

        // every variable has a value and no clause was falsified: print the
        // assignment ('~name' for false in Knuth style, '-v' in DIMACS style)
        // and stop
        if (level==vars_t)
        {
#if 0
                printf ("move: ");
                for (unsigned v=0; v<level; v++)
                        printf ("%d", move[v]);
                printf ("\n");
#endif
                printf ("SAT\n");
                if (knuth_style)
                        printf (" ");
                for (unsigned v=0; v<vars_t; v++)
                {
                        if (knuth_style)
                        {
                                if ((move[v])==FALSE_AND_TRUE_HAVENT_TRIED ||
                                    (move[v])==FALSE_AFTER_TRUE_FAILED)
                                        printf ("~%s ", var_n_name[v+1].c_str());
                                else
                                        printf ("%s ", var_n_name[v+1].c_str());
                        }
                        else
                        {
                                if ((move[v])==FALSE_AND_TRUE_HAVENT_TRIED ||
                                    (move[v])==FALSE_AFTER_TRUE_FAILED)
                                        printf ("-%d ", v+1);
                                else
                                        printf ("%d ", v+1);
                        };
                }
                printf ("\n");
                if (verbose>=1)
                        printf ("nodes=%d\n", nodes);
                exit(0);
        }
        // pick the first alternative

        // literals[level*2+1] is the watch list of this variable's positive
        // literal (see var2lit()).  It is the list reconnected when the
        // variable is set to false.  literals[level*2] belongs to the negative
        // literal and is reconnected when the variable is set to true.
        int first=literals[level*2+1]; // for false lit
        int second=literals[level*2];  // for true lit
        //printf ("line %d first=%d second=%d\n", __LINE__, first, second);

        // this part differs from D.Knuth's sat0w.w
        // NOTE(review): true is tried first only when `first` is empty and
        // `second` is not.  That is exactly when true is the value that has
        // clauses to reconnect; in every other case false is tried first.
        // Only the search order (and which model gets printed) depends on
        // this choice, not SAT vs UNSAT.  Confirm the intent.
        if ((first == -1) && (second == -1)) // both WLs are empty
                move[level]=FALSE_AND_TRUE_HAVENT_TRIED;
        else if ((first != -1) && (second != -1)) // both WLs are present
                move[level]=FALSE_AND_TRUE_HAVENT_TRIED;
        else if ((first != -1) && (second == -1)) // first (for false lit) present, but not the second (for true lit)
                move[level]=FALSE_AND_TRUE_HAVENT_TRIED;
        else if ((first == -1) && (second != -1)) // second (for true lit) present, but not the first (for false lit)
                move[level]=TRUE_AND_FALSE_HAVENT_TRIED;
        else
        {
                assert (!"can't be here");
        };
        
        //printf ("line %d, move[%d]=%d\n", __LINE__, level, move[level]);

        // or, this can be as simple as:
        //move[level]=FALSE_AND_TRUE_HAVENT_TRIED;

call_reconnect_all_clauses:
        // try first alternative...
        // on success, descend to the next variable (like a recursive call)
        if (verbose>1)
                printf ("trying %d for variable %d\n", move[level], level+1);
        if (reconnect_all_clauses(level))
        {
                level++;
                goto begin;
        };
        // can't reconnect some clause

switch_and_backtrack:
        // reached when the current value of variable `level` failed here, or
        // when both values of the deeper variable failed (level-- below)
        // the first alternative has been failed, so switch to the second
        if (move[level]==FALSE_AND_TRUE_HAVENT_TRIED)
        {
                move[level]=TRUE_AFTER_FALSE_FAILED;
                goto call_reconnect_all_clauses;
        }
        else if (move[level]==TRUE_AND_FALSE_HAVENT_TRIED)
        {
                move[level]=FALSE_AFTER_TRUE_FAILED;
                goto call_reconnect_all_clauses;
        };
        // we are here if move[level]==FALSE_AFTER_TRUE_FAILED or TRUE_AFTER_FALSE_FAILED

        // backtrack or report UNSAT
        if (level==0)
        {
                printf ("UNSAT\n");
                if (verbose>0)
                        printf ("nodes=%d\n", nodes);
                exit(0);
        };
        level--;
        // like if we back from recursive call...
        goto switch_and_backtrack;
};

void solve ()
{
        // allocate everything and set to default:
        literals=(int*)malloc(sizeof(int)*literals_t);
        for (unsigned i=0; i<literals_t; i++)
                literals[i]=-1;
        WL=(int*)malloc(sizeof(int)*clauses_t);
        for (unsigned i=0; i<clauses_t; i++)
                WL[i]=-1;

        // init: connect all clauses, while current watched literal is 0th in each clause
        for (unsigned c_no=0; c_no<clauses_t; c_no++)
        {
                int first_var=clause_DB[clause_start[c_no]];
                reattach_clause_to_lit(var2lit(first_var), c_no);
        };

        if (verbose>=2)
        {
                printf ("initial:\n");
                dump_literals ();
        };

        //printf ("vars_t=%d\n", vars_t);
        move=(unsigned char*)malloc(sizeof(unsigned)*vars_t);
        for (int i=0; i<vars_t; i++)
                move[i]=FALSE_AND_TRUE_HAVENT_TRIED;

        try_bits (0);
        assert(!"we can't be here");
};

// Remove trailing whitespace (spaces, tabs, '\r', '\n', ...) from s, in place.
// See also: https://stackoverflow.com/questions/216823/whats-the-best-way-to-trim-stdstring
static inline void rtrim(std::string &s)
{
        std::string::size_type keep=s.size();
        while (keep>0 && std::isspace(static_cast<unsigned char>(s[keep-1])))
                keep--;
        s.resize(keep);
}

// read standard DIMACS CNF file
void read_CNF_file (const char *fname)
{
        // for each clause
        class clause
        {
                public:
                        std::vector<int> clause;
        };

        std::vector<class clause> clauses;

        std::ifstream file(fname);
        assert (file);

        // parse header:
        std::string header;
        // skip comments:
        while (true)
        {
                assert (std::getline(file, header));
                if (header[0]!='c')
                        break;
        };
        std::istringstream header_stream(header);
        std::string s;
        assert (getline(header_stream, s, ' ' )); assert (s=="p");
        assert (getline(header_stream, s, ' ' )); assert (s=="cnf");
        assert (getline(header_stream, s, ' ' ));
        rtrim (s);
        vars_t=std::stol(s);
        literals_t=vars_t*2;
        assert (getline(header_stream, s, ' ' ));
        rtrim (s);
        clauses_t=std::stol(s);

        // parse list of clauses:
        unsigned literals_total_in_all_clauses=0;
        std::string str;
        while (std::getline(file, str))
        {
                rtrim(str);
                if (str.size()==0 || str[0]=='c')
                        continue;

                // parse space-separated list of numbers:
                std::stringstream in(str);
                class clause cl;
                std::string s;
                while (std::getline(in, s, ' '))
                {
                        if (s.size()==0)
                                continue;
                        //std::cerr << "[" << s << "]" << std::endl;
                        int v=std::stol(s);
                        if (v==0) // terminating zero
                                break;
                        cl.clause.push_back(v);
                        literals_total_in_all_clauses++;
                }
                clauses.push_back(cl);
        }
        if (clauses.size() != clauses_t)
                printf ("warning, clauses total in header: %d but %ld are read\n", clauses_t, clauses.size());
        clauses_t=clauses.size();
        clause_len=(unsigned*)malloc(sizeof(unsigned)*clauses_t);
        clause_start=(unsigned*)malloc(sizeof(unsigned)*clauses_t);
        for (unsigned i=0; i<clauses_t; i++)
        {
                clause_len[i]=clauses[i].clause.size();
        };
        clause_DB=(int*)malloc(sizeof(int)*literals_total_in_all_clauses);
        unsigned start=0;
        for (unsigned i=0; i<clauses_t; i++)
        {
                clause_start[i]=start;
                for (unsigned j=0; j<clause_len[i]; j++)
                {
                        clause_DB[start]=clauses[i].clause[j];
                        start++;
                        assert (start<=literals_total_in_all_clauses);
                };
        };
};

// read D.Knuth's style SAT file:
void read_SAT_file (const char *fname)
{
        // for each clause
        class clause
        {
                public:
                        std::vector<int> clause;
        };

        std::vector<class clause> clauses;

        std::ifstream file(fname);
        assert (file);

        // parse list of clauses:
        unsigned literals_total_in_all_clauses=0;
        std::string str;
        unsigned var_n=1;
        while (std::getline(file, str))
        {
                if (str.size()==0)
                        continue;
                if (str[0]=='~' && str[1]==' ')
                        continue;

                // parse space-separated list of numbers:
                std::stringstream in(str);
                class clause cl;
                std::string s;
                while (std::getline(in, s, ' '))
                {
                        if (s.size()==0)
                                continue;
                        int v;
                        std::string var_name;
                        bool negated=false;

                        if (s[0]=='~')
                        {
                                var_name=s.substr(1); // https://stackoverflow.com/questions/34698270/substr-from-the-1-to-the-last-character-c
                                negated=true;
                        }
                        else
                        {
                                var_name=s;
                                negated=false;
                        };

                        if (var_name_n.find(var_name) == var_name_n.end())
                        {
                                // var not found, so add it
                                var_n_name[var_n]=var_name;
                                var_name_n[var_name]=var_n;
                                if (negated)
                                        cl.clause.push_back(-var_n);
                                else
                                        cl.clause.push_back(var_n);
                                var_n++;
                        }
                        else
                        {
                                if (negated)
                                        cl.clause.push_back(-var_name_n[var_name]);
                                else
                                        cl.clause.push_back(var_name_n[var_name]);
                        };
                        literals_total_in_all_clauses++;
                }
                clauses.push_back(cl);
        }
        vars_t=var_n-1;
        literals_t=vars_t*2;
        clauses_t=clauses.size();
        if (verbose>0)
        {
                printf ("vars_t=%d\n", vars_t);
                printf ("clauses_t=%d\n", clauses_t);
                printf ("literals_total_in_all_clauses=%d\n", literals_total_in_all_clauses);
        };
        clause_len=(unsigned*)malloc(sizeof(unsigned)*clauses_t);
        clause_start=(unsigned*)malloc(sizeof(unsigned)*clauses_t);
        for (unsigned i=0; i<clauses_t; i++)
        {
                clause_len[i]=clauses[i].clause.size();
        };
        clause_DB=(int*)malloc(sizeof(int)*literals_total_in_all_clauses);
        unsigned start=0;
        for (unsigned i=0; i<clauses_t; i++)
        {
                clause_start[i]=start;
                for (unsigned j=0; j<clause_len[i]; j++)
                {
                        clause_DB[start]=clauses[i].clause[j];
                        start++;
                        assert (start<=literals_total_in_all_clauses);
                };
        };
};

// Return true if `value` ends with `ending`; an empty `ending` always matches.
// See also: https://stackoverflow.com/questions/874134/find-out-if-string-ends-with-another-string-in-c
inline bool ends_with(std::string const & value, std::string const & ending)
{
        return value.size() >= ending.size() &&
               value.compare(value.size() - ending.size(), ending.size(), ending) == 0;
}

// Entry point.  Usage: [-v|-vv] <filename.cnf|sat>
// Arguments are processed left to right, so options must come before the
// file name.  The first non-option argument is read and solved, and solve()
// ends the process.  Returns 1 on usage errors: no arguments, no file name,
// or an unknown file extension.  Unknown options are reported and ignored.
int main(int argc, char* argv[])
{
        if (argc<2)
        {
                std::cerr << "Usage: [-v|-vv] <filename.cnf|sat>" << std::endl;
                std::cerr << "File format determined by extension: *.cnf for DIMACS, *.sat for Knuth's style files" << std::endl;
                return 1;
        }
        for (int i=1; i<argc; i++)
        {
                if (argv[i][0]=='-')
                {
                        if (strcmp (argv[i], "-v")==0)
                                verbose=1;
                        else if (strcmp (argv[i], "-vv")==0)
                                verbose=2;
                        else
                                std::cerr << "Unknown option ignored: " << argv[i] << std::endl;
                        continue;
                };

                const std::string fname(argv[i]);
                if (ends_with(fname, ".sat"))
                        knuth_style=true;
                else if (ends_with(fname, ".cnf"))
                        knuth_style=false;
                else
                {
                        std::cerr << "Can't determine file type" << std::endl;
                        return 1;
                };
                if (knuth_style)
                        read_SAT_file (argv[i]);
                else
                        read_CNF_file (argv[i]);
                solve ();
                assert(!"we can't be here");
        };
        // only options were given
        std::cerr << "No input file given" << std::endl;
        return 1;
};

