'''
huffman-coding.py
This code was written in 15-113 lecture on Tue 24-Jan and Thu 26-Jan.
It is only for demonstrational purposes, and may contain
dubious style and even an occasional bug.
'''

import requests, string, heapq

class Node:
    '''
    One node of a Huffman prefix tree.

    Attributes:
        symbol: label of the node.  For a leaf this is the encoded
            character; makePrefixTree labels a merged node with the
            concatenation of its two children's symbols.
        frequency: numeric weight (int or float) used to order nodes.
        children: list of child Nodes; an empty list marks a leaf.
    '''

    def __init__(self, symbol, frequency, children):
        self.symbol = symbol
        self.frequency = frequency
        self.children = children

    def __lt__(self, other):
        # heapq orders its items with '<', so nodes compare by frequency.
        # Returning NotImplemented for foreign types lets Python raise the
        # standard TypeError instead of an anonymous Exception.
        if not isinstance(other, Node):
            return NotImplemented
        return self.frequency < other.frequency

    def __repr__(self):
        # '.2f' works for ints as well as floats and always prints two
        # decimals; a bare '0.2' precision raises ValueError for ints and
        # switches to exponent notation for values >= 10.
        return f'{self.symbol}({self.frequency:.2f})'

def readFromUrl(url, timeout=10):
    '''
    Download url and return the response body as text.

    Args:
        url: address to fetch with an HTTP GET.
        timeout: seconds to wait for the server before giving up.
            Without a timeout, requests may block indefinitely.

    Returns:
        The decoded body of the response (response.text).

    Raises:
        requests.exceptions.RequestException: on connection problems,
            timeouts, or a 4xx/5xx status code.
    '''
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an HTTP error instead of handing an error page to
    # the parser as if it were the requested data.
    response.raise_for_status()
    return response.text

def getFrequencyData():
    '''
    Fetch the character-frequency page and return just its table rows.

    Rows are collected after the 'ASCII  Char  Count' header line, up to
    and including the row for ASCII 126.  The rows are joined with
    newlines and every '<br ab>' marker is removed from the result.
    '''
    headerLine = '<br ab>ASCII  Char  Count'
    finalLine = '<br ab>126    ~       8 ( 0.0003%)'
    page = readFromUrl('http://fitaly.com/board/domper3/posts/136.html')

    tableRows = [ ]
    collecting = False
    for row in page.splitlines():
        if row == headerLine:
            # the header itself is not part of the data
            collecting = True
        elif collecting:
            tableRows.append(row)
        # checked after the append so the final row is kept
        if row == finalLine:
            break
    joined = '\n'.join(tableRows)
    return joined.replace('<br ab>','')

def makeFrequencyMap(data):
    '''
    Parse frequency-table text into a {character: percentage} dict.

    Each non-blank line must start with a decimal ASCII code and end
    with a percentage written like '( 0.0003%)'.  Only characters found
    in string.printable are kept.

    Args:
        data: the table as one string, one row per line.

    Returns:
        dict mapping each kept character to its percentage as a float.

    Raises:
        ValueError: if a non-blank line does not start with an integer
            or does not end with a parsable percentage.
    '''
    d = {}
    for line in data.splitlines():
        # The percentage may be glued to its '(' (e.g. '(15.4352%)'), so
        # turn '(' into whitespace before splitting into columns.
        entries = line.replace('(', ' ').split()
        if not entries:
            # blank or whitespace-only line: nothing to parse
            continue
        char = chr(int(entries[0]))
        if char in string.printable:
            # last column looks like '0.0003%)'; drop the trailing '%)'
            d[char] = float(entries[-1][:-2])
    return d

def makeScrabbleFrequencyMap():
    '''
    Build a synthetic frequency map from Scrabble letter scores.

    Each lowercase letter is weighted by the inverse of its Scrabble
    score (cheap letters are treated as common) and the space character
    gets weight 1.  The weights are then scaled so that all 27 values
    sum to 100, i.e. they can be read as percentages.

    Returns:
        dict mapping 'a'..'z' and ' ' to a percentage (float).
    '''
    letterScores = [
    #  a, b, c, d, e, f, g, h, i, j, k, l, m,
       1, 3, 3, 2, 1, 4, 2, 4, 1, 8, 5, 1, 3,
    #  n, o, p, q, r, s, t, u, v, w, x, y, z
       1, 1, 3,10, 1, 1, 1, 1, 4, 4, 8, 4,10
    ]
    weights = {chr(ord('a') + i): 1 / score
               for i, score in enumerate(letterScores)}
    weights[' '] = 1
    # Normalize by the sum of the very weights being scaled; summing a
    # different function of the score here would make the result not
    # add up to 100.
    total = sum(weights.values())
    return {symbol: weight / total * 100
            for symbol, weight in weights.items()}

def makePrefixTree():
    '''
    Build the Huffman prefix tree and return its root Node.

    Downloads and parses the frequency table, keeps only lowercase
    letters and the space character, and stores that filtered map in
    the global frequencyMap.  The tree is built by repeatedly merging
    the two lowest-frequency nodes.
    '''
    global frequencyMap # do not do this
    allFrequencies = makeFrequencyMap(getFrequencyData())
    frequencyMap = {
        char: percent
        for char, percent in allFrequencies.items()
        if char.islower() or char == ' '
    }
    # (To experiment with Scrabble-based weights instead, assign
    #  makeScrabbleFrequencyMap() to frequencyMap here.)

    # Start with one leaf per symbol in a min-heap keyed on frequency.
    pending = [ ]
    for char, percent in frequencyMap.items():
        heapq.heappush(pending, Node(char, percent, []))

    # Merge the two lightest nodes until a single tree remains.
    while len(pending) > 1:
        first = heapq.heappop(pending)
        second = heapq.heappop(pending)
        merged = Node(first.symbol + second.symbol,
                      first.frequency + second.frequency,
                      [first, second])
        heapq.heappush(pending, merged)

    # The only node left is the root of the whole tree.
    return pending[0]

def makeVariableCode(prefixTree):
    '''
    Return a dict mapping each leaf symbol of prefixTree to its code.

    A code is the string of child indexes ('0', '1') followed from the
    root down to the leaf.  Leaves are visited left to right, so the
    dict is filled in that order.  Asserts that no internal node has
    more than two children.
    '''
    codes = dict()
    # Explicit stack of (node, path so far) in place of recursion.
    todo = [(prefixTree, '')]
    while todo:
        current, path = todo.pop()
        if current.children == []:
            codes[current.symbol] = path
            continue
        assert len(current.children) < 3
        # Push in reverse so child 0 is popped (visited) first.
        for index in reversed(range(len(current.children))):
            todo.append((current.children[index], path + str(index)))
    return codes

def printTree(node, depth=0):
    '''
    Print the tree rooted at node, one node per line.

    Each line shows the node's symbol and frequency, indented by two
    spaces per level; depth is the indentation level of node itself.
    Parents are printed before their children, children in list order.
    '''
    todo = [(node, depth)]
    while todo:
        current, level = todo.pop()
        indent = '  ' * level
        print(f'{indent}{current.symbol} {current.frequency}')
        # Reverse so the first child ends up on top of the stack.
        for child in reversed(current.children):
            todo.append((child, level + 1))

def encodeToBitsAsString(codeMap, s):
    '''
    Encode s as a string of '0'/'1' characters.

    Looks up every character of s in codeMap and concatenates the
    codes.  A character missing from codeMap raises KeyError.
    '''
    pieces = [ ]
    for symbol in s:
        pieces.append(codeMap[symbol])
    return ''.join(pieces)

def decodeFromBitsAsString(prefixTree, encodedString):
    '''
    Decode a string of '0'/'1' characters back into text.

    Walks prefixTree from the root, following child int(bit) for each
    bit; on reaching a leaf its symbol is emitted and the walk restarts
    at the root.  Bits left over after the last complete symbol are
    ignored.
    '''
    decoded = [ ]
    cursor = prefixTree
    for bit in encodedString:
        cursor = cursor.children[int(bit)]
        if cursor.children != [ ]:
            # still on an internal node; keep descending
            continue
        decoded.append(cursor.symbol)
        cursor = prefixTree
    return ''.join(decoded)

def getExpectedBitsPerSymbol():
    '''
    Return the frequency-weighted average code length, in bits.

    Reads two module globals: frequencyMap (symbol -> frequency) and
    codeMap (symbol -> bit string).  The average is normalized by the
    actual sum of the frequencies rather than by a fixed 100, because
    frequencyMap is a filtered subset of the full table and its values
    need not add up to 100.

    Returns:
        sum(frequency * len(code)) / sum(frequency), or 0 when
        frequencyMap is empty or all frequencies are zero.
    '''
    # this uses 2 globals .  Sorry.  Sue us.
    totalFrequency = 0
    weightedSum = 0
    for symbol, frequency in frequencyMap.items():
        totalFrequency += frequency
        weightedSum += frequency * len(codeMap[symbol])
    if totalFrequency == 0:
        # nothing to average over; avoid dividing by zero
        return 0
    return weightedSum / totalFrequency

# Demo driver.  These statements run whenever the module is executed or
# imported.  makePrefixTree() downloads the frequency table (through
# getFrequencyData) and sets the global frequencyMap as a side effect.
prefixTree = makePrefixTree()
# symbol -> bit-string code; getExpectedBitsPerSymbol() reads this global.
codeMap = makeVariableCode(prefixTree)
# Sample text: only lowercase letters and spaces, the symbols that
# makePrefixTree keeps in the tree.
s = 'this is a test it is so amazing'
encodedString = encodeToBitsAsString(codeMap, s)
print(encodedString)
# Round trip: decode the bits with the same tree and print the text.
print(decodeFromBitsAsString(prefixTree, encodedString))
# Average code length in bits per symbol under frequencyMap.
print(getExpectedBitsPerSymbol())