import os
from Parsing import NULLDOM
from Simulation.Abstract import InvalidModelError
from Simulation.DomainCountContext import DomainCountContextModel


class Engine(object):
    """
    Domain architecture evolution simulation engine.

    A domain architecture is handled as a list of domains whose first and
    last entries are the NULLDOM sentinel.  Each simulation step proposes a
    single insertion or deletion, accepts it with the model-derived
    transition probability, and records both the resulting architecture and
    a description of the proposed event.
    """
    # These are imported inside the class body so that they are reachable as
    # attributes of the engine (self.random(), self.expovariate(...)).
    from random import random, expovariate, randint
    from math import log

    def __init__(self, **kwargs):
        """
        Initialize the simulation engine with the simulation parameters.

        dataset: (string) path to the dataset; read only when 'use_pickle'
                 is absent or false
        eventRate: lambda for events
        branchLength: Length of the branch to simulate over
        startDomArch: an optional starting domain architecture (sequence of
                      domains); when missing or None the simulation starts
                      from the empty architecture
        model: (string) event model; only 'domaincountcontext' is accepted
        pseudoCount: the pseudoCount used in calculating transition probability
        insertionBias: divisor applied to the deletion threshold
                       1/len(alphabet) in getNextState
        use_pickle: boolean, whether or not to use a pickle for the model instance
        model_pkl: path to the pickled model, dependent on 'use_pickle' being present

        Raises: InvalidModelError when 'model' names an unknown model,
                KeyError when a required parameter is missing.
        """
        # Time-related parameters
        self.lmbd = kwargs['eventRate']
        self.branchLength = kwargs['branchLength']

        # Probability-related parameters
        self.pseudoCount = kwargs['pseudoCount']
        self.insertionBias = kwargs['insertionBias']

        # Build the model: load the pickled instance when requested,
        # otherwise initialize it from the dataset.
        modelName = kwargs['model']
        if modelName != 'domaincountcontext':
            raise InvalidModelError("Invalid Model name: {}".format(modelName))
        if kwargs.get('use_pickle', False):
            self.model = DomainCountContextModel.unpickle(kwargs['model_pkl'])
        else:
            self.model = DomainCountContextModel(kwargs['dataset'])

        # Prepare the starting architecture.  Work on a private list copy so
        # the caller's object is never mutated, then make sure it is wrapped
        # in NULLDOM sentinels on both ends.
        given = kwargs.get('startDomArch')
        start = [] if given is None else list(given)
        if not start or start[0] != NULLDOM:
            start.insert(0, NULLDOM)
        if len(start) == 1 or start[-1] != NULLDOM:
            start.append(NULLDOM)

        # Local tracking variables (the history stores its own copy)
        self.domArchHistory = [start[:]]
        self.eventHistory = []

    def run(self):
        """
        Run the simulator (the main loop).

        Time advances in unit steps from 1 up to and including branchLength.
        Every step appends one entry to domArchHistory (the architecture
        after the step, unchanged if the proposal was rejected) and one
        tuple to eventHistory:
        (simTime, eventType, eventLoc, eventDom, transitionPro, randNum).
        """
        domArch = self.domArchHistory[0][:]
        # Fixed unit time steps are used here; getNextTimeStep() provides an
        # exponentially distributed alternative that is currently not used.
        simTime = 1
        while simTime <= self.branchLength:  # test time constraint
            # Propose the next architecture together with the event description
            nextDA, eventType, eventLoc, eventDom = self.getNextState(domArch)

            # Accept the proposal with probability transitionPro
            transitionPro = self.getTransitionPro(domArch, nextDA, eventType, eventLoc)
            randNum = self.random()
            if randNum <= transitionPro:
                domArch = nextDA

            # Record history (rejected proposals are recorded as well)
            self.domArchHistory.append(domArch[:])
            self.eventHistory.append((simTime, eventType, eventLoc, eventDom, transitionPro, randNum))

            # Increment time/branch length
            simTime += 1

    def getNextState(self, domArch):
        """
        The function to choose the domain architecture for the next state

        domArch: the recent domain architecture in the form of list of domains
        Return: nextDA, the next domain architecture in the form of list of domains,
                and the description of the event (eventType, eventLoc, eventDom)
        """
        # A deletion is proposed only when the random draw falls below
        # (1/len(alphabet))/insertionBias and the architecture holds more
        # than its two sentinel entries; everything else is an insertion.
        threshold = (1.0 / len(self.model.alphabet)) / self.insertionBias
        randNum = self.random()
        if randNum >= threshold or len(domArch) == 2:
            eventType = 'ins'
        else:
            eventType = 'del'

        # The model decides where the event happens and which domain it involves
        eventLoc, eventDom = self.model.getEventEffects(eventType, domArch)

        # Apply the event to a copy so that domArch itself stays untouched
        nextDA = domArch[:]
        if eventType == 'ins':
            nextDA.insert(eventLoc, eventDom)
        else:
            del nextDA[eventLoc]
        return nextDA, eventType, eventLoc, eventDom

    def _smoothedPairCount(self, pair):
        """
        Return pseudoCount plus the model's dblCounts entry for the given
        (left, right) domain pair; pairs missing from dblCounts count as 0.
        """
        observed = self.model.dblCounts
        return self.pseudoCount + (observed[pair] if pair in observed else 0)

    def getTransitionPro(self, domArch, nextDA, eventType, eventLoc):
        """
        The function to calculate the transition probability between the current
        domain architecture and next domain architecture.

        domArch: the recent domain architecture in the form of list of domains
        nextDA: the next domain architecture in the form of list of domains
        eventType: the type of event, 'ins' or 'del'
        eventLoc: the location of the event, which is also the index of the domain
                    where the event takes place
        Return: transitionPro, the transition probability
        Raises: ValueError when eventType is neither 'ins' nor 'del'
        """
        n = len(domArch)
        m = len(nextDA)
        if n == 3 and m == 2:
            # The only domain would be deleted: prevent the protein from disappearing
            return 0
        if n == 2 and m == 3:
            # The first domain is being inserted: make sure this always succeeds
            return 1
        if eventType not in ('ins', 'del'):
            raise ValueError(
                "Error: unexpected event type in the function getTransitionPro: {!r}".format(eventType))

        i = eventLoc
        alphabetSize = len(self.model.alphabet)
        if eventType == 'del':  # Deletion of domArch[i]
            joined = self._smoothedPairCount((nextDA[i-1], nextDA[i]))    # neighbours after removal
            left = self._smoothedPairCount((domArch[i-1], domArch[i]))    # left pair before removal
            right = self._smoothedPairCount((domArch[i], domArch[i+1]))   # right pair before removal
            # NOTE(review): a missing ('1', domain) key raises KeyError unless
            # doubleEndCount supplies defaults - confirm against the model.
            endTotal = self.model.doubleEndCount[('1', domArch[i])] + self.pseudoCount * alphabetSize
            return joined * endTotal / left / right

        # Insertion: nextDA[i] is the newly inserted domain
        left = self._smoothedPairCount((nextDA[i-1], nextDA[i]))      # left pair after insertion
        right = self._smoothedPairCount((nextDA[i], nextDA[i+1]))     # right pair after insertion
        split = self._smoothedPairCount((domArch[i-1], domArch[i]))   # pair that gets separated
        endTotal = self.model.doubleEndCount[('1', nextDA[i])] + self.pseudoCount * alphabetSize
        return left * right / split / endTotal

    def _hasHadDomains(self):
        """
        In the history of the domain architecture, has it ever had domains?
        Is the entire history full of empty architectures?
        Used to prevent incorrect early termination when the first few
        architectures are empty.

        Return: True when the average architecture length over the history
                exceeds 2 (the length of an empty, sentinel-only architecture).
        """
        # "average > 2" written without a division so the result does not
        # depend on integer-vs-true division semantics.
        totalLength = sum(len(domArch) for domArch in self.domArchHistory)
        return totalLength > 2 * len(self.domArchHistory)

    def getNextTimeStep(self):
        """
        Return a waiting time drawn via expovariate with rate eventRate (lmbd).
        """
        return self.expovariate(self.lmbd)

    def getDomainArchHistory(self):
        """
        Return the entire domain architecture history as a list of lists.
        """
        return self.domArchHistory

    def getEventHistory(self):
        """
        Return the domain architecture event history as a list of tuples.
        """
        return self.eventHistory

    def getParameters(self):
        """
        Return the simulation parameters as a dictionary with the keys
        'eventRate', 'branchLength', 'model' (str() of the model instance)
        and 'pseudoCount'.
        """
        # NOTE(review): insertionBias is not reported here - confirm whether
        # consumers of this dictionary expect it.
        return {
            'eventRate': self.lmbd,
            'branchLength': self.branchLength,
            'model': str(self.model),
            'pseudoCount': self.pseudoCount,
        }



class FloatIter(object):
    """
    Hacky implementation of 'range' that works on floats.
    (Serve batch experiment type)

    The sequence start, start+step, start+2*step, ... is materialized into
    the list 'elements'; 'stop' itself is excluded, as with range().  When
    step is a float every element is rounded to the number of decimal
    places of step to hide floating point noise; for any other step type
    the elements are exactly those of range(start, stop, step).
    """

    def __init__(self, start, stop, step):
        """
        start: first value of the sequence
        stop: exclusive end of the sequence
        step: increment; a float selects the rounding float implementation,
              anything else is passed straight to range()
        """
        if not isinstance(step, float):
            self.elements = list(range(start, stop, step))
            return

        digits = self._decimalPlaces(step)
        # One extra candidate is generated because the float division may
        # land just below the true count; the stop filter below removes any
        # candidate that reaches or passes 'stop'.
        count = int((stop - start) / step) + 1
        candidates = (round(start + (step * i), digits) for i in range(count))
        if step > 0:  # positive step sizes
            self.elements = [item for item in candidates if item < stop]
        else:  # negative step sizes
            self.elements = [item for item in candidates if item > stop]

    @staticmethod
    def _decimalPlaces(step):
        """
        Return the number of decimal places in the textual form of 'step'.

        Handles both plain ('0.25' -> 2) and scientific ('1e-07' -> 7,
        '2.5e-06' -> 7) notation; counting only the characters after the
        '.' is wrong for the latter.  Never returns a negative number.
        """
        mantissa, _, exponent = str(step).lower().partition('e')
        fraction = mantissa.partition('.')[2]
        return max(0, len(fraction) - int(exponent or 0))

    def __iter__(self):
        return iter(self.elements)

    def __len__(self):
        return len(self.elements)

    def __contains__(self, item):
        return item in self.elements
