package com.wcohen.ss;

import cern.colt.matrix.impl.AbstractFormatter;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.api.Tokenizer;
import java.util.Iterator;

/* loaded from: input_file:com/wcohen/ss/TFIDF.class */
/**
 * Token-based string similarity using TF-IDF weighted, unit-length token vectors.
 *
 * <p>Each string is tokenized and converted into a {@link UnitVector}. When collection
 * statistics are available ({@code collectionSize > 0}) each token weight becomes
 * {@code log(w + 1) * log(collectionSize / df)}, where {@code w} is the token's weight in the
 * bag before conversion and {@code df} its document frequency (tokens absent from the
 * document-frequency table are treated as {@code df = 1}). Otherwise every token gets weight 1.
 * The weights are then scaled to unit Euclidean length, and {@link #score} returns the dot
 * product of the two vectors.
 */
public class TFIDF extends AbstractStatisticalTokenDistance {
    /** The vector most recently built by {@link #prepare(String)}; {@code null} until then. */
    private UnitVector lastVector;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * A bag of tokens whose weights are converted to unit-normalized TF-IDF weights when it is
     * constructed.
     */
    public class UnitVector extends BagOfTokens {
        /** Distance object whose collection statistics supply the IDF values. */
        private final TFIDF owner;

        /**
         * Builds a TF-IDF unit vector from a string and its tokens.
         *
         * @param tfidf the distance whose document frequencies and collection size are used
         * @param str the original string
         * @param tokenArr the tokens of {@code str}
         */
        public UnitVector(TFIDF tfidf, String str, Token[] tokenArr) {
            super(str, tokenArr);
            this.owner = tfidf;
            termFreq2TFIDF();
        }

        /**
         * Builds a TF-IDF unit vector from an existing bag of tokens.
         *
         * @param tfidf the distance whose document frequencies and collection size are used
         * @param bagOfTokens source of the string and tokens
         */
        public UnitVector(TFIDF tfidf, BagOfTokens bagOfTokens) {
            // The delegated constructor already performs the TF-IDF conversion; running it a
            // second time would re-weight the already-normalized weights.
            this(tfidf, bagOfTokens.unwrap(), bagOfTokens.getTokens());
        }

        /** Replaces each token weight by its TF-IDF weight, then scales to unit length. */
        private void termFreq2TFIDF() {
            double sumOfSquares = 0.0d;
            Iterator it = tokenIterator();
            while (it.hasNext()) {
                Token token = (Token) it.next();
                double weight;
                if (owner.collectionSize > 0) {
                    Integer dfCount = (Integer) owner.documentFrequency.get(token);
                    // Tokens missing from the frequency table are treated as occurring in one
                    // document, so they receive the maximum IDF.
                    double df = dfCount == null ? 1.0d : dfCount.intValue();
                    weight = Math.log(getWeight(token) + 1.0d) * Math.log(owner.collectionSize / df);
                } else {
                    weight = 1.0d;
                }
                setWeight(token, weight);
                sumOfSquares += weight * weight;
            }
            double norm = Math.sqrt(sumOfSquares);
            if (norm == 0.0d) {
                // Every weight is zero (or there are no tokens); dividing would yield NaN.
                return;
            }
            Iterator it2 = tokenIterator();
            while (it2.hasNext()) {
                Token token = (Token) it2.next();
                setWeight(token, getWeight(token) / norm);
            }
        }
    }

    /**
     * Creates a TF-IDF distance that tokenizes strings with the given tokenizer.
     *
     * @param tokenizer tokenizer used by {@link #prepare(String)} and {@link #score}
     */
    public TFIDF(Tokenizer tokenizer) {
        super(tokenizer);
        this.lastVector = null;
    }

    /** Creates a TF-IDF distance using the superclass's default tokenizer setup. */
    public TFIDF() {
        this.lastVector = null;
    }

    /**
     * Returns the dot product of the TF-IDF unit vectors of the two strings.
     *
     * @param stringWrapper first string (prepared or raw)
     * @param stringWrapper2 second string (prepared or raw)
     * @return the similarity score
     */
    @Override // com.wcohen.ss.AbstractStringDistance, com.wcohen.ss.api.StringDistance
    public double score(StringWrapper stringWrapper, StringWrapper stringWrapper2) {
        checkTrainingHasHappened(stringWrapper, stringWrapper2);
        UnitVector first = asUnitVector(stringWrapper);
        UnitVector second = asUnitVector(stringWrapper2);
        double dot = 0.0d;
        Iterator it = first.tokenIterator();
        while (it.hasNext()) {
            Token token = (Token) it.next();
            if (second.contains(token)) {
                dot += first.getWeight(token) * second.getWeight(token);
            }
        }
        return dot;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts a wrapper to a {@link UnitVector}, reusing it when it already is one.
     *
     * @param stringWrapper the wrapper to convert
     * @return a TF-IDF unit vector for the wrapped string
     */
    public UnitVector asUnitVector(StringWrapper stringWrapper) {
        if (stringWrapper instanceof UnitVector) {
            return (UnitVector) stringWrapper;
        }
        if (stringWrapper instanceof BagOfTokens) {
            return new UnitVector(this, (BagOfTokens) stringWrapper);
        }
        String text = stringWrapper.unwrap();
        return new UnitVector(this, text, this.tokenizer.tokenize(text));
    }

    /**
     * Tokenizes {@code str} into a TF-IDF unit vector and remembers it as the last prepared
     * vector.
     *
     * @param str the string to prepare
     * @return the prepared vector
     */
    @Override // com.wcohen.ss.AbstractStringDistance, com.wcohen.ss.api.StringDistance
    public StringWrapper prepare(String str) {
        this.lastVector = new UnitVector(this, str, this.tokenizer.tokenize(str));
        return this.lastVector;
    }

    /**
     * Returns the tokens of the vector most recently built by {@link #prepare(String)}.
     *
     * @throws IllegalStateException if {@link #prepare(String)} has not been called
     */
    public Token[] getTokens() {
        return requirePrepared().getTokens();
    }

    /**
     * Returns the weight of {@code token} in the vector most recently built by
     * {@link #prepare(String)}.
     *
     * @throws IllegalStateException if {@link #prepare(String)} has not been called
     */
    public double getWeight(Token token) {
        return requirePrepared().getWeight(token);
    }

    /** Returns {@link #lastVector}, failing with a clear message if nothing was prepared. */
    private UnitVector requirePrepared() {
        if (this.lastVector == null) {
            throw new IllegalStateException("prepare(String) must be called before accessing the last vector");
        }
        return this.lastVector;
    }

    /**
     * Returns the recorded document frequency of {@code token}.
     *
     * @return the stored count, or 0 when the token has no recorded frequency
     */
    @Override // com.wcohen.ss.AbstractStatisticalTokenDistance
    public int getDocumentFrequency(Token token) {
        Integer count = (Integer) this.documentFrequency.get(token);
        return count == null ? 0 : count.intValue();
    }

    /** Records {@code i} as the document frequency of {@code token}. */
    public void setDocumentFrequency(Token token, int i) {
        this.documentFrequency.put(token, Integer.valueOf(i));
    }

    /** Returns the number of documents the IDF statistics are based on. */
    public int getCollectionSize() {
        return this.collectionSize;
    }

    /** Sets the number of documents the IDF statistics are based on. */
    public void setCollectionSize(int i) {
        this.collectionSize = i;
    }

    /**
     * Lists the tokens common to both inputs with their weights, followed by the score.
     *
     * <p>Both arguments are cast to {@code BagOfTokens}; other wrapper types raise
     * {@link ClassCastException}.
     */
    @Override // com.wcohen.ss.AbstractStringDistance, com.wcohen.ss.api.StringDistance
    public String explainScore(StringWrapper stringWrapper, StringWrapper stringWrapper2) {
        BagOfTokens first = (BagOfTokens) stringWrapper;
        BagOfTokens second = (BagOfTokens) stringWrapper2;
        StringBuilder explanation = new StringBuilder("Common tokens: ");
        PrintfFormat printfFormat = new PrintfFormat("%.3f");
        Iterator it = first.tokenIterator();
        while (it.hasNext()) {
            Token token = (Token) it.next();
            if (second.contains(token)) {
                explanation.append(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR)
                        .append(token.getValue())
                        .append(": ")
                        .append(printfFormat.sprintf(first.getWeight(token)))
                        .append("*")
                        .append(printfFormat.sprintf(second.getWeight(token)));
            }
        }
        explanation.append("\nscore = ").append(score(stringWrapper, stringWrapper2));
        return explanation.toString();
    }

    public String toString() {
        return "[TFIDF]";
    }

    /** Command-line entry point; forwards to the inherited {@code doMain} helper. */
    public static void main(String[] strArr) {
        doMain(new TFIDF(), strArr);
    }
}
