#!/usr/bin/python

##############################################################################
# QBF and SAT Benchmarks -- Boolean formula equivalency
# Author: Will Klieber
# See accompanying PDF documentation for an overview.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
##############################################################################

import sys
import os
import pdb
import pprint
import random
import argparse
import hashlib
from collections import namedtuple 

# Debugging aid: calling stop() drops into the pdb debugger at the call site.
stop = pdb.set_trace

def die(text):
    """Report a fatal error on stderr and terminate the program.

    The message is preceded by a line naming the calling function and the
    line number it is currently executing (read from the caller's frame).

    Args:
        text: The error message; it is converted with str().  A trailing
            newline is appended when missing.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    caller = sys._getframe(1)
    sys.stderr.write("Error in function '%s' at line %d:\n" %
        (caller.f_code.co_name, caller.f_lineno))
    text = str(text)
    # endswith() instead of text[-1]: an empty message must not raise
    # IndexError before we get the chance to exit cleanly.
    if not text.endswith("\n"):
        text += "\n"
    sys.stderr.write(text + "\n")
    #stop()  # uncomment to break into the debugger before exiting
    sys.exit(1)

def flatten(L):
    """Lazily yield every item of every sub-iterable of L, in order."""
    def _walk(groups):
        for group in groups:
            for element in group:
                yield element
    # iter() is applied right away so a non-iterable argument fails at call
    # time, exactly as it would with a generator expression.
    return _walk(iter(L))

def irange(start, end):
    """Return the range start..end with the end value included."""
    upper_bound = end + 1
    return range(start, upper_bound)

def unique(coll):
    """Yield the items of coll in order, skipping any item already yielded."""
    seen = set()
    for item in coll:
        if item not in seen:
            seen.add(item)
            yield item

def swap_keys_with_values(d):
    """Return a new dict mapping each value of d back to its key.

    Uses items() rather than the Python-2-only iteritems() so the helper
    works on both Python 2 and Python 3.  If several keys share a value,
    the key encountered last during iteration wins.
    """
    return dict((v, k) for (k, v) in d.items())

##############################################################################

class memoized(object):
    """Decorator that caches a function's results keyed by its positional args.

    The wrapped function is available as .func and the result table as
    .cache.  Arguments must be hashable.
    """

    def __init__(self, func):
        self.cache = {}
        self.func = func

    def __call__(self, *args):
        known = self.cache
        if args in known:
            return known[args]
        # First call with these arguments: compute, remember, return.
        result = self.func(*args)
        known[args] = result
        return result


class Glo(object):
    """Namespace for program-wide state; attributes are attached at run time."""

def is_lit(x):
    """Tell whether x is a literal, i.e. a plain int.

    Anything else (such as a formula with logical operators) is not a
    literal.  The check is on the exact type, so bool values do not count.
    """
    return type(x) is int

class Fmla(tuple):
    """Interned formula node stored as the tuple (operator, arg1, arg2, ...).

    Constructing a Fmla with arguments that were used before returns the
    object created the first time, so equality and hashing are by identity.
    Every new node also receives a SHA-256 based hash value (derived from its
    operator and the hash values of its arguments) and a creation index.
    """
    id_cache = {}   # constructor args -> the Fmla built from them
    idx = {}        # Fmla -> creation index
    next_idx = 1    # creation index for the next new Fmla
    rev_hash = {}   # SHA-256 based hash value -> Fmla
    hash = {}       # Fmla -> SHA-256 based hash value

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        return self is not other

    __hash__ = object.__hash__

    def __new__(cls, *args):
        existing = Fmla.id_cache.get(args)
        if existing is not None:
            return existing

        node = tuple.__new__(Fmla, args)
        head_text = str(args[0])
        tail_text = str(tuple(fmla_hash(x) for x in args[1:]))
        digest = hashlib.sha256((head_text + tail_text).encode('utf-8')).hexdigest()
        hashval = int(digest, 16)
        # This should be true with very high probability.
        assert hashval > Glo.args.vars
        Fmla.hash[node] = hashval
        if hashval in Fmla.rev_hash:
            die("Two formulas both hash to the same value: %s" % (hex(hashval),))
        Fmla.rev_hash[hashval] = node
        Fmla.id_cache[args] = node
        assert node not in Fmla.idx
        Fmla.idx[node] = Fmla.next_idx
        Fmla.next_idx += 1
        return node

def fmla_hash(fmla):
    """Return the sort key of fmla: the stored hash of a Fmla, or an int itself.

    Any other type yields None.
    """
    kind = type(fmla)
    if kind is Fmla:
        return Fmla.hash[fmla]
    if kind is int:
        return fmla
    return None

def sorted_by_fmla_hash(args):
    """Return a new list holding the items of args ordered by fmla_hash()."""
    ordered = list(args)
    ordered.sort(key=lambda item: fmla_hash(item))
    return ordered

def init_Fmla():
    """Create the module-level constant formulas Fmla_True and Fmla_False."""
    global Fmla_True, Fmla_False
    # True is built first, so it receives the earlier creation index.
    Fmla_True, Fmla_False = Fmla(True), Fmla(False)
        
##############################################################################

def write_qcir(self, prefix, outf):
    """Write the formula `self` in QCIR-G14 format.

    Args:
        self: Root formula; must be a non-constant, non-literal formula whose
            argument order is already canonical.
        prefix: Sequence of (quantifier, variables) pairs, where quantifier
            is 'exists', 'forall' or 'free'.
        outf: Either a file name (str), which is opened for writing, or an
            already-open writable stream.  The stream is closed on return.
    """
    if type(outf) == str:
        with open(outf, 'wt') as stream:
            write_qcir(self, prefix, stream)
        return

    assert self not in (Fmla_True, Fmla_False)
    assert not is_lit(self)
    assert self == canonicalize_arg_order(self)

    # Order the subformulas so that every gate is defined before it is used.
    ordering = get_subformula_print_ordering(self)
    ranked = sorted((rank, node) for (node, rank) in ordering.items())
    in_print_order = [node for (rank, node) in ranked]
    fmla_num = calc_subfmla_nums(in_print_order)
    gates = [node for node in in_print_order if not is_lit(node)]
    # Only the positively-numbered polarity of each gate gets a definition.
    gates = list(unique((g if fmla_num[g] > 0 else negate(g)) for g in gates))

    outf.write("#QCIR-G14 %i\n" % (max(fmla_num.values()),))
    comment = "# leafs=%i, vars=%i, seed=%i, tickle=%r" % \
        (Glo.args.leafs, Glo.args.vars, Glo.seed, Glo.args.tickle)
    if Glo.args.top_gate != 'and':
        comment += ", top_gate='%s'" % (Glo.args.top_gate,)
    if Glo.args.qbf and Glo.args.tickle:
        comment += ", pretickle_path=%i" % (Glo.args.pretickle_path,)
    outf.write(comment + "\n")

    for (quantifier, bound_vars) in prefix:
        assert quantifier in ['exists', 'forall', 'free']
        var_text = ", ".join(str(v) for v in bound_vars)
        outf.write("%s(%s)\n" % (quantifier, var_text))
    outf.write("output(%i)\n" % (fmla_num[self],))

    defined = set()
    for gate in gates:
        op = gate[0]
        operands = gate[1:]
        for operand in operands:
            assert is_lit(operand) or (operand in defined)
        defined.add(gate)
        defined.add(negate(gate))
        operand_text = ", ".join([str(fmla_num[x]) for x in operands])
        outf.write("%i = %s(%s)\n" % (fmla_num[gate], op, operand_text))
    outf.close()


def write_dimacs(self, prefix, outf):
    """Write the formula `self` as a (Q)DIMACS CNF file.

    Every gate gets its own variable; the clauses define each gate in terms
    of its arguments, plus one unit clause asserting the root.

    Args:
        self: Root formula; must be a non-constant, non-literal formula whose
            argument order is already canonical.
        prefix: List of [quantifier, variables] entries with quantifier
            'exists' or 'forall'; empty for plain SAT output.  NOTE: when
            non-empty, this list is modified in place -- the gate variables
            are appended to the last block if it is existential, otherwise a
            new innermost 'exists' block is appended.
        outf: Either a file name (str), which is opened for writing, or an
            already-open, seekable, writable stream.  The stream is closed on
            return.
    """
    if type(outf) == str:
        with open(outf, 'wt') as f:
            write_dimacs(self, prefix, f)
            return
    assert(self not in (Fmla_True, Fmla_False))
    assert(not is_lit(self))
    assert(self == canonicalize_arg_order(self))

    nesting_info = get_subformula_print_ordering(self)
    subformulas = [fmla for (order, fmla) in sorted((order, fmla) for (fmla, order) in nesting_info.items())]
    gate_num = calc_subfmla_nums(subformulas)
    subformulas = [x for x in subformulas if not is_lit(x)]
    # Keep only the positively-numbered polarity of each gate.
    subformulas = list(unique((x if gate_num[x] > 0 else negate(x)) for x in subformulas))

    def iter_clauses(fmla):
        # Yields the clauses (lists of formulas/literals) that define gate fmla.
        (op, args) = (fmla[0], fmla[1:])
        if op == 'xor':
            # (a xor b) is encoded as (a ? ~b : b).
            assert(len(args) == 2)
            op = 'ite'
            args = (args[0], negate(args[1]), args[1])
        #
        if op == 'and':
            # (f <==> (x1 & x2 & x3)) expands to
            # (~x1 | ~x2 | ~x3 | f)  &  (~f | x1)  &  (~f | x2)  &  (~f | x3)
            yield [fmla] + [negate(x) for x in args]
            for x in args:
                yield [negate(fmla), x]
        elif op == 'or':
            # (f <==> (x1 | x2 | x3)) expands to
            # (~f | x1 | x2 | x3)   &   (~x1 | f)  &  (~x2 | f)  &  (~x3 | f)
            yield [negate(fmla)] + [x for x in args]
            for x in args:
                yield [fmla, negate(x)]
        elif op == 'not':
            return
        elif op == 'ite':
            # v <==> (c ? t : f) expands to
            # ((c & t) => v) & ((c & ~t) => ~v) & ((~c & f) => v) & ((~c & ~f) => ~v)
            # which expands to
            # (~c | ~t | v)  &  (~c | t | ~v)  &  (c | ~f | v)  &  (c | f | ~v)
            (cond, tbra, fbra) = args
            yield [negate(cond), negate(tbra), fmla]
            yield [negate(cond), tbra, negate(fmla)]
            yield [cond, negate(fbra), fmla]
            yield [cond, fbra, negate(fmla)]
        else:
            # Report the offending operator (previously the whole root
            # formula was formatted here by mistake).
            die("Unknown operator '%s'\n" % (op,))

    # One-element list so the nested function can update the counter.
    p_num_clauses = [0]

    def write_clause(clause):
        # Writes one clause, refusing repeated or contradictory literals.
        clause_str = " ".join(str(gate_num[x]) for x in clause)
        hit = set()
        for lit in clause:
            lit = gate_num[lit]
            if lit in hit:
                die("Repeated literal %i in clause [%s]" % (lit, clause_str))
            if -lit in hit:
                die("Contradictory literals %i and %i in clause [%s]" % (lit, -lit, clause_str))
            hit.add(lit)
        outf.write(clause_str + "  0\n")
        p_num_clauses[0] += 1

    comment = "c leafs=%i, vars=%i, seed=%i, tickle=%r" % \
        (Glo.args.leafs, Glo.args.vars, Glo.seed, Glo.args.tickle)
    if Glo.args.top_gate != 'and':
        comment += ", top_gate='%s'" % (Glo.args.top_gate,)
    if Glo.args.qbf and Glo.args.tickle:
        comment += ", pretickle_path=%i" % (Glo.args.pretickle_path,)
    comment += "\n"
    outf.write(comment)
    # The "p cnf" header needs the final clause count, so a blank line is
    # reserved now and overwritten once all clauses have been written.
    outf.write((" " * 78) + "\n")  # Reserve space for header
    if len(prefix) != 0:
        gate_vars = []
        for x in subformulas:
            if is_lit(x): continue
            gate_vars.append(gate_num[x])
        # Gate variables are quantified existentially in the innermost block.
        if prefix[-1][0] == 'exists':
            prefix[-1][1].extend(gate_vars)
        elif prefix[-1][0] == 'forall':
            prefix.append(['exists', gate_vars])
        else:
            die("Bad quantifier '%s'\n" % (prefix[-1][0],))
    for (quantifier, vars) in prefix:
        quantifier = {'exists':'e', 'forall':'a'}[quantifier]
        outf.write("%s %s  0\n" % (quantifier, " ".join(str(x) for x in vars)))

    write_clause([self])
    for fmla in subformulas:
        if is_lit(fmla) or len(fmla) == 1:
            continue
        for clause in iter_clauses(fmla):
            write_clause(clause)
    num_vars = max(gate_num.values())
    try:
        # Jump back to the reserved line right after the comment line.
        outf.seek(len(comment))
    except IOError:
        die("IO ERROR: Couldn't seek to beginning of output file!")
    outf.write("p cnf %d %d " % (num_vars, p_num_clauses[0]))
    outf.close()


def calc_subfmla_nums(subformulas):
    """Assign a variable number to each gate, for use in (Q)DIMACS and QCIR.

    Literals keep their own value as their number.  Gates are numbered
    consecutively, in the order given, starting just above the largest
    variable seen among the literals; the negation of a gate gets the
    negated number.

    Args:
        subformulas: Iterable (traversed twice) of literals and formulas.

    Returns:
        Dict mapping every literal, its negation, every gate and its
        negation to a signed integer.
    """
    fmla_num = {}
    next_fmla_num = 1
    for x in subformulas:
        if is_lit(x):
            fmla_num[x] = x
            fmla_num[-x] = -x
            # abs(): a variable that occurs only as a negative literal must
            # still be skipped over, otherwise a gate could be given the same
            # number as that variable.
            next_fmla_num = max(next_fmla_num, abs(x) + 1)
    for fmla in subformulas:
        if is_lit(fmla):
            continue
        fmla_neg = negate(fmla)
        if fmla in fmla_num:
            # Already numbered, as itself or as the negation of an earlier gate.
            assert(fmla_neg in fmla_num)
            continue
        fmla_num[fmla] = next_fmla_num
        fmla_num[fmla_neg] = -next_fmla_num
        next_fmla_num += 1
    return fmla_num


def get_subformula_print_ordering(fmla, nest_lev=None, p_next_idx=None):
    """Compute the order in which the gate definitions are printed.

    For QCIR, gates must be defined before they are used.  Each subformula
    is mapped to the pair (nesting_level, index): nesting_level is the tree
    depth of the subformula (1 for literals) and index is the order in which
    it was encountered in a depth-first search.

    Args:
        fmla: Literal or formula to traverse.
        nest_lev: Result dict shared across the recursion; leave as None on
            the outermost call.
        p_next_idx: One-element list holding the next index; leave as None
            on the outermost call.

    Returns:
        The nest_lev dict.
    """
    if nest_lev is None:
        nest_lev = {}
        p_next_idx = [1]

    def take_index():
        taken = p_next_idx[0]
        p_next_idx[0] = taken + 1
        return taken

    if is_lit(fmla):
        # A literal is (re)stamped with a fresh index on every visit.
        nest_lev[fmla] = (1, take_index())
        return nest_lev
    if len(fmla) == 1:
        die("Unexpected constant: %r\n" % (fmla,))
    if fmla in nest_lev:
        return nest_lev

    deepest_child = 0
    for child in fmla[1:]:
        get_subformula_print_ordering(child, nest_lev, p_next_idx)
        deepest_child = max(deepest_child, nest_lev[child][0])
    nest_lev[fmla] = (deepest_child + 1, take_index())
    return nest_lev

##############################################################################


def subfmla_count(fmla, count):
    """Record in `count` how many times each non-literal subformula is reached.

    The arguments of a subformula are only descended into the first time
    that subformula is reached.
    """
    if is_lit(fmla):
        return
    occurrences = count.get(fmla, 0) + 1
    count[fmla] = occurrences
    if occurrences > 1:
        return
    for child in fmla[1:]:
        subfmla_count(child, count)


def dol_fmla(self, hit=None, counts=None):
    """For debugging, build a human-readable representation of the formula.

    A subformula reached more than once is printed in full the first time,
    tagged "$<index>:", and as the back-reference "$<index>" afterwards.
    A literal is returned unchanged (as an int).

    Args:
        self: Literal or formula to render.
        hit: Dict of subformulas already printed; leave as None on the
            outermost call.
        counts: Reach counts from subfmla_count(); leave as None on the
            outermost call.
    """
    if is_lit(self):
        return self
    if hit is None:
        hit = {}
        counts = {}
        subfmla_count(self, counts)

    if len(self) == 1:
        if self == Fmla_False:
            return "False()"
        if self == Fmla_True:
            return "True()"
        die("Unknown op '%s'\n" % (self,))

    if self in hit:
        return "$" + str(hit[self])
    index = Fmla.idx[self]
    hit[self] = index

    op = self[0]
    label = ("$%d:" % (index,)) if counts[self] > 1 else ""
    rendered = []
    for arg in self[1:]:
        if is_lit(arg):
            rendered.append(str(arg))
        else:
            rendered.append(dol_fmla(arg, hit, counts))
    return "%s%s(%s)" % (label, op, ", ".join(rendered))


@memoized
def negate(self):
    """Return the negation of a literal or formula.

    Literals change sign and the two constants swap.  'and'/'or' are dualised
    with every argument negated, 'xor' negates its first argument, 'ite'
    negates both branches, and 'not' yields its argument.
    """
    if is_lit(self):
        return -self

    op = self[0]
    args = self[1:]
    if len(args) == 0:
        if self == Fmla_False:
            return Fmla_True
        if self == Fmla_True:
            return Fmla_False
        die("Unknown op '%s'\n" % (op,))

    if op == 'not':
        negated = args[0]
    elif op == 'and':
        flipped = [negate(arg) for arg in args]
        negated = Fmla('or', *sorted_by_fmla_hash(flipped))
    elif op == 'or':
        flipped = [negate(arg) for arg in args]
        negated = Fmla('and', *sorted_by_fmla_hash(flipped))
    elif op == 'xor':
        xor_args = canonicalize_xor_args([negate(args[0]), args[1]])
        negated = Fmla('xor', *xor_args)
    elif op == 'ite':
        (test, tbra, fbra) = args
        negated = Fmla('ite', test, negate(tbra), negate(fbra))
    else:
        assert False
    return negated


def rand_fmla(num_leafs, lits, exclude, operator):
    """Generate a random formula with num_leafs leafs.

    The result is a binary tree whose gate types alternate: an 'xor' gate has
    one 'and' and one 'or' child (in random order), while 'and'/'or' gates
    have two 'xor' children.

    Args:
        num_leafs: Number of leafs in the generated (sub)tree.
        lits: Sequence of literals to draw leafs from.
        exclude: Either () or a set of literals that may no longer be used;
            when it is a set, each chosen leaf and its negation are added.
        operator: Gate type at the root: 'xor', 'and' or 'or'.

    Reads Glo.prefilled_leafs and Glo.leaf_excl_count; advances Glo.cur_leaf
    once per generated leaf.
    """
    # Generate a random formula.
    if num_leafs == 1:
        if not(Glo.prefilled_leafs[Glo.cur_leaf] is None):
            # This leaf position is pre-assigned; only its polarity is random.
            ret = Glo.prefilled_leafs[Glo.cur_leaf]
            ret = random.choice([ret, -ret])
        else:
            ret = random.choice(lits)
        if ret in exclude:
            # Redraw from the literals that are still allowed.
            try:
                ret = random.choice(list(set(lits) - exclude))
            except IndexError:
                die("Not enough variables!")
        if isinstance(exclude, set):
            exclude.add(ret)
            exclude.add(-ret)
        Glo.cur_leaf += 1
        return ret
    # Split the leafs as evenly as possible between the two children; for an
    # odd count, which child gets the extra leaf is random.
    child_leafs = [num_leafs // 2, 
                   num_leafs - num_leafs // 2]
    if num_leafs % 2 != 0:
        random.shuffle(child_leafs)
    # Once the subtree is small enough, start tracking used literals so that
    # they are not repeated within it.
    if exclude == () and num_leafs <= Glo.leaf_excl_count:
        exclude = set()
    if operator == 'xor':
        child_ops = ['and', 'or']
        random.shuffle(child_ops)
    elif operator in ['and', 'or']:
        child_ops = ['xor', 'xor']
    else:
        die("Unknown operator '%s'" % (operator,))
    child = [None, None]
    for i in range(0, 2):
        child[i] = rand_fmla(child_leafs[i], lits, exclude, child_ops[i])
    return Fmla(operator, *child)


@memoized
def simplify(fmla):
    """Simplify a formula by folding constants and trivial gates.

    Precondition: For making XOR gates from ITEs, fmla must be canonicalized;
    otherwise, it might be the case that tbra != negate(negate(tbra)).
    """
    if is_lit(fmla):
        return fmla
    op = fmla[0]
    if len(fmla[1:]) == 0:
        if fmla == Fmla_False:
            return Fmla_False
        if fmla == Fmla_True:
            return Fmla_True
        die("Unknown constant '%s'\n" % (op,))

    args = [simplify(child) for child in fmla[1:]]

    if op in ('and', 'or'):
        # base is the neutral element of the gate, negbase the dominating one.
        if op == 'and':
            (base, negbase) = (Fmla_True, Fmla_False)
        else:
            (base, negbase) = (Fmla_False, Fmla_True)

        def survives(arg):
            if arg == base:
                return False
            if arg == negbase:
                raise DeadExc
            return True

        try:
            args = tuple([arg for arg in args if survives(arg)])
        except DeadExc:
            return negbase
        if len(args) == 0:
            return base
        if len(args) == 1:
            return args[0]
    elif op == 'xor':
        assert len(args) == 2
        # Move a constant argument, if any, to the front.
        if args[1] in (Fmla_True, Fmla_False):
            args = [args[1], args[0]]
        (first, second) = (args[0], args[1])
        if first == Fmla_False:
            return second
        if first == Fmla_True:
            return negate(second)
        if first == second:
            return Fmla_False
        if first == negate(second):
            return Fmla_True
    elif op == 'ite':
        (test, tbra, fbra) = args
        if test == Fmla_True:
            return tbra
        if test == Fmla_False:
            return fbra
        if tbra == fbra:
            return tbra
        if tbra == Fmla_True:
            return simplify(canonicalize_arg_order(Fmla('or', test, fbra)))
        if fbra == Fmla_True:
            return simplify(canonicalize_arg_order(Fmla('or', negate(test), tbra)))
        if tbra == Fmla_False:
            return simplify(canonicalize_arg_order(Fmla('and', negate(test), fbra)))
        if fbra == Fmla_False:
            return simplify(canonicalize_arg_order(Fmla('and', test, tbra)))
        if fbra == negate(tbra):
            return simplify(canonicalize_arg_order(Fmla('xor', test, fbra)))
        assert negate(fbra) != negate(negate(tbra))
    else:
        die("Unknown operator: '%s'\n" % op)
    return Fmla(op, *args)


@memoized
def to_nnf(fmla):
    """Convert a formula to negation normal form (NNF).

    Literals and constants are returned unchanged; 'xor' is first rewritten
    as an 'ite', and 'ite' becomes an 'and' of two 'or' gates.
    """
    if is_lit(fmla):
        return fmla
    op = fmla[0]
    if len(fmla[1:]) == 0:
        return fmla
    args = [to_nnf(child) for child in fmla[1:]]
    if op in ('and', 'or'):
        return Fmla(op, *args)
    if op == 'xor':
        as_ite = Fmla('ite', args[0], negate(args[1]), args[1])
        return to_nnf(as_ite)
    if op == 'ite':
        (sel, y, z) = args
        when_true = Fmla('or', negate(sel), y)
        when_false = Fmla('or', sel, z)
        return Fmla('and', when_true, when_false)
    die("Unknown operator: '%s'\n" % op)


def init_eval_cache():
    """Return a fresh evaluation cache pre-seeded with the two constants."""
    cache = {}
    cache[Fmla_True] = True
    cache[Fmla_False] = False
    return cache


def eval_fmla(fmla, asgn, cache):
    """Evaluate a literal or formula under an assignment.

    Args:
        fmla: Literal (int) or formula.
        asgn: Mapping/sequence indexed by variable number giving each
            variable's truth value.
        cache: Dict of already evaluated formulas; updated with every
            non-literal formula evaluated here.  See init_eval_cache().

    Returns:
        The truth value of fmla under asgn.
    """
    if is_lit(fmla):
        if fmla < 0:
            return not asgn[-fmla]
        else:
            return asgn[fmla]
    try:
        return cache[fmla]
    except KeyError:
        pass
    (op, args) = (fmla[0], fmla[1:])
    if len(args) == 0:
        # Constants are normally served from the cache; handle them here too
        # in case the caller passed a cache that was not pre-seeded.  (This
        # branch used to reference an undefined name and return formula
        # objects instead of booleans.)
        if fmla == Fmla_False:
            ret = False
        elif fmla == Fmla_True:
            ret = True
        else:
            die("Unknown constant '%s'\n" % (op,))
        cache[fmla] = ret
        return ret
    args = [eval_fmla(arg, asgn, cache) for arg in args]
    if op == 'not':
        ret = not(args[0])
    elif op == 'and':
        ret = not any([(arg == False) for arg in args])
    elif op == 'or':
        ret = any([(arg == True) for arg in args])
    elif op == 'xor':
        ret = args[0] ^ args[1]
    elif op == 'ite':
        if args[0] == True:
            ret = args[1]
        else:
            ret = args[2]
    else:
        die("Unknown operation: '%s'\n" % op)
    cache[fmla] = ret
    return ret


def pretickle(fmla, asgn, eval_cache):
    """Prepare a formula so that tickle() can flip its value via one leaf.

    Modifies the formula so that tickle can guarantee to change the formula's
    truth value (under asgn) by changing a single leaf.

    Args:
        fmla: Literal or formula built from binary 'and', 'or' and 'xor'.
        asgn: Variable assignment, as accepted by eval_fmla().
        eval_cache: Evaluation cache, as accepted by eval_fmla().

    Returns:
        The (possibly rebuilt) formula; literals are returned unchanged.

    When Glo.args.pretickle_path is set, only one argument per gate is
    processed recursively, and the index of that argument is recorded in
    Glo.tickle_nodes for the rebuilt gate.
    """
    # Modifies the formula so that tickle can guarantee to change the formula's
    # truth value (under asgn) by changing a single leaf.
    def evalf(subfmla):
        return eval_fmla(subfmla, asgn, eval_cache)
    if is_lit(fmla):
        return fmla
    (op, args) = (fmla[0], fmla[1:])
    args = list(args)
    fmla_val = evalf(fmla)
    assert(fmla_val == True or fmla_val == False)
    # Default argument to follow: the low bit of the formula's hash value.
    itick = Fmla.hash[fmla] & 1  # last-minute change; get a random bit without perturbing the PRNG state.
    if op in ('and', 'or'):
        assert(len(args) == 2);
        iflip = None
        # A false 'and' / true 'or' can only be flipped through an argument
        # that has the same value as the gate, and only if it is the sole one.
        if (op=='and' and fmla_val == False) or (op=='or' and fmla_val == True):
            if evalf(args[0]) == fmla_val and evalf(args[1]) == fmla_val:
                # Both arguments force the gate's value: negate a random one
                # so that the other becomes the single deciding argument.
                iflip = random.choice([0,1])
                args[iflip] = negate(args[iflip])
                itick = 1 - iflip
            elif evalf(args[0]) == fmla_val and evalf(args[1]) != fmla_val:
                itick = 0
            elif evalf(args[0]) != fmla_val and evalf(args[1]) == fmla_val:
                itick = 1
            else: 
                assert(False)
        if Glo.args.pretickle_path:
            # Path mode: descend into the chosen argument only and remember
            # the choice for tickle().
            args[itick] = pretickle(args[itick], asgn, eval_cache)
            ret = Fmla(op, *args)
            assert(ret not in Glo.tickle_nodes)
            Glo.tickle_nodes[ret] = itick
        else:
            args = list([pretickle(x, asgn, eval_cache) for x in args])
            ret = Fmla(op, *args)
        return ret
    if op == 'xor':
        # Flipping either argument of an xor flips the gate, so no argument
        # needs to be negated here.
        if Glo.args.pretickle_path:
            args[itick] = pretickle(args[itick], asgn, eval_cache)
            ret = Fmla(op, *args)
            assert(ret not in Glo.tickle_nodes)
            Glo.tickle_nodes[ret] = itick
        else:
            args = list([pretickle(x, asgn, eval_cache) for x in args])
            ret = Fmla(op, *args)
        return ret
    #elif op == 'ite':
    #    args = list([pretickle(x, asgn, eval_cache) for x in args])
    #    return Fmla(op, args[0], args[1], args[2])
    else:
        die("Unknown operation: '%s'\n" % op)
    # Not reached: every branch above returns or calls die().
    return ret


def tickle(fmla, asgn, eval_cache, stack):
    """Flip the truth value of fmla under asgn by changing one of its leafs.

    Precondition: fmla has been pretickled.

    Args:
        fmla: Literal or formula built from binary 'and', 'or' and 'xor'.
        asgn: Variable assignment, as accepted by eval_fmla().
        eval_cache: Evaluation cache, as accepted by eval_fmla().
        stack: List to which every non-literal formula visited is appended.

    Returns:
        The rebuilt formula; for a literal, its negation.
    """
    # Flips the truth value of fmla under asgn by changing one of its leafs.
    # Precondition: fmla has been pretickled.
    def evalf(subfmla):
        return eval_fmla(subfmla, asgn, eval_cache)
    if is_lit(fmla):
        return -fmla
    stack.append(fmla)
    (op, args) = (fmla[0], fmla[1:])
    args = list(args)
    fmla_val = evalf(fmla)
    assert(fmla_val == True or fmla_val == False)
    if op in ('and', 'or'):
        assert(len(args) == 2);
        if (op=='and' and fmla_val == False) or (op=='or' and fmla_val == True):
            # The gate's value is forced by the argument(s) sharing its value;
            # tickle that argument.  j counts how many such arguments exist.
            j = 0
            for i in range(0, len(args)):
                if evalf(args[i]) == fmla_val:
                    if Glo.args.pretickle_path:
                        assert(Glo.tickle_nodes[fmla] == i)
                    args[i] = tickle(args[i], asgn, eval_cache, stack)
                    j += 1
                    assert(j == 1)  # guaranteed by pretickle
            return Fmla(op, args[0], args[1])
        else:
            # True 'and' / false 'or': flipping either argument flips the gate.
            # The random draw is made even when the recorded path overrides it.
            iflip = random.choice([0,1])
            if Glo.args.pretickle_path:
                iflip = Glo.tickle_nodes[fmla]
            args[iflip] = tickle(args[iflip], asgn, eval_cache, stack)
            return Fmla(op, args[0], args[1])
    elif op == 'xor':
        # Flipping either argument of an xor flips the gate.
        iflip = random.choice([0,1])
        if Glo.args.pretickle_path:
            iflip = Glo.tickle_nodes[fmla]
        args[iflip] = tickle(args[iflip], asgn, eval_cache, stack)
        return Fmla(op, args[0], args[1])
    #elif op == 'ite':
    #    if evalf(args[0]) == True:
    #        args[1] = tickle(args[1], asgn, eval_cache, stack)
    #    else:
    #        args[2] = tickle(args[2], asgn, eval_cache, stack)
    #    return Fmla(op, args[0], args[1], args[2])
    else:
        die("Unknown operation: '%s'\n" % op)
    # Not reached: every branch above returns or calls die().
    return ret


def vars_in_fmla(fmla, var_set=None, hit=None):
    """Collect the variables (absolute literal values) occurring in fmla.

    Args:
        fmla: Literal or formula to scan.
        var_set: Set to add the variables to; a new set is created when None.
        hit: Set of subformulas already scanned; a new set is created when
            None.

    Returns:
        var_set, containing the variable number of every literal reachable
        from fmla.
    """
    if var_set is None:
        var_set = set()
        hit = set()
    if hit is None:
        hit = set()
    if is_lit(fmla):
        var_set.add(abs(fmla))
        return var_set
    if fmla in hit:
        return var_set
    # Remember visited subformulas so that shared ones are scanned only once
    # (the set used to stay empty, making the check above ineffective).
    hit.add(fmla)
    for arg in fmla[1:]:
        vars_in_fmla(arg, var_set, hit)
    return var_set


def qbf_split(fmla, asgn, stack):
    """Replace selected leafs by an if-then-else over a fresh selector.

    Replaces each leaf x with (e ? x : y), where e is a fresh outermost
    existential variable and y is a literal (of a variable from the original
    formula).  The two branches are stored in sorted order, so x is not
    necessarily the 'then' branch.  Whether the leaf at position
    Glo.cur_leaf is split is decided by Glo.split_leafs.

    Args:
        fmla: Literal or formula to transform; must equal stack[-1].
        asgn: Unused here; only passed along through the recursion.
        stack: List of the enclosing formulas down to fmla, root first.  It
            is pushed/popped during the recursion and is unchanged on return.

    Returns:
        The transformed formula.

    Advances Glo.cur_leaf for every leaf and Glo.next_outer_var for every
    leaf that is split.
    """
    assert(stack[-1] == fmla)
    if is_lit(fmla):
        if not Glo.split_leafs[Glo.cur_leaf]:
            Glo.cur_leaf += 1
            return fmla
        Glo.cur_leaf += 1
        # Prefer a partner variable that does not occur near this leaf.
        if (len(stack) > 3):
            verboten_vars = vars_in_fmla(stack[-3])
            # The option is None unless given on the command line.  Python 2
            # evaluated "int < None" as False; Python 3 raises TypeError, so
            # skip the comparison explicitly in that case.
            excl_count = Glo.args.leaf_excl_count
            if (excl_count is not None and
                    Glo.args.vars - len(verboten_vars) < excl_count):
                verboten_vars = vars_in_fmla(stack[-2])
        else:
            verboten_vars = vars_in_fmla(stack[0])
        if Glo.args.vars - len(verboten_vars) < 4:
            # Too few candidates left: only forbid the leaf's own variable.
            verboten_vars = [abs(fmla)]
        avail_vars = [x for x in irange(1, Glo.args.vars) if x not in verboten_vars]
        choice_b = random.choice(avail_vars)
        choice_b = random.choice([choice_b, -choice_b])
        sel = Glo.next_outer_var
        Glo.next_outer_var += 1
        branches = sorted([fmla, choice_b])
        return (Fmla('ite', sel, branches[0], branches[1]))
    (op, args) = (fmla[0], fmla[1:])
    new_args = []
    for arg in args:
        stack.append(arg)
        new_args.append(qbf_split(arg, asgn, stack))
        assert(stack[-1] == arg)
        stack.pop()
    return Fmla(op, *new_args)


def shuffle_arg_order(fmla):
    """Return a copy of fmla with the arguments of every gate shuffled.

    The arguments of 'ite' gates keep their positions; literals are returned
    unchanged.
    """
    if is_lit(fmla):
        return fmla
    op = fmla[0]
    children = [shuffle_arg_order(child) for child in fmla[1:]]
    if op != 'ite':
        random.shuffle(children)
    return Fmla(op, *children)
    
def canonicalize_xor_args(cur_args):
    """Pick the canonical argument list for an 'xor' gate.

    Both the given arguments and their negations are sorted by hash; the
    list whose 'xor' formula has the smaller hash is returned (the given
    arguments on a tie).
    """
    negated_sorted = sorted_by_fmla_hash([negate(arg) for arg in cur_args])
    plain_sorted = sorted_by_fmla_hash(cur_args)
    negated_key = fmla_hash(Fmla('xor', *negated_sorted))
    plain_key = fmla_hash(Fmla('xor', *plain_sorted))
    if negated_key < plain_key:
        return negated_sorted
    return plain_sorted
    
@memoized
def canonicalize_arg_order(fmla):
    """Return fmla with the arguments of every gate put in canonical order.

    Arguments are sorted by hash, except for 'ite' gates, whose argument
    positions are kept; 'xor' gates are additionally normalised through
    canonicalize_xor_args().  Literals are returned unchanged.
    """
    if is_lit(fmla):
        return fmla
    op = fmla[0]
    children = [canonicalize_arg_order(child) for child in fmla[1:]]
    if op == 'xor':
        children = canonicalize_xor_args(sorted_by_fmla_hash(children))
    elif op != 'ite':
        children = sorted_by_fmla_hash(children)
    return Fmla(op, *children)
    
    

def to_all_ites(fmla):
    """Rewrite a formula so that every gate is an 'ite' gate.

    and(a, b) becomes ite(a, b, False), or(a, b) becomes ite(a, True, b) and
    xor(a, b) becomes ite(a, ~b, b).  Literals are returned unchanged.
    """
    if is_lit(fmla):
        return fmla
    op = fmla[0]
    kids = [to_all_ites(child) for child in fmla[1:]]
    if op == 'and':
        return Fmla('ite', kids[0], kids[1], Fmla_False)
    if op == 'or':
        return Fmla('ite', kids[0], Fmla_True, kids[1])
    if op == 'xor':
        return Fmla('ite', kids[0], negate(kids[1]), kids[1])
    if op == 'ite':
        return Fmla('ite', kids[0], kids[1], kids[2])
    die("Unknown operation: '%s'\n" % op)

# Context threaded through refactor(); chain_len is incremented on each nested
# call made while pushing an outer ITE inwards and is reset to 0 when a
# selector is refactored on its own.
RefactorCtx = namedtuple('RefactorCtx', ['chain_len'])
refactor_init_ctx = RefactorCtx(0)

@memoized    
def refactor(fmla, t_outer, f_outer, ctx):
    """Build a formula for "fmla ? t_outer : f_outer" with the test pushed inwards.

    Args:
        fmla: Literal, constant or 'ite' formula acting as the test.
        t_outer: Result when fmla is true.
        f_outer: Result when fmla is false.
        ctx: RefactorCtx carrying the current chain length.

    Returns:
        For a literal, the gate ite(fmla, t_outer, f_outer); for a constant,
        the matching branch; for an 'ite', a rebuilt formula (see below).

    Results are memoized on all four arguments.  Uses refactor.decs as a
    shared pool of shuffled 0/1 decisions and increments refactor.count each
    time the selector itself is refactored over the rebuilt branches.
    """
    # Refactors
    # "(sel ? t_inner : f_inner) ? t_outer : f_outer" to
    # "(sel ? (t_inner ? t_outer : f_outer) : (f_inner ? t_outer : f_outer)".
    if is_lit(fmla):
        return Fmla('ite', fmla, t_outer, f_outer)
    (op, args) = (fmla[0], fmla[1:])
    if len(args) == 0:
        if fmla == Fmla_True:    return t_outer
        elif fmla == Fmla_False: return f_outer
        else: die("Unknown constant '%s'\n" % (op,))
    if op=='ite':
        (sel, t_inner, f_inner) = args
        # t_in_old, f_in_old and is_andor are only referenced by the
        # commented-out check at the end of this branch.
        (sel, t_in_old, f_in_old) = args
        is_andor = ((t_inner in [Fmla_True, Fmla_False]) or 
                    (f_inner in [Fmla_True, Fmla_False]))
        chain_len = ctx.chain_len
        # Trivial: both inner branches are literals and both outer branches
        # are constants.
        is_trivial = (is_lit(t_inner) and is_lit(f_inner) and 
            len(set([t_outer, f_outer]) - set([Fmla_True, Fmla_False])) == 0)
        t_inner = refactor(t_inner, t_outer, f_outer, ctx._replace(chain_len = chain_len + 1))
        f_inner = refactor(f_inner, t_outer, f_outer, ctx._replace(chain_len = chain_len + 1))
        # Refill the decision pool with two 0s and two 1s in random order.
        if len(refactor.decs) == 0:
            refactor.decs = [0,0,1,1]
            random.shuffle(refactor.decs)
        # A decision is only popped from the pool when chain_len is 1 or 2
        # and the case is not trivial (short-circuit evaluation).
        if chain_len <= 2 and (chain_len == 0 or is_trivial or refactor.decs.pop()):
            ret = refactor(sel, t_inner, f_inner, ctx._replace(chain_len = chain_len + 1))
            refactor.count += 1
        else:
            ret = Fmla('ite', refactor(sel, Fmla_True, Fmla_False, ctx._replace(chain_len=0)), t_inner, f_inner)
        #if simplify(ret) != simplify(Fmla('ite', Fmla('ite', sel, t_in_old, f_in_old), t_outer, f_outer)):
        #    refactor.count += 1
        return ret
    else:
        die("Unsupported operator: '%s'\n" % op)

# Function-level state, stored as attributes on the memoized wrapper.
refactor.count = 0
refactor.decs = []

def parse_args():
    """Parse the command line (sys.argv) and return the argparse namespace.

    Prints a short usage summary and exits with status 1 when no arguments
    are given.  After parsing, pretickle_path defaults to 1 and '--qdimacs'
    is folded into the 'fmt' attribute.

    Returns:
        argparse.Namespace with the attributes prob_type, leafs, vars, seed,
        tickle, outfile, reclim, qfrac, qneg, fmt, qdimacs, no_shuffle,
        no_refactor, pretickle_path, leaf_excl_count and top_gate.

    Raises:
        SystemExit: On missing or invalid arguments.
    """
    if len(sys.argv) < 2:
        print("Usage: benchgen.py qbf NumLeafs NumVars -o OutFile [options]\n" +
              "       benchgen.py sat NumLeafs NumVars -o OutFile [options]\n" +
              "       benchgen.py -h")
        sys.exit(1)
    parser = argparse.ArgumentParser(epilog=
        "Example: benchgen.py QBF 60 30 -o out.qcir\n" +
        "Example: benchgen.py SAT 400 20 -o out.cnf\n",
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("prob_type", choices=['qbf','sat'], type=str.lower, help="problem type: QBF or SAT")
    parser.add_argument("leafs", type=int, help="number of leafs")
    parser.add_argument("vars",  type=int, help="number of variables")
    parser.add_argument("-s", "--seed", type=int, help="seed for random number generator")
    parser.add_argument("-t", "--tickle", type=int, metavar='N',
        help="Flips the polarity of a leaf; see documentation PDF for details. " +
        "For SAT: produces an instance with N equivalences AND'd together.")
    parser.add_argument("-o", type=str, dest="outfile", required=True,
        help="output file (if dir, automatically generate filename based on args)")
    parser.add_argument("--reclim", type=int, default=2000, help="recursion limit " +
        "(increase this if Python dies with 'RuntimeError: maximum recursion depth exceeded')")
    parser.add_argument("--qfrac", type=str, help="fraction of leafs to split on", default="1.0")
    parser.add_argument("--qneg", action="store_true",
        help="negate QBF instance (for CNF, this results in 2 quantifier blocks instead of 3)")
    parser.add_argument("--fmt", type=str, help="output file format ('qcir', 'qdimacs', 'dimacs')")
    parser.add_argument("--qdimacs", action="store_true", help="same as '--fmt qdimacs'")
    parser.add_argument("--no-shuffle", action="store_true", dest="no_shuffle", help="don't shuffle the arg order")
    parser.add_argument("--no-refactor", action="store_true", dest="no_refactor", help="(for debugging)")
    parser.add_argument("--pretickle-path", type=int, dest="pretickle_path", help="pretickle a path (versus the whole tree)")
    parser.add_argument("--leaf-excl-count", type=int, metavar='N', dest="leaf_excl_count",
        help="no variable may occur more than once in a subtree with N or fewer leafs")
    parser.add_argument("--top-gate", choices=['and','or','xor'], default='and', metavar='GATE_TYPE',
        dest="top_gate", help="gate at the root of the tree")

    args = parser.parse_args()

    if args.pretickle_path is None:
        args.pretickle_path = 1

    # Reject bad input with a proper usage error; an assert would vanish
    # under "python -O" and otherwise produce a bare traceback.
    if args.leafs <= 0:
        parser.error("number of leafs must be positive (got %d)" % (args.leafs,))
    if args.qdimacs:
        if args.fmt and args.fmt != 'qdimacs':
            die("Inconsistent command-line arguments for '--fmt' and '--qdimacs'!")
        args.fmt = 'qdimacs'
    return args


def main():
    """Generate one benchmark instance and write it to the output file.

    Steps, in order:
      1. Parse the command line (stored in Glo.args) and call init_Fmla().
      2. Seed the random number generator, from '--seed' if given or from
         30 fresh random bits otherwise, then re-seed with a value that
         also mixes in the var and leaf counts.
      3. If the output path is a directory, build a filename from the
         arguments and place the output file in that directory.
      4. Resolve and validate the output format for the problem type
         ('qcir'/'qdimacs' for QBF, 'dimacs' for SAT).
      5. Build a random formula with rand_fmla(), derive a shuffled and
         refactored variant of it, combine the two, and write the result
         with write_qcir() or write_dimacs().

    Exits through die() on invalid argument combinations.
    """
    args = parse_args()
    Glo.args = args
    init_Fmla()
    seed = args.seed
    if seed is None:
        seed = random.getrandbits(30)
    Glo.seed = seed
    random.seed(seed)
    # Re-seed so that the random stream also depends on the instance size.
    random.seed(random.getrandbits(60) ^ (args.vars + args.leafs * 100000))

    if os.path.isdir(args.outfile):
        # The output path is a directory: generate a filename from the args.
        if args.prob_type == 'qbf':
            filename = "klieber2017%s-%03d-%02d" % (args.prob_type[0], args.leafs, args.vars)
        else:
            filename = "klieber2017%s-%04d-%03d" % (args.prob_type[0], args.leafs, args.vars)
        if (args.tickle):
            filename += "-t" + str(args.tickle)
        else:
            filename += "-eq"
        if (args.qneg):
            filename += "-neg"
        if (args.qfrac != "1.0"):
            filename += "-qfrac" + args.qfrac
        if Glo.seed != 1:
            filename += "-s" + str(Glo.seed)
        if (args.fmt is None):
            args.fmt = {'qbf':'qcir', 'sat':'dimacs'}[args.prob_type]
        # Derive the extension from the resolved format, so that it is also
        # defined when the format was given explicitly on the command line
        # ('--fmt' or '--qdimacs').  DIMACS files use the '.cnf' extension;
        # the other formats use their own name.
        file_ext = {'dimacs': 'cnf'}.get(args.fmt, args.fmt)
        filename += "." + file_ext
        args.outfile = os.path.join(args.outfile, filename)
        print("# %s" % (args.outfile,))

    Fmla.next_idx = args.vars + 10
    Glo.first_outer_var = args.vars + 1
    Glo.next_outer_var = Glo.first_outer_var
    Glo.leaf_excl_count = 8
    if args.leaf_excl_count:
        Glo.leaf_excl_count = args.leaf_excl_count

    if args.vars < Glo.leaf_excl_count:
        die("Too few vars; minimum is %d.\n" % (Glo.leaf_excl_count,))
    
    if args.fmt is None:
        # No explicit format: infer it from the output file's extension.
        args.fmt = args.outfile.split('.')[-1]
        if args.prob_type == 'sat' and args.fmt == 'cnf':
            args.fmt = 'dimacs'

    valid_fmts = {
        "qbf": ['qcir','qdimacs'], 
        "sat": ['dimacs']
        }[args.prob_type]
    if args.fmt not in valid_fmts:
        if len(valid_fmts) == 1:
            die("Invalid value for '--fmt' option: must be " + 
                "%r for %s instances." % (valid_fmts[0], args.prob_type))
        else:
            die("Invalid value for '--fmt' option: must be " + 
                "one of %r for %s instances." % (valid_fmts, args.prob_type))

    if args.seed is None:
        # Report the auto-generated seed so that the run can be reproduced.
        print("# seed: %d" % (seed,))

    args.qbf = (args.prob_type == 'qbf')
    if args.qbf:
        Fmla.next_idx += args.leafs
    # All literals: each variable in both polarities.
    lits = tuple(flatten((i, -i) for i in irange(1, args.vars)))
    Glo.cur_leaf = 0
    Glo.prefilled_leafs = [None] * args.leafs
    if args.qbf:
        if args.leafs > args.vars:
            # Assign each variable to a distinct, randomly chosen leaf.
            for (leaf, var) in (zip(random.sample(range(0, args.leafs), args.vars), irange(1, args.vars))):
                Glo.prefilled_leafs[leaf] = var
    
        # Mark a random '--qfrac' fraction of the leafs as split leafs.
        Glo.split_leafs = [False] * args.leafs
        for leaf in random.sample(range(0, args.leafs), int(float(args.qfrac) * args.leafs)):
            Glo.split_leafs[leaf] = True

    sys.setrecursionlimit(args.reclim)

    # Random truth assignment; index 0 is unused because variables start at 1.
    asgn = tuple([random.choice([False,True]) for i in range(0, args.vars + 1)])
    #sys.stdout.write("# asgn: " + ", ".join(str(i if asgn[i]==True else -i) for i in irange(1, args.vars)) + "\n")


    fmla = rand_fmla(args.leafs, lits, (), args.top_gate)

    if args.qbf:
        if args.vars <= args.leafs:
            assert(len(vars_in_fmla(fmla)) == args.vars)
        split_fmla = fmla
        if args.tickle:
            # Reject unsupported values before doing any tickling work.
            if args.tickle != 1:
                die("For QBF, the only accepted value for '--tickle' is 1.")
            Glo.tickle_nodes = {}
            fmla = pretickle(fmla, asgn, init_eval_cache())
            tickled_fmla = tickle(fmla, asgn, init_eval_cache(), [])
            split_fmla = tickled_fmla
        Glo.cur_leaf = 0
        split_fmla = qbf_split(split_fmla, asgn, [split_fmla])
        shuffled = fmla
        if not args.no_shuffle:
            shuffled = shuffle_arg_order(shuffled)
        if not args.no_refactor:
            refactored = refactor(to_all_ites(shuffled), Fmla_True, Fmla_False, refactor_init_ctx)
        else:
            refactored = shuffled
        #print("# Refactor count: %d" % (refactor.count,))
        refactored = simplify(canonicalize_arg_order(refactored))
        refactored = canonicalize_arg_order(refactored)
        final_fmla = Fmla('ite', split_fmla, refactored, negate(refactored))
        # NOTE(review): the result of this call is discarded; presumably it is
        # kept for a side effect of canonicalize_arg_order -- confirm.
        canonicalize_arg_order(fmla)
        final_fmla = canonicalize_arg_order(final_fmla)
        #
        #print(dol_fmla(fmla))
        #print("#")
        #print(dol_fmla(tickled_fmla))
        #print(dol_fmla(canonicalize_arg_order(fmla)))
        #print(dol_fmla(Fmla('xor', fmla, canonicalize_arg_order(fmla))))
        #print("#")
        #print(dol_fmla(split_fmla))
        #print("#")
        #print(dol_fmla(refactored))
        #print("#")
        #print(dol_fmla(final_fmla))
        #
        # The outer block holds the variables numbered from
        # Glo.first_outer_var up to (excluding) Glo.next_outer_var; the inner
        # block holds the original variables 1..vars.
        e_vars = range(Glo.first_outer_var, Glo.next_outer_var)
        u_vars = irange(1, Glo.args.vars)
        if not(args.qneg):
            prefix = [['exists', e_vars], ['forall', u_vars]]
        else:
            # Negated instance: swap the quantifiers and negate the matrix.
            prefix = [['forall', e_vars], ['exists', u_vars]]
            final_fmla = negate(final_fmla)
        if args.fmt == 'qcir':
            write_qcir(final_fmla, prefix, args.outfile)
        elif args.fmt == 'qdimacs':
            write_dimacs(canonicalize_arg_order(to_nnf(final_fmla)), prefix, args.outfile)
        else:
            die("Unknown QBF format '%s'\n" % (args.fmt,))
        return
    else:
        if args.tickle:
            def gen_equiv():
                # Build one XOR of a fresh random formula with the refactored
                # form of its tickled variant.
                Glo.cur_leaf = 0
                fmla = rand_fmla(args.leafs, lits, (), args.top_gate)
                Glo.tickle_nodes = {}
                fmla = pretickle(fmla, asgn, init_eval_cache())
                tickled = tickle(fmla, asgn, init_eval_cache(), [])
                if not args.no_shuffle:
                    tickled = shuffle_arg_order(tickled)
                if not args.no_refactor:
                    refactored = refactor(to_all_ites(tickled), Fmla_True, Fmla_False, refactor_init_ctx)
                else:
                    refactored = tickled
                refactored = simplify(canonicalize_arg_order(refactored))
                return Fmla('xor', fmla, refactored)
            # AND together '--tickle' many independently generated XORs.
            equivs = []
            for i in range(0, args.tickle):
                equivs.append(gen_equiv())
            final_fmla = Fmla('and', *equivs)
        else:
            shuffled = fmla
            if not args.no_shuffle:
                shuffled = shuffle_arg_order(shuffled)
            if not args.no_refactor:
                refactored = refactor(to_all_ites(shuffled), Fmla_True, Fmla_False, refactor_init_ctx)
            else:
                refactored = shuffled
            #print("# Refactor count: %d" % (refactor.count,))
            refactored = simplify(canonicalize_arg_order(refactored))
            # XOR of the formula with its shuffled/refactored form.
            final_fmla = Fmla('xor', fmla, refactored)
        #print(dol_fmla(fmla))
        #print("#")
        #print(dol_fmla(refactored))
        #print("#")
        #print(dol_fmla(final_fmla))
        write_dimacs(canonicalize_arg_order(to_nnf(final_fmla)), [], args.outfile)

# Run the generator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()

