# SUBMIT ONLY THIS FILE FOR THE PROGRAMMING PART. DO NOT SUBMIT run.py.

import numpy as np

from queue import PriorityQueue
from dataclasses import dataclass, field
from typing import Any


"""
convert an auction into a linear program

input: 
m: the number of items
bids: see assignment handout
output: a tuple of three np.arrays (c, A, b)
    such that solving the winner-determination problem 
    for the auction is equivalent to solving the integer program
    max <c, x> st. Ax <= b, x in {0, 1}^n.
"""
def to_integer_program(m, bids):
    # you should implement this!
    pass

class IPSearchProblem:
    """Search-problem formulation of an integer program given by (c, A, b).

    Presumably the program is  max <c, x>  s.t.  Ax <= b, x in {0, 1}^n, the
    form produced by to_integer_program. Instances are meant to be passed to
    ``a_star``, which calls ``source``, ``is_goal``, ``f`` and ``children``.
    Every method is a placeholder for the assignment. Until a method is filled
    in, it raises NotImplementedError instead of silently returning None.
    """

    def __init__(self, c, A, b):
        """Set up the search problem for the integer program (c, A, b)."""
        # you should implement this!
        raise NotImplementedError("IPSearchProblem.__init__ is not implemented yet")

    def source(self):
        """Return a representation of the source node."""
        # you should implement this!
        raise NotImplementedError("IPSearchProblem.source is not implemented yet")

    def is_goal(self, node):
        """Return True if and only if the node is a goal state."""
        # you should implement this!
        raise NotImplementedError("IPSearchProblem.is_goal is not implemented yet")

    def f(self, node):
        """Return the f-value of the node.

        Warning: A* as written in the starter code assumes a *MAXIMIZATION*
        problem. Therefore, f should be *decreasing* as more things get
        assigned, not *increasing*.
        """
        # you should implement this!
        raise NotImplementedError("IPSearchProblem.f is not implemented yet")

    def children(self, node):
        """Return a list of children of the node."""
        # you should implement this!
        raise NotImplementedError("IPSearchProblem.children is not implemented yet")


####################### BEGIN STARTER CODE #######################

# You are free to modify the code below. It is possible to do well
# on this assignment without modifying this code, though, and we would
# recommend that. Our solution does not modify this code.

"""
runs A* search on a *maximizing* problem.

assumption: f values are always non-increasing

input: a search problem object
"""
def a_star(problem):
    @dataclass(order=True)
    class PQItem:
        priority: float
        node: Any=field(compare=False)

    nodes_searched = 0
    q = PriorityQueue()
    source = problem.source()
    q.put(PQItem(-problem.f(source), source))
    while not q.empty():
        top = q.get()
        nodes_searched += 1
        if problem.is_goal(top.node): return -top.priority, nodes_searched
        for child in problem.children(top.node):
            q.put(PQItem(-problem.f(child), child))
    # there's always a feasible solution, so you should never get here!
    assert False

# Numerical tolerance the simplex routines below use when comparing floats
# against zero (reduced costs, ratio-test directions, feasibility checks).
TOL = 1.0e-8
# Running count of simplex_feasible loop iterations over all calls.
total_simplex_iters = 0

def get_total_simplex_iters():
    """Report the running total of simplex_feasible loop iterations.

    The value is read from the module-level ``total_simplex_iters`` counter.
    Nothing in this file resets that counter after start-up, so the total
    accumulates across every simplex call.
    """
    iterations_so_far = total_simplex_iters
    return iterations_so_far

"""
uses the Sherman-Morrison formula
    https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
to perform fast matrix inverse updates
"""
def simplex_feasible(c, A, b, basis):
    global total_simplex_iters
    I, A_I_inv = basis
    I = np.array(I)
    A_I_inv = np.array(A_I_inv)
    while 1:
        total_simplex_iters += 1
        # A_I_inv = np.linalg.inv(A[:,I])
        xI = A_I_inv @ b
        Ac = A_I_inv.T @ -c[I]

        # print(c.shape, A.T.shape, Ac.shape)
        cbar = -c - A.T @ Ac
        cost_idx = np.where(cbar < -TOL)[0]
        if cost_idx.size == 0: 
            obj = c[I] @ xI
            x = np.zeros_like(c)
            x[I] = xI
            return obj, x, (I, A_I_inv)
        j = np.argmin(cbar)
        d = -A_I_inv @ A[:,j]
        Ineg = np.where(d < -TOL)[0]
        if Ineg.size == 0:
            # unbounded
            return np.inf, None, (I, A_I_inv)
        red = -xI[Ineg] / d[Ineg]
        Istar = Ineg[np.argmin(red)]
        i = I[Istar]
        I[Istar] = j

        u = A[:, j] - A[:, i]
        A_I_inv -= (np.outer(A_I_inv @ u, A_I_inv[Istar, :]) / (1 + np.dot(A_I_inv[Istar, :], u)) )

"""A complete simplex algorithm, for solving LPs in standard form:
    max <c, x> s.t. Ax == b, x >= 0

Arguments:
    c {numpy array} -- a cost vector describing the objective function.
    A {numpy ndarray} -- a matrix describing the lhs of the constraints.
    b {numpy array} -- a vector describing the rhs of the constraints. Your
    solution should satisfy A @ x == b.
    basis {numpy array} -- warm-start basis from a previous run 
        (must be feasible, or else undefined behavior may occur!!)
Returns: a tuple (v, x, basis) where
    v -- the objective value (-np.inf if infeasible; +np.inf if unbounded)
    x -- an optimal solution {numpy array} (None if infeasible or unbounded)
    basis = (I, A_I_inv) -- optimal basis (i.e., A_I_inv = inverse of A[:, I]).
"""
def simplex(c, A, b, basis=None):
    assert len(b.shape) == 1
    assert len(c.shape) == 1
    assert (b.shape[0], c.shape[0]) == A.shape
    if basis is None:
        m, n = A.shape
        sgn = (b > 0).astype(float) * 2 - 1
        A = A * np.reshape(sgn, [-1, 1])
        b = b * sgn
        c1 = np.concatenate((-np.ones(m), np.zeros(n)))
        A1 = np.hstack([np.eye(m), A])
        I1 = np.arange(m)
        A_I_inv = np.linalg.inv(A1[:, I1])
        obj, x, _ = simplex_feasible(c1, A1, b, (I1, A_I_inv))
        if obj < -TOL:
            # infeasible
            return -np.inf, None, None
        I = np.where(x[m:] > TOL)[0]
        Iold = I
        if len(I) < m:
            I = list(I)
            for i in range(n):
                I.append(i)
                if np.linalg.matrix_rank(A[:, I]) != len(I): I.pop()
                if len(I) == m: break
        if len(I) < m:
            print("error: malformed LP (perhaps you have a column of all zeros?)")
            return None, None, None
        A_I_inv = np.linalg.inv(A[:, I])
    else:
        I, A_I_inv = basis
        # feasibility check
        assert (A_I_inv @ b > -TOL).all()
    return simplex_feasible(c, A, b, (I, A_I_inv))
