package phase;

import blbutil.FloatArray;
import blbutil.FloatList;
import ints.IntArray;
import ints.IntList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import vcf.Markers;

/* loaded from: input_file:phase/PhaseBaum1.class */
/**
 * Re-estimates the phase of a target sample's unphased heterozygotes and
 * imputes its missing genotypes by running backward and then forward HMM
 * passes over the sample's marker clusters.
 *
 * <p>The HMM states for a sample are chosen by
 * {@code BasicPhaseStates.ibsStates}, which also fills {@code refAlleles} and
 * {@code mismatches}. Three parallel HMM tracks are run:
 * <ul>
 * <li>Tracks 1 and 2 follow the sample's two haplotypes. Their mismatch data
 * are exchanged whenever a segment's haplotypes are swapped.</li>
 * <li>Track 0 is copied into tracks 1 and 2 at the start of every segment
 * between unphased-heterozygote clusters. NOTE(review): what track 0 models is
 * determined by {@code BasicPhaseStates.ibsStates}; confirm there.</li>
 * </ul>
 *
 * <p>Instances reuse unsynchronized, mutable scratch buffers across calls to
 * {@link #phase(int)}, so a single instance must not be used concurrently. The
 * static swap-rate counters are atomic and shared by all instances.
 */
public class PhaseBaum1 {

    /** Haplotype swap-state changes since the last reset, summed over all instances. */
    private static final AtomicLong nSwaps = new AtomicLong(0L);

    /** Unphased heterozygotes processed since the last reset, summed over all instances. */
    private static final AtomicLong nUnphHets = new AtomicLong(0L);

    private final PhaseData phaseData;
    /** {@code true} if the current iteration index is below the burn-in iteration count. */
    private final boolean burnin;
    private final EstPhase estPhase;
    /** Markers of the stage-1 target genotypes. */
    private final Markers markers;
    private final int nMarkers;
    /** Per missing-genotype cluster: allele index carried by each HMM state. */
    private final List<int[]> refAlleles;
    /** Mismatch data indexed [track][cluster][state]; filled by {@code states.ibsStates}. */
    private final byte[][][] mismatches;
    private final float pMismatch;
    /** Scratch emission probabilities: {@code [1]} is the mismatch probability and {@code [0] = 1 - [1]}. */
    private final float[] emProbs;
    private final int maxStates;
    private final BasicPhaseStates states;
    /** Likelihood ratio of the chosen phase, one entry per unphased-het cluster. */
    private final FloatList lrList;
    /** Number of HMM states for the current sample. */
    private int nStates;
    /** Forward probabilities indexed [track][state]. */
    private final float[][] fwd;
    /** Backward probabilities indexed [track][state]. */
    private final float[][] bwd;
    /** Per-track sums passed to and returned by {@code HmmUpdater.fwdUpdate}. */
    private final float[] sum;
    /** Per missing-genotype cluster: state probabilities for haplotype track 1. */
    private final List<float[]> missProbs1;
    /** Per missing-genotype cluster: state probabilities for haplotype track 2. */
    private final List<float[]> missProbs2;
    /** Per unphased-het cluster: backward probabilities for haplotype track 1. */
    private final List<float[]> bwdHet1;
    /** Per unphased-het cluster: backward probabilities for haplotype track 2. */
    private final List<float[]> bwdHet2;
    /** If {@code true}, the next forward segment's haplotypes are swapped before it is processed. */
    private boolean swapHaps = false;
    /** Cursor into the missing-genotype cluster buffers. */
    private int missIndex = -1;
    /** Swap-state changes for the current sample. */
    private int swapCnt = 0;

    /**
     * Returns the ratio of haplotype swap-state changes to unphased
     * heterozygotes accumulated since the last reset, and resets both
     * counters to zero.
     *
     * @return the swap rate, or {@code 0.0} if no unphased heterozygotes were
     * processed since the last reset
     */
    public static double getAndResetSwapRate() {
        // getAndSet reads and clears each counter atomically, so increments
        // made by concurrent phase(...) calls are never discarded.
        long swaps = nSwaps.getAndSet(0L);
        long unphHets = nUnphHets.getAndSet(0L);
        // Floating-point division. The previous long division truncated every
        // rate below 1.0 to 0 and threw ArithmeticException on a zero count.
        return unphHets == 0L ? 0.0 : (double) swaps / unphHets;
    }

    /**
     * Creates a phaser whose data, parameters, and HMM states come from the
     * specified {@code PbwtPhaseIbs} object.
     *
     * @param pbwtPhaseIbs source of the phasing data and of the HMM states
     */
    public PhaseBaum1(PbwtPhaseIbs pbwtPhaseIbs) {
        this.phaseData = pbwtPhaseIbs.phaseData();
        this.burnin = phaseData.it() < phaseData.fpd().par().burnin();
        this.estPhase = phaseData.estPhase();
        this.markers = phaseData.fpd().stage1TargGT().markers();
        this.nMarkers = markers.size();
        this.maxStates = phaseData.fpd().par().phase_states();
        this.states = new BasicPhaseStates(pbwtPhaseIbs, maxStates);
        this.lrList = new FloatList(200);
        this.refAlleles = new ArrayList<>();
        this.mismatches = new byte[3][nMarkers][maxStates];
        this.pMismatch = phaseData.pMismatch();
        this.emProbs = new float[] {1.0f - pMismatch, pMismatch};
        this.fwd = new float[3][maxStates];
        this.bwd = new float[3][maxStates];
        this.sum = new float[3];
        this.missProbs1 = new ArrayList<>();
        this.missProbs2 = new ArrayList<>();
        this.bwdHet1 = new ArrayList<>();
        this.bwdHet2 = new ArrayList<>();
    }

    /**
     * Returns the number of target samples.
     *
     * @return the number of samples in {@code phaseData.fpd().targGT()}
     */
    public int nTargSamples() {
        return phaseData.fpd().targGT().nSamples();
    }

    /**
     * Re-estimates the phase of the specified target sample's unphased
     * heterozygotes, imputes its missing genotypes, and stores the result in
     * the {@code EstPhase} object. A sample with no unphased heterozygotes and
     * no missing genotypes is left unchanged. The static swap-rate counters
     * are updated in either case.
     *
     * @param sample a target sample index
     */
    public void phase(int sample) {
        SamplePhase samplePhase = estPhase.get(sample);
        swapHaps = false;
        swapCnt = 0;
        int nUnphased = samplePhase.unphased().size();
        int nMissing = samplePhase.missing().size();
        if (nMissing > 0 || nUnphased > 0) {
            lrList.clear();
            ensureCapacity(nUnphased, nMissing);
            MarkerCluster markerCluster = new MarkerCluster(phaseData, sample);
            // bwdAlg counts missIndex down to 0; fwdAlg then counts it back up.
            missIndex = markerCluster.nMissingGTClusters();
            nStates = states.ibsStates(sample, markerCluster, refAlleles, mismatches);
            bwdAlg(markerCluster);
            fwdAlg(samplePhase, markerCluster);
            updatePhase(sample, samplePhase);
            estPhase.set(sample, samplePhase);
        }
        nUnphHets.addAndGet(nUnphased);
        nSwaps.addAndGet(swapCnt);
    }

    /**
     * Grows the per-cluster scratch buffers so they hold at least the
     * specified numbers of entries. Existing buffers are reused.
     *
     * @param nUnphased required number of unphased-het buffers
     * @param nMissing required number of missing-genotype buffers
     */
    private void ensureCapacity(int nUnphased, int nMissing) {
        while (refAlleles.size() < nMissing) {
            refAlleles.add(new int[maxStates]);
            missProbs1.add(new float[maxStates]);
            missProbs2.add(new float[maxStates]);
        }
        while (bwdHet1.size() < nUnphased) {
            bwdHet1.add(new float[maxStates]);
            bwdHet2.add(new float[maxStates]);
        }
    }

    /**
     * Runs the forward pass over all clusters, one segment per unphased-het
     * cluster. The phase decision for each het is made as soon as the
     * forward pass reaches that het's cluster.
     */
    private void fwdAlg(SamplePhase samplePhase, MarkerCluster markerCluster) {
        IntArray unphClusters = markerCluster.unphClusters();
        Arrays.fill(fwd[0], 0, nStates, 1.0f / nStates);
        sum[0] = 1.0f;
        int start = 0;
        for (int het = 0, nHets = unphClusters.size(); het < nHets; ++het) {
            int end = unphClusters.get(het);
            fwdAlg(samplePhase, markerCluster, start, end);
            phaseHet(het);
            start = end;
        }
        if (start < markerCluster.nClusters()) {
            fwdAlg(samplePhase, markerCluster, start, markerCluster.nClusters());
        }
    }

    /**
     * Runs the forward pass over clusters {@code [start, end)}, first swapping
     * the segment's haplotypes if the preceding phase decision requires it.
     * Tracks 1 and 2 restart from track 0. Missing genotypes in the segment
     * are imputed as they are reached.
     */
    private void fwdAlg(SamplePhase samplePhase, MarkerCluster markerCluster, int start, int end) {
        if (swapHaps) {
            swapHaps(samplePhase, markerCluster, start, end);
        }
        System.arraycopy(fwd[0], 0, fwd[1], 0, nStates);
        System.arraycopy(fwd[0], 0, fwd[2], 0, nStates);
        sum[1] = sum[0];
        sum[2] = sum[0];
        FloatArray pRecomb = markerCluster.pRecomb();
        for (int c = start; c < end; ++c) {
            float pRec = pRecomb.get(c);
            setEmProbs(markerCluster, c);
            sum[0] = HmmUpdater.fwdUpdate(fwd[0], sum[0], pRec, emProbs, mismatches[0][c], nStates);
            sum[1] = HmmUpdater.fwdUpdate(fwd[1], sum[1], pRec, emProbs, mismatches[1][c], nStates);
            sum[2] = HmmUpdater.fwdUpdate(fwd[2], sum[2], pRec, emProbs, mismatches[2][c], nStates);
            if (markerCluster.clustHasMissingGT(c)) {
                imputeAlleles(samplePhase, markerCluster, c);
            }
        }
    }

    /**
     * Sets {@code emProbs} for the specified cluster. The mismatch probability
     * is the cluster's marker count multiplied by {@code pMismatch}.
     */
    private void setEmProbs(MarkerCluster markerCluster, int cluster) {
        int nClusterMarkers = markerCluster.clusterEnd(cluster) - markerCluster.clusterStart(cluster);
        emProbs[1] = nClusterMarkers * pMismatch;
        emProbs[0] = 1.0f - emProbs[1];
    }

    /**
     * Swaps the mismatch data of tracks 1 and 2 for clusters
     * {@code [start, end)}, and swaps the sample's haplotypes over the markers
     * spanned by those clusters.
     */
    private void swapHaps(SamplePhase samplePhase, MarkerCluster markerCluster, int start, int end) {
        for (int c = start; c < end; ++c) {
            byte[] tmp = mismatches[1][c];
            mismatches[1][c] = mismatches[2][c];
            mismatches[2][c] = tmp;
        }
        samplePhase.swapHaps(markerCluster.clusterStart(start), markerCluster.clusterEnd(end - 1));
    }

    /**
     * Imputes the alleles at a single-marker cluster that has a missing
     * genotype. The stored backward probabilities are combined with the
     * current forward probabilities, and the cursor advances to the next
     * missing cluster. The caller must ensure the cluster has a missing
     * genotype.
     */
    private void imputeAlleles(SamplePhase samplePhase, MarkerCluster markerCluster, int cluster) {
        if (swapHaps) {
            // Backward tracks were stored before this segment was swapped.
            float[] tmp = missProbs1.get(missIndex);
            missProbs1.set(missIndex, missProbs2.get(missIndex));
            missProbs2.set(missIndex, tmp);
        }
        float[] probs1 = missProbs1.get(missIndex);
        float[] probs2 = missProbs2.get(missIndex);
        int[] alleles = refAlleles.get(missIndex);
        for (int s = 0; s < nStates; ++s) {
            probs1[s] *= fwd[1][s];
            probs2[s] *= fwd[2][s];
        }
        assert markerCluster.clusterEnd(cluster) - markerCluster.clusterStart(cluster) == 1;
        imputeAlleles(samplePhase, markerCluster.clusterStart(cluster), probs1, probs2, alleles);
        ++missIndex;
    }

    /**
     * Sets each haplotype's allele at the specified marker to the allele with
     * the largest summed state probability. Ties go to the lower allele index.
     *
     * @param marker a marker index
     * @param probs1 per-state probabilities for haplotype 1
     * @param probs2 per-state probabilities for haplotype 2
     * @param alleles the allele carried by each state
     */
    private void imputeAlleles(SamplePhase samplePhase, int marker, float[] probs1, float[] probs2,
            int[] alleles) {
        int nAlleles = markers.marker(marker).nAlleles();
        float[] alProbs1 = new float[nAlleles];
        float[] alProbs2 = new float[nAlleles];
        for (int s = 0; s < nStates; ++s) {
            alProbs1[alleles[s]] += probs1[s];
            alProbs2[alleles[s]] += probs2[s];
        }
        int best1 = 0;
        int best2 = 0;
        for (int a = 1; a < nAlleles; ++a) {
            if (alProbs1[a] > alProbs1[best1]) {
                best1 = a;
            }
            if (alProbs2[a] > alProbs2[best2]) {
                best2 = a;
            }
        }
        samplePhase.setAllele1(marker, best1);
        samplePhase.setAllele2(marker, best2);
    }

    /**
     * Runs the backward pass from the last cluster to the first. It stores the
     * haplotype-track backward probabilities at each unphased-het cluster
     * boundary and at each missing-genotype cluster, for later use by the
     * forward pass.
     */
    private void bwdAlg(MarkerCluster markerCluster) {
        IntArray unphClusters = markerCluster.unphClusters();
        Arrays.fill(bwd[0], 0, nStates, 1.0f / nStates);
        int end = markerCluster.nClusters() - 1;
        if (markerCluster.clustHasMissingGT(end)) {
            --missIndex;
            System.arraycopy(bwd[0], 0, missProbs1.get(missIndex), 0, nStates);
            System.arraycopy(bwd[0], 0, missProbs2.get(missIndex), 0, nStates);
        }
        for (int het = unphClusters.size() - 1; het >= 0; --het) {
            int start = unphClusters.get(het) - 1;
            assert start >= 0;
            bwdAlg(markerCluster, start, end);
            System.arraycopy(bwd[1], 0, bwdHet1.get(het), 0, nStates);
            System.arraycopy(bwd[2], 0, bwdHet2.get(het), 0, nStates);
            end = start;
        }
        bwdAlg(markerCluster, 0, end);
    }

    /**
     * Runs the backward pass from cluster {@code end} down to cluster
     * {@code start}. Tracks 1 and 2 restart from track 0.
     */
    private void bwdAlg(MarkerCluster markerCluster, int start, int end) {
        FloatArray pRecomb = markerCluster.pRecomb();
        System.arraycopy(bwd[0], 0, bwd[1], 0, nStates);
        System.arraycopy(bwd[0], 0, bwd[2], 0, nStates);
        for (int c = end - 1; c >= start; --c) {
            int next = c + 1;
            float pRec = pRecomb.get(next);
            // The emissions belong to cluster `next`, whose mismatches are used
            // below, so size them by that cluster, as fwdAlg does. (They were
            // previously sized by cluster `c`.)
            setEmProbs(markerCluster, next);
            HmmUpdater.bwdUpdate(bwd[0], pRec, emProbs, mismatches[0][next], nStates);
            HmmUpdater.bwdUpdate(bwd[1], pRec, emProbs, mismatches[1][next], nStates);
            HmmUpdater.bwdUpdate(bwd[2], pRec, emProbs, mismatches[2][next], nStates);
            if (markerCluster.clustHasMissingGT(c)) {
                --missIndex;
                System.arraycopy(bwd[1], 0, missProbs1.get(missIndex), 0, nStates);
                System.arraycopy(bwd[2], 0, missProbs2.get(missIndex), 0, nStates);
            }
        }
    }

    /**
     * Decides the phase at unphased-het cluster {@code het}. It compares the
     * same-track continuation (fwd1·bwd1)(fwd2·bwd2) with the crossed
     * continuation (fwd1·bwd2)(fwd2·bwd1), and sets {@code swapHaps} when the
     * crossed product is larger. The ratio of the larger product to the
     * smaller is appended to {@code lrList}.
     */
    private void phaseHet(int het) {
        float[] bwd1 = bwdHet1.get(het);
        float[] bwd2 = bwdHet2.get(het);
        float p11 = 0.0f;
        float p12 = 0.0f;
        float p21 = 0.0f;
        float p22 = 0.0f;
        for (int s = 0; s < nStates; ++s) {
            p11 += fwd[1][s] * bwd1[s];
            p12 += fwd[1][s] * bwd2[s];
            p21 += fwd[2][s] * bwd1[s];
            p22 += fwd[2][s] * bwd2[s];
        }
        boolean prevSwap = swapHaps;
        float sameProb = p11 * p22;
        float crossProb = p12 * p21;
        swapHaps = sameProb < crossProb;
        if (swapHaps != prevSwap) {
            ++swapCnt;
        }
        lrList.add(swapHaps ? crossProb / sameProb : sameProb / crossProb);
    }

    /**
     * After burn-in, replaces the sample's unphased set with the hets whose
     * likelihood ratio is below the threshold derived from
     * {@code leaveUnphasedProp}. NOTE(review): assumes {@code lrList} holds
     * one entry per element of {@code unphased}; confirm.
     */
    private void updatePhase(int sample, SamplePhase samplePhase) {
        IntArray unphased = samplePhase.unphased();
        if (unphased.size() <= 0 || burnin) {
            return;
        }
        float leaveUnphasedProp = phaseData.leaveUnphasedProp(sample);
        IntList stillUnphased = new IntList();
        float threshold = threshold(lrList, leaveUnphasedProp);
        for (int j = 0, n = unphased.size(); j < n; ++j) {
            if (lrList.get(j) < threshold) {
                stillUnphased.add(unphased.get(j));
            }
        }
        samplePhase.setUnphased(IntArray.create(stillUnphased, nMarkers));
    }

    /**
     * Returns the element at rank {@code round(prop * n)} of the sorted
     * values, with the rank clamped to {@code [0, n-1]}. Returns
     * {@code Float.NEGATIVE_INFINITY} for an empty list, so that no value
     * falls below the threshold.
     */
    private static float threshold(FloatList lrs, float prop) {
        float[] sorted = lrs.toArray();
        if (sorted.length == 0) {
            return Float.NEGATIVE_INFINITY;
        }
        Arrays.sort(sorted);
        int rank = (int) Math.floor((prop * sorted.length) + 0.5f);
        rank = Math.max(0, Math.min(rank, sorted.length - 1));
        return sorted[rank];
    }
}
