# Visualize grid with heat-map colors

import math
import Tkinter
from PIL import Image, ImageDraw

def cstring(c):
    """Convert an RGB color to a "#rrggbb" hex color string.

    Each channel of c is a float nominally in [0.0, 1.0]. It is scaled to
    0-255 (truncated by int()) and written as two lowercase hex digits,
    e.g. (0.0, 0.5, 1.0) -> "#007fff".

    Channels outside [0.0, 1.0] are clamped to 0..255 first. Without the
    clamp, 1.1 would format as three digits ("118") and a negative channel
    as e.g. "-e5", yielding an invalid color string.
    """
    fields = ["%.2x" % max(0, min(255, int(255 * x))) for x in c]
    return "#" + "".join(fields)

class Colors:
    """Palette of named colors as (red, green, blue) float triples in [0.0, 1.0].

    Used with cstring() to produce hex color strings.
    """
    # Vertices of the RGB unit cube
    black   = (0.0, 0.0, 0.0)
    blue    = (0.0, 0.0, 1.0)
    green   = (0.0, 1.0, 0.0)
    cyan    = (0.0, 1.0, 1.0)
    red     = (1.0, 0.0, 0.0)
    magenta = (1.0, 0.0, 1.0)
    yellow  = (1.0, 1.0, 0.0)
    white   = (1.0, 1.0, 1.0)
    # Half-intensity mixes
    lightred = (1.0, 0.5, 0.5)
    gray     = (0.5, 0.5, 0.5)

class HeatMap:
    """Map scalar values to heat-map hex color strings.

    Values are expected in [0.0, 1.0]. Values below splitVal are spread
    piecewise-linearly across scaleList[0] .. scaleList[-2]. Values from
    splitVal up to 1.0 blend between the final two colors. A value of
    exactly 0 is rendered black (see genColor).
    """
    colorList = (Colors.magenta, Colors.blue, Colors.cyan, Colors.green, Colors.yellow, Colors.red)
    # Brightness multiplier for each color, paired by position with colorList.
    # NOTE(review): weightList has one more entry than colorList; the
    # trailing 1.0 is never used, so the final red is scaled by 0.9.
    weightList = (0.2, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
    # Below splitVal, interpolate over color range.
    # Above, interpolate between final two colors
    splitVal = 0.5
    scaleList = []

    def __init__(self):
        """Precompute scaleList: each colorList entry scaled by its weight."""
        # A comprehension rather than map(lambda (c): ...), which is a
        # syntax error on Python 3 (and returns an iterator there).
        self.scaleList = [[channel * self.weightList[i] for channel in color]
                          for i, color in enumerate(self.colorList)]

    def interpolate(self, x):
        """Return the hex color string for value x, nominally in [0.0, 1.0].

        x is clamped to [0.0, 1.0] first. Below 0 the segment index would
        go negative and silently select colors from the end of scaleList.
        Above 1 the blend weight would exceed 1 and extrapolate past the
        last color.
        """
        x = max(0.0, min(1.0, x))
        if x >= self.splitVal:
            # Upper band: blend the last two colors.
            lcolor = self.scaleList[-2]
            rcolor = self.scaleList[-1]
            point = (x - self.splitVal) / (1.0 - self.splitVal)
        else:
            # Lower band: rescale x to [0, 1) and spread it across the
            # segments joining scaleList[0] .. scaleList[-2].
            segs = len(self.scaleList) - 2
            sx = (x / self.splitVal) * segs
            # Guard against float rounding pushing int(sx) up to segs.
            # The resulting color (scaleList[segs]) is the same either way.
            interval = min(int(sx), segs - 1)
            point = sx - interval
            lcolor = self.scaleList[interval]
            rcolor = self.scaleList[interval + 1]
        color = [rcolor[idx] * point + lcolor[idx] * (1 - point) for idx in range(3)]
        return cstring(color)

    def genColor(self, val):
        """Return the color string for val; exactly 0 maps to black."""
        if val == 0:
            return cstring(Colors.black)
        return self.interpolate(val)

    def genColors(self, vlist):
        """Return a list of color strings, one per value in vlist."""
        return [self.genColor(v) for v in vlist]

class Display:
    """Tk window that draws an nrow x ncol grid of colored squares.

    The first and last rows and columns are offset by boundarySeparation
    pixels (see xyPos), visually separating them from the interior. The four
    corner cells are kept painted with neutralColor.
    """
    nrow = 100
    ncol = 100
    squareSize = 8
    boundarySeparation = 2
    display = None  # TK Window
    frame = None    # Frame within window
    canvas = None   # Canvas within frame
    gridSquares = [] # Set of rectangles, nrow * ncol total
    allSquares = []  # Set of all rectangles
    colorList = []  # Most recent set of colors
    hmap = None
    neutralColor = cstring(Colors.gray)

    def __init__(self, nrow, ncol, maxdim = 800):
        """Create the window and draw an initial all-black grid.

        nrow, ncol: grid dimensions in cells. The window title reports
            (nrow-2) X (ncol-2), presumably the grid without its outer ring
            of boundary cells.
        maxdim: pixel budget for the larger dimension; each square is
            maxdim // max(nrow, ncol) pixels on a side.
        """
        self.nrow = nrow
        self.ncol = ncol
        self.squareSize = maxdim // max(self.ncol, self.nrow)
        self.display = Tkinter.Tk()
        self.display.title('Heat Flow Simulation of %d X %d grid' % (nrow-2, ncol-2))
        self.frame = Tkinter.Frame(self.display)
        self.frame.pack(fill=Tkinter.BOTH)
        iwidth = self.imageWidth()
        iheight = self.imageHeight()
        self.canvas = Tkinter.Canvas(self.frame,
                                     width = iwidth,
                                     height = iheight)
        self.canvas.pack(fill=Tkinter.BOTH)
        # Neutral background; shows through the boundary gaps between squares.
        bgSquare = self.canvas.create_rectangle(0, 0, iwidth, iheight, width = 0,
                                                fill = self.neutralColor)
        self.allSquares = [bgSquare]
        self.gridSquares = []
        black = cstring(Colors.black)
        # Row-major creation order, so gridSquares[index(r, c)] is cell (r, c).
        for r in range(self.nrow):
            for c in range(self.ncol):
                (x, y) = self.xyPos(r, c)
                sq = self.canvas.create_rectangle(x, y, x+self.squareSize, y+self.squareSize, width = 0,
                                                  fill = black)
                self.gridSquares.append(sq)
                self.allSquares.append(sq)
        self.fixCorners()
        self.hmap = HeatMap()
        self.update()

    def imageWidth(self):
        """Total width in pixels: all columns plus a separation on each side."""
        return self.ncol * self.squareSize + 2 * self.boundarySeparation

    def imageHeight(self):
        """Total height in pixels: all rows plus a separation on each side."""
        return self.nrow * self.squareSize + 2 * self.boundarySeparation

    def _cornerIndices(self):
        """Return the flat (row-major) indices of the four corner cells."""
        lastRow = self.nrow - 1
        lastCol = self.ncol - 1
        # index() takes (row, col); passing (col, row) here would pick the
        # wrong cells on non-square grids.
        return (self.index(0, 0), self.index(0, lastCol),
                self.index(lastRow, 0), self.index(lastRow, lastCol))

    # Keep colors of corner rectangles neutral
    def fixCorners(self):
        """Paint the four corner cells with neutralColor."""
        for idx in self._cornerIndices():
            self.colorSquare(idx, self.neutralColor)

    def update(self):
        """Flush pending drawing to the screen."""
        self.canvas.update()

    def xyPos(self, r, c):
        """Return the top-left pixel (x, y) of cell (r, c).

        Every cell past the first row/column is shifted by one
        boundarySeparation. The last row/column is shifted by a second one,
        leaving gaps around the border cells.
        """
        x = self.squareSize * c
        if c > 0:
            x += self.boundarySeparation
        if c == self.ncol - 1:
            x += self.boundarySeparation
        y = self.squareSize * r
        if r > 0:
            y += self.boundarySeparation
        if r == self.nrow - 1:
            y += self.boundarySeparation
        return (x, y)

    def index(self, r, c):
        """Return the row-major flat index of cell (r, c)."""
        return r * self.ncol + c

    def rowCol(self, idx):
        """Inverse of index(): return (row, col) for a flat index."""
        r = idx // self.ncol
        c = idx % self.ncol
        return (r, c)

    def colorSquare(self, idx, color):
        """Set the fill of grid square idx; out-of-range indices are ignored."""
        if 0 <= idx < len(self.gridSquares):
            square = self.gridSquares[idx]
            self.canvas.itemconfig(square, fill = color)

    # Set colors based on counts for each square
    def setColors(self, vlist = None):
        """Color each cell from vlist (row-major values) and refresh the window.

        vlist defaults to no values (nothing recolored). Values are mapped
        through the HeatMap. The resulting color strings are remembered in
        colorList for capture(). Corners are then re-neutralized.
        """
        # None instead of a shared mutable [] default.
        if vlist is None:
            vlist = []
        clist = self.hmap.genColors(vlist)
        self.colorList = clist
        for idx, color in enumerate(clist):
            self.colorSquare(idx, color)
        self.fixCorners()
        self.update()

    def finish(self):
        """Destroy the Tk window."""
        self.display.destroy()

    def capture(self, fname):
        """Render the most recent colors (colorList) to an image file fname.

        Mirrors the canvas: neutral background, squares placed by xyPos, and
        corner cells in neutralColor. Save failures are reported on stdout,
        not raised.
        """
        img = Image.new('RGB', (self.imageWidth(), self.imageHeight()), self.neutralColor)
        dimg = ImageDraw.Draw(img)
        corners = set(self._cornerIndices())
        # PIL's rectangle() includes the far corner point. x..x+squareSize
        # would therefore be squareSize+1 pixels wide and overlap a pixel of
        # the neighbouring square or boundary gap.
        extent = max(self.squareSize - 1, 0)
        # Values past the grid would otherwise land in the bottom gap.
        ncells = self.nrow * self.ncol
        for idx, color in enumerate(self.colorList[:ncells]):
            if idx in corners:
                color = self.neutralColor
            r, c = self.rowCol(idx)
            x, y = self.xyPos(r, c)
            dimg.rectangle((x, y, x + extent, y + extent), fill = color)
        try:
            img.save(fname)
        except Exception as ex:
            # Best-effort save: report and continue.
            print("Could not save image to file %s (%s)" % (fname, str(ex)))
        else:
            print("Saved image in file %s." % (fname))
