/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.statistics;

import calhoun.analysis.crf.statistics.BasicStats;
import calhoun.analysis.crf.statistics.GammaDistribution;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * A two-component mixture of gamma distributions, used here to model a distribution of
 * positive lengths.
 *
 * <p>The density is {@code pdist1 * Gamma(shape1, lambda1) + (1 - pdist1) * Gamma(shape2, lambda2)},
 * where each {@code lambda} is a rate parameter (component mean = shape / lambda). Parameters are
 * either supplied directly or fit to observed lengths by expectation-maximization (EM). Small or
 * degenerate samples fall back to a single exponential distribution: weight 1 on component 1,
 * whose shape is 1.
 */
public class MixtureOfGammas implements Serializable {
    private static final long serialVersionUID = 6269582005717276381L;
    private static final Log log = LogFactory.getLog(MixtureOfGammas.class);

    /** Samples smaller than this are modeled as exponential rather than as a mixture. */
    private static final int MIN_LENGTHS_FOR_MIXTURE = 20;
    /** Fixed number of EM iterations; no convergence test is applied. */
    private static final int EM_ITERATIONS = 40;
    /** Mean of the exponential distribution used when no lengths are supplied. */
    private static final double DEFAULT_EXPONENTIAL_MEAN = 100.0;

    /** Weight of component 1, in [0, 1]; component 2 has weight {@code 1 - pdist1}. */
    private double pdist1;
    /** Shape of component 1. */
    private double shape1;
    /** Rate of component 1. */
    private double lambda1;
    /** Shape of component 2. */
    private double shape2;
    /** Rate of component 2. */
    private double lambda2;

    /**
     * Creates a mixture with explicitly given parameters.
     *
     * @param pdist1  weight of component 1; must lie in [0, 1]
     * @param shape1  shape of component 1; must be positive
     * @param lambda1 rate of component 1; must be positive
     * @param shape2  shape of component 2; must be positive
     * @param lambda2 rate of component 2; must be positive
     */
    public MixtureOfGammas(double pdist1, double shape1, double lambda1, double shape2, double lambda2) {
        Assert.a(pdist1 >= 0.0);
        Assert.a(pdist1 <= 1.0);
        Assert.a(shape1 > 0.0);
        Assert.a(lambda1 > 0.0);
        Assert.a(shape2 > 0.0);
        Assert.a(lambda2 > 0.0);
        this.pdist1 = pdist1;
        this.shape1 = shape1;
        this.lambda1 = lambda1;
        this.shape2 = shape2;
        this.lambda2 = lambda2;
    }

    /**
     * Fits a mixture to the given lengths and logs a one-line summary at INFO level.
     *
     * @param lengths observed lengths; every entry must be strictly positive (not modified)
     */
    public MixtureOfGammas(double[] lengths) {
        this(lengths, false);
    }

    /**
     * Fits a mixture, or optionally a plain exponential, to the given lengths and logs a
     * one-line summary at INFO level.
     *
     * @param lengths                observed lengths; every entry must be strictly positive (not modified)
     * @param forceExponentialLength if true, always fit a single exponential distribution
     */
    public MixtureOfGammas(double[] lengths, boolean forceExponentialLength) {
        this.setup(lengths, forceExponentialLength);
        log.info(this.summary());
    }

    /**
     * Fits this model to {@code lengths}, replacing all current parameters.
     *
     * <p>The model collapses to a single exponential distribution with the sample mean when any
     * of the following holds: {@code forceExponentialLength} is set; fewer than 20 lengths are
     * given; or every length lies within 1% of the smallest one. An empty input instead yields an
     * exponential with mean 100. In all other cases the two components are fit by EM.
     *
     * <p>The caller's array is not modified; a sorted copy is used internally.
     *
     * @param lengths                observed lengths; every entry must be strictly positive
     * @param forceExponentialLength if true, always fit a single exponential distribution
     */
    public void setup(double[] lengths, boolean forceExponentialLength) {
        int nLengths = lengths.length;
        for (double length : lengths) {
            Assert.a(length > 0.0);
        }
        if (nLengths == 0) {
            log.warn("Train mixture gamma called with no length inputs; returning an exponential distn with mean 100");
            this.setExponential(DEFAULT_EXPONENTIAL_MEAN);
            return;
        }
        // Sort a private copy. Previously the caller's own array was sorted in place.
        double[] sortedLengths = lengths.clone();
        Arrays.sort(sortedLengths);
        if (log.isDebugEnabled()) {
            log.debug("Mixture of Gammas trainer called to model these lengths: " + ColtUtil.format(sortedLengths));
        }
        boolean exponentialDistribution = false;
        if (forceExponentialLength) {
            log.warn("you called a mixture of gammas model but set flag to force it to model as an exponential length distribution.");
            exponentialDistribution = true;
        }
        if (nLengths < MIN_LENGTHS_FOR_MIXTURE) {
            log.warn("fewer than 20 lengths supplied for training; modeling with an exponential distribution instead of a mixture of Gammas");
            exponentialDistribution = true;
        }
        if (allNearlyEqual(sortedLengths)) {
            log.warn("All the lengths we're asked to model are extremely close to the same value, suggesting the length is artifically fixed for some reason.  This is probably a mistake and some othe model of length is appropriate.  Will model as exponential length distribution.");
            exponentialDistribution = true;
        }
        if (exponentialDistribution) {
            this.setExponential(BasicStats.meanDoubleArray(sortedLengths));
            return;
        }
        this.fitByEm(sortedLengths);
    }

    /**
     * Collapses the model to a single exponential distribution with the given mean. Component 2
     * receives placeholder parameters and zero weight.
     */
    private void setExponential(double mean) {
        this.pdist1 = 1.0;
        this.shape1 = 1.0;
        this.lambda1 = 1.0 / mean;
        this.shape2 = 1.0;
        this.lambda2 = 1.0;
    }

    /** Returns true when every value lies in [0.99, 1.01] times {@code values[0]}; requires a non-empty array. */
    private static boolean allNearlyEqual(double[] values) {
        double reference = values[0];
        for (double value : values) {
            if (value < 0.99 * reference || value > 1.01 * reference) {
                return false;
            }
        }
        return true;
    }

    /**
     * Fits both mixture components to {@code sortedLengths} using a fixed number of EM iterations.
     *
     * @param sortedLengths non-empty array of strictly positive lengths
     */
    private void fitByEm(double[] sortedLengths) {
        int nLengths = sortedLengths.length;
        double median = BasicStats.medianDoubleArray(sortedLengths);
        int len = nLengths + 4;
        double[] x = Arrays.copyOf(sortedLengths, len);
        double[] post = new double[len];
        // Four pseudo-observations around the median have fixed posteriors that the E-step never
        // overwrites. The two at or below the median are pinned to component 1 and the two at or
        // above it to component 2. This keeps the mean of post (the new pdist1) strictly between
        // 0 and 1, so the M-step divisions below never divide by zero.
        x[nLengths] = 0.9 * median;
        post[nLengths] = 1.0;
        x[nLengths + 1] = median;
        post[nLengths + 1] = 1.0;
        x[nLengths + 2] = median;
        post[nLengths + 2] = 0.0;
        x[nLengths + 3] = 1.1 * median;
        post[nLengths + 3] = 0.0;

        // Initial guesses: component means of 60 (shape 15 / rate 0.25) and 100 (shape 5 / rate 0.05).
        this.pdist1 = 0.5;
        this.shape1 = 15.0;
        this.lambda1 = 0.25;
        this.shape2 = 5.0;
        this.lambda2 = 0.05;

        for (int iteration = 0; iteration < EM_ITERATIONS; ++iteration) {
            // E-step: for each real observation, the posterior probability that it came from component 1.
            for (int j = 0; j < nLengths; ++j) {
                double p1 = this.pdist1 * GammaDistribution.gamma(this.shape1, this.lambda1, x[j]);
                double p2 = (1.0 - this.pdist1) * GammaDistribution.gamma(this.shape2, this.lambda2, x[j]);
                if (!(p1 + p2 > 0.0)) {
                    Assert.a(false, "x[j]=" + x[j] + "  p1=" + p1 + "  p2=" + p2 + "  pdist1=" + this.pdist1 + "  shape1=" + this.shape1 + "  lambda1=" + this.lambda1 + "  shape2=" + this.shape2 + "  lambda2=" + this.lambda2);
                }
                post[j] = p1 / (p1 + p2);
            }
            // M-step: posterior-weighted sums of x and log(x) over all points, pseudo-points included.
            double mean1 = 0.0;
            double mean2 = 0.0;
            double meanlog1 = 0.0;
            double meanlog2 = 0.0;
            for (int j = 0; j < len; ++j) {
                Assert.a(x[j] > 0.0);
                double logx = Math.log(x[j]);
                mean1 += post[j] * x[j];
                mean2 += (1.0 - post[j]) * x[j];
                meanlog1 += post[j] * logx;
                meanlog2 += (1.0 - post[j]) * logx;
            }
            this.pdist1 = BasicStats.meanDoubleArray(post);
            // Normalizers: the total posterior weight assigned to each component.
            double weight1 = this.pdist1 * (double) len;
            double weight2 = (1.0 - this.pdist1) * (double) len;
            // mleg maps the weighted mean and mean-log to {shape, rate}.
            double[] plam1 = GammaDistribution.mleg(mean1 / weight1, meanlog1 / weight1);
            double[] plam2 = GammaDistribution.mleg(mean2 / weight2, meanlog2 / weight2);
            this.shape1 = plam1[0];
            this.lambda1 = plam1[1];
            this.shape2 = plam2[0];
            this.lambda2 = plam2[1];
        }
    }

    /**
     * Returns the natural log of the mixture density at {@code x}.
     *
     * <p>Special cases:
     * <ul>
     *   <li>When one component's weight is below 0.001 (so the other's exceeds 0.999), only the
     *       dominant component's log density is returned. Its log weight, a term of magnitude
     *       about 0.001 or less, is dropped.</li>
     *   <li>Otherwise the two terms are combined after shifting both by the larger log density.
     *       A term more than 100 nats below the other is treated as negligible.</li>
     * </ul>
     *
     * <p>NOTE(review): if both component log densities are negative infinity (presumably when
     * {@code x} is outside the gamma support), the shift produces NaN and the final
     * {@code Assert.a} fails. Confirm what {@code GammaDistribution.lgamma} returns for
     * {@code x <= 0}.
     *
     * @param x point at which to evaluate
     * @return the log density; only the fully combined case is asserted to be finite
     */
    public double logEvaluate(double x) {
        double lret1 = GammaDistribution.lgamma(this.shape1, this.lambda1, x);
        double lret2 = GammaDistribution.lgamma(this.shape2, this.lambda2, x);
        if (this.pdist1 > 0.999) {
            return lret1;
        }
        if (this.pdist1 < 0.001) {
            return lret2;
        }
        double maxlog = Math.max(lret1, lret2);
        lret1 -= maxlog;
        lret2 -= maxlog;
        if (lret1 < -100.0) {
            return maxlog + Math.log(1.0 - this.pdist1) + lret2;
        }
        if (lret2 < -100.0) {
            return maxlog + Math.log(this.pdist1) + lret1;
        }
        double ret = maxlog + Math.log(this.pdist1 * Math.exp(lret1) + (1.0 - this.pdist1) * Math.exp(lret2));
        Assert.a(ret != Double.NEGATIVE_INFINITY && ret != Double.POSITIVE_INFINITY && !Double.isNaN(ret));
        return ret;
    }

    /**
     * Returns the mixture density at {@code x}.
     *
     * @param x point at which to evaluate
     * @return the weighted sum of the two component densities; asserted to be finite and not NaN
     */
    public double evaluate(double x) {
        double ret = this.pdist1 * GammaDistribution.gamma(this.shape1, this.lambda1, x)
                + (1.0 - this.pdist1) * GammaDistribution.gamma(this.shape2, this.lambda2, x);
        Assert.a(!Double.isInfinite(ret) && !Double.isNaN(ret));
        return ret;
    }

    /**
     * Writes a one-line parameter summary to {@code out}.
     *
     * @param out destination stream
     */
    public void summarize(PrintStream out) {
        out.println(this.summary());
    }

    /** Builds a one-line summary: the weight, plus shape, rate and mean of each component. */
    private String summary() {
        return "MIXTURE OF GAMMAS INFO: pr(dist1)=" + this.pdist1
                + "  shape1=" + this.shape1
                + "  rate1=" + this.lambda1
                + "  mean1=" + this.shape1 / this.lambda1
                + "  shape2=" + this.shape2
                + "  rate2=" + this.lambda2
                + "  mean2=" + this.shape2 / this.lambda2;
    }

    /**
     * Returns the mixing weight of component 1.
     *
     * @return {@code pdist1}, in [0, 1]
     */
    public double getMix() {
        return this.pdist1;
    }

    /**
     * Returns the mean of the mixture distribution.
     *
     * @return {@code pdist1 * shape1 / lambda1 + (1 - pdist1) * shape2 / lambda2}
     */
    public double getMean() {
        return this.pdist1 * this.shape1 / this.lambda1 + (1.0 - this.pdist1) * this.shape2 / this.lambda2;
    }
}

