'''
This code was written in 15-113 lecture.
It is only for demonstrational purposes, and may contain
dubious style and even an occasional bug.
'''

from collections import deque
import copy, random, heapq, time

class StateCollection:
    '''
    Abstract frontier used by solve(): a container that states are added
    to and then drained one at a time. Subclasses differ only in the order
    getNext() hands states back. Every method here, including the
    constructor, raises NotImplementedError, so only subclasses that
    override __init__ can be instantiated.
    '''

    def __init__(self):
        raise NotImplementedError()

    def isEmpty(self):
        '''Return True when no states remain in the collection.'''
        raise NotImplementedError()

    def add(self, value):
        '''Store one state in the collection.'''
        raise NotImplementedError()

    def getNext(self):
        '''Remove and return the next state, in subclass-defined order.'''
        raise NotImplementedError()

class Heap(StateCollection):
    '''
    Min-priority frontier built on heapq: getNext() removes and returns the
    smallest stored state as ordered by the states' __lt__. astar() in this
    file uses it.
    '''

    def __init__(self):
        # The base-class constructor always raises, so it is not called.
        self.heap = []

    def isEmpty(self):
        return not self.heap

    def add(self, value):
        heapq.heappush(self.heap, value)

    def getNext(self):
        return heapq.heappop(self.heap)

class StackOrQueue(StateCollection):
    '''
    Shared base for Stack and Queue. States are always appended on the
    right end of a deque; each subclass decides which end getNext() takes
    from. getNext() itself is left abstract here.
    '''

    def __init__(self):
        # The base-class constructor always raises, so it is not called.
        self.deque = deque()

    def isEmpty(self):
        return not self.deque

    def add(self, value):
        self.deque.append(value)

    def getNext(self):
        raise NotImplementedError()

class Stack(StackOrQueue):
    '''LIFO frontier: getNext() returns the most recently added state.'''

    def getNext(self):
        newest = self.deque.pop()
        return newest

class Queue(StackOrQueue):
    '''FIFO frontier: getNext() returns the oldest state still stored.'''

    def getNext(self):
        oldest = self.deque.popleft()
        return oldest

class PuzzleState:
    '''
    Abstract interface that solve() requires of a puzzle state. States are
    stored in a set during search, so subclasses must provide consistent
    __hash__ and __eq__. Every method here, including the constructor,
    raises NotImplementedError.
    '''

    def __init__(self):
        raise NotImplementedError()

    def isSolution(self):
        '''Return True if this state is a goal state.'''
        raise NotImplementedError()

    def getChildren(self):
        '''Return the list of states reachable in one move.'''
        raise NotImplementedError()

    def __hash__(self):
        raise NotImplementedError()

    def __eq__(self, other):
        raise NotImplementedError()

class NPuzzleState(PuzzleState):
    '''
    One board position of the N x N sliding-tile puzzle.

    The board is a list of N rows, each a list of N one-character strings:
    tiles are letters counting up from 'A' and the empty cell is '-'.
    The goal board (see makeSolutionPuzzle) holds the letters in row-major
    order with '-' in the last cell, e.g. for N = 2:
        [ ['A', 'B'],
          ['C', '-'] ]

    Each state keeps a reference to its parent and the list of moves
    ('U', 'D', 'L', 'R' -- the direction the empty cell moved) that led to
    it from the root, so a solved state carries its own solution path.
    '''

    # Number of NPuzzleState objects built since the last root (parentless)
    # state: constructing a root resets it to 1, each child adds 1.
    stateCount = None

    def __init__(self, puzzle, parent=None, move=None):
        '''
        puzzle -- N x N board (see class docstring); stored by reference,
                  not copied. N is taken from the length of the first row.
        parent -- the state this one was generated from, or None for a root.
        move   -- the move that turned parent into this state.
        '''
        self.puzzle = puzzle
        self.N = len(puzzle[0])
        self.emptyRow, self.emptyCol = self.findEmptyCell()
        self.parent = parent
        if self.parent is not None:
            self.moves = self.parent.moves + [move]
            NPuzzleState.stateCount += 1
        else:
            self.moves = [ ]
            NPuzzleState.stateCount = 1
        # f = g + h: moves taken so far plus an admissible estimate of the
        # moves remaining. This is the key __lt__ (and so a Heap) orders by.
        admissibleHeuristicToGoal = self.getAdmissibleHeuristicToGoal()
        self.totalToGoal = len(self.moves) + admissibleHeuristicToGoal

    def __lt__(self, other):
        # Order states by f = g + h so a min-heap pops the most promising
        # state first. Returning NotImplemented for foreign types lets Python
        # try the reflected comparison and then raise a proper TypeError.
        if not isinstance(other, NPuzzleState):
            return NotImplemented
        return self.totalToGoal < other.totalToGoal

    def getAdmissibleHeuristicToGoal(self):
        '''
        Return the sum, over every lettered tile, of the Manhattan distance
        from its current cell to its goal cell. Each move shifts exactly one
        tile by one cell, so this never overestimates the moves still needed.
        '''
        total = 0
        for row in range(self.N):
            for col in range(self.N):
                piece = self.puzzle[row][col]
                if piece != '-':
                    offset = ord(piece) - ord('A')
                    targetRow, targetCol = divmod(offset, self.N)
                    total += abs(row - targetRow) + abs(col - targetCol)
        return total

    def __repr__(self):
        return '\n'.join([' '.join(row) for row in self.puzzle])

    def __hash__(self):
        # Lists are unhashable, so hash an immutable tuple-of-tuples snapshot.
        return hash(tuple([tuple(v) for v in self.puzzle]))

    def __eq__(self, other):
        # Two states are equal when their boards match; path is ignored.
        return (isinstance(other, NPuzzleState) and
                (self.puzzle == other.puzzle))

    def findEmptyCell(self):
        '''
        Return (row, col) of the '-' cell.

        Raises ValueError if the board has no empty cell.
        '''
        for row in range(self.N):
            for col in range(self.N):
                if self.puzzle[row][col] == '-':
                    return (row, col)
        raise ValueError("puzzle has no empty cell ('-')")

    def isSolution(self):
        '''Return True if the letters are in row-major order from 'A'.'''
        for offset in range(self.N**2 - 1):
            row, col = divmod(offset, self.N)
            targetLetter = chr(ord('A') + offset)
            if self.puzzle[row][col] != targetLetter:
                return False
        return True

    def getChildren(self):
        '''
        Return one child state per legal move, in the order U, D, L, R.
        '''
        children = [ ]
        for drow, dcol, move in [(-1, 0, 'U'), (+1, 0, 'D'),
                                 (0, -1, 'L'), (0, +1, 'R')]:
            newRow, newCol = self.emptyRow + drow, self.emptyCol + dcol
            if ((newRow >= 0) and (newRow < self.N) and
                (newCol >= 0) and (newCol < self.N)):
                # Cells hold immutable strings, so copying each row yields
                # an independent board -- far cheaper than copy.deepcopy.
                newPuzzle = [list(r) for r in self.puzzle]
                # Slide the neighbouring tile into the empty cell; the empty
                # cell thereby moves in the direction named by `move`.
                newPuzzle[self.emptyRow][self.emptyCol] = newPuzzle[newRow][newCol]
                newPuzzle[newRow][newCol] = '-'
                child = NPuzzleState(newPuzzle, self, move)
                children.append(child)
        return children

    @staticmethod
    def makeRandomPuzzleState(N, randomMovesCount):
        '''
        Return a root state reached by applying randomMovesCount random legal
        moves to the solved N x N board (so it is always solvable).
        '''
        solutionPuzzle = NPuzzleState.makeSolutionPuzzle(N)
        state = NPuzzleState(solutionPuzzle)
        # TODO: keep a set of visited boards and avoid revisiting them. The
        # walk may currently undo its own moves, so the result can be fewer
        # than randomMovesCount moves away from the goal.
        for _ in range(randomMovesCount):
            state = random.choice(state.getChildren())
        # Rebuild as a root so the scramble is not part of .moves and
        # stateCount restarts at 1.
        return NPuzzleState(state.puzzle)

    @staticmethod
    def makeSolutionPuzzle(N):
        '''
        Return the solved N x N board: letters from 'A' in row-major order,
        with '-' left in the last cell.
        '''
        puzzle = [(['-'] * N) for _ in range(N)]
        for offset in range(N**2 - 1):
            row, col = divmod(offset, N)
            targetLetter = chr(ord('A') + offset)
            puzzle[row][col] = targetLetter
        return puzzle

def solve(puzzleState, stateCollection):
    if puzzleState.isSolution():
        return puzzleState
    else:
        seenStates = {puzzleState}
        stateCollection.add(puzzleState)
        while not stateCollection.isEmpty():
            state = stateCollection.getNext()
            for childState in state.getChildren():
                if childState not in seenStates:
                    if childState.isSolution():
                        return childState
                    else:
                        stateCollection.add(childState)
                        seenStates.add(childState)
        return None

def dfs(startState):
    '''Depth-first search: solve() driven by a LIFO Stack frontier.'''
    frontier = Stack()
    return solve(startState, frontier)

def bfs(startState):
    '''Breadth-first search: solve() driven by a FIFO Queue frontier.'''
    frontier = Queue()
    return solve(startState, frontier)

def astar(startState):
    '''
    Best-first search: solve() driven by a min-Heap frontier, so states are
    expanded in the order defined by their __lt__.
    NOTE(review): solve() tests for the goal when a state is generated and
    never re-adds a seen state, so the result is not guaranteed to be a
    shortest path. Confirm this if optimality matters.
    '''
    frontier = Heap()
    return solve(startState, frontier)

def testSolver(startState, solverFn):
    '''
    Run solverFn on startState and print the moves it found, the number of
    NPuzzleState objects created during the search, and the elapsed time.
    If the solver returns None, print a message instead of crashing.
    '''
    print(f'\nTesting {solverFn.__name__}:')
    NPuzzleState.stateCount = 1 # just the start state
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-run.
    time0 = time.perf_counter()
    solution = solverFn(startState)
    time1 = time.perf_counter()
    if solution is None:
        print('No solution found', NPuzzleState.stateCount)
    else:
        print(solution.moves, NPuzzleState.stateCount)
    print(f'Time = {round(time1-time0,2)} seconds')

# Fixed seed so every run scrambles the same start puzzle.
random.seed(7)

# With testDFS on, a smaller, less-scrambled board is used (presumably
# because DFS is impractical on the larger one) and DFS runs before BFS
# and A*.
testDFS = False
boardSize, scrambleMoves = (3, 15) if testDFS else (5, 30)
startState = NPuzzleState.makeRandomPuzzleState(N=boardSize,
                                                randomMovesCount=scrambleMoves)
print(startState)

solversToRun = [dfs, bfs, astar] if testDFS else [bfs, astar]
for solverFn in solversToRun:
    testSolver(startState, solverFn)