'''
combinatoric-puzzles.py
This code was written in 15-113 lecture.
It is for demonstration purposes only, and may contain
dubious style and even an occasional bug.

Using combinatorics (and backtracking in one case) to solve:
    * Three-Card Poker
    * Krypto (24)
    * Cryptarithm (SEND+MORE=MONEY)
'''

from itertools import *
from collections import defaultdict
import operator
import time

###########################################################
### Helper Fns
###########################################################

def count(iterable):
    """Return the number of items produced by iterable (consuming it)."""
    total = 0
    for _item in iterable:
        total += 1
    return total

def countIf(iterable, test):
    """Return how many items of iterable make test(item) truthy."""
    matches = 0
    for item in iterable:
        if test(item):
            matches += 1
    return matches

###########################################################
### Three-Card Poker
###########################################################

# See:
# http://en.wikipedia.org/wiki/Three_card_poker#Hand_Ranks_of_Three_Card_Poker

def getDeck():
    """Return the 52-card deck as a list of (face, suit) tuples.

    Faces are 'A23456789TJQK' (T = ten) and suits are 'CDHS'.  Cards are
    ordered by face first, then suit: ('A','C'), ('A','D'), ..., ('K','S').
    """
    return [(face, suit) for face in 'A23456789TJQK' for suit in 'CDHS']

def totalHands(cardsPerHand):
    """Count every distinct cardsPerHand-card hand dealt from getDeck().

    The hands are enumerated with itertools.combinations rather than
    computed by formula; for three cards the count is 22100.
    """
    allHands = combinations(getDeck(), cardsPerHand)
    return count(allHands)
    
def hasPair(hand):
    """Return True if the most frequent face in hand occurs exactly twice.

    hand is a sequence of (face, suit) cards as produced by getDeck().
    A hand whose cards all share a single suit is treated as a flush and
    never counts as a pair.  Three of a kind is not a pair (its top face
    count is 3).  An empty hand raises ValueError (max() of nothing).
    """
    if len({suit for _face, suit in hand}) == 1:
        return False
    faces = [face for face, _suit in hand]
    return max(faces.count(face) for face in faces) == 2

def totalPairs(cardsPerHand):
    """Count the cardsPerHand-card hands from getDeck() for which hasPair() holds.

    For three-card hands this is 3744: 13 faces * 6 suit pairs * 48 other
    cards.
    """
    allHands = combinations(getDeck(), cardsPerHand)
    return countIf(allHands, hasPair)

def testThreeCardPoker():
    """Check the three-card hand and pair counts against known totals."""
    print('Testing three-card poker...', end='')
    # C(52, 3) hands in all; 13 * C(4, 2) * 48 of them hold exactly one pair.
    expectedHands, expectedPairs = 22100, 3744
    assert(totalHands(3) == expectedHands)
    assert(totalPairs(3) == expectedPairs)
    print('Passed!')

###########################################################
### Krypto
###########################################################

# Binary operators evalKrypto() applies directly, keyed by the symbol shown
# to the user.  '/' is deliberately floor division.
_KRYPTO_OPERATORS = {
    '+': operator.add,
    '-': operator.sub,
    '*': operator.mul,
    '/': operator.floordiv,
    '//': operator.floordiv,
    '%': operator.mod,
    '**': operator.pow,
    '^': operator.xor,
    '&': operator.and_,
    '|': operator.or_,
    '<<': operator.lshift,
    '>>': operator.rshift,
}

def evalKrypto(numbers, operators):
    """Evaluate numbers joined by operators strictly left to right.

    There is no operator precedence: numbers=(23, 3, 7) with
    operators=('-', '*') gives (23 - 3) * 7 == 140.  '/' means floor
    division ('//').  Symbols in _KRYPTO_OPERATORS are applied directly;
    any other symbol is evaluated as a Python binary operator via eval().

    Args:
        numbers: non-empty sequence of numbers.
        operators: sequence of operator symbols with at least
            len(numbers) - 1 entries; extra entries are ignored.

    Returns:
        The left-to-right result.

    Raises:
        ValueError: if there are fewer than len(numbers) - 1 operators.
        IndexError: if numbers is empty.
        ArithmeticError (e.g. ZeroDivisionError): from the arithmetic.
    """
    if len(operators) < len(numbers) - 1:
        raise ValueError('need one operator between each pair of numbers')
    result = numbers[0]
    for symbol, number in zip(operators, numbers[1:]):
        apply = _KRYPTO_OPERATORS.get(symbol)
        if apply is not None:
            result = apply(result, number)
        else:
            # Parenthesize both operands so a negative running total is not
            # re-associated (e.g. '-3 ** 2' would parse as -(3 ** 2)).
            # NOTE: eval() executes arbitrary code; only pass trusted symbols.
            result = eval(f'({result}) {symbol} ({number})')
    return result

def makeKryptoString(numbers, operators):
    """Render a krypto solution as a compact string such as '23-3*7'.

    numbers[0] is written alone; each later number is prefixed by the
    operator at the matching position of operators, with no spaces.  The
    string is meant to be read strictly left to right, like evalKrypto():
    '23-3*7' stands for (23 - 3) * 7.

    Args:
        numbers: sequence of numbers (rendered with str()).
        operators: any iterable of operator strings, e.g. a tuple from
            itertools.product or a plain list.

    Returns:
        The rendered string.  As with zip(), rendering stops at whichever of
        numbers / operators (plus the leading number) runs out first.
    """
    # tuple() so list operators work too; ('',) is the empty "operator"
    # in front of the first number.
    prefixes = ('',) + tuple(operators)
    return ''.join(prefix + str(number)
                   for prefix, number in zip(prefixes, numbers))

def krypto(numbers, target, ops='+-*/'):
    """Find an arrangement of numbers and operators that reaches target.

    Tries every ordering of numbers (itertools.permutations) and, for each,
    every choice of operators between adjacent numbers
    (itertools.product), evaluating strictly left to right with
    evalKrypto().  Combinations whose arithmetic fails (e.g. dividing by a
    0 in the input) are skipped rather than aborting the search.

    Args:
        numbers: the numbers to use, each exactly once.
        target: the value to reach.
        ops: iterable of operator symbols.  A string is split into single
            characters, so multi-character operators such as '**' must be
            passed in a list.

    Returns:
        The first solution found, rendered by makeKryptoString(), or None.
    """
    slotCount = len(numbers) - 1
    for ordering in permutations(numbers):
        for operators in product(ops, repeat=slotCount):
            try:
                value = evalKrypto(ordering, operators)
            except ArithmeticError:
                # e.g. ZeroDivisionError for '/' or '%' by 0: not a solution.
                continue
            if value == target:
                return makeKryptoString(ordering, operators)
    return None

def testKrypto():
    """Check krypto() on solvable and unsolvable instances.

    Expected strings read strictly left to right: '23-3*7' is (23-3)*7.
    """
    print('Testing krypto()...', end='')
    extendedOps = ['+','-','*','/','**','^','%']
    cases = [
        (([23, 7, 3], 140), '23-3*7'),
        (([104, 11, 17, 26], 144), '17*26/11+104'),
        (([3, 22, 44], 195112, extendedOps), '22^44**3'),
        (([1, 2, 3], 42), None),
    ]
    for args, expected in cases:
        assert(krypto(*args) == expected)
    print('Passed!')

###########################################################
### Cryptarithm
###########################################################

'''
   SEND    9567
  +MORE    1085
  -----    ----
  MONEY   10652

Try every assignment of distinct digits to the sorted letters, in
itertools.permutations order:
  DEMNORSY
  01234567
  01234568
  01234569
  ...
'''

def cryptarithm(puzzle):
    """Solve an addition cryptarithm such as 'SEND+MORE=MONEY' by brute force.

    Every injective assignment of digits to the puzzle's distinct letters
    is tried in itertools.permutations order and checked with
    solvesCryptarithm(); no word may start with 0.

    Args:
        puzzle: words joined by '+' with the sum after '=',
            e.g. 'SEND+MORE=MONEY'.

    Returns:
        The puzzle with digits substituted (e.g. '9567+1085=10652') for the
        first assignment that works, or None if there is none (including
        when there are more than 10 distinct letters).
    """
    words = puzzle.replace('=', '+').split('+')
    letters = ''.join(sorted(set(''.join(words))))
    firstLetters = ''.join(sorted({word[0] for word in words}))
    # `permutations` comes from the file's `from itertools import *`; the
    # bare module name `itertools` is not bound in this module.
    for digits in permutations(range(10), len(letters)):
        digitMap = makeMap(letters, digits)
        if solvesCryptarithm(puzzle, firstLetters, digitMap):
            return makeSolutionString(puzzle, digitMap)
    return None

def makeSolutionString(puzzle, map):
    """Return puzzle with every letter replaced by its digit from map.

    Non-letter characters (such as '+' and '=') are copied unchanged; a
    letter missing from map raises KeyError.
    """
    return ''.join(str(map[ch]) if ch.isalpha() else ch for ch in puzzle)

def solvesCryptarithm(puzzle, firstLetters, map):
    """Return True if substituting map's digits into puzzle gives a true sum.

    This is the brute-force check: it builds a Python expression and
    eval()s it.  isLegalPartialSolution() does the same job arithmetically,
    column by column.

    Args:
        puzzle: e.g. 'SEND+MORE=MONEY'.
        firstLetters: letters that begin a word; each must map to nonzero.
            When it covers every word's first letter (as cryptarithm()
            passes it), this also keeps eval() away from literals like
            '0123', which are a SyntaxError in Python 3.
        map: letter -> digit; every letter of firstLetters must be a key.

    Returns:
        False if a leading letter is 0, otherwise the eval() result of the
        substituted equation.
    """
    if any(map[letter] == 0 for letter in firstLetters):
        return False
    expression = puzzle
    for letter, digit in map.items():
        expression = expression.replace(letter, str(digit))
    # A lone '=' is assignment syntax; Python needs '==' to compare.
    return eval(expression.replace('=', '=='))

def makeMap(letters, digits):
    """Pair letters with digits position by position into a dict.

    Surplus items in the longer argument are ignored (zip semantics); a
    repeated letter keeps the digit from its last position.
    """
    return dict(zip(letters, digits))

def testCryptarithm():
    """Solve three known cryptarithms, check the answers, and report timing."""
    print('Testing cryptarithm()...', end='')
    solver = fasterCryptarithm  # swap in cryptarithm to time brute force
    cases = [
        # from: http://en.wikipedia.org/wiki/Cryptarithm
        ('SEND+MORE=MONEY', '9567+1085=10652'),
        # from: http://cryptarithms.awardspace.us/
        ('BARREL+BROOMS=SHOVELS', '893360+832241=1725601'),
        ('COUPLE+COUPLE=QUARTET', '653924+653924=1307848'),
    ]
    start = time.time()
    for puzzle, expected in cases:
        assert(solver(puzzle) == expected)
    elapsed = time.time() - start
    print(f'Passed in {round(elapsed, 1)}s')

###########################################################
### Faster Cryptarithm
###########################################################

def fasterCryptarithm(puzzle):
    """Solve an addition cryptarithm by backtracking, one letter at a time.

    Letters receive digits in the order getColishLetters() lists them
    (ones column first).  Digits are tried in ascending order, and after
    each assignment isLegalPartialSolution() prunes any choice that already
    contradicts a fully assigned column or gives a leading letter 0.

    Returns:
        The puzzle with digits substituted for the first assignment found in
        that search order, or None if the search is exhausted.
    """
    words = puzzle.replace('=', '+').split('+')
    letterOrder = getColishLetters(words)
    leadingLetters = ''.join(sorted({word[0] for word in words}))
    digitMap = {}
    usedDigits = set()

    def assign(position):
        # All letters placed, and every step passed the partial check.
        if position == len(letterOrder):
            return True
        letter = letterOrder[position]
        for digit in range(10):
            if digit in usedDigits:
                continue
            digitMap[letter] = digit
            usedDigits.add(digit)
            if (isLegalPartialSolution(words, leadingLetters, digitMap)
                    and assign(position + 1)):
                return True
            # Undo this choice before trying the next digit.
            usedDigits.discard(digit)
            del digitMap[letter]
        return False

    if not assign(0):
        return None
    return makeSolutionString(puzzle, digitMap)

def isLegalPartialSolution(words, firstLetters, map):
    """Check a possibly partial letter->digit assignment column by column.

    words[:-1] are the addends and words[-1] is the sum, as produced by
    splitting e.g. 'SEND+MORE=MONEY' on '+' and '='.  Columns are added
    from the ones digit upward, carrying as usual, across the longest word.
    At the first column containing an unassigned letter the check stops and
    optimistically returns True (the assignment may still be completed).

    Args:
        words: addend words followed by the sum word.
        firstLetters: letters that begin a word; none may map to 0.
        map: the current letter -> digit assignment (may be partial).

    Returns:
        False if a leading letter is 0, if a fully assigned column's digits
        disagree with the sum, or if a carry is left over after the last
        column (the sum would need one more digit); otherwise True.
    """
    for letter in firstLetters:
        if map.get(letter) == 0:
            return False
    width = max((len(word) for word in words), default=0)
    carry = 0
    for col in range(width):
        digits = []
        for word in words:
            if col < len(word):
                letter = word[-1 - col]
                if letter not in map:
                    return True  # undecided column: cannot rule it out yet
                digits.append(map[letter])
            else:
                digits.append(0)  # word is too short to reach this column
        *addendDigits, sumDigit = digits
        columnTotal = carry + sum(addendDigits)
        if columnTotal % 10 != sumDigit:
            return False
        carry = columnTotal // 10
    # A leftover carry means the true sum is longer than the sum word.
    return carry == 0

def getColishLetters(words):
    """Return the distinct letters of words in column-by-column order.

    Columns are scanned from the ones digit leftward, visiting every word in
    each column, and each letter is kept at its first occurrence.  This is
    the order a column-wise adder needs the letters in.  The scan spans the
    longest word (not just words[2]), so letters that appear only in an
    addend longer than the sum are still included.

    Returns:
        A list of single-character strings; [] for an empty words list.
    """
    ordered = []
    width = max((len(word) for word in words), default=0)
    for col in range(width):
        for word in words:
            if col < len(word):
                letter = word[-1 - col]
                if letter not in ordered:
                    ordered.append(letter)
    return ordered

###########################################################
### testAll and main
###########################################################

def testAll():
    """Run each puzzle's self-test in turn: poker, krypto, cryptarithm."""
    for runTest in (testThreeCardPoker, testKrypto, testCryptarithm):
        runTest()

def main():
    """Entry point: run every self-test in this file."""
    testAll()

# Run the tests only when executed as a script, so importing this module
# (e.g. to reuse krypto() or fasterCryptarithm()) has no side effects.
if __name__ == '__main__':
    main()