Main Page   Namespace List   Class Hierarchy   Alphabetical List   Compound List   File List   Namespace Members   Compound Members   File Members   Related Pages  

PLSA.hpp

Go to the documentation of this file.
00001 /*==========================================================================
00002  * Copyright (c) 2003 University of Massachusetts.  All Rights Reserved.
00003  *
00004  * Use of the Lemur Toolkit for Language Modeling and Information Retrieval
00005  * is subject to the terms of the software license set forth in the LICENSE
00006  * file included with this software, and also available at
00007  * http://www.lemurproject.org/license.html
00008  *
00009  *==========================================================================
00010 */
00011 /*
00012   Probabilistic Latent Semantic Analysis 
00013   Java Reference implementation from Andrew Schein and Alexandrin Popescul
00014   Author: dmf 2/2003
00015  */
00016 
00017 #ifndef _PLSA_HPP
00018 #define _PLSA_HPP
00019 #include "common_headers.hpp"
00020 #include <cmath>
00021 #include <set>
00022 #include "Index.hpp"
00023 #include "FreqVector.hpp"
00024 
00025 //forward declaration.
00026 class PLSA;
00027 
00028 // Pointer-to-member-function type for PLSA methods taking (int, int) and returning double. It is the first parameter type of PLSA::doLogLikelihood(), and the jointEstimate*() members match this signature. The forward declaration of PLSA above is required because this typedef names PLSA before the class is defined.
00029 typedef double (PLSA::*jointfuncType)(int, int);
00030 
00031 
00035 // Probabilistic Latent Semantic Analysis model. TODO (original author): need both transient and permanent versions (store model components).
00036 class PLSA  { // NOTE(review): holds raw pointer members and declares a destructor, but no copy constructor/assignment is declared here -- confirm instances are never copied.
00037 public:
00038   // The constructors take many tuning parameters and declare no default arguments, so callers must supply every value. TODO (original author): use sensible defaults.
00039   // TODO (original author): document arguments. Parameter meanings noted below are inferred from names only -- confirm against the implementation.
00041   PLSA(const Index &dbIndex, int numCats, HashFreqVector **train, // builds a model from caller-supplied train/validate vectors (presumably one HashFreqVector* per document -- confirm)
00042        HashFreqVector **validate, int numIter, 
00043        int numRestarts, double betastart, 
00044        double betastop, double anneal, double betaMod); // presumably these initialise startBeta/betaMin/annealcue/betaModifier -- confirm
00045 
00047   PLSA(const Index &dbIndex, int testPercentage, int numCats, int numIter, // no data arrays passed; testPercentage presumably feeds selectTestTrain() -- confirm
00048        int numRestarts, double betastart, 
00049        double betastop, double anneal, double betaMod);
00051   PLSA(const Index &dbIndex); // index-only constructor, presumably paired with readArrays() to load a stored model -- confirm. NOTE(review): not explicit, so an Index converts implicitly to a PLSA.
00052   virtual ~PLSA(); // what is released is decided in the implementation (see ownMem) -- not visible here
00053     
00055   void iterateWithRestarts(); // presumably runs training with numberOfRestarts restarts and keeps the best parameters -- confirm
00056 
00059   double *get_p_z() const   {return p_z_best;} // returns the p_z_best pointer itself (internal storage, not a copy); caller can modify through it despite const
00061   double **get_p_w_z() const {return p_w_z_best;} // returns the p_w_z_best pointer itself, not a copy
00063   double **get_p_d_z() const {return p_d_z_best;} // returns the p_d_z_best pointer itself, not a copy
00065   double getProb(int d, int w) const ; // presumably the model's probability estimate for document index d and word index w -- confirm
00066   // Matrix dimensions.
00068   int numWords() const {return sizeW;} // returns sizeW
00070   int numDocs() const {return sizeD;} // returns sizeD
00072   int numCats() const {return sizeZ;} // returns sizeZ
00075   bool readArrays(); // presumably loads the arrays saved by writeArrays() and returns true on success -- confirm
00076 
00077 private:
00078   // Data members.
00080   const Index &ind; // reference member: the referenced Index must outlive this object (presumably bound to the constructors' dbIndex)
00082   int sizeZ; // category count; returned by numCats()
00084   int sizeD; // document count; returned by numDocs()
00086   int sizeW; // word count; returned by numWords()
00087 
00089   HashFreqVector **data; // passed in; presumably the training vectors -- confirm
00091   HashFreqVector **testData; // passed in; presumably the validation/held-out vectors -- confirm
00093   set<int, less<int> > *invIndex; // constructed; relies on unqualified set/less being visible (presumably via common_headers.hpp -- confirm)
00094   
00096   double startBeta, beta, betaMin; // presumably the start value, current value and lower bound of beta -- confirm
00098   double betaModifier; // presumably set from the constructors' betaMod -- confirm
00100   double annealcue; // presumably set from the constructors' anneal; exact use is in the implementation
00102   int R; // meaning not visible in this header; presumably set up by initR()
00104   int numberOfIterations; // presumably set from the constructors' numIter
00106   int numberOfRestarts; // presumably set from the constructors' numRestarts
00108   double bestTestLL; // presumably the best test-data log-likelihood seen so far -- confirm
00110   double bestA; // meaning not visible in this header
00112   bool bestOnly; // meaning not visible in this header
00114   bool ownMem; // presumably true when this object allocated data/testData itself and must free them -- confirm
00116   double *p_z_current; // "current" copy of the three parameter arrays (p_z, p_w_z, p_d_z)
00118   double **p_w_z_current;
00120   double **p_d_z_current;
00121 
00123   double *p_z_prev; // "prev" copy of the three parameter arrays
00125   double **p_w_z_prev;
00127   double **p_d_z_prev;
00128 
00130   double *p_z_best; // "best" copy of the three parameter arrays; these are what get_p_z()/get_p_w_z()/get_p_d_z() return
00132   double **p_w_z_best;
00134   double **p_d_z_best;
00135 
00136   // Private helpers.
00138   void setPrevToCurrent(); // the set*To*() helpers presumably copy one parameter set (current/prev/best) onto another -- confirm copy direction in the implementation
00140   void setCurrentToBest();
00142   void setBestToCurrent();
00144   void setBestToPrev();
00146   void setPrevToBest();
00147 
00150   double getAverageLikelihood();
00153   double getAverageLikelihoodPrev(); // presumably the same measure computed from the "prev" parameters -- confirm
00154 
00156   double jointEstimate (int indexD, int indexW); // the four jointEstimate*() members share the (int, int) -> double signature of jointfuncType
00158   double jointEstimateCurrent (int indexD, int indexW);
00160   double jointEstimateBest (int indexD, int indexW);
00163   double jointEstimateBeta (int indexD, int indexW);
00164   
00166   void iterate();
00168   void initializeParameters();
00169 
00172   double doLogLikelihood(jointfuncType, HashFreqVector **&myData); // takes a pointer to a PLSA member function of type jointfuncType and a data array passed by reference
00174   double logLikelihood();
00176   double validateDataLogLikelihood();
00178   double validateCurrentLogLikelihood();
00180   double bestDataLogLikelihood();
00182   double interleavedIterationEM();
00184   void selectTestTrain(int testPercent); // presumably splits the index's documents into data/testData using testPercent -- confirm
00186   void init();
00188   void initR();
00190   enum pType {P_Z = 0, P_W_Z = 1, P_D_Z = 2}; // tag passed to readArray()/writeArray(); enumerator names mirror the p_z, p_w_z and p_d_z arrays
00191 
00195   //  bool readArrays();  (left commented out here; readArrays() is declared in the public section of this class)
00197   void writeArrays();  
00199   bool readArray(ifstream& infile, enum pType which); // relies on unqualified ifstream being visible; <fstream> is not included directly by this header
00201   void writeArray(ofstream& ofile, enum pType which); // relies on unqualified ofstream being visible; <fstream> is not included directly by this header
00202 };
00203 
00204 #endif /* _PLSA_HPP */

Generated on Wed Nov 3 12:59:01 2004 for Lemur Toolkit by doxygen 1.2.18