package edu.cmu.minorthird.classify.sequential;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.MathUtil;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;

/**
 * Beam-search decoder that assigns a label to every position of an {@link Instance} sequence
 * using a per-position {@link Classifier}.
 *
 * <p>The search runs left to right. At each position, each partial labelling in the current
 * beam is used to build an {@link InstanceFromSequence} that carries a label-history window of
 * length {@code historySize}. The window is filled by {@code InstanceFromSequence.fillHistory}
 * from the labels chosen so far. The classifier scores that instance, and the partial
 * labelling is extended with every candidate label, weighted by the classifier's score for
 * that label.
 *
 * <p>Extensions whose history windows are equal are merged, keeping only the higher-scoring
 * one. The new beam is then sorted best-first, and at most {@link #getMaxBeamSize()} of its
 * entries are expanded at the next position.
 *
 * <p>A search stores its input and resulting beam in instance fields. Query methods such as
 * {@link #viterbi(int)} therefore refer to the most recent search, and a single searcher
 * should not be used for concurrent searches.
 */
public class BeamSearcher implements SequenceConstants, Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger log = Logger.getLogger(BeamSearcher.class);

    /** Length of the label-history window given to the classifier and used to merge entries. */
    private int historySize;
    /** Candidate labels tried at every position; the constructor requires at least two. */
    private String[] possibleClassLabels;
    /** Scores one position at a time, given the instance plus its label history. */
    private Classifier classifier;
    /** Non-transient, hence part of the serialized form; kept for stream compatibility. */
    private final int CURRENT_SERIAL_VERSION = 1;
    /** Maximum number of beam entries expanded at each position. */
    private int beamSize = 10;

    /** Input of the most recent search; read by {@link #viterbi(int)}. */
    private transient Instance[] instances;
    /** Result of the most recent search, sorted best-first. Null after deserialization. */
    private transient Beam beam;

    /**
     * Creates a searcher.
     *
     * @param classifier scores each position; it receives instances augmented with the label
     *     history of the partial labelling being extended
     * @param historySize length of the label-history window
     * @param exampleSchema supplies the candidate labels via {@code validClassNames()}
     * @throws IllegalArgumentException if the schema declares fewer than two class names
     */
    public BeamSearcher(Classifier classifier, int historySize, ExampleSchema exampleSchema) {
        this.classifier = classifier;
        this.historySize = historySize;
        this.possibleClassLabels = exampleSchema.validClassNames();
        if (this.possibleClassLabels.length < 2) {
            throw new IllegalArgumentException(
                "possibleClassLabels.length=" + this.possibleClassLabels.length + " <2 ???");
        }
        this.beam = new Beam(historySize);
    }

    /** Returns the maximum number of beam entries expanded at each position. */
    public int getMaxBeamSize() {
        return this.beamSize;
    }

    /**
     * Sets the maximum number of beam entries expanded at each position. A value of zero or
     * less expands nothing, so any search over a non-empty sequence ends with an empty beam.
     */
    public void setMaxBeamSize(int i) {
        this.beamSize = i;
    }

    /**
     * Runs an unconstrained search and returns its best labelling.
     *
     * @return one {@link ClassLabel} per instance, carrying the classifier's weight for the
     *     chosen label at that position
     */
    public ClassLabel[] bestLabelSequence(Instance[] instanceArr) {
        doSearch(instanceArr);
        return viterbi(0);
    }

    /**
     * Builds a sequence instance whose label history is filled from an empty label list. This
     * is the same history the search uses at the first position of a sequence.
     *
     * @param instance the underlying instance
     * @param i length of the history window
     */
    public static Instance getBeamInstance(Instance instance, int i) {
        String[] history = new String[i];
        InstanceFromSequence.fillHistory(history, new String[0], 0);
        return new InstanceFromSequence(instance, history);
    }

    /**
     * Runs an unconstrained beam search over {@code instanceArr}. The result is exposed through
     * {@link #viterbi(int)}, {@link #score(int)} and {@link #getNumberOfSolutionsFound()}.
     *
     * @throws IllegalStateException if fewer than two candidate labels are known
     */
    public void doSearch(Instance[] instanceArr) {
        search(instanceArr, null);
    }

    /**
     * Runs a beam search in which position {@code i} may only take the label
     * {@code classLabelArr[i].bestClassName()}.
     *
     * <p>Positions at or beyond {@code classLabelArr.length}, and positions whose entry is
     * {@code null}, are unconstrained. A {@code null} array means no constraints at all. If a
     * required label is not among the candidate labels, no extension survives and the beam
     * becomes empty.
     *
     * @throws IllegalStateException if fewer than two candidate labels are known
     */
    public void doSearch(Instance[] instanceArr, ClassLabel[] classLabelArr) {
        search(instanceArr, classLabelArr);
    }

    /** Shared implementation of both {@code doSearch} overloads; {@code constraints} may be null. */
    private void search(Instance[] instanceArr, ClassLabel[] constraints) {
        this.instances = instanceArr;
        if (this.possibleClassLabels.length < 2) {
            throw new IllegalStateException(
                "possibleClassLabels.length=" + this.possibleClassLabels.length + " <2 ???");
        }
        this.beam = new Beam(this.historySize);
        this.beam.add(new BeamEntry());
        for (int position = 0; position < instanceArr.length; position++) {
            ClassLabel constraint = constraintAt(constraints, position);
            String requiredLabel = (constraint == null) ? null : constraint.bestClassName();
            Beam next = new Beam(this.historySize);
            int width = Math.min(this.beam.size(), this.beamSize);
            for (int k = 0; k < width; k++) {
                BeamEntry entry = this.beam.get(k);
                ClassLabel classification = this.classifier.classification(
                    entry.beamInstance(instanceArr[position], this.historySize));
                for (String label : this.possibleClassLabels) {
                    if (constraint == null || requiredLabel.equals(label)) {
                        next.add(entry.extend(label, classification.getWeight(label)));
                    }
                }
            }
            next.sort();
            this.beam = next;
        }
    }

    /** Returns the constraint for {@code position}, or null if that position is unconstrained. */
    private static ClassLabel constraintAt(ClassLabel[] constraints, int position) {
        if (constraints == null || position >= constraints.length) {
            return null;
        }
        return constraints[position];
    }

    /**
     * Returns the number of entries in the final beam of the most recent search. This is 0
     * when no search has run since construction or deserialization.
     */
    public int getNumberOfSolutionsFound() {
        return (this.beam == null) ? 0 : this.beam.size();
    }

    /**
     * Returns the labelling held by the {@code i}-th entry of the most recent search's final
     * beam. The beam is sorted by total score, best first, so {@code 0} is the best labelling.
     */
    public ClassLabel[] viterbi(int i) {
        ClassLabel[] result = new ClassLabel[this.instances.length];
        BeamEntry entry = this.beam.get(i);
        for (int position = 0; position < result.length; position++) {
            result[position] = entry.toClassLabel(position);
        }
        return result;
    }

    /** Returns the total score (sum of per-position weights) of the {@code i}-th beam entry. */
    public float score(int i) {
        return (float) this.beam.get(i).score;
    }

    /**
     * Runs an unconstrained search. Then, for each position of the best labelling, it renders
     * the chosen label, the classifier's explanation of the instance it scored, and the
     * running total score.
     */
    public String explain(Instance[] instanceArr) {
        doSearch(instanceArr);
        StringBuilder out = new StringBuilder();
        BeamEntry best = this.beam.get(0);
        BeamEntry prefix = new BeamEntry();
        for (int i = 0; i < instanceArr.length; i++) {
            out.append("Classification for instance ").append(i).append(" is ")
                .append(best.labels[i]).append(" (score ").append(best.scores[i]).append("):\n");
            out.append(this.classifier.explain(prefix.beamInstance(instanceArr[i], this.historySize)));
            prefix = prefix.extend(best.labels[i], best.scores[i]);
            out.append("\nRunning total score: ").append(prefix.score)
                .append(AbstractFormatter.DEFAULT_SLICE_SEPARATOR);
        }
        return out.toString();
    }

    /**
     * Builds a structured explanation of the best labelling, analogous to
     * {@link #explain(Instance[])}.
     */
    public Explanation getExplanation(Instance[] instanceArr) {
        doSearch(instanceArr);
        BeamEntry best = this.beam.get(0);
        BeamEntry prefix = new BeamEntry();
        Explanation.Node root = new Explanation.Node("BeamSearcher Classification");
        for (int i = 0; i < instanceArr.length; i++) {
            Explanation.Node positionNode = new Explanation.Node(
                "Classification for instance " + i + " is " + best.labels[i]
                    + " (score " + best.scores[i] + "):\n");
            // Explain the instance the classifier actually scored during the search, including
            // its label-history features, rather than the bare input instance.
            Instance beamInstance = prefix.beamInstance(instanceArr[i], this.historySize);
            Explanation.Node classifierNode = this.classifier.getExplanation(beamInstance).getTopNode();
            if (classifierNode == null) {
                classifierNode = new Explanation.Node(this.classifier.explain(beamInstance));
            }
            positionNode.add(classifierNode);
            prefix = prefix.extend(best.labels[i], best.scores[i]);
            positionNode.add(new Explanation.Node(
                "\nRunning total score: " + prefix.score + AbstractFormatter.DEFAULT_SLICE_SEPARATOR));
            root.add(positionNode);
        }
        return new Explanation(root);
    }

    /**
     * The partial labellings alive after one position. Entries whose history windows are
     * equal (see {@link BeamKey}) are merged, keeping the strictly higher-scoring one.
     */
    private static final class Beam {
        private final int historySize;
        /** Insertion order until {@link #sort()}; an entry that replaces another goes to the end. */
        private final List<BeamEntry> entries = new ArrayList<BeamEntry>();
        private final Map<BeamKey, BeamEntry> entryByKey = new HashMap<BeamKey, BeamEntry>();

        Beam(int historySize) {
            this.historySize = historySize;
        }

        BeamEntry get(int i) {
            return this.entries.get(i);
        }

        int size() {
            return this.entries.size();
        }

        /** Adds {@code candidate} unless an entry with the same history scores at least as well. */
        void add(BeamEntry candidate) {
            BeamKey key = new BeamKey(candidate, this.historySize);
            BeamEntry incumbent = this.entryByKey.get(key);
            if (incumbent == null || incumbent.score < candidate.score) {
                if (incumbent != null) {
                    // BeamEntry does not override equals, so this removes exactly the incumbent.
                    this.entries.remove(incumbent);
                }
                this.entries.add(candidate);
                this.entryByKey.put(key, candidate);
            }
        }

        /** Sorts entries best-first; the sort is stable, so ties keep their insertion order. */
        void sort() {
            Collections.sort(this.entries);
        }
    }

    /** An immutable partial labelling: the chosen labels, their weights, and the total score. */
    private static final class BeamEntry implements Comparable<BeamEntry> {
        final String[] labels;
        final double[] scores;
        final double score;

        /** Creates the empty labelling a search starts from. */
        BeamEntry() {
            this(new String[0], new double[0], 0.0d);
        }

        private BeamEntry(String[] labels, double[] scores, double score) {
            this.labels = labels;
            this.scores = scores;
            this.score = score;
        }

        /** Orders by descending total score; {@code Double.compare} is a total order, even for NaN. */
        @Override
        public int compareTo(BeamEntry other) {
            return Double.compare(other.score, this.score);
        }

        ClassLabel toClassLabel(int position) {
            return new ClassLabel(this.labels[position], this.scores[position]);
        }

        /** Returns a new entry with {@code label} appended and {@code labelScore} added to the total. */
        BeamEntry extend(String label, double labelScore) {
            int n = this.labels.length;
            String[] newLabels = Arrays.copyOf(this.labels, n + 1);
            double[] newScores = Arrays.copyOf(this.scores, n + 1);
            newLabels[n] = label;
            newScores[n] = labelScore;
            return new BeamEntry(newLabels, newScores, this.score + labelScore);
        }

        /**
         * Wraps {@code instance} with this entry's label history. A fresh history array is used
         * for each call, so instances handed to the classifier never share a mutable buffer.
         */
        Instance beamInstance(Instance instance, int historySize) {
            String[] history = new String[historySize];
            fillHistory(history);
            return new InstanceFromSequence(instance, history);
        }

        /** Fills {@code history} from this entry's labels via {@code InstanceFromSequence.fillHistory}. */
        void fillHistory(String[] history) {
            InstanceFromSequence.fillHistory(history, this.labels, this.labels.length);
        }

        @Override
        public String toString() {
            return "[entry: " + Arrays.toString(this.labels) + ";" + Arrays.toString(this.scores)
                + "; score:" + this.score + "]";
        }
    }

    /** Hash key identifying a beam entry by its label-history window only. */
    private static final class BeamKey {
        private final String[] keyHistory;

        BeamKey(BeamEntry entry, int historySize) {
            this.keyHistory = new String[historySize];
            entry.fillHistory(this.keyHistory);
        }

        /** Order-sensitive and null-safe, consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            return Arrays.hashCode(this.keyHistory);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof BeamKey)) {
                return false;
            }
            return Arrays.equals(this.keyHistory, ((BeamKey) obj).keyHistory);
        }

        @Override
        public String toString() {
            StringBuilder out = new StringBuilder("[Key ");
            for (String label : this.keyHistory) {
                out.append(label).append(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
            }
            return out.append("]").toString();
        }
    }
}
