////////////////////////////////////////////////////////
// File  : reachInfo.cc   
// Desc. : Performs an explicit fix point analysis
//         on which grounded predicates can be reached 
//         from the initial state. Obs: this analysis 
//         is relaxed by ignoring the delete list
//         (otherwise it would be as complex as the
//          planning problem itself)
//         The algorithm was suggested by Edelkamp and
//         Helmert.
//
//         - Computes the depth of each fluent, used for
//           an HSPr approximation (also suggested by
//           Edelkamp & Helmert)
//  
// Author: Rune M. Jensen, CS, CMU
// Date  : 12/8/01
////////////////////////////////////////////////////////

#include <math.h>
#include <stream.h>
#include "common.hpp"
#include "set.hpp"
#include "pddl.hpp"
#include "numDomain.hpp"
#include "reachInfo.hpp"


////////////////////////////////////////////////////////
// Aux. functions
////////////////////////////////////////////////////////

// functions for translating facts to bit vectors 
// Packs a grounded fact into its integer (bit vector) id.
//   fact      : [pred, obj1, ..., objN], all numbered from 1
//   predShift : number of bits reserved for the predicate index
//   objShift  : number of bits reserved for each object index
// Layout, most significant bits first:
//   objN-1 | ... | obj1-1 | pred-1
// Expects every argument to be an object number (> 0).
inline int fact2id(vector<int> &fact,int &predShift, int &objShift) {
  int packed = 0;
  int argIdx = fact.size();
  // fold the arguments in from the last one so that the first argument
  // ends up directly above the predicate bits
  while (--argIdx >= 1)
    packed = (packed << objShift) + (fact[argIdx] - 1);
  return (packed << predShift) + (fact[0] - 1);
}



// Decodes a fact id (layout of fact2id) back into [pred, obj1, ..., objN],
// where N = predArity[pred-1]. Previous contents of `fact` are discarded.
inline void id2fact(int &id, vector<int> &fact, 
		    int &objMask,int &predMask,
		    int &predShift, int &objShift,
		    vector<int> &predArity) {
  fact.clear();
  int bits = id;
  const int predIdx = bits & predMask;   // zero-based predicate index
  bits >>= predShift;

  fact.push_back(predIdx + 1);
  for (int remaining = predArity[predIdx]; remaining > 0; remaining--) {
    fact.push_back((bits & objMask) + 1);
    bits >>= objShift;
  }
}
    

// Decodes a projection id into [param1, ..., paramK] (object numbers from
// 1), K = projArity. The first parameter sits in the least significant
// bits. Previous contents of `proj` are discarded.
inline void id2proj(int &id, vector<int> &proj, 
		    int &objMask, int &objShift,
		    int &projArity) {
  proj.clear();
  int bits = id;
  for (int remaining = projArity; remaining > 0; --remaining) {
    proj.push_back((bits & objMask) + 1);
    bits >>= objShift;
  }
}




//IN
// factLst : numerical rep. of a fact list (each fact is [pred, obj1, ...])
// pred    : number of the predicate of interest
// arg     : position (1..N) of the argument whose domain is wanted
//OUT
// the set of objects found at position `arg` in the facts of `factLst`
// whose predicate is `pred` (used on init to find parameter domains)
set<int> domainOf(vector< vector<int> > &factLst,int pred, int arg) {
  set<int> values;
  for (vector< vector<int> >::iterator fi = factLst.begin(); fi != factLst.end(); ++fi)
    if ((*fi)[0] == pred)
      values.insert((*fi)[arg]);
  return values;
}


			     
// Returns the level of a precondition predicate: the largest parameter
// index among its arguments (parameters are encoded as -1, -2, ...), or 0
// when every argument is a constant.
int levelOf(vector<int> &pred) {
  int level = 0;
  for (vector<int>::size_type k = 1; k < pred.size(); ++k) {
    const int paramIdx = -pred[k];     // positive only for parameters
    if (paramIdx > level)
      level = paramIdx;
  }
  return level;
}


// Builds the reachability view of a (parameterized) action.
//   act         : numerical action; parameters appear as -1,-2,... in its
//                 facts, objects as 1,2,...
//   staticPreds : predicates that no action adds or deletes
//   numDom      : numerical domain (object count and init facts)
// Fills in:
//   paramDom   : per-parameter candidate objects, narrowed by the static
//                predicates of the precondition (relaxed for static facts
//                with more than one argument: each argument position is
//                narrowed independently)
//   pre        : precondition predicates grouped by level (= largest
//                parameter index they mention, 0 for grounded ones)
//   predsInPre : predicate numbers kept in pre
//   add        : the add list, unchanged
reachAction::reachAction(numAction &act, set<int> &staticPreds, 
			 numDomain &numDom) {

  name = act.name;
  paramNum = act.paramNum;

  ////////////////////////////////////////////////////////
  // make domain of parameters based on static predicates
  // in init
  ////////////////////////////////////////////////////////
  set<int> allObjs;
  for (int i=0; i<numDom.objNum; i++)
    allObjs.insert(i+1);

  // initialize parameter domain to all objects
  paramDom.resize(paramNum,allObjs);

  for (int i=0; i<act.pre.size(); i++) {
    int pred = (act.pre[i])[0];
    if (!setMember(staticPreds,pred)) continue;

    for (int j=1; j < act.pre[i].size(); j++) {
      // parameters are encoded as negative numbers (-1 = first parameter)
      // and objects as positive ones. Only a parameter has a domain to
      // constrain; a constant argument would give a negative index here.
      int param = -(act.pre[i])[j];
      if (param > 0)
	{
	  set<int> s1 = domainOf(numDom.init,pred,j);
	  paramDom[param-1] = setIntersection(paramDom[param-1],s1);
	}
    }
  }

  ///////////////////////////////////////////////////////
  // make leveling of precondition
  ///////////////////////////////////////////////////////
  pre.resize(paramNum+1); // level 0 is for grounded predicates
  for (int i=0; i<act.pre.size(); i++) {
    // Static predicates with at most one argument are not kept: a unary
    // one over a parameter is already captured by paramDom above; nullary
    // ones and unary ones over a constant are dropped (a relaxation).
    bool smallStatic = setMember(staticPreds,(act.pre[i])[0]) && (act.pre[i].size() < 3);
    if (smallStatic) continue;

    // insert predicate at the level of its max parameter index
    predsInPre.insert((act.pre[i])[0]);
    pre[levelOf(act.pre[i])].push_back(act.pre[i]);
  }

  // add not changed 
  add = act.add;
}
	


  
////////////////////////////////////////////////////////
// Print functions
////////////////////////////////////////////////////////

// Prints the names of the objects in `s` (numbered from 1), each one
// followed by a space.
void printObjSet(set<int> s, numDomain &numDom) {
  set<int>::iterator it = s.begin();
  while (it != s.end())
    cout << numDom.obj[*it++ - 1] << " ";
}

// Prints the names of the predicates in `s` (numbered from 1), each one
// followed by a space.
void printPredSet(set<int> s,numDomain  &numDom) {
  set<int>::iterator it = s.begin();
  while (it != s.end())
    cout << numDom.pred[*it++ - 1] << " ";
}

void printProjSet(set<int> s,int objShift,int objMask,int projArity,numDomain &numDom) {
  vector<int> proj;
  
  for (set<int>::iterator si = s.begin(); si != s.end(); ++si) 
    {
      int id = *si;
      id2proj(id,proj,objMask,objShift,projArity);
      cout << "(";
      for (int i=0; i<proj.size(); i++) {
	cout << numDom.obj[proj[i]-1];
	if (i < proj.size()-1) cout << ",";
      }
      cout << ") ";
    }
}


// Prints a fact as "pred(arg1,...,argN)". Object arguments are printed by
// name; parameter arguments (negative numbers) are printed as numbers.
void printFact(vector<int> &fact,numDomain &numDom) {
  cout << numDom.pred[fact[0]-1] << "(";
  for (int k = 1; k < fact.size(); k++) {
    if (k > 1) cout << ",";
    if (fact[k] >= 0)
      cout << numDom.obj[fact[k]-1];
    else
      cout << fact[k];
  }
  cout << ")";
}


void printFactSet2(set<int> s,int predShift, int objShift,
	     int predMask,int objMask, vector<int> predArity,
	     map<int,int> &depth,numDomain &numDom) {

  vector<int> fact;
  
  for (set<int>::iterator si = s.begin(); si != s.end(); ++si) 
    {
      int id = *si;
      id2fact(id,fact,objMask,predMask,
	      predShift,objShift,predArity);
      printFact(fact,numDom);
      cout << " depth = " << depth[id] << endl;
    }
}




void printFactSet(set<int> s,int predShift, int objShift,
		  int predMask,int objMask, vector<int> predArity,
		  numDomain &numDom) {

  vector<int> fact;
  
  for (set<int>::iterator si = s.begin(); si != s.end(); ++si) 
    {
      int id = *si;
      id2fact(id,fact,objMask,predMask,
	      predShift,objShift,predArity);
      printFact(fact,numDom);
      cout << " ";
    }
}





//IN
// pred   : ungrounded pred of action
// projid : projection of action (bitvector rep.)
// ...
//OUT
// bitvector rep. of instantiated fact 
int instantiate(vector<int> pred,int projId,int projArity,int objMask,int predMask,
		int predShift,int objShift) {

  vector<int> proj;
  id2proj(projId,proj,objMask,objShift,projArity);

  for (int i=1; i < pred.size(); i++)
    if (pred[i] < 0)
      pred[i] = proj[-pred[i]-1];
   
  return fact2id(pred,predShift,objShift);;
}






// Prints the reachability view of the action: parameter domains, the
// precondition predicates per level, the add list and the projections
// applied so far (oldProjections, decoded with objShift/objMask).
void reachAction::print(numDomain &numDom,int objShift,int objMask) {
  cout << name << endl
       << "  ParamNum: " << paramNum << endl
       << "    Domain:\n";
  for (int p = 1; p <= paramNum; p++) {
    cout << "      param" << p << ": ";
    printObjSet(paramDom[p-1],numDom);
    cout << endl;
  }

  cout << "  Preds in pre: ";
  printPredSet(predsInPre,numDom);
  cout << endl;
  for (int lvl = 0; lvl < pre.size(); lvl++) {
    cout << "    Level " << lvl << " : "; 
    for (vector< vector<int> >::iterator fi = pre[lvl].begin(); fi != pre[lvl].end(); ++fi) {
      printFact(*fi,numDom);
      cout << " ";
    }
    cout << endl;
  }

  cout << "  Add : ";
  for (int k = 0; k < add.size(); k++) {
    printFact(add[k],numDom);
    cout << " ";
  }
  cout << endl << "  Applied proj.: ";
  printProjSet(oldProjections,objShift,objMask,paramNum,numDom); 
  cout << endl;  
}

 


// Dumps the whole reachability structure: bit layout, predicate arities,
// init facts split into fluents and static facts, static predicates, the
// action structures, and the open/closed fact sets (closed facts with
// their depth).
void reachInfo::print(numDomain &numDom) {
  cout << "ReachInfo Structure\n"
       << "PredShift : " << predShift << endl
       << "ObjShift : "  << objShift  << endl
       << "MaxArity : "  << maxArity  << endl
       << "PredArity: \n";
  for (int p = 0; p < predArity.size(); p++) 
    cout << "  pred" << p << ": " << predArity[p] << endl;

  cout << "Initfluents : ";
  printFactSet(initFluents,predShift,objShift,predMask,objMask,predArity,numDom);
  cout << endl << "InitStatFacts : ";
  printFactSet(initStaticFacts,predShift,objShift,predMask,objMask,predArity,numDom);
  cout << endl << "StaticPreds : ";
  printPredSet(staticPreds,numDom);
  cout << endl << "Action structures :\n";
  for (int k = 0; k < act.size(); k++) {
    act[k].print(numDom,objShift,objMask);
    cout << endl;
  }

  cout << "Openfacts : ";
  printFactSet(openFacts,predShift,objShift,predMask,objMask,predArity,numDom);
  cout << endl << "ClosedFacts : ";
  printFactSet2(closedFacts,predShift,objShift,predMask,objMask,predArity,depth,numDom);
  cout << endl;
}

   

////////////////////////////////////////////////////////
// Main reach structure 
// analysis entry point
////////////////////////////////////////////////////////

// Runs the whole reachability analysis on numDom:
//  1. chooses the bit layout of fact ids (predShift/objShift bits and the
//     matching masks) and exits if a fact id would need more than
//     INTBITNUM bits,
//  2. classifies predicates as static (never added nor deleted by any
//     action) and splits init into static facts and fluents,
//  3. builds one reachAction per action,
//  4. runs the fix point (fixedPoint) and rescales the depths by
//     squizeFactor (range).
void reachInfo::analyse(numDomain &numDom, double squizeFactor) {

  // Bits needed to encode the zero-based indices 0..n-1 (0 bits when
  // n <= 1). Computed with integers: ceil(log(n)/log(2)) is evaluated in
  // floating point and is not guaranteed to be exact (a quotient just
  // above an integer makes ceil add a spare bit), and log(0) is -inf,
  // which made predNum == 0 undefined.
  predShift = 0;
  for (int v = numDom.predNum - 1; v > 0; v >>= 1)
    predShift++;
  objShift = 0;
  for (int v = numDom.objNum - 1; v > 0; v >>= 1)
    objShift++;

  predMask = (1 << predShift) - 1;
  objMask  = (1 << objShift) - 1;

  maxArity = numDom.maxArity;
  predArity = numDom.predArity;
  
  if (maxArity*objShift + predShift > INTBITNUM) {
    cout << "reachability.cc:reachInfo::reachInfo: Bitrep too long\nexiting\n";
    exit(1);
  }

  // find static facts preds
  // init assume all preds are static and remove the preds
  // of any add or del list
  for (int i=0; i<numDom.predNum; i++)
    staticPreds.insert(i+1);
  for (int i=0; i<numDom.act.size(); i++) {
    for (int j=0; j<numDom.act[i].add.size(); j++)
      staticPreds.erase((numDom.act[i].add[j])[0]);
    for (int j=0; j<numDom.act[i].del.size(); j++)
      staticPreds.erase((numDom.act[i].del[j])[0]);
  }

  // find static facts and fluents in init
  for (int i=0; i < numDom.init.size(); i++)
    if (setMember(staticPreds,(numDom.init[i])[0]))
      initStaticFacts.insert(fact2id(numDom.init[i],predShift,objShift));
    else
      initFluents.insert(fact2id(numDom.init[i],predShift,objShift));
  
  // make action structures
  for (int i=0; i<numDom.act.size(); i++)
    act.push_back(reachAction(numDom.act[i],staticPreds,numDom));

  // make fixpoint analysis
  fixedPoint();

  // re-adjust fact depths in order to weaken the HSPr heuristic,
  // (if this is needed)
  range(squizeFactor);
}




////////////////////////////////////////////////////////
// projection struct used
// to instantiate parameters
// member functions
////////////////////////////////////////////////////////

// Positions the enumeration at the first projection: every parameter
// iterator points at the smallest object of its domain. Only iterators
// into paramDom are stored, so paramDom must outlive the projection.
inline projection::projection(vector< set<int> > &paramDom) {
  for (vector< set<int> >::iterator di = paramDom.begin(); di != paramDom.end(); ++di) {
    pBegin.push_back(di->begin());
    pEnd.push_back(di->end());
    pi.push_back(di->begin());
  }
}

//IN
// param : parameter to advance (numbered 1,2,3,...)
// Projections are traversed depth first in ascending parameter order,
// like an odometer whose last digit is the last parameter: when parameter
// i is advanced, every parameter j > i is reset to the start of its
// domain; when parameter i runs past the end of its domain, parameter
// i-1 is advanced instead.
//OUT
// the index of the smallest parameter that was advanced, or -1 when no
// projections are left
// Programmer comments:
//  predicates, objects and parameters are numbered 1,2,3,... (as in
//  numDom) and not 0,1,2,...; only array indices are zero-based.
inline int projection::next(int &param) {
  int p = param;
  // advance parameter p, carrying into p-1 while p has run off its domain
  for (++pi[p-1]; pi[p-1] == pEnd[p-1]; ++pi[p-1])
    if (--p == 0)
      return -1;

  // restart every later parameter at the first element of its domain
  for (int k = pi.size() - 1; k >= p; k--)
    pi[k] = pBegin[k];

  return p;
}

// Precondition: at least one parameter.
// Packs the current projection into an int with objShift bits per
// parameter, [ParamN, ..., Param2, Param1] from most to least significant
// bits, each stored as object number - 1 (the layout read by id2proj).
inline int projection::id(int &objShift) {
  int k = pi.size() - 1;
  int packed = *(pi[k]) - 1;
  while (k-- > 0)
    packed = (packed << objShift) + (*(pi[k]) - 1);
  return packed;
}
  
    
// make bit vector id of fact
inline int projection::instantiate(vector<int> &fact,int &predShift,int &objShift) {

  int id = 0;
  for (int i=fact.size()-1; i > 0; i--)
    {
      id = id << objShift;
      if (fact[i] > 0)
	id += fact[i]-1;
      else
	id += *(pi[-fact[i]-1]) - 1;
    }
  id = id << predShift;
  id += fact[0]-1;
  return id;
}




////////////////////////////////////////////////////////
// reachability analysis function
// MAIN FUNCTION
//
// Obs: parameter domains are intialized to all objects
//      but then constrained according to static
//      "typing preds" in the precondition
////////////////////////////////////////////////////////
void reachInfo::fixedPoint() {

  
  // initialize open and closed list
  closedFacts = initStaticFacts;
  openFacts = initFluents;
  for (set<int>::iterator si = initFluents.begin(); si != initFluents.end(); ++si) 
    openList.push_back(*si); // I have to keep open facts in a list to guarantee that actions are applied in a BFS manner
                             // however I also track a set of open facts in order to make computations on the set faster 

  // set depth to 0 of all static preds
  for (set<int>::iterator si = initStaticFacts.begin(); si != initStaticFacts.end(); ++si)
    depth[*si] = 0;

  // fire all actions with empty precondition,
  // add the produced facts to openFacts, and
  // remove the action from the action vector
  for (int i=0; i<act.size(); i++) 
    if (act[i].predsInPre.empty())      
      // add fully instantiated preds in addlist
      for (int pred = 0; pred < act[i].add.size(); pred++)
	{
	  int fid = fact2id(act[i].add[pred],predShift,objShift);
	  // only add fact to openFacts/list if its not already a closed fact and open facts
	  if (!setMember(closedFacts,fid) && !setMember(openFacts,fid))
	    {
	      openFacts.insert(fid);
	      openList.push_back(fid);
	      depth[fid] = 0;
	    }
	}
  
  
  ////////////////////////////////////////////////////////
  // Main loop
  ////////////////////////////////////////////////////////
  while (!openList.empty())
    {
      int curFactId; 

      // get next fact from openFacts
      curFactId = openList.front();
      openList.pop_front();
      openFacts.erase(curFactId);
      
      // set current depth
      int curDepth = depth[curFactId];

      // translate id to fact (called current fact)
      vector<int> curFact;
      id2fact(curFactId,curFact,objMask,predMask,predShift,
	      objShift,predArity);

	
      // if current fact is not in closedFacts
      if (!setMember(closedFacts,curFactId))  
	{
	  closedFacts.insert(curFactId);
	  
	  // go through actions
	  for (int i = 0; i < act.size(); i++)
	    // if current fact is a predicate in the precondition 
	    // of the action then try to match current fact to predicate in precondition
	    if (setMember(act[i].predsInPre,curFact[0]))
	      {
		// try to find a predicate in the precondition 
		// matching the current fact
		vector< set<int> > curParamDom;
		int curFactLevel;
		int curFactPos;
		bool matchOk = false;
		// go through all levels
		for (int l = 0; l < act[i].pre.size(); l++)
		  {
		    // go through all predicates at level l
		    for (int n=0; n < act[i].pre[l].size(); n++)
		      if (((act[i].pre[l])[n])[0] == curFact[0])
			{ 
			  // predicate found match its arguments
			  matchOk = true;
			  for (int a = 1; a < (act[i].pre[l])[n].size(); a++) 
			    {
			      if (((act[i].pre[l])[n])[a] < 0)
				// argument is a parameter
				if (!setMember(act[i].paramDom[-(((act[i].pre[l])[n])[a])-1],curFact[a]))
				  {
				    matchOk = false;
				    break;
				  }
			      if (((act[i].pre[l])[n])[a] > 0 && ((act[i].pre[l])[n])[a] != curFact[a]) 
				{
				  matchOk = false;
				  break;
				}
			    }			  
			  
			  // constrain parameter domain if match found
			  if (matchOk)
			    {
			      // init paramdom
			      curParamDom = act[i].paramDom;
			      
			    // record position of matched predicate
			      curFactLevel = l;  
			      curFactPos = n;			    			    
			      
			      for (int a = 1; a < (act[i].pre[l])[n].size(); a++)
				if (((act[i].pre[l])[n])[a] < 0) {
				  curParamDom[-(((act[i].pre[l])[n])[a])-1].clear();
				  curParamDom[-(((act[i].pre[l])[n])[a])-1].insert(curFact[a]);
				}
			      break;
			    }
			}
		    if (matchOk) break;
		  }
		
		// if current fact matched check if level 0 is satisfied
		if (matchOk)
		  {
		    bool Oklevel0 = true;
		    for (int j = 0; j<act[i].pre[0].size(); j++) {
		      int f = fact2id((act[i].pre[0])[j],predShift,objShift);
		      if (!setMember(closedFacts,f)) {
			Oklevel0 = false;
			break;
		      }
		    }
		    
		    
		    ////////////////////////////////////////////////////////
		    // Produce Projections if level 0 is satisfied 		   
		    ////////////////////////////////////////////////////////
		    if (Oklevel0)
		      {
			// remove predicate of current fact from precondition
			// to reduce cost of checking the precondition
			vector<int> curPred = (act[i].pre[curFactLevel])[curFactPos];
			vector< vector<int> >::iterator vi = (act[i].pre[curFactLevel]).begin();
			for (int q = 0; q < curFactPos; q++) ++vi;
			act[i].pre[curFactLevel].erase(vi);
			
			if (act[i].paramNum > 0)
			  {
			    //////////////////////////////
			    // Case 1: the action has a  
			    // non-empty set of parameters
			    //////////////////////////////
			    
			    projection proj(curParamDom);
			    
			    // go through all projections
			    int fstParam = 1;
			    while (!proj.empty())
			      {
				bool levelOk;
				for (int level = fstParam; level < act[i].pre.size(); level++)
				  {
				    levelOk = true; 
				    for (int pred = 0; pred < act[i].pre[level].size(); pred++)
				      // check each fact in level
				      if (!setMember(closedFacts,
						     proj.instantiate((act[i].pre[level])[pred],
								      predShift,objShift))) 
					{
					  levelOk = false;
					  break;
					}
				    
				    if (!levelOk) 
				      {
					// change projection and continue at resulting level
					// the level is incremented by the for loop just after this 
					// adjustment, thus it level must be one lower than the correct
					// value
					level = proj.next(level) - 1;
					if (level < 0) break;
				      }
				  }
				
				if (levelOk)
				  {
				    // we went successfully through all levels
				    // of precondition apply projection
				    int projId = proj.id(objShift);
				    
				    if (!setMember(act[i].oldProjections,projId))
				      {
					// projection has never been applied before
					// apply it
					act[i].oldProjections.insert(projId);
					
					for (int pred = 0; pred < act[i].add.size(); pred++)
					  {
					    int fid = proj.instantiate(act[i].add[pred],
								       predShift,objShift);
					    // only add fact to openFacts/list if its not already
					    // a closed fact or an open fact
					    if (!setMember(closedFacts,fid) && !setMember(openFacts,fid))
					      {
						openFacts.insert(fid);
						openList.push_back(fid);
						depth[fid] = curDepth + 1;
					      }
					  }
				      }
				    // change projection and continue at resulting level
				    fstParam = proj.next(act[i].paramNum);			    
				  }
			      }				
			  }
			else
			  {
			    //////////////////////////////
			    // Case 2 parameter list
			    // of action is empry
			    //////////////////////////////
			    
			    // since level 0 is satisfied, the precondition is
			    // satisfied: add fully instantiated preds of addlist
			    for (int pred = 0; pred < act[i].add.size(); pred++)
			      {
				int fid = fact2id(act[i].add[pred],predShift,objShift);
				// only add fact to openFacts/list if its not already
				// a closed fact or open fact
				if (!setMember(closedFacts,fid) && !setMember(openFacts,fid))
				  {
				    openFacts.insert(fid);
				    openList.push_back(fid);
				    depth[fid] = curDepth + 1;
				  }
			      }
			}
			
			// reinsert the predicate matched by current fact
			// in the precondition
			act[i].pre[curFactLevel].push_back(curPred);
		      }
		  }
	      }
	}
    }
}

	      



// Rescales every recorded fact depth to int(squizeFactor * depth), which
// truncates toward zero. Per the caller (analyse), this is meant to weaken
// the HSPr heuristic when needed.
void reachInfo::range(double squizeFactor) {
  for (map<int,int>::iterator it = depth.begin(); it != depth.end(); ++it)
    it->second = int(squizeFactor * double(it->second));
}
