#include  "defns.i"
#include  "extern.i"


    /*  Compute the gain from literals of the form R(A) or ~R(A).
	Literals that look promising are proposed as candidates, and the
	best compact clause found so far is updated when this literal
	would improve on it.

	NB: The various counts have components with the following meanings:
		Pos	positive (+) tuples
		Tot	all tuples
		T	obtained from R(A)
		F	obtained using ~R(A)
		Now	number of current tuples
		Orig	number of tuples in original training set
		New	referring to training set if use R(A) or ~R(A)  */

    /*  This version uses the same kind of pruning cutoffs as args.c  */


float  ComputeGain(R, A, LitBits, CanUse)
/*     -----------  */
    Relation R;
    Vars A;
    float LitBits;
    Boolean CanUse;
{
    /*  Evaluate the literal R(A) and, when NEGLITERALS is set, its
	negation ~R(A) over the current training set.

	  A	  argument variables A[1..R->Arity]; a variable numbered
		  above MaxVar is new (not yet bound by the clause)
	  LitBits encoding cost of the literal, used in the coding check
		  and recorded in any literal that is proposed or saved
	  CanUse  when false the gain is only measured and reported: the
		  weak-literal and coding checks are skipped and nothing is
		  proposed

	Returns the larger of the two gains, 0.001 if a determinate
	literal was proposed instead, or 0.0 if the literal was pruned.  */

    int TInstances,				/* matches of R for the
						   current case */
        NowTPos = 0, TPos = 0, TTot = 0,        /* params if use R(A)
						   as literal */
	NowFPos = 0, NowFNeg = 0, NowFTot,	/* ditto if use ~R(A) */
	OrigTPos = 0, OrigFPos = 0,		/* no original pos tuples
						   that would pass */
	OrigTTot = 0, OrigFTot = 0,		/* no original tuples
						   that would pass */
	Best, N, i, j, UnboundVars = 0;
    Var V, W,  Unbound[MAXARITY+1], New[MAXARITY*MAXARITY],
        Previous[MAXARITY*MAXARITY];
    int Col, PossibleDuplicateVars = 0;
    Const PreviousVal, FirstBinding[MAXARITY+1];
    Tuple *TSP, Case;
    float Gain, PosGain, NegGain, MinUsefulGain, NowTPosThresh, NowFNegThresh,
	Worth();
    Boolean WeakPos, WeakNeg,  RuleOutT, RuleOutF,
	Determinate, IdenticalBindings, OK, ExcessTuplesT, ExcessTuplesF,
	CompactPos;
    Literal L;
    int il, jl;
    VarInfo Vl;


    /*  Relationships between the counts once every case has been seen
	(each current case is counted on exactly one side):

	NowTPos + NowFPos = Pos		(assuming Pos is the number of
					 positive current tuples)
	NowFTot = NowFPos + NowFNeg		*/

    N = R->Arity;

    /*  Record each argument position that holds a new variable (only
	its first occurrence).  A repeat is detected by scanning the
	positions already recorded in Unbound, so no array is indexed by
	the variable number itself: V can be far larger than MAXARITY.

	For every new variable, also record each (new column, existing
	variable) pair as a possible duplicate binding, checked below.  */

    ForEach(i, 1, N)
    {
	V = A[i];
	if ( V > MaxVar )
	{
	    for ( j = 1 ; j <= UnboundVars && A[Unbound[j]] != V ; j++ )
		;

	    if ( j > UnboundVars )
	    {
		Unbound[++UnboundVars] = i;

		/*  Never write past the end of New[] / Previous[]  */

		for ( W = 1 ;
		      W <= MaxVar &&
		      PossibleDuplicateVars < (int) (sizeof New / sizeof New[0]) ;
		      W++ )
		{
		    New[PossibleDuplicateVars] = i;
		    Previous[PossibleDuplicateVars++] = W;
		}
	    }
	}
    }

    ClearBits;

    IdenticalBindings = Determinate = DETERMINATE && UnboundVars > 0;

    /*  The minimum gain that would be of interest is just enough to give
	a literal a chance to be saved by the backup procedure or, if
	there are determinate literals, to reach the required fraction
	of the maximum possible gain  */

    RuleOutT = false;
    RuleOutF = ! NEGLITERALS;

    ExcessTuplesT = false;
    ExcessTuplesF = false;

    MinUsefulGain = NPossible < MAXPOSSLIT ? MINALTFRAC * BestLitGain :
		    Max(Possible[MAXPOSSLIT]->Gain, MINALTFRAC * BestLitGain);

    if ( NDeterminate && MinUsefulGain < DETERMINATE * MaxPossibleGain )
    {
	MinUsefulGain = DETERMINATE * MaxPossibleGain;
    }

    /*  Pruning cutoffs: R(A) is ruled out once too few positive cases
	can still satisfy it, ~R(A) once too many negative cases fail R(A)  */

    NowTPosThresh = MinUsefulGain / BaseInfo - 0.001;
    NowFNegThresh = (Pos + 1) * (pow(2.0, BaseInfo - MinUsefulGain / Pos) - 1)
                    + 0.001;

    TSP = TrainingSet;
    while ( Case = *TSP++ )
    {
	if ( Join(R->Pos, R->PosIndex, A, Case, N, false) )
	{
	    TInstances = NFound;

	    /*  IdenticalBindings survives only if every matching case
		binds the new variables exactly as the first one did  */

	    if ( ! TTot )
	    {
		memcpy(FirstBinding, Found[0], (N+1) * sizeof(Const));
	    }
	    else
	    for ( i = 1 ; IdenticalBindings && i <= UnboundVars ; i++ )
	    {
		V = Unbound[i];
		IdenticalBindings = FirstBinding[V] == Found[0][V];
	    }

	    /*  In positive cases, discard each (new column, existing
		variable) pair for which some match binds the new column
		to a different value than the existing variable's  */

	    for ( i = 0 ; Positive(Case) && i < PossibleDuplicateVars ; )
	    {
		Col = New[i];
		PreviousVal = Case[Previous[i]];
		OK = true;

		for ( j = 0 ; OK && j < NFound ; j++ )
		{
		    OK = Found[j][Col] == PreviousVal;
		}

		if ( OK )
		{
		    i++;
		}
		else
		{
		    PossibleDuplicateVars--;
		    for ( j = i ; j < PossibleDuplicateVars ; j++ )
		    {
			New[j] = New[j+1];
			Previous[j] = Previous[j+1];
		    }
		}
	    }
	}
	else
	{
	    TInstances = 0;
	}

	/*  Determinate: exactly one match for each positive case and at
	    most one for each negative case  */

	Determinate &= ( Positive(Case) ? TInstances == 1 : TInstances <= 1 );

	if ( TInstances )
	{
	    TTot += TInstances;
	    if ( Positive(Case) )
	    {
		NowTPos++;
		TPos += TInstances;
	    }

	    /*  Count each original tuple at most once (Case[0]&Mask
		presumably identifies the originating tuple)  */

            if ( ! TestBit(Case[0]&Mask, TrueBit) )
            {
                SetBit(Case[0]&Mask, TrueBit);
                OrigTTot++;
                if ( Positive(Case) ) OrigTPos++;
            }

	    if(TTot>MAXTUPLES && !ExcessTuplesT )
	    {
	        ExcessTuplesT = true;
		VERBOSE(1)
		{
		    printf("\t");
		    PrintLiteral(R, true, A);
		    printf("\t tuple limit exceeded");
		}
		RuleOutT = true;
		if ( RuleOutF )
		{
		    VERBOSE(2) printf("\tabandoned");
		    VERBOSE(1) printf("\n");
		    return 0.0;
		}
		VERBOSE(1) printf("\n");
	    }
	}

	if ( ! TInstances )
	{
	    if ( Positive(Case) )
	    {
		NowFPos++;

		if ( Pos - NowFPos < NowTPosThresh && ! RuleOutT )
		{
		    RuleOutT = true;
		    if ( RuleOutF )
		    {
			VERBOSE(2)
			{
			    printf("\t");
			    PrintLiteral(R, true, A);
			    printf("\tTrue pos <= %d -- abandoned\n", 
				   Pos - NowFPos);
			}

			return 0.0;
		    }
		}
	    }
	    else
	    {
		NowFNeg++;

		if ( NowFNeg > NowFNegThresh && ! RuleOutF )
		{
		    RuleOutF = true;
		    if ( RuleOutT )
		    {
			VERBOSE(2)
			{
			    printf("\t");
			    PrintLiteral(R, true, A);
			    printf("\tFalse neg >= %d -- abandoned\n", 
				   NowFNeg);
			}

			return 0.0;
		    }
		}
	    }

            if ( ! TestBit(Case[0]&Mask, FalseBit) )
	    {
                SetBit(Case[0]&Mask, FalseBit);
                OrigFTot++;
                if ( Positive(Case) ) OrigFPos++;
	    }

	    if(NowFPos+NowFNeg>MAXTUPLES && !ExcessTuplesF )
	    {
	        ExcessTuplesF = true;
		VERBOSE(1)
		{
		    printf("\t");
		    PrintLiteral(R, false, A);
		    printf("\t tuple limit exceeded");
		}
		RuleOutF = true;
		if ( RuleOutT )
		{
		    VERBOSE(2) printf("\tabandoned");
		    VERBOSE(1) printf("\n");
		    return 0.0;
		}
		VERBOSE(1) printf("\n");
	    }
	}
    }

    NowFTot = NowFPos + NowFNeg;

    PosGain = Worth(NowTPos, TPos, TTot, UnboundVars);
    WeakPos = PosGain < 0.001 || NowFTot <= 0;

    NegGain = ( NEGLITERALS ) ? Worth(NowFPos, NowFPos, NowFTot, 0) : 0.0;
    WeakNeg = NegGain < 0.001 || NowFTot >= Tot;

    /*  Weak literal sequence check  */

    if ( ! Determinate && CanUse && WeakPos && WeakNeg 
	&& WeakLiterals >= PATIENCE )
    {
	VERBOSE(2)
	{
	    printf("\t");
	    PrintLiteral(R, true, A);
	    printf("\ttoo many weak literals\n");
	}

	return 0.0;
    }

    /*  Encoding length check: Best is the number of original positive
	tuples covered by whichever of R(A) / ~R(A) is preferred  */

    if ( CanUse &&
	 ( PosGain > NegGain && (Best = OrigTPos) ||
	   NegGain > 0 && (Best = OrigFPos) ) &&
         Except(CycleTot, Best) < UsedSoFar + LitBits )
    {
	VERBOSE(2)
	{
	    printf("\t");
	    PrintLiteral(R, true, A);
	    printf("\tTrue %d, False %d, Covers %d: coding violation\n",
		NowTPos, NowFPos, Best);
	}

	return 0.0;
    }

    /*  Would the addition of this literal to the clause create the best
	compact clause so far?  That is, would R(A) (or ~R(A)) leave only
	positive original tuples, and more of them than the current best?
	CompactPos selects R(A); it is computed once, before
	CompactClCover is updated below.  */

    CompactPos = (OrigTPos==OrigTTot) && (OrigTPos>CompactClCover);

    if( CompactPos
        || (NEGLITERALS && (OrigFPos==OrigFTot) && (OrigFPos>CompactClCover)) )
    {
	/*  NOTE(review): any previous CompactClause is not released here;
	    confirm whether pmalloc'd memory is reclaimed elsewhere.  */

        CompactClause = (Clause) pmalloc((NLit+2) * sizeof(Literal));

	/*  Copy the literals of the clause under construction ...  */

	for ( il = 0 ; il < NLit ; il++ )
	{
	    L = (Literal) pmalloc(sizeof(struct _lit_rec));
	    memcpy(L, NewClause[il], sizeof(struct _lit_rec));
	    CompactClause[il] = L;
	}

	/*  ... then append R(A) or ~R(A).  Sizes are scaled by the element
	    size rather than assumed to be one byte per element.  */

	L = (Literal) pmalloc(sizeof(struct _lit_rec));    
	L->Rel  = R;
	L->Sign = CompactPos;
	L->Bits = LitBits;
	L->Args = (Vars) pmalloc((R->Arity + 1) * sizeof *L->Args);
	memcpy(L->Args, A, (R->Arity + 1) * sizeof *L->Args);
	L->FloatingDet = false;
	if ( R == Target )
	{
	    L->RefOrderedArg =
		(Ordering*) pmalloc((Target->Arity + 1) * sizeof(Ordering));
	    ForEach(il,1,Target->Arity) 
	    {
	        L->RefOrderedArg[il] = PartialOrder[L->Args[il]][il];
	    }
	}

	/* Need to unfloat associated determinate literals if any */

	ForEach(il,1,R->Arity)
	{
	    Vl = Variable[A[il]];
	    for(jl=0;jl<Vl->DetDeps;jl++)
	    {
	        CompactClause[Vl->DetLits[jl]]->FloatingDet = false;
	    }
	}
	
	CompactClause[NLit] = L;
	CompactClause[NLit+1] = Nil;

	CompactClCover = CompactPos ? OrigTPos : OrigFTot;

	CompactClNLit = NLit+1;

	VERBOSE(1)
	{
	    printf("Best clause so far, covering %d\n\t",
		   CompactClCover);
	    PrintClause(Target, CompactClause);
	}
    }

    VERBOSE(2)
    if ( PosGain > 0 || NegGain > 0 || Determinate || Verbosity >= 2 )
    {
	if ( Determinate ) putchar('#');
	if ( IdenticalBindings) putchar('X');
	if ( PossibleDuplicateVars ) printf(" %s=%s", 
                      Variable[A[New[0]]]->Name, Variable[Previous[0]]->Name);

	printf("\t");
	PrintLiteral(R, true, A);
	printf("\tTrue %d[%d,%d]: gain %.2f", NowTPos, TPos, TTot, PosGain);
	if ( NEGLITERALS )
	{
	    printf(";  False %d,%d: gain %.2f", NowFPos, NowFTot, NegGain);
	}
	newline;
    }

    Gain = Max(PosGain, NegGain);

    if ( CanUse )
    {
	/*  A determinate literal with insufficient gain may still be
	    proposed for its new bindings  */

	if ( Determinate && ! IdenticalBindings &&
	     Gain < DETERMINATE * MaxPossibleGain &&
	     !ExcessTuplesT && !PossibleDuplicateVars &&
	     ProposeDeterminateLiteral(R, A, LitBits, UnboundVars) )
	{
	    return 0.001;
	}

	if ( PosGain > 1E-6 && !ExcessTuplesT && !PossibleDuplicateVars)
	{
	    ProposeLiteral(R, true, A, TTot, LitBits, OrigTPos, OrigTTot, 
			   PosGain);
	}

	if ( NegGain > 1E-6 && !ExcessTuplesF )
	{
	    ProposeLiteral(R, false, A, NowFTot, LitBits, OrigFPos, OrigFTot, 
			   NegGain);
	}
    }
    return Gain;
}



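    /*  Compute the aggregate gain of a literal.  N is the number of
	current positive tuples retained by the literal; P and T are the
	positive and total tuple counts of the training set that would
	result from adding it; UV is the number of new (unbound)
	variables it introduces.  The basic gain is N * (BaseInfo -
	Info(P, T)), i.e. the number of retained positive tuples times
	the information gained regarding each; but there is a minor
	adjustment:
	  - a literal that has some positive tuples and no gain,
	    but introduces one or more unbound variables, is
	    given a slight gain  */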

float Worth(N, P, T, UV)
/*    -----  */
    int N, P, T, UV;
{
    float PerTuple, Info();

    /*  Information gained regarding each retained positive tuple  */

    PerTuple = BaseInfo - Info(P, T);

    /*  No real gain, yet some positive tuples survive and new variables
	get bound: award a tiny notional gain that grows with UV  */

    if ( N && UV && PerTuple < 1E-6 )
    {
	return 0.0009 + UV * 0.0001;
    }

    return N * PerTuple;
}



float Info(N, T)
/*    ----  */
    int N, T;
{
    /*  Log2(T+1) - Log2(N+1): with Log2 presumably the base-2 logarithm,
	this is the information (in bits) that a tuple is one of the N
	positive tuples among T, each count incremented by one so that
	zero counts are defined  */

    int Total = T + 1, Positives = N + 1;

    return Log2(Total) - Log2(Positives);
}
