/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.features.supporting;

import calhoun.analysis.crf.statistics.BasicStats;
import calhoun.util.Assert;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Fits a distribution over fixed-length motifs drawn from a four-symbol alphabet (symbols are
 * encoded as the integers 0..3) so that it reproduces every pairwise-position marginal observed
 * in a set of training examples.
 *
 * <p>Distributions are flat arrays of length {@code 4^span}. A motif {@code m[0..span-1]} is
 * stored at index {@code sum_p m[p] * 4^(span-1-p)}, i.e. position 0 is the most significant
 * base-4 digit. Training starts from the uniform distribution and, on each iteration, rescales
 * the current distribution so that it satisfies one randomly chosen pairwise constraint.
 *
 * <p>No static mutable scratch state is used: every call works on freshly allocated arrays.
 */
public class MaxentMotifModel {
    private static final Log log = LogFactory.getLog(MaxentMotifModel.class);
    /** Not read anywhere in this class; retained for source compatibility. */
    boolean debug = log.isDebugEnabled();

    /**
     * Trains a distribution over motifs of length {@code span} whose joint distribution for
     * every pair of positions matches the (pseudocount-smoothed) empirical pair frequencies.
     *
     * @param motifExamples training motifs; each must have length {@code span} and contain only
     *     symbols in {@code 0..3}. May be empty, in which case the flat distribution results.
     * @param span motif length; must be at least 2
     * @param nIter number of enforcement steps; each step enforces one pairwise constraint
     *     chosen uniformly at random (via {@link Math#random()}, so results are not reproducible)
     * @param pseudocount count added to each of the 16 symbol pairs before normalizing
     * @return probabilities of length {@code 4^span}, indexed by the base-4 encoding of a motif
     *     with position 0 as the most significant digit
     */
    public static double[] trainMaxentDistributionUsingAllPairwiseConstraints(List<int[]> motifExamples, int span, int nIter, double pseudocount) {
        int nExamples = motifExamples.size();
        if (nExamples == 0) {
            log.warn("Warning -- attempting to train a maxent distribution without any examples; the flat distribution will eventually be returned.");
        }
        for (int[] motif : motifExamples) {
            Assert.a(motif.length == span);
            for (int symbol : motif) {
                Assert.a(symbol >= 0 && symbol < 4);
            }
        }
        List<Constraint> motifConstraints = makeAllPairwiseConstraints(motifExamples, span, pseudocount);
        if (log.isDebugEnabled()) {
            log.debug("The number of motifConstraints is " + motifConstraints.size());
        }
        return trainMaxentDistribution(motifConstraints, span, nIter);
    }

    /**
     * Builds one {@link Constraint} for every position pair {@code pos1 < pos2}, giving
     * {@code motifLen * (motifLen - 1) / 2} constraints in total.
     *
     * @param motifExamples training motifs, each of length {@code motifLen}
     * @param motifLen motif length; must be greater than 1
     * @param pseudocount count added to each symbol pair when estimating each constraint
     * @return the constraints, ordered by {@code pos1} and then {@code pos2}
     */
    private static List<Constraint> makeAllPairwiseConstraints(List<int[]> motifExamples, int motifLen, double pseudocount) {
        Assert.a(motifLen > 1);
        for (int[] motif : motifExamples) {
            Assert.a(motif.length == motifLen);
        }
        List<Constraint> ret = new ArrayList<Constraint>(motifLen * (motifLen - 1) / 2);
        for (int pos1 = 0; pos1 < motifLen - 1; ++pos1) {
            for (int pos2 = pos1 + 1; pos2 < motifLen; ++pos2) {
                ret.add(new Constraint(motifLen, pos1, pos2, motifExamples, pseudocount));
            }
        }
        return ret;
    }

    /**
     * Starts from the uniform distribution over {@code 4^span} cells and performs {@code nIter}
     * enforcement steps, each applying one constraint chosen uniformly at random.
     *
     * @return the resulting distribution; the uniform one if there are no constraints
     */
    private static double[] trainMaxentDistribution(List<Constraint> motifConstraints, int span, int nIter) {
        int hSize = 1;
        for (int j = 0; j < span; ++j) {
            hSize *= 4;
        }
        double[] ret = new double[hSize];
        double flat = 1.0 / (double) hSize;
        for (int j = 0; j < hSize; ++j) {
            ret[j] = flat;
        }
        int nCon = motifConstraints.size();
        if (nCon == 0) {
            log.warn("Warning -- no constraints, returning maximum entropy distribution");
            return ret;
        }
        for (int iter = 0; iter < nIter; ++iter) {
            // Unseeded Math.random(): the constraint order differs from run to run.
            int cNum = (int) ((double) nCon * Math.random());
            Constraint constraint = motifConstraints.get(cNum);
            if (log.isDebugEnabled()) {
                log.debug("Enforcing constraint number " + cNum + " which is " + constraint.stringSummary());
            }
            ret = constraint.enforce(ret);
        }
        return ret;
    }

    /**
     * A target joint distribution for the symbols at two motif positions, estimated from
     * training examples, which can be imposed on a full motif distribution.
     */
    private static class Constraint {
        /** Motif length. */
        final int span;
        /** First constrained position; always strictly less than {@link #pos2}. */
        final int pos1;
        /** Second constrained position. */
        final int pos2;
        /** Target pair probabilities, indexed by {@code 4 * symbol[pos1] + symbol[pos2]}. */
        double[] prob;

        public Constraint(int newspan, int pos1, int pos2, List<int[]> motifExamples, double pseudocount) {
            this.span = newspan;
            this.pos1 = pos1;
            this.pos2 = pos2;
            Assert.a(0 <= pos1);
            Assert.a(pos1 < pos2);
            Assert.a(pos2 < this.span);
            Assert.a(2 <= this.span);
            this.train(motifExamples, pseudocount);
        }

        /** Encodes a (first, second) symbol pair as an index into {@link #prob}. */
        private static int pairCode(int first, int second) {
            return 4 * first + second;
        }

        /**
         * Estimates {@link #prob} from the examples: each of the 16 pairs starts at
         * {@code pseudocount}, every example adds 1 to its pair, and the counts are normalized.
         */
        private void train(List<int[]> motifExamples, double pseudocount) {
            for (int[] motif : motifExamples) {
                Assert.a(motif.length == this.span);
            }
            double[] counts = new double[16];
            for (int z = 0; z < 16; ++z) {
                counts[z] = pseudocount;
            }
            for (int[] motif : motifExamples) {
                counts[pairCode(motif[this.pos1], motif[this.pos2])] += 1.0;
            }
            double total = 0.0;
            for (double c : counts) {
                total += c;
            }
            if (total == 0.0) {
                // No examples and no pseudocount: nothing is known about this pair, so use the
                // uniform pair distribution rather than dividing by zero (which yields NaN).
                for (int z = 0; z < 16; ++z) {
                    counts[z] = 1.0 / 16.0;
                }
            } else {
                for (int z = 0; z < 16; ++z) {
                    counts[z] /= total;
                }
            }
            this.prob = counts;
            double sum = BasicStats.sumDoubleArray(this.prob);
            Assert.a(sum > 0.999 && sum < 1.001);
        }

        /**
         * Returns the {@link #prob} index of the (pos1, pos2) symbols of a distribution cell.
         * The shifts select the base-4 digits of those positions within the cell index.
         */
        private static int cellPair(int cell, int shift1, int shift2) {
            return pairCode((cell >>> shift1) & 3, (cell >>> shift2) & 3);
        }

        /**
         * Rescales {@code pp} in place so that its marginal over ({@link #pos1}, {@link #pos2})
         * equals {@link #prob}, and returns the same array.
         *
         * <p>The cells of each pair are multiplied by {@code prob[z] / currentMass[z]}. A pair
         * with target probability zero has all of its cells set to zero, so the result sums to
         * the same total as {@link #prob}.
         *
         * @param pp distribution indexed as described on the enclosing class; assumed to have
         *     length {@code 4^span}
         * @return {@code pp}, modified
         */
        public double[] enforce(double[] pp) {
            // Position p is the base-4 digit with weight 4^(span-1-p) in a cell index.
            int shift1 = 2 * (this.span - 1 - this.pos1);
            int shift2 = 2 * (this.span - 1 - this.pos2);

            double[] current = new double[16];
            for (int cell = 0; cell < pp.length; ++cell) {
                current[cellPair(cell, shift1, shift2)] += pp[cell];
            }

            double[] mult = new double[16];
            double changeneeded = 0.0;
            for (int z = 0; z < 16; ++z) {
                changeneeded += Math.abs(this.prob[z] - current[z]);
                if (this.prob[z] > 0.0) {
                    Assert.a(current[z] > 0.0);
                    mult[z] = this.prob[z] / current[z];
                } else {
                    mult[z] = 0.0;
                }
            }

            for (int cell = 0; cell < pp.length; ++cell) {
                pp[cell] *= mult[cellPair(cell, shift1, shift2)];
            }
            if (log.isDebugEnabled()) {
                log.debug("amount of change needed to enforce constraint was " + changeneeded);
            }
            return pp;
        }

        /** Short human-readable description used in debug logging. */
        public String stringSummary() {
            return "constraint_pos1=" + this.pos1 + "_pos2=" + this.pos2 + "_span=" + this.span;
        }
    }
}

