#!/usr/bin/python

import sys
import getopt
import random
import time

# Model of heat conduction in two dimensions

def usage(name):
    """Print the command-line help for program *name*, then exit with status 0."""
    helpText = [
        "%s [-h] [-v] [-V] [-n NROW] [-c NCOL] [-b BMODE] [-e EPS] [-s STEPS] [-C FILE]" % name,
        "  -h       Print this message",
        "  -v       Verbose mode",
        "  -V       Visualize as heat map",
        "  -n NROW  Set number of rows in grid",
        "  -c NCOL  Set number of columns in grid",
        "  -b BMODE Boundary conditions (h:horizontal, c:corner, d:diagonal, r:random)",
        "  -e EPS   Set convergence limit",
        "  -s STEPS Set limit on maximum number of steps",
        "  -C FILE  Capture final image in FILE (.png or .jpg)",
    ]
    for line in helpText:
        print(line)
    sys.exit(0)

# Set to True once the optional hshow visualization module has been imported.
specialImported = False

# Seconds to pause between steps while a visualizer is attached (used by Grid.run).
timeDelta = 0.0

def importSpecial():
    """Import the hshow visualization module the first time it is needed.

    hshow is only required for the heat-map display, so it is imported
    lazily and bound as a module-level global. Later calls do nothing.
    """
    global specialImported, hshow
    if specialImported:
        return
    import hshow
    specialImported = True

class BoundaryCondition:
    """Fixed temperatures along the four edges of the grid.

    ``northRow`` and ``southRow`` hold one value per column;
    ``eastColumn`` and ``westColumn`` hold one value per row.

    Modes:
      'h' horizontal - west edge 0.0, east edge 1.0, north and south edges
          ramp linearly from west to east.
      'd' diagonal   - north rises west->east, east falls north->south,
          south falls west->east, west rises north->south.
      'c' corner     - north falls west->east, west falls north->south,
          east and south edges 0.0.
      'r' random     - every edge value drawn from random.random().
    Any other mode prints an error and exits with status 1.
    """
    northRow = []
    eastColumn = []
    southRow = []
    westColumn = []
    modeNames = {'h':"Horizontal", 'c':"Corner", 'd':"Diagonal", 'r':"Random" }

    @staticmethod
    def _ramp(count, step, descending=False):
        # Linear ramp step*1 .. step*count, or 1.0 minus each of those values.
        if descending:
            return [1.0 - step * (k+1) for k in range(count)]
        return [step * (k+1) for k in range(count)]

    def __init__(self, nrow, ncol, mode):
        if ncol is None:
            ncol = nrow  # square grid unless told otherwise

        if mode == 'r':
            # Draw order north, east, south, west keeps seeded runs reproducible.
            self.northRow = [random.random() for _ in range(ncol)]
            self.eastColumn = [random.random() for _ in range(nrow)]
            self.southRow = [random.random() for _ in range(ncol)]
            self.westColumn = [random.random() for _ in range(nrow)]
            return

        if mode not in ('h', 'd', 'c'):
            print("Invalid boundary condition mode %s" % mode)
            sys.exit(1)

        colStep = 1.0/(ncol+1)
        if mode == 'h':
            self.northRow = self._ramp(ncol, colStep)
            self.eastColumn = [1.0] * nrow
            self.southRow = self._ramp(ncol, colStep)
            self.westColumn = [0.0] * nrow
            return

        rowStep = 1.0/(nrow+1)
        if mode == 'd':
            self.northRow = self._ramp(ncol, colStep)
            self.eastColumn = self._ramp(nrow, rowStep, descending=True)
            self.southRow = self._ramp(ncol, colStep, descending=True)
            self.westColumn = self._ramp(nrow, rowStep)
        else:  # mode == 'c'
            self.northRow = self._ramp(ncol, colStep, descending=True)
            self.eastColumn = [0.0] * nrow
            self.westColumn = self._ramp(nrow, rowStep, descending=True)
            self.southRow = [0.0] * ncol

class Grid:
    """Model of two-dimensional heat conduction on a rectangular grid.

    Temperatures live in one flat, row-major list of
    ``(nrow + 2) * (ncol + 2)`` floats: the ``nrow x ncol`` interior cells
    surrounded by a one-cell border that holds the boundary temperatures.
    Interior coordinates run ``0..nrow-1`` and ``0..ncol-1``. Index ``-1``
    and ``nrow``/``ncol`` address the border: row -1 is the north edge and
    column -1 is the west edge.
    """
    nrow = 30               # default number of interior rows
    ncol = 30               # default number of interior columns
    temperature = []        # flat temperature array, rebuilt by initialize()
    conductivity = 0.80     # weight given to the neighbour average in step()
    epsilon = 0.0001        # default convergence limit for run()
    verbose = True          # print per-step progress in run()
    visualizer = None       # hshow.Display when visualization is enabled

    def __init__(self, nrow = None, ncol = None, epsilon = None, verbose = True, visualize = True):
        """Create a grid whose cells (interior and border) are all 0.0.

        nrow: interior rows (class default 30). ncol defaults to nrow.
        epsilon: convergence limit used by run() (class default 0.0001).
        verbose: print per-step progress in run().
        visualize: attach an hshow.Display; requires the hshow module.
        """
        if nrow is not None:
            self.nrow = nrow
            # By default, make square grid
            self.ncol = nrow
        if ncol is not None:
            self.ncol = ncol
        if epsilon is not None:
            self.epsilon = epsilon
        if verbose is not None:
            self.verbose = verbose
        if visualize:
            importSpecial()
            self.visualizer = hshow.Display(self.nrow+2, self.ncol+2)
        # initialize() already pushes the temperatures to the visualizer,
        # so no second setColors call is needed here.
        self.initialize()

    # How large is the temperature array
    def arraySize(self):
        """Return the length of the flat temperature list, border included."""
        return (self.nrow + 2) * (self.ncol + 2)

    # Grid rows numbered from 0 to nrow-1
    def rowIndex(self, r):
        """Return the flat index of the west-border cell of grid row r."""
        # Make room for boundary row
        return (r+1) * (self.ncol+2)

    def locate(self, r, c):
        """Return the flat index of grid cell (r, c); -1 addresses the border."""
        return self.rowIndex(r) + (c+1)

    def setTemperature(self, r, c, val):
        self.temperature[self.locate(r,c)] = val

    def getTemperature(self, r, c):
        return self.temperature[self.locate(r,c)]

    def initialize(self, boundaryConditions = None):
        """Reset every cell to 0.0, then optionally load the border values.

        boundaryConditions: object with northRow/southRow (ncol values) and
        eastColumn/westColumn (nrow values), e.g. a BoundaryCondition.
        The border corner cells are left at 0.0.
        """
        self.temperature = [0.0] * self.arraySize()

        if boundaryConditions is not None:
            nr = boundaryConditions.northRow
            ec = boundaryConditions.eastColumn
            sr = boundaryConditions.southRow
            wc = boundaryConditions.westColumn

            for r in range(self.nrow):
                self.setTemperature(r, -1,        wc[r])
                self.setTemperature(r, self.ncol, ec[r])
            for c in range(self.ncol):
                self.setTemperature(-1,        c, nr[c])
                self.setTemperature(self.nrow, c, sr[c])

        if self.visualizer is not None:
            self.visualizer.setColors(self.temperature)

    def step(self):
        """Perform one update of every interior cell and return the largest change.

        Each cell becomes ``0.25*k*(sum of 4 neighbours) + (1-k)*old``, where
        k is the conductivity. Every read comes from the previous state
        (Jacobi-style), and the border cells are never modified.
        """
        old = self.temperature
        new = list(old)
        stride = self.ncol + 2   # flat-index distance between vertical neighbours
        k = self.conductivity
        maxDiff = 0
        for r in range(self.nrow):
            first = self.locate(r, 0)
            for idx in range(first, first + self.ncol):
                ov = old[idx]
                nv = old[idx - stride]
                ev = old[idx + 1]
                sv = old[idx + stride]
                wv = old[idx - 1]
                newv = 0.25 * k * (nv+ev+sv+wv) + (1-k) * ov
                diff = abs(newv - ov)
                maxDiff = max(maxDiff, diff)
                new[idx] = newv
        self.temperature = new
        if self.visualizer is not None:
            self.visualizer.setColors(self.temperature)
        return maxDiff

    def run(self, maxSteps):
        """Call step() until the largest change drops below epsilon.

        Stops after at most maxSteps steps and prints a one-line summary
        either way. While a visualizer is attached, pauses timeDelta seconds
        between steps.
        """
        diff = 0
        # Count steps from 1 so the reported numbers are actual step counts.
        for count in range(1, maxSteps + 1):
            diff = self.step()
            if self.verbose:
                print("Step %d. Max. difference = %.6f" % (count, diff))
            if diff < self.epsilon:
                print("Terminated after %d steps. Max. difference = %.6f" % (count, diff))
                return
            if self.visualizer is not None:
                time.sleep(timeDelta)
        print("Failed to Terminate after %d steps. Max. difference = %.6f" % (maxSteps, diff))

    def showRow(self, r):
        """Print grid row r (-1..nrow). Border rows leave the corner positions blank."""
        if r < 0 or r == self.nrow:
            svals = ["      "] + ["%.3f " % self.getTemperature(r, c) for c in range(self.ncol)] + ["     "]
        else:
            svals = ["%.3f " % self.getTemperature(r, c-1) for c in range(self.ncol+2)]
        print("".join(svals))

    def show(self):
        """Print the whole grid, including the border rows."""
        for r in range(self.nrow+2):
            self.showRow(r-1)

def solve(nrow, ncol, mode, epsilon, maxSteps, verbose, visualize, cfile):
    """Build a grid with the chosen boundary conditions and iterate it to convergence.

    nrow, ncol: interior grid size. If ncol is None, the grid is square.
    mode: boundary-condition letter accepted by BoundaryCondition.
    epsilon: convergence limit (None keeps the Grid default).
    maxSteps: maximum number of iterations.
    verbose: print per-step progress and the final grid.
    visualize: show the heat-map display (requires hshow).
    cfile: if not None, save the final image to this file. This needs a
        visualizer; without one, a message is printed instead.
    """
    # Build the boundary conditions first, so an invalid mode is reported
    # before any visualization window is opened.
    bc = BoundaryCondition(nrow, ncol, mode)
    g = Grid(nrow = nrow, ncol = ncol, epsilon = epsilon, verbose = verbose, visualize = visualize)
    g.initialize(bc)
    g.run(maxSteps)
    if verbose:
        g.show()
    if cfile is not None:
        if g.visualizer is None:
            # Previously this raised AttributeError on None.capture.
            print("Can't capture image unless also run visualization")
        else:
            g.visualizer.capture(cfile)


def run(name, args):
    """Parse command-line options and run the heat-conduction simulation.

    name: program name, shown in help and error messages.
    args: option words (normally sys.argv[1:]).

    An unrecognized option, a missing option argument, or a malformed
    numeric value prints a short message and exits with status 1, instead
    of ending in a traceback.
    """
    nrow = 30
    ncol = None
    maxSteps = 1000
    epsilon = None
    verbose = False
    visualize = False
    mode = 'h'
    cfile = None

    def number(convert, opt, val):
        # Convert a numeric option value; exit cleanly on malformed input.
        try:
            return convert(val)
        except ValueError:
            print("%s: invalid value '%s' for option %s" % (name, val, opt))
            sys.exit(1)

    try:
        optlist, _extra = getopt.getopt(args, "hvVn:c:b:e:s:C:")
    except getopt.GetoptError as err:
        print("%s: %s (use -h for help)" % (name, err))
        sys.exit(1)

    for (opt, val) in optlist:
        if opt == '-h':
            usage(name)
        elif opt == '-v':
            verbose = True
        elif opt == '-V':
            visualize = True
        elif opt == '-n':
            nrow = number(int, opt, val)
        elif opt == '-c':
            ncol = number(int, opt, val)
        elif opt == '-b':
            mode = val
        elif opt == '-e':
            epsilon = number(float, opt, val)
        elif opt == '-s':
            maxSteps = number(int, opt, val)
        elif opt == '-C':
            cfile = val

    if ncol is None:
        ncol = nrow

    if not visualize and cfile is not None:
        print("Can't capture image unless also run visualization")
        cfile = None

    solve(nrow, ncol, mode, epsilon, maxSteps, verbose, visualize, cfile)

if __name__ == "__main__":
    run(sys.argv[0], sys.argv[1:])
