package deepboof.impl.forward.standard;

import deepboof.forward.SpatialBatchNorm;
import deepboof.tensors.Tensor_F32;

/* loaded from: classes2.dex */
/**
 * Per-channel (spatial) batch normalization over {@link Tensor_F32} data.
 *
 * <p>Each sample's input shape must have exactly 3 dimensions; dimension 0 is treated as the
 * channel count. Per-channel parameters live in {@code params}, reshaped by
 * {@link #_initialize()} to {@code [channels, 4]} when gamma/beta are required and
 * {@code [channels, 2]} otherwise.
 */
public class SpatialBatchNorm_F32 extends FunctionBatchNorm_F32 implements SpatialBatchNorm<Tensor_F32> {

    /**
     * Creates the layer.
     *
     * @param requiresGammaBeta forwarded to the parent constructor. Presumably the parent stores
     *     it as the {@code requiresGammaBeta} field that {@link #_initialize()} reads to choose 4 vs 2
     *     parameters per channel; confirm in FunctionBatchNorm_F32.
     */
    public SpatialBatchNorm_F32(boolean requiresGammaBeta) {
        super(requiresGammaBeta);
    }

    /**
     * Normalizes {@code input} into {@code output}, channel by channel.
     *
     * <p>The input is read as a rank-4 tensor. Dimension 1 is the channel count; dimensions 2 and 3
     * together form the per-channel spatial block. Dimension 0 is presumably the batch, iterated
     * {@code miniBatchSize} times. Both tensors are walked with a single running index each. The
     * method therefore assumes they are contiguous and laid out identically from their
     * {@code startIndex}.
     *
     * <p>For each channel, with parameters {@code (mean, scale[, gamma, beta])}:
     * <ul>
     *   <li>without gamma/beta: {@code out = (in - mean) * scale}
     *   <li>with gamma/beta: {@code out = (in - mean) * gamma * scale + beta}
     * </ul>
     * {@code scale} is multiplied rather than divided. It is presumably a precomputed inverse
     * standard deviation supplied by the parent class; TODO confirm.
     *
     * @param input source tensor
     * @param output destination tensor, written sequentially from its {@code startIndex}
     */
    @Override
    public void _forward(Tensor_F32 input, Tensor_F32 output) {
        final int numChannels = input.length(1);
        final int spatialSize = input.length(2) * input.length(3);
        int inIdx = input.startIndex;
        int outIdx = output.startIndex;

        if (hasGammaBeta()) {
            for (int sample = 0; sample < this.miniBatchSize; sample++) {
                // The same per-channel parameters are reused for every sample in the batch.
                int paramIdx = this.params.startIndex;
                for (int channel = 0; channel < numChannels; channel++) {
                    final float mean = this.params.d[paramIdx++];
                    final float scale = this.params.d[paramIdx++];
                    final float gamma = this.params.d[paramIdx++];
                    final float beta = this.params.d[paramIdx++];
                    // Evaluation order ((x - mean) * gamma) * scale + beta is kept deliberately so
                    // float rounding matches exactly.
                    for (final int inEnd = inIdx + spatialSize; inIdx < inEnd; inIdx++, outIdx++) {
                        output.d[outIdx] = (input.d[inIdx] - mean) * gamma * scale + beta;
                    }
                }
            }
        } else {
            for (int sample = 0; sample < this.miniBatchSize; sample++) {
                int paramIdx = this.params.startIndex;
                for (int channel = 0; channel < numChannels; channel++) {
                    final float mean = this.params.d[paramIdx++];
                    final float scale = this.params.d[paramIdx++];
                    for (final int inEnd = inIdx + spatialSize; inIdx < inEnd; inIdx++, outIdx++) {
                        output.d[outIdx] = (input.d[inIdx] - mean) * scale;
                    }
                }
            }
        }
    }

    /**
     * Validates the input shape and sizes the output and parameter tensors.
     *
     * <p>The output shape is a copy of the input shape. One parameter shape,
     * {@code [shapeInput[0], 4 or 2]}, is registered in {@code shapeParameters}, and
     * {@code params} is reshaped to it.
     *
     * @throws IllegalArgumentException if the input shape does not have exactly 3 dimensions
     */
    @Override
    public void _initialize() {
        if (this.shapeInput.length != 3) {
            throw new IllegalArgumentException("Expected 3 DOF in a spatial shape (C,W,H)");
        }
        this.shapeOutput = this.shapeInput.clone();

        final int paramsPerChannel = this.requiresGammaBeta ? 4 : 2;
        final int[] paramShape = {this.shapeInput[0], paramsPerChannel};
        // The same array instance is both registered and used for the reshape.
        this.shapeParameters.add(paramShape);
        this.params.reshape(paramShape);
    }
}
