import random

class lda:
    """Latent Dirichlet allocation trained with collapsed Gibbs sampling.

    Data structures and conventions:
      topic k, docId d, and wordId w are integer indices
      x[d][j] = w, index of j-th word token in doc d
      z[d][j] = k, index of latent topic of j-th word token in doc d
      vocab[w] = string for the word with index w
      totalTopicCount[k]   = number of tokens assigned to topic k
      docTopicCount[d][k]  = number of tokens in topic k for document d
      wordTopicCount[w][k] = number of occurrences of word w in topic k
      totalWords = number of word tokens in the corpus
    """

    def __init__(self):
        self.numTopics = 10
        # Dirichlet hyperparameters; overwritten with the standard defaults
        # alpha = 1/K and beta = 1/W inside initGibbs()
        self.alpha = 0.1
        self.beta = 0.001
        # for trace output, how often to print progress updates
        self.dstep = 250

    class topicCounter:
        """Sparse counter mapping topic index -> count; absent keys read as 0."""

        def __init__(self):
            self._ctrMap = {}

        def get(self, k):
            return self._ctrMap.get(k, 0)

        def __getitem__(self, k):
            return self._ctrMap.get(k, 0)

        def add(self, k, delta=+1):
            """Add delta to the count for topic k."""
            self._ctrMap[k] = self.get(k) + delta
            # keep the map sparse: drop entries that return to zero
            if not self._ctrMap[k]:
                del self._ctrMap[k]

        def total(self):
            """Sum of all stored counts.

            NOTE(review): built-in sum() over this object would never
            terminate, because __getitem__ returns 0 for every index and
            never raises IndexError to stop legacy iteration.
            """
            return sum(self._ctrMap.values())

    def initGibbs(self):
        """Assign a uniform-random topic to every token and build count tables.

        Also sets the standard default hyperparameters alpha = 1/K and
        beta = 1/W, overriding whatever the constructor set.
        Requires self.x and self.vocab to be loaded first.
        """
        print('.initializing latent vars')
        self.totalTopicCount = self.topicCounter()
        self.docTopicCount = [self.topicCounter() for _ in range(len(self.x))]
        self.wordTopicCount = [self.topicCounter() for _ in range(len(self.vocab))]
        self.z = [[-1] * len(doc) for doc in self.x]
        for d in range(len(self.x)):
            if (d + 1) % self.dstep == 0:
                print('..doc %d of %d' % (d + 1, len(self.x)))
            for j in range(len(self.x[d])):
                w = self.x[d][j]
                k = random.randint(0, self.numTopics - 1)
                self.z[d][j] = k
                self.docTopicCount[d].add(k, 1)
                self.wordTopicCount[w].add(k, 1)
                self.totalTopicCount.add(k, 1)
        # reasonable default parameters
        self.alpha = 1.0 / self.numTopics
        self.beta = 1.0 / len(self.vocab)
        print('alpha: %s beta: %s' % (self.alpha, self.beta))

    def runGibbs(self, maxT):
        """Run maxT full sweeps of collapsed Gibbs sampling over all tokens."""
        for t in range(maxT):
            print('.iteration %d of %d' % (t + 1, maxT))
            for d in range(len(self.x)):
                if (d + 1) % self.dstep == 0:
                    print('..doc %d of %d' % (d + 1, len(self.x)))
                for j in range(len(self.x[d])):
                    k = self.resample(d, j)
                    self.flip(d, j, self.z[d][j], k)

    def flip(self, d, j, k_old, k_new):
        """Update all counts (and z itself) to reflect z[d][j]: k_old -> k_new."""
        if k_old != k_new:
            w = self.x[d][j]
            self.docTopicCount[d].add(k_old, -1)
            self.docTopicCount[d].add(k_new, +1)
            self.wordTopicCount[w].add(k_old, -1)
            self.wordTopicCount[w].add(k_new, +1)
            self.totalTopicCount.add(k_old, -1)
            self.totalTopicCount.add(k_new, +1)
            self.z[d][j] = k_new

    def resample(self, d, j):
        """Sample a new value of z[d][j] from its full conditional.

        Pr(z_dj = k | everything else) is proportional to
            (n_dk + alpha) * (n_wk + beta) / (n_k + W*beta)
        where the counts exclude the current assignment of token (d, j)
        and W is the vocabulary size.
        """
        w = self.x[d][j]
        # FIX: the normalizer of the word term is W*beta (vocab size), not
        # totalWords*beta (corpus token count) as the original code had.
        W_beta = len(self.vocab) * self.beta
        p = []
        norm = 0.0
        for k in range(self.numTopics):
            # exclude this token's own current assignment from the counts
            excl = 1 if self.z[d][j] == k else 0
            # unnormalized chance of picking topic k in doc d
            ak = self.docTopicCount[d][k] - excl + self.alpha
            # unnormalized chance of picking topic k for word w
            bk = ((self.wordTopicCount[w][k] - excl + self.beta)
                  / (self.totalTopicCount[k] - excl + W_beta))
            pk = ak * bk
            p.append(pk)
            norm += pk
        # pick randomly from the normalized distribution
        r = random.random()
        cum = 0.0
        for k in range(self.numTopics):
            cum += p[k] / norm
            if r < cum:
                return k
        # FIX: guard against float roundoff leaving cum fractionally below 1,
        # which previously made this method fall through and return None
        return self.numTopics - 1

    def phi(self, w, k):
        """Estimated probability of word w under topic k (sums to 1 over w)."""
        num = self.wordTopicCount[w][k] + self.beta
        # FIX: smoothing mass is W*beta (vocab size), not totalWords*beta
        denom = self.totalTopicCount[k] + len(self.vocab) * self.beta
        return num / denom

    def theta(self, d, k):
        """Estimated weight of topic k in doc d (sums to 1 over k)."""
        num = self.docTopicCount[d][k] + self.alpha
        # FIX: the original sum(self.docTopicCount[d]) looped forever;
        # total() gives the token count of doc d directly
        denom = self.docTopicCount[d].total() + self.numTopics * self.alpha
        return num / denom

    ####################
    # i/o
    ####################

    def loadDat(self, filename):
        """Load from the format used by Blei's lda-c: line d contains
        number-of-unique-words <space> wordId:freq <space> wordId:freq ..."""
        print('.loading data from %s' % filename)
        self.totalWords = 0
        self.x = []
        with open(filename, 'r') as fdat:
            for d, line in enumerate(fdat):
                if (d + 1) % self.dstep == 0:
                    print('..doc %d' % (d + 1))
                parts = line.strip().split(" ")
                # parts[0] is the unique-word count; unused because each
                # wordId is expanded freq times into the token list
                doc = []
                for pair in parts[1:]:
                    wordId, freq = pair.split(':')
                    doc.extend([int(wordId)] * int(freq))
                    self.totalWords += int(freq)
                self.x.append(doc)

    def loadVocab(self, filename):
        """Load vocabulary: line number w is the string for the word with id w."""
        with open(filename, 'r') as fvocab:
            self.vocab = [line.strip() for line in fvocab]

    def saveModel(self, filename):
        """Save the latent topic assignments z, one space-separated line per doc."""
        with open(filename, 'w') as fmod:
            for zd in self.z:
                fmod.write(" ".join(str(zdj) for zdj in zd))
                fmod.write("\n")

    def loadModel(self, filename):
        """Restore a model - latent vars and counts - saved with saveModel.

        Assumes the data has already been loaded with loadDat.
        """
        # initGibbs just sets up the counters (with random assignments);
        # flip() below then moves every token to its saved topic
        self.initGibbs()
        print('.loading model from %s' % filename)
        with open(filename, 'r') as fmod:
            for d, line in enumerate(fmod):
                if (d + 1) % self.dstep == 0:
                    print('..doc %d of %d' % (d + 1, len(self.x)))
                # split() (not split(" ")) tolerates the trailing newline
                # and a zero-token document's empty line
                zdj = [int(tok) for tok in line.split()]
                for j in range(len(zdj)):
                    self.flip(d, j, self.z[d][j], zdj[j])

    ####################
    # display code
    ####################

    def showWord(self, index):
        """For debugging, show the vocabulary word with the given id."""
        print('word %d : %s' % (index, self.vocab[index]))

    def showDoc(self, d):
        """For debugging, show every token of document d."""
        print('doc %d :' % d)
        for w in self.x[d]:
            print(' # %d : %s' % (w, self.vocab[w]))

    def heavyWords(self, k, numTop=10):
        """Return the numTop most likely (wordId, phi) pairs for topic k."""
        # sorted() over the id range; range(...).sort() only worked in py2
        ranked = sorted(range(len(self.vocab)), key=lambda w: -self.phi(w, k))
        return [(w, self.phi(w, k)) for w in ranked[:numTop]]

    def showHeavyWords(self, k, numTop=10):
        for (w, phi_wk) in self.heavyWords(k, numTop):
            print("%04d %0.5f %s" % (w, phi_wk, self.vocab[w]))

    def heavyDocs(self, k, numTop=10):
        """Return the numTop most likely (docId, theta) pairs for topic k."""
        ranked = sorted(range(len(self.x)), key=lambda d: -self.theta(d, k))
        return [(d, self.theta(d, k)) for d in ranked[:numTop]]

    def showHeavyDocs(self, k, numTop=10):
        for (d, theta_dk) in self.heavyDocs(k, numTop):
            print("%04d %0.5f" % (d, theta_dk))

##############################################################################
# main program
##############################################################################

if __name__ == "__main__":
    # what to do: 'runN' trains for N Gibbs sweeps then saves; 'load20' restores
    op = 'run5'

    # common steps: load the corpus and its vocabulary
    ap = lda()
    print('loading data...')
    ap.loadDat('/usr0/wcohen/code/lda/ap/ap.dat')
    print('loading vocab...')
    ap.loadVocab('/usr0/wcohen/code/lda/ap/vocab.txt')

    # one table instead of three copy-pasted branches that differed only in
    # the sweep count (and wrongly printed 'loading data...' again).
    # NOTE(review): every op saves to the same 'ap20.mod' file, which is
    # what 'load20' reads back - preserved as-is.
    sweeps = {'run5': 5, 'run20': 20, 'run30': 30}
    if op in sweeps:
        ap.initGibbs()
        ap.runGibbs(sweeps[op])
        ap.saveModel('ap20.mod')
    elif op == 'load20':
        ap.loadModel('ap20.mod')