# Python Starter Code for Arithmetic Coding
import random
random.seed(a=0)  # fixed seed -> identical distribution and sequences on every run

def arithmeticEncode(p, seq):
  # TODO: Implement this function
  return None

def arithmeticDecode(p, code):
  # TODO: Implement this function
  return None

# Define the source distribution: symbol 1 (the termination symbol) gets a
# small fixed weight, the other symbols random weights; normalise to sum to 1.
numSymbols = 5
p0 = [0.01]
p0.extend([random.uniform(0.1, 1.1) for _ in range(numSymbols - 1)])
weightTotal = sum(p0)  # computed once rather than once per element
p = [x / weightTotal for x in p0]

# Short (length-6) example. Message symbols are drawn from 2..numSymbols;
# symbol 1 is reserved as the termination symbol and appended last.
seqLen = 6
seq = [random.randint(2, numSymbols) for _ in range(seqLen)]
seq.append(1)  # termination symbol
code = arithmeticEncode(p, seq)
decodeStr = arithmeticDecode(p, code)
print(code)
print(seq)
print(decodeStr)
print('Input and decoded output are equal: ' + str(decodeStr == seq))

# Longer (length-30) example. It needs the termination symbol too; without
# it the decoder cannot tell where the message ends.
seqLen = 30
seq = [random.randint(2, numSymbols) for _ in range(seqLen)]
seq.append(1)  # termination symbol
code = arithmeticEncode(p, seq)
decodeStr = arithmeticDecode(p, code)
# display the full code, input and decoded output for visual verification
print(code)
print(seq)
print(decodeStr)

# check that arithmeticDecode(p, arithmeticEncode(p, input)) == input
print('Input and decoded output are equal: ' + str(decodeStr == seq))
