00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef _PLSA_HPP
00018 #define _PLSA_HPP
00019 #include "common_headers.hpp"
00020 #include <cmath>
00021 #include <set>
00022 #include "Index.hpp"
00023 #include "FreqVector.hpp"
00024
00025
// Forward declaration: only the class name is required to spell the
// member-function-pointer type below; the class body is defined later.
class PLSA;

/// Pointer to a non-static PLSA member function of shape `double (int, int)`.
///
/// Lets code choose one of several identically-shaped PLSA methods at run time
/// and pass that choice around as a value. Invoke it through an object with
/// `(obj.*fn)(a, b)` or through a pointer with `(ptr->*fn)(a, b)`.
typedef double (PLSA::*jointfuncType)(int, int);
00030
00031
00035
/// PLSA -- model class declared by this header (the name suggests
/// Probabilistic Latent Semantic Analysis; NOTE(review): confirm against the
/// implementation file, which is not visible here).
///
/// NOTE(review): this text is a rendered source listing. Every code line
/// carries a leading listing number, some listing numbers are skipped
/// entirely (e.g. 00040), and lines that show only a number may have lost
/// their content during extraction. Restore from the original header before
/// treating this declaration as complete.
///
/// NOTE(review): the class stores raw pointer members, has an `ownMem` flag
/// and a virtual destructor, but no copy constructor or copy assignment is
/// visible. The compiler-generated ones would copy the pointers shallowly; if
/// ~PLSA() frees these arrays when ownMem is set, copying a PLSA would double
/// free. Consider declaring both private / deleted -- TODO confirm with the
/// destructor's definition.
00036 class PLSA {
00037 public:
00038
00039
    /// Constructor taking pre-split data sets.
    /// NOTE(review): the opening of this declaration is missing from the
    /// listing -- line 00040 is absent and 00041 carries no text, so the
    /// `PLSA(` token and every parameter before `validate` are not visible.
    /// Do not infer them; check the original header.
    /// Visible parameters: `validate` (data set, presumably held-out
    /// validation data), `numIter` (presumably EM iterations per run),
    /// `numRestarts` (presumably number of restarts), and the beta schedule
    /// values `betastart`, `betastop`, `anneal`, `betaMod` -- meanings
    /// presumed from names, TODO confirm.
00041
00042 HashFreqVector **validate, int numIter,
00043 int numRestarts, double betastart,
00044 double betastop, double anneal, double betaMod);
00045
    /// Constructor that builds the model from an Index.
    /// @param dbIndex         source data index.
    /// @param testPercentage  presumably the percentage of data held out for
    ///                        testing (cf. selectTestTrain(int testPercent)).
    /// @param numCats         presumably the number of latent categories.
    /// Remaining parameters match the constructor above (iterations, restarts,
    /// beta schedule) -- assumed, TODO confirm.
00047 PLSA(const Index &dbIndex, int testPercentage, int numCats, int numIter,
00048 int numRestarts, double betastart,
00049 double betastop, double anneal, double betaMod);
    /// Constructor taking only an Index; other settings come from defaults
    /// defined in the implementation (not visible here).
    /// NOTE(review): a single-argument constructor that is not `explicit`
    /// allows implicit Index -> PLSA conversions; consider marking it
    /// `explicit` if no caller relies on that.
00051 PLSA(const Index &dbIndex);
    /// Virtual, so a derived object may be destroyed through a PLSA pointer.
00052 virtual ~PLSA();
00053
    /// Public training driver. Presumably repeats EM numberOfRestarts times
    /// and keeps the best parameters in the *_best members -- TODO confirm in
    /// the implementation.
00055 void iterateWithRestarts();
00056
    /// Accessors for the "best" parameter tables. Each returns the stored
    /// member pointer directly (no copy). Although the methods are const, the
    /// returned pointers are non-const, so callers can modify the stored
    /// estimates. NOTE(review): the storage presumably lives only as long as
    /// this object -- confirm with the destructor.
    /// get_p_z returns p_z_best (presumably P(z) per latent class).
00059 double *get_p_z() const {return p_z_best;}
    /// Returns p_w_z_best (presumably a P(w|z) table).
00061 double **get_p_w_z() const {return p_w_z_best;}
    /// Returns p_d_z_best (presumably a P(d|z) table).
00063 double **get_p_d_z() const {return p_d_z_best;}
    /// Model probability for document index d and word index w (presumed from
    /// the names; definition not visible).
00065 double getProb(int d, int w) const ;
00066
00068
00070
00072
00075
00076
00077 private:
00078
    // NOTE(review): the number-only lines below may be member declarations
    // whose text was lost in extraction -- verify against the original.
00080
00082
00084
00086
00087
    /// Data set used for training (presumed from the name; cf. the
    /// HashFreqVector **& parameter of doLogLikelihood).
00089 HashFreqVector **data;
00091
00093
00094
    /// Beta schedule state: starting value, current value, and lower bound.
    /// Together with betaModifier and annealcue this presumably drives a
    /// tempered/annealed EM schedule -- TODO confirm in the implementation.
00096 double startBeta, beta, betaMin;
    /// Presumably the factor applied to beta when it is adjusted.
00098 double betaModifier;
    /// Presumably the cue/threshold deciding when beta is adjusted.
00100 double annealcue;
    /// Meaning not evident from this header (cf. initR()); TODO document.
00102 int R;
    /// Presumably stored from the numIter constructor argument.
00104 int numberOfIterations;
    /// Presumably stored from the numRestarts constructor argument.
00106 int numberOfRestarts;
    /// Presumably the best held-out log-likelihood seen so far.
00108 double bestTestLL;
    /// Meaning not evident from this header; TODO document.
00110 double bestA;
    /// Flag; meaning not evident from this header; TODO document.
00112 bool bestOnly;
    /// Presumably whether this object owns (and must free) its arrays/data.
00114 bool ownMem;
    // Three generations of parameter tables: current, previous, and best.
    // p_z is a double*, p_w_z and p_d_z are double** (presumably 2-D tables
    // indexed by word/document and latent class -- axis order TODO confirm).
00116 double *p_z_current;
00118 double **p_w_z_current;
00120 double **p_d_z_current;
00121
00123 double *p_z_prev;
00125 double **p_w_z_prev;
00127 double **p_d_z_prev;
00128
00130 double *p_z_best;
00132 double **p_w_z_best;
00134 double **p_d_z_best;
00135
00136
00138
00140
00142
00144
00146
00147
    /// Average likelihood, presumably under the current parameters.
00150 double getAverageLikelihood();
    /// Average likelihood, presumably under the previous parameters.
00153 double getAverageLikelihoodPrev();
00154
    // Joint estimates for (document index, word index). All four share the
    // signature double(int, int), so each can be passed as a jointfuncType to
    // doLogLikelihood. Which parameter set each one uses is presumed from the
    // suffix (current / best / beta-tempered) -- TODO confirm.
00156 double jointEstimate (int indexD, int indexW);
00158 double jointEstimateCurrent (int indexD, int indexW);
00160 double jointEstimateBest (int indexD, int indexW);
00163 double jointEstimateBeta (int indexD, int indexW);
00164
    /// One training iteration (details in the implementation).
00166 void iterate();
    /// Sets up the parameter tables before training (presumed from the name).
00168 void initializeParameters();
00169
    /// Generic log-likelihood over myData, using the supplied joint-estimate
    /// member function. NOTE(review): myData is taken by non-const reference
    /// to pointer, so from the signature the callee could reseat the caller's
    /// pointer -- confirm whether plain `HashFreqVector **` would suffice.
00172 double doLogLikelihood(jointfuncType, HashFreqVector **&myData);
    // Wrappers presumably calling doLogLikelihood with a specific estimator
    // and data set (training, validation/current, best) -- TODO confirm.
00174 double logLikelihood();
00176 double validateDataLogLikelihood();
00178 double validateCurrentLogLikelihood();
00180 double bestDataLogLikelihood();
    /// EM variant with interleaved updates; returns a double (presumably a
    /// likelihood value) -- TODO confirm.
00182 double interleavedIterationEM();
    /// Splits the data into test and training parts; testPercent presumably
    /// gives the test share.
00184 void selectTestTrain(int testPercent);
00186 void init();
00188 void initR();
    /// Indices naming the three parameter tables (P_Z = 0, P_W_Z = 1,
    /// P_D_Z = 2); presumably used to select a table by index.
00190 enum pType {P_Z = 0, P_W_Z = 1, P_D_Z = 2};
00191
    // NOTE(review): number-only lines below may hold lost declarations.
00195
00197
00199
00201
00202 };
00203
00204 #endif