package phase;

import ints.IntArray;
import java.util.Optional;
import java.util.Random;
import vcf.GT;
import vcf.RefGT;
import vcf.XRefGT;

/* loaded from: input_file:phase/Stage2Baum.class */
/**
 * Assigns phased genotypes to the target markers that lie before, between, and after the
 * stage-1 markers, using HMM state probabilities computed at the stage-1 markers.
 *
 * <p>For a target sample, {@link #phase(int)} runs {@link HmmStateProbs#run} once for each of
 * the sample's two haplotypes (haplotype index {@code 2*sample + hapBit}). Each run fills, for
 * every stage-1 marker, the haplotype indices of the HMM states and their probabilities. Every
 * marker of the unphased target data that is not one of the stage-1 markers mapped by
 * {@code fpd.stage1To2()} is then processed as follows:
 * <ul>
 *   <li>If either target allele is missing, both alleles are imputed.</li>
 *   <li>If the genotype is heterozygous, it is oriented to maximize the product of the two
 *       haplotypes' unscaled allele probabilities. Exact ties are broken at random.</li>
 *   <li>Otherwise (homozygous) it is kept unchanged.</li>
 * </ul>
 * The result is stored with {@link Stage2Haps#setPhasedGT}.
 *
 * <p>The working buffers ({@code nStates}, {@code states}, {@code probs}) and the {@link Random}
 * are overwritten by every call to {@link #phase(int)}. An instance must therefore not be used
 * by more than one thread at a time.
 */
public class Stage2Baum {

    /** Share of a heterozygous state's weight given to each allele when neither or both are low-frequency. */
    private static final double EQUAL_SHARE = 0.5d;
    /** Share given to the allele flagged by {@code isLowFreq} when exactly one allele is flagged. */
    private static final double LOW_FREQ_SHARE = 0.55d;
    /** Share given to the unflagged allele when exactly one allele is flagged by {@code isLowFreq}. */
    private static final double NOT_LOW_FREQ_SHARE = 0.45d;

    private final FixedPhaseData fpd;
    private final PhaseData phaseData;
    private final HmmStateProbs stateProbs;
    /** Number of HMM states filled in for haplotype bit 0 and 1 of the current sample. */
    private final int[] nStates = new int[2];
    /** {@code states[hapBit][stage1Marker][j]}: haplotype index of HMM state {@code j}. */
    private final int[][][] states;
    /** {@code probs[hapBit][stage1Marker][j]}: probability of HMM state {@code j}. */
    private final float[][][] probs;
    private final XRefGT hiFreqPhasedGT;
    private final GT unphTargGT;
    private final Optional<RefGT> refGT;
    /** Haplotype indices {@code < nTargHaps} are target haplotypes; larger ones are reference haplotypes. */
    private final int nTargHaps;
    private final int nStage1Markers;
    private final Stage2Haps stage2Haps;
    /** Maps a stage-1 marker index to its marker index in {@code unphTargGT}. */
    private final IntArray stage1To2;
    private final Random rand;

    /**
     * Creates a stage-2 phaser.
     *
     * @param lowFreqPhaseIbs source of the phasing data and of the HMM state model
     * @param stage2Haps destination for the phased genotypes
     */
    public Stage2Baum(LowFreqPhaseIbs lowFreqPhaseIbs, Stage2Haps stage2Haps) {
        this.phaseData = lowFreqPhaseIbs.phaseData();
        this.fpd = phaseData.fpd();
        this.nStage1Markers = fpd.stage1TargGT().nMarkers();
        this.stateProbs = new HmmStateProbs(lowFreqPhaseIbs);
        int maxStates = stateProbs.maxStates();
        this.states = new int[2][nStage1Markers][maxStates];
        this.probs = new float[2][nStage1Markers][maxStates];
        this.hiFreqPhasedGT = phaseData.estPhase().phasedHaps();
        this.unphTargGT = fpd.targGT();
        this.refGT = fpd.refGT();
        this.nTargHaps = unphTargGT.nHaps();
        this.stage2Haps = stage2Haps;
        this.stage1To2 = fpd.stage1To2();
        this.rand = new Random(phaseData.seed());
    }

    /**
     * Returns the number of target samples.
     *
     * @return the number of samples in the phased high-frequency target haplotypes
     */
    public int nTargSamples() {
        return hiFreqPhasedGT.nSamples();
    }

    /**
     * Imputes and phases all non-stage-1 target markers of one sample and stores the results
     * in the {@code Stage2Haps} given to the constructor.
     *
     * <p>The random generator is reseeded with {@code phaseData.seed() + sample}. Random
     * tie-breaking for a sample is therefore reproducible regardless of which samples were
     * processed earlier.
     *
     * @param sample index of the target sample (presumably in {@code [0, nTargSamples())})
     */
    public void phase(int sample) {
        rand.setSeed(phaseData.seed() + sample);
        int hap1 = sample << 1;
        nStates[0] = stateProbs.run(hap1, states[0], probs[0]);
        nStates[1] = stateProbs.run(hap1 | 1, states[1], probs[1]);
        // Process the open intervals between consecutive stage-1 markers; the stage-1
        // markers themselves are skipped (assumes stage1To2 is strictly increasing).
        int start = 0;
        for (int stage1Mkr = 0; stage1Mkr < nStage1Markers; ++stage1Mkr) {
            int targMkr = stage1To2.get(stage1Mkr);
            imputeInterval(sample, start, targMkr);
            start = targMkr + 1;
        }
        imputeInterval(sample, start, unphTargGT.nMarkers());
    }

    /**
     * Sets a phased genotype for each target marker in {@code [start, end)} for one sample.
     *
     * @param sample target sample index
     * @param start first marker (inclusive)
     * @param end last marker (exclusive)
     */
    private void imputeInterval(int sample, int start, int end) {
        for (int m = start; m < end; ++m) {
            int a1 = unphTargGT.allele1(m, sample);
            int a2 = unphTargGT.allele2(m, sample);
            if (a1 < 0 || a2 < 0) {
                // Any missing allele: impute both haplotypes independently.
                a1 = imputeAllele(m, 0);
                a2 = imputeAllele(m, 1);
            } else if (a1 != a2) {
                float[] alProbs1 = unscaledAlProbs(m, 0, a1, a2);
                float[] alProbs2 = unscaledAlProbs(m, 1, a1, a2);
                float pKeep = alProbs1[a1] * alProbs2[a2];
                float pSwap = alProbs1[a2] * alProbs2[a1];
                if (pKeep < pSwap || (pKeep == pSwap && rand.nextBoolean())) {
                    // Swap via a temporary; a direct a1=a2; a2=a1 would make the genotype homozygous.
                    int tmp = a1;
                    a1 = a2;
                    a2 = tmp;
                }
            }
            stage2Haps.setPhasedGT(m, sample, a1, a2);
        }
    }

    /**
     * Returns unscaled allele weights for one haplotype of a heterozygous target genotype.
     *
     * <p>Each HMM state's probability is linearly interpolated between the preceding stage-1
     * marker and the next one, using {@code fpd.prevStage1Wt(m)}. The state's sample consists
     * of haplotypes {@code hap} and {@code hap ^ 1}; states with a missing allele are ignored.
     * <ul>
     *   <li>If the state's sample is homozygous, the weight goes to that allele.</li>
     *   <li>If it is heterozygous, the weight counts only when exactly one of {@code a1},
     *       {@code a2} is low-frequency and carried by that sample. The weight then goes to
     *       that target allele.</li>
     * </ul>
     *
     * @param m target marker index
     * @param hapBit 0 or 1, selects which of the sample's haplotypes' HMM states to use
     * @param a1 first target allele
     * @param a2 second target allele
     * @return weights indexed by allele (length {@code nAlleles} at marker {@code m})
     */
    private float[] unscaledAlProbs(int m, int hapBit, int a1, int a2) {
        float[] alProbs = new float[unphTargGT.marker(m).nAlleles()];
        boolean lowFreq1 = fpd.isLowFreq(m, a1);
        boolean lowFreq2 = fpd.isLowFreq(m, a2);
        // NOTE(review): assumes prevStage1Marker(m) >= 0 even for markers before the first
        // stage-1 marker — TODO confirm.
        int mkrA = fpd.prevStage1Marker(m);
        int mkrB = Math.min(mkrA + 1, nStage1Markers - 1);
        int[] stateHaps = states[hapBit][mkrA];
        float[] probsA = probs[hapBit][mkrA];
        float[] probsB = probs[hapBit][mkrB];
        float wtA = fpd.prevStage1Wt(m);
        float wtB = 1.0f - wtA;
        for (int j = 0, n = nStates[hapBit]; j < n; ++j) {
            int hap = stateHaps[j];
            int b1 = allele(m, hap);
            int b2 = allele(m, hap ^ 1);
            if (b1 < 0 || b2 < 0) {
                continue;
            }
            float p = (wtA * probsA[j]) + (wtB * probsB[j]);
            if (b1 == b2) {
                alProbs[b1] += p;
            } else {
                boolean carries1 = lowFreq1 && (a1 == b1 || a1 == b2);
                boolean carries2 = lowFreq2 && (a2 == b1 || a2 == b2);
                if (carries1 != carries2) {
                    alProbs[carries1 ? a1 : a2] += p;
                }
            }
        }
        return alProbs;
    }

    /**
     * Imputes one allele of a target haplotype from the interpolated HMM state probabilities.
     *
     * <p>Each state's probability is interpolated as in {@link #unscaledAlProbs} and
     * distributed over the state sample's alleles:
     * <ul>
     *   <li>Reference haplotypes and homozygous state samples contribute fully to
     *       {@code allele(m, hap)}.</li>
     *   <li>Heterozygous target state samples split the weight 0.5/0.5. The split is 0.55/0.45
     *       in favour of the allele flagged by {@code isLowFreq} when exactly one is flagged.</li>
     * </ul>
     *
     * @param m target marker index
     * @param hapBit 0 or 1, selects which of the sample's haplotypes' HMM states to use
     * @return the allele with the largest accumulated weight (lowest allele on ties)
     */
    private int imputeAllele(int m, int hapBit) {
        float[] alProbs = new float[unphTargGT.marker(m).nAlleles()];
        int mkrA = fpd.prevStage1Marker(m);
        int mkrB = Math.min(mkrA + 1, nStage1Markers - 1);
        int[] stateHaps = states[hapBit][mkrA];
        float[] probsA = probs[hapBit][mkrA];
        float[] probsB = probs[hapBit][mkrB];
        float wtA = fpd.prevStage1Wt(m);
        float wtB = 1.0f - wtA;
        for (int j = 0, n = nStates[hapBit]; j < n; ++j) {
            int hap = stateHaps[j];
            int b1 = allele(m, hap);
            int b2 = allele(m, hap ^ 1);
            if (b1 < 0 || b2 < 0) {
                continue;
            }
            float p = (wtA * probsA[j]) + (wtB * probsB[j]);
            if (b1 == b2 || hap >= nTargHaps) {
                alProbs[b1] += p;
            } else {
                boolean lowFreq1 = fpd.isLowFreq(m, b1);
                boolean lowFreq2 = fpd.isLowFreq(m, b2);
                double share1;
                double share2;
                if (lowFreq1 == lowFreq2) {
                    share1 = EQUAL_SHARE;
                    share2 = EQUAL_SHARE;
                } else if (lowFreq1) {
                    share1 = LOW_FREQ_SHARE;
                    share2 = NOT_LOW_FREQ_SHARE;
                } else {
                    share1 = NOT_LOW_FREQ_SHARE;
                    share2 = LOW_FREQ_SHARE;
                }
                // Accumulate in double, then narrow to float.
                alProbs[b1] = (float) (alProbs[b1] + (share1 * p));
                alProbs[b2] = (float) (alProbs[b2] + (share2 * p));
            }
        }
        return maxIndex(alProbs);
    }

    /**
     * Returns the allele carried by a target or reference haplotype.
     *
     * <p>{@code refGT.get()} throws {@code NoSuchElementException} if there is no reference
     * panel. Presumably state haplotypes {@code >= nTargHaps} only occur when a reference is
     * present — verify.
     *
     * @param m target marker index
     * @param hap haplotype index; values {@code >= nTargHaps} address reference haplotypes
     * @return the allele, or a negative value if missing (as reported by the underlying data)
     */
    private int allele(int m, int hap) {
        if (hap < nTargHaps) {
            return unphTargGT.allele(m, hap);
        }
        return refGT.get().allele(m, hap - nTargHaps);
    }

    /**
     * Returns the index of the largest element, preferring the lowest index on ties.
     *
     * @param values the array to scan; an empty array yields 0
     * @return index of the maximum element
     */
    private static int maxIndex(float[] values) {
        int best = 0;
        for (int k = 1; k < values.length; ++k) {
            if (values[k] > values[best]) {
                best = k;
            }
        }
        return best;
    }
}
