package com.wcohen.ss;

import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.api.StringWrapperIterator;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.api.Tokenizer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:com/wcohen/ss/JensenShannonDistance.class */
/**
 * Token-based string similarity built on the Jensen-Shannon divergence.
 *
 * <p>{@link #prepare(String)} tokenizes a string into a {@link BagOfTokens} and replaces every
 * token weight with the value returned by the subclass hook
 * {@link #smoothedProbability(Token, double, double)}. Subclasses may use
 * {@link #backgroundProb(Token)}, which is estimated from the corpus passed to
 * {@link #train(StringWrapperIterator)}.
 *
 * <p>{@link #score(StringWrapper, StringWrapper)} returns
 * {@code 0.5 * sum_t [h(p_t) + h(q_t) - h(p_t + q_t)] / ln 2} over tokens present in both bags,
 * where {@code h(x) = -x ln x}. When each prepared bag's weights sum to 1, this equals
 * {@code 1 - JS(P, Q)} with JS measured in bits: 1 for identical distributions, 0 for disjoint ones.
 */
public abstract class JensenShannonDistance extends AbstractTokenizedStringDistance {
    /** Number of times each token was seen during {@link #train(StringWrapperIterator)}. */
    private final Map<Token, Integer> backgroundFrequency = new HashMap<Token, Integer>();

    /** Total number of token occurrences counted during training (denominator of backgroundProb). */
    int totalTokens = 0;

    /**
     * Creates a distance that tokenizes strings with the given tokenizer.
     *
     * @param tokenizer tokenizer used by {@link #prepare(String)} and training
     */
    public JensenShannonDistance(Tokenizer tokenizer) {
        super(tokenizer);
    }

    /** Creates a distance using the superclass's default tokenizer. */
    public JensenShannonDistance() {
        super();
    }

    /**
     * Accumulates background token frequencies from a training corpus.
     *
     * <p>Each string is converted to a bag of tokens. Every token yielded by the bag's iterator
     * increments both its own count and {@code totalTokens} by one.
     *
     * @param stringWrapperIterator the training strings; consumed fully
     */
    @Override // com.wcohen.ss.AbstractTokenizedStringDistance
    public final void train(StringWrapperIterator stringWrapperIterator) {
        while (stringWrapperIterator.hasNext()) {
            Iterator<Token> it = asBagOfTokens(stringWrapperIterator.nextStringWrapper()).tokenIterator();
            while (it.hasNext()) {
                Token token = it.next();
                totalTokens++;
                Integer count = backgroundFrequency.get(token);
                // Autoboxing goes through Integer.valueOf, which already caches small values;
                // the old hand-rolled ONE/TWO/THREE cache and boxed '==' checks were unnecessary.
                backgroundFrequency.put(token, count == null ? 1 : count + 1);
            }
        }
    }

    /**
     * Tokenizes {@code str} and replaces each token weight with its smoothed probability.
     *
     * @param str the raw string
     * @return a {@link BagOfTokens} whose weights come from {@link #smoothedProbability}
     */
    @Override // com.wcohen.ss.AbstractStringDistance, com.wcohen.ss.api.StringDistance
    public final StringWrapper prepare(String str) {
        BagOfTokens bagOfTokens = new BagOfTokens(str, this.tokenizer.tokenize(str));
        // Captured once so every token is smoothed against the same raw total,
        // even though setWeight below rewrites the individual weights.
        double totalWeight = bagOfTokens.getTotalWeight();
        Iterator<Token> it = bagOfTokens.tokenIterator();
        while (it.hasNext()) {
            Token token = it.next();
            bagOfTokens.setWeight(token, smoothedProbability(token, bagOfTokens.getWeight(token), totalWeight));
        }
        return bagOfTokens;
    }

    /**
     * Computes the smoothed probability that becomes a token's weight in {@link #prepare(String)}.
     *
     * @param token the token being weighted
     * @param d the token's raw weight in the bag
     * @param d2 the bag's total raw weight before smoothing
     * @return the smoothed probability to store as the token's weight
     */
    protected abstract double smoothedProbability(Token token, double d, double d2);

    /**
     * Returns the fraction of training token occurrences that were {@code token}.
     *
     * <p>Returns 0.0 for tokens never seen in training. It also returns 0.0 when no training has
     * happened ({@code totalTokens == 0}); otherwise the result would be NaN and would propagate
     * into every smoothed weight and score.
     *
     * @param token the token to look up
     * @return the background probability in [0, 1]
     */
    public double backgroundProb(Token token) {
        if (totalTokens == 0) {
            return 0.0d;
        }
        Integer count = backgroundFrequency.get(token);
        double frequency = (count == null) ? 0.0d : count.intValue();
        return frequency / totalTokens;
    }

    /**
     * Scores two prepared bags. See the class documentation for the formula.
     *
     * @param stringWrapper a {@link BagOfTokens} produced by {@link #prepare(String)}
     * @param stringWrapper2 a {@link BagOfTokens} produced by {@link #prepare(String)}
     * @return the similarity score; only tokens present in both bags contribute
     * @throws ClassCastException if either argument is not a {@link BagOfTokens}
     */
    @Override // com.wcohen.ss.AbstractStringDistance, com.wcohen.ss.api.StringDistance
    public final double score(StringWrapper stringWrapper, StringWrapper stringWrapper2) {
        BagOfTokens bagOfTokens = (BagOfTokens) stringWrapper;
        BagOfTokens bagOfTokens2 = (BagOfTokens) stringWrapper2;
        double sum = 0.0d;
        Iterator<Token> it = bagOfTokens.tokenIterator();
        while (it.hasNext()) {
            Token token = it.next();
            if (bagOfTokens2.contains(token)) {
                sum -= commonTokenTerm(bagOfTokens.getWeight(token), bagOfTokens2.getWeight(token));
            }
        }
        return (0.5d * sum) / Math.log(2.0d);
    }

    /**
     * Per-token term {@code h(p + q) - h(p) - h(q)} used by score and explainScore.
     * It is non-positive for non-negative weights, so score subtracts it.
     */
    private double commonTokenTerm(double p, double q) {
        return (h(p + q) - h(p)) - h(q);
    }

    /**
     * Entropy contribution {@code -x ln x}, using the standard convention {@code h(0) = 0}.
     * Without the guard, {@code -0 * Math.log(0)} is {@code 0 * -Infinity = NaN}.
     */
    private double h(double d) {
        if (d <= 0.0d) {
            return 0.0d;
        }
        return (-d) * Math.log(d);
    }

    /**
     * Produces a human-readable breakdown of {@link #score(StringWrapper, StringWrapper)}.
     *
     * <p>For each common token it lists both weights and the per-token delta
     * {@code h(p+q) - h(p) - h(q)}, followed by the final score.
     *
     * @param stringWrapper a {@link BagOfTokens} produced by {@link #prepare(String)}
     * @param stringWrapper2 a {@link BagOfTokens} produced by {@link #prepare(String)}
     * @return the explanation text
     */
    @Override // com.wcohen.ss.AbstractStringDistance, com.wcohen.ss.api.StringDistance
    public final String explainScore(StringWrapper stringWrapper, StringWrapper stringWrapper2) {
        StringBuilder sb = new StringBuilder();
        PrintfFormat printfFormat = new PrintfFormat("%.3f");
        BagOfTokens bagOfTokens = (BagOfTokens) stringWrapper;
        BagOfTokens bagOfTokens2 = (BagOfTokens) stringWrapper2;
        sb.append("Common tokens: ");
        Iterator<Token> it = bagOfTokens.tokenIterator();
        while (it.hasNext()) {
            Token token = it.next();
            if (bagOfTokens2.contains(token)) {
                double weight = bagOfTokens.getWeight(token);
                double weight2 = bagOfTokens2.getWeight(token);
                sb.append(" " + token.getValue() + ": ");
                sb.append(printfFormat.sprintf(weight));
                sb.append("*");
                sb.append(printfFormat.sprintf(weight2));
                sb.append(":delta=");
                sb.append(printfFormat.sprintf(commonTokenTerm(weight, weight2)));
            }
        }
        sb.append("\nscore = " + score(stringWrapper, stringWrapper2));
        return sb.toString();
    }
}
