from Simulation.Abstract import Model, InsertionEvent, DeletionEvent
import os
import pickle
from random import randint


class DomainCountContextModel(Model):
    """
    Insertion/deletion model over domain architectures.

    Supported event types are listed in ``events``. Insertions delegate the
    choice of domain to ``DomainCountContextInsertion``; deletions pick an
    interior position of the architecture uniformly at random.
    """

    events = ['ins', 'del']

    def __init__(self, path):
        """
        Initialize the Domain Count Context Model with alphabet, double counts,
        doubleEndCount, doubleEndType, onlySingleton, and startDomain.

        path: file path to the root of the experiment's data directory
              (optionally) it can be the path to the pickle that the
              model should be loaded from (any path ending in "pkl")

        NOTE(review): every file read here goes through ``pickle.load``;
        only point this at trusted data directories.
        """
        if path.endswith("pkl"):
            # Load from pickle. Rebinding ``self`` to the unpickled object
            # would only change a local name and leave this instance empty,
            # so copy the loaded instance's attributes onto this one.
            loaded = DomainCountContextModel.unpickle(path)
            self.__dict__.update(vars(loaded))
            return

        formattedDir = os.path.join(path, "formatted")
        # Set pickle name in case we want to pickle it later
        self.pickleName = os.path.abspath(
            os.path.join(formattedDir, "DomainCountContextModel.pkl"))
        # Load from many pre-processed pickles
        alphabetSet = self._loadFormattedPickle(formattedDir, "alphabet.pkl")
        self.dblCounts = self._loadFormattedPickle(formattedDir, "doubleCounts-domArchs.pkl")
        self.doubleEndCount = self._loadFormattedPickle(formattedDir, "doubleEndCount.pkl")
        self.doubleEndType = self._loadFormattedPickle(formattedDir, "doubleEndType.pkl")
        self.onlySingleton = self._loadFormattedPickle(formattedDir, "onlySingleton.pkl")
        self.startDomain = self._loadFormattedPickle(formattedDir, "startDomain.pkl")
        # Transfer the alphabet set to a list (same iteration order as the set)
        self.alphabet = list(alphabetSet)

    @staticmethod
    def _loadFormattedPickle(formattedDir, fileName):
        """Load and return the pickled object stored at formattedDir/fileName."""
        with open(os.path.join(formattedDir, fileName), 'rb') as pklF:
            return pickle.load(pklF)

    def getEventEffects(self, eventType, domArch):
        """
        Get the proposed effects of a given event on the specified architecture.

        eventType: the type of the event, either 'ins' or 'del'
        domArch: list of domain strings. Positions 0 and len(domArch) - 1 are
                 never chosen as event locations (presumably boundary markers;
                 confirm against callers).

        Return: (location of the event in the specified architecture (int), a domain)
                - 'ins': location drawn uniformly from 1..len(domArch)-1, domain
                  chosen by DomainCountContextInsertion.
                - 'del': location drawn uniformly from 1..len(domArch)-2, domain
                  is the one currently at that location.

        Raises: ValueError for an unknown eventType. randint also raises
                ValueError for a 'del' event when there is no interior domain
                (len(domArch) <= 2).
        """
        n = len(domArch) - 2  # number of selectable interior positions
        if eventType == 'ins':
            # n + 1 allows inserting just before the last position.
            eventLoc = randint(1, n + 1)
            return DomainCountContextInsertion.getEventParameters(
                domArch, self.alphabet, self.startDomain, eventLoc)
        if eventType == 'del':
            eventLoc = randint(1, n)
            return eventLoc, domArch[eventLoc]
        raise ValueError(
            "Error: unexpected event type in the function getEventEffects: %r"
            % (eventType,))

    def pickle(self):
        """
        Pickle this instance of the Domain Count Context Model to
        ``self.pickleName``.
        """
        # ``pickle`` below is the module: class attributes (including this
        # method) are not visible by bare name inside method bodies.
        with open(self.pickleName, 'wb') as pklF:
            pickle.dump(self, pklF)

    @staticmethod
    def unpickle(path):
        """
        Unpickle the instance of Domain Count Context Model at 'path'.

        path: a string path (probably ending in
              'formatted/DomainCountContextModel.pkl')

        Return: whatever object the pickle contains.
        """
        with open(path, 'rb') as pklF:
            return pickle.load(pklF)

    def __str__(self):
        return "Domain Count Context"


class DomainCountContextInsertion(InsertionEvent):
    """
    Choose a domain to insert according to context in the corpus, or more
    specifically, the domain triples in the corpus.

    NOTE(review): the current implementation does not consult triples. A
    two-element architecture gets a domain weighted by ``startDomain`` counts;
    any other architecture gets a uniformly random domain from ``alphabet``.
    """

    @staticmethod
    def getEventParameters(domArch, alphabet, startDomain, eventLoc):
        """
        Get the domain for an insertion event.

        domArch: list of domain strings
        alphabet: list of all distinct domains
        startDomain: dict of domain -> count, plus a 'sum' entry
        eventLoc: insertion position, passed through unchanged

        Return: (eventLoc, chosen domain)
        """
        dom = DomainCountContextInsertion._getDomain(domArch, alphabet, startDomain)
        return eventLoc, dom

    @staticmethod
    def _getDomain(domArch, alphabet, startDomain):
        """
        Randomly choose a domain from all the domains that can be inserted.

        Returns the (string) domain chosen.

        domArch: list of domain strings
        alphabet: the list of all the distinct domains obtained from the raw dataset
        startDomain: a dictionary of int counts for all domains that could be
                     a singleton, plus a 'sum' key holding the total of those counts

        Raises: ValueError if startDomain['sum'] exceeds the sum of the
                per-domain counts (the weighted draw could otherwise return
                no domain). randint raises ValueError if 'sum' < 1.
        """
        if len(domArch) == 2:
            # First insertion: weighted draw by cumulative subtraction, so a
            # domain is picked with probability count / startDomain['sum'].
            total = startDomain['sum']
            randNum = randint(1, total)
            for key in startDomain:
                if key == 'sum':
                    continue
                randNum -= startDomain[key]
                if randNum <= 0:
                    return key
            # Previously this fell through and silently returned None.
            raise ValueError(
                "startDomain['sum'] (%r) exceeds the total of its per-domain counts"
                % (total,))
        # Otherwise every domain in the alphabet is equally likely.
        return alphabet[randint(0, len(alphabet) - 1)]
