# DO NOT SUBMIT THIS FILE. 
# SUBMIT ONLY THE FILE problems.py FOR THE PROGRAMMING PART.

# Feel free to modify any code in this file since you're not submitting it.

import random
import numpy as np
import cvxpy as cp
import time
import sys
import importlib

# Name of the solution module to exercise: the first command-line argument
# when one is supplied, otherwise the default "problems".
fname = sys.argv[1] if len(sys.argv) > 1 else "problems"
# Imported by name at runtime so the module under test can be chosen from the
# command line.
p = importlib.import_module(fname)


"""
generates a random auction with m items and n bids.

input:
- m is the number of items
- n is the number of bids
- p is the probability that a random bid includes item i
"""
def random_auction(m, n, p):
    
    bidset = set()
    itemset = set()
    # generate the bids
    while len(bidset) < n:
        b = list()
        for i in range(m):
            if random.random() <= p:
                b.append(i)
                itemset.add(i)
        b = tuple(b)
        
        if b and b not in bidset:
            # print(b)
            bidset.add(b)
    bids = [(b, random.randint(1, m)) for b in sorted(bidset)]
    return (m, bids)




def run_tests():
    """Check the candidate solver against a series of random auctions.

    The RNG is seeded with 0, then for m = 1, 2, 3, ... a random auction with
    m items and m bids is generated, converted to an integer program with
    ``p.to_integer_program`` and solved with ``p.a_star`` on a
    ``p.IPSearchProblem``.  The value obtained is compared with the value on
    the corresponding line of ``values.txt`` and a progress line is printed,
    with the candidate value coloured green on a match and red otherwise.

    The run ends once ``values.txt`` has no more lines.
    """
    random.seed(0)

    def fmt(correct, candidate):
        # ANSI green when the rounded values agree, ANSI red otherwise.
        if np.round(correct) == np.round(candidate):
            return "\033[0;32m%s\033[0m" % candidate
        else:
            return "\033[0;31m%s\033[0m" % candidate

    total_nodes_searched = 0

    # The context manager closes the file even if the solver raises.
    with open("values.txt") as correct_values:
        start_time = time.time()

        # Iterating the file (instead of calling next() in an endless loop)
        # lets the run finish cleanly when the expected values run out.
        for m, line in enumerate(correct_values, start=1):
            # Each line holds two whitespace-separated fields; only the second
            # (the expected value) is used.  The first is presumably m.
            _, correct_value = line.split()
            correct_value = int(correct_value)

            # The inclusion probability m ** -r, with r drawn from [0, 1),
            # lies between 1/m and 1.
            _, bids = random_auction(m, m, m ** -random.random())

            c, A, b = p.to_integer_program(m, bids)

            # nodes_searched stays None when the reference solver below is
            # used, which suppresses the node statistics in the output.
            nodes_searched = None
            value = None

            # x = cp.Variable(A.shape[1], boolean=True)
            # value = cp.Problem(cp.Maximize(c @ x), [A @ x <= b]).solve(solver="GLPK_MI")

            # To test `to_integer_program`, uncomment the two lines above
            # and comment the line below. To test `IPSearchProblem`, do the
            # opposite.

            value, nodes_searched = p.a_star(p.IPSearchProblem(c, A, b))

            # print(m, "%d" % value)

            # The elapsed time is cumulative since the start of the run.
            print(m, "expected %d, got %s, totals: time %.3fs" % (
                correct_value, fmt(correct_value, value),
                time.time() - start_time), end="")
            if nodes_searched is not None:
                total_nodes_searched += nodes_searched
                print(", nodes %d, simplex iters %d" % (
                    total_nodes_searched, p.get_total_simplex_iters()))
            else:
                print()

# Run the test loop only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    run_tests()
