package com.google.research.reflection.predictor;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Vector;

/* loaded from: classes.dex */
/**
 * Multiclass linear classifier trained online with a clipped passive-aggressive update.
 *
 * <p>One weight vector of {@code num_dimensions_} floats is kept per class in
 * {@code parameters_}. Each training step compares the score of the true class against the score
 * of ONE other class drawn uniformly at random (not the highest-scoring rival). When the hinge loss
 * {@code 1 - score(true) + score(rival)} is positive, it:
 * <ul>
 *   <li>moves the true class's weights towards the example, and</li>
 *   <li>moves the rival's weights away from the example.</li>
 * </ul>
 * The step size is {@code loss / ||update||^2}, capped at {@code aggressiveness_}.
 *
 * <p>Training and resizing perform unsynchronized read-modify-write sequences on
 * {@code parameters_}, so concurrent use of one instance requires external locking.
 */
public class MulticlassPA {
    /** Upper bound on the per-example step size. */
    float aggressiveness_;
    /** Number of classes, i.e. number of weight vectors in {@code parameters_}. */
    public int num_classes_;
    /** Length of each per-class weight vector. */
    int num_dimensions_;
    /** Used in place of a zero squared norm so the step-size division never divides by zero. */
    final float kEpsilon = 1.0E-4f;
    /** {@code parameters_.get(c).get(d)} is the weight of feature {@code d} for class {@code c}. */
    Vector<Vector<Float>> parameters_ = new Vector<>(1000);

    /** Creates an empty model with no classes; state can be loaded later with {@link #read}. */
    public MulticlassPA() {
    }

    /**
     * Creates a model whose weights are all zero.
     *
     * @param numClasses number of classes
     * @param numDimensions length of each per-class weight vector
     * @param aggressiveness cap on the per-example step size
     */
    public MulticlassPA(int numClasses, int numDimensions, float aggressiveness) {
        this.num_classes_ = numClasses;
        this.num_dimensions_ = numDimensions;
        this.aggressiveness_ = aggressiveness;
        InitializeParameters();
    }

    /**
     * Fatal check: if {@code value < bound}, prints the violated relation to stderr and terminates
     * the process with exit status 1.
     */
    public static void CHECK_GE(int value, int bound) {
        if (value < bound) {
            System.err.println(value + " < " + bound);
            System.exit(1);
        }
    }

    /**
     * Fatal check: if {@code value >= bound}, prints the violated relation to stderr and terminates
     * the process with exit status 1.
     */
    public static void CHECK_LT(int value, int bound) {
        if (value >= bound) {
            System.err.println(value + " >= " + bound);
            System.exit(1);
        }
    }

    /** Debug variant of {@link #CHECK_GE}: exits if {@code value < bound}. */
    private void DCHECK_GE(int value, int bound) {
        if (value < bound) {
            System.err.println(value + " < " + bound);
            System.exit(1);
        }
    }

    /** Debug variant of {@link #CHECK_LT}: exits if {@code value >= bound}. */
    private void DCHECK_LT(int value, int bound) {
        if (value >= bound) {
            System.err.println(value + " >= " + bound);
            System.exit(1);
        }
    }

    /** Returns the Euclidean norm of {@code fArr}. */
    public static float L2Norm(float[] fArr) {
        return (float) Math.sqrt(L2NormSquare(fArr));
    }

    /** Returns the sum of squares of the entries of {@code fArr}. */
    public static float L2NormSquare(float[] fArr) {
        float sum = 0.0f;
        for (float x : fArr) {
            sum += x * x;
        }
        return sum;
    }

    /** Returns a new vector holding {@code size} zero weights. */
    private static Vector<Float> zeroVector(int size) {
        Vector<Float> weights = new Vector<>(size);
        for (int d = 0; d < size; d++) {
            weights.add(0.0f);
        }
        return weights;
    }

    /** Resets {@code parameters_} to {@code num_classes_} zero vectors of {@code num_dimensions_}. */
    void InitializeParameters() {
        this.parameters_.setSize(this.num_classes_);
        for (int c = 0; c < this.num_classes_; c++) {
            this.parameters_.set(c, zeroVector(this.num_dimensions_));
        }
    }

    /**
     * Draws a class index uniformly from {@code [0, num_classes_)} by rejection sampling, never
     * returning {@code i}.
     *
     * @throws IllegalStateException if fewer than two classes exist; the rejection loop would
     *     otherwise never terminate
     */
    int PickAClassExcept(int i) {
        if (this.num_classes_ < 2) {
            throw new IllegalStateException(
                    "Need at least 2 classes to pick a rival, have " + this.num_classes_);
        }
        int candidate;
        do {
            candidate = (int) (Math.random() * this.num_classes_);
        } while (candidate == i);
        return candidate;
    }

    /**
     * Returns the dot product of the dense features {@code fArr} with {@code vector}.
     * {@code vector} must have at least {@code fArr.length} entries.
     */
    public float Score(float[] fArr, Vector<Float> vector) {
        float sum = 0.0f;
        for (int d = 0; d < fArr.length; d++) {
            sum += vector.get(d) * fArr[d];
        }
        return sum;
    }

    /** Returns the sum of squared values of the sparse feature list {@code vector}. */
    public float SparseL2NormSquare(Vector<NameValuePair> vector) {
        float sum = 0.0f;
        for (NameValuePair pair : vector) {
            sum += pair.value * pair.value;
        }
        return sum;
    }

    /**
     * Returns the dot product of the sparse features {@code vector} with the dense weights
     * {@code vector2}. Each entry contributes {@code vector2[index] * value}; repeated indices
     * contribute once per occurrence.
     */
    public float SparseScore(Vector<NameValuePair> vector, Vector<Float> vector2) {
        float sum = 0.0f;
        for (NameValuePair pair : vector) {
            // Multiplying by 1.0f is exact, so no special case is needed for binary features.
            sum += vector2.get(pair.index) * pair.value;
        }
        return sum;
    }

    /**
     * Computes the clipped step size: {@code loss / normSquare}, with a zero {@code normSquare}
     * replaced by {@code kEpsilon}, capped at {@code aggressiveness_}.
     */
    private float stepSize(float loss, float normSquare) {
        float denominator = normSquare == 0.0f ? this.kEpsilon : normSquare;
        float step = loss / denominator;
        // Same comparison as a plain cap (not Math.min) so NaN/-0.0 handling is unchanged.
        return step > this.aggressiveness_ ? this.aggressiveness_ : step;
    }

    /**
     * Performs one training step on a sparse example.
     *
     * @param vector sparse features of the example
     * @param i true class; the process exits if it is outside {@code [0, num_classes_)}
     * @return the hinge loss before the update, or 0 if it was not positive (no update made)
     * @throws IllegalStateException if the model has fewer than two classes
     */
    public float SparseTrainOneExample(Vector<NameValuePair> vector, int i) {
        CHECK_GE(i, 0);
        CHECK_LT(i, this.num_classes_);
        Vector<Float> target = this.parameters_.get(i);
        float targetScore = SparseScore(vector, target);
        Vector<Float> rival = this.parameters_.get(PickAClassExcept(i));
        float loss = (1.0f - targetScore) + SparseScore(vector, rival);
        if (loss <= 0.0f) {
            return 0.0f;
        }
        // Both class vectors move by step * x, so the squared update norm is 2 * ||x||^2.
        float step = stepSize(loss, SparseL2NormSquare(vector) * 2.0f);
        for (NameValuePair pair : vector) {
            float delta = pair.value * step;
            target.set(pair.index, target.get(pair.index) + delta);
            rival.set(pair.index, rival.get(pair.index) - delta);
        }
        return loss;
    }

    /**
     * Performs one training step on a dense example shared by all classes.
     *
     * @param fArr dense features; must not be longer than the weight vectors
     * @param i true class; the process exits if it is outside {@code [0, num_classes_)}
     * @return the hinge loss before the update, or 0 if it was not positive (no update made)
     * @throws IllegalStateException if the model has fewer than two classes
     */
    public float TrainOneExample(float[] fArr, int i) {
        CHECK_GE(i, 0);
        CHECK_LT(i, this.num_classes_);
        Vector<Float> target = this.parameters_.get(i);
        float targetScore = Score(fArr, target);
        Vector<Float> rival = this.parameters_.get(PickAClassExcept(i));
        float loss = (1.0f - targetScore) + Score(fArr, rival);
        if (loss <= 0.0f) {
            return 0.0f;
        }
        // Both class vectors move by step * x, so the squared update norm is 2 * ||x||^2.
        float step = stepSize(loss, L2NormSquare(fArr) * 2.0f);
        for (int d = 0; d < fArr.length; d++) {
            float delta = fArr[d] * step;
            target.set(d, target.get(d) + delta);
            rival.set(d, rival.get(d) - delta);
        }
        return loss;
    }

    /**
     * Performs one training step where each class has its own dense feature vector.
     *
     * @param fArr {@code fArr[c]} holds the features used to score class {@code c}
     * @param i true class; the process exits if it is outside {@code [0, num_classes_)}
     * @return the hinge loss before the update, or 0 if it was not positive (no update made)
     * @throws IllegalStateException if the model has fewer than two classes
     */
    public float TrainOneExample(float[][] fArr, int i) {
        CHECK_GE(i, 0);
        CHECK_LT(i, this.num_classes_);
        float[] targetFeatures = fArr[i];
        Vector<Float> target = this.parameters_.get(i);
        float targetScore = Score(targetFeatures, target);
        int rivalClass = PickAClassExcept(i);
        float[] rivalFeatures = fArr[rivalClass];
        Vector<Float> rival = this.parameters_.get(rivalClass);
        float loss = (1.0f - targetScore) + Score(rivalFeatures, rival);
        if (loss <= 0.0f) {
            return 0.0f;
        }
        float step = stepSize(loss, L2NormSquare(targetFeatures) + L2NormSquare(rivalFeatures));
        // Each class is updated over its own feature vector, matching how it was scored, so
        // vectors of different lengths are handled.
        for (int d = 0; d < targetFeatures.length; d++) {
            target.set(d, target.get(d) + targetFeatures[d] * step);
        }
        for (int d = 0; d < rivalFeatures.length; d++) {
            rival.set(d, rival.get(d) - rivalFeatures[d] * step);
        }
        return loss;
    }

    /** Returns the cap on the per-example step size. */
    public float aggressiveness() {
        return this.aggressiveness_;
    }

    /** Returns the length of each per-class weight vector. */
    public int dimensions() {
        return this.num_dimensions_;
    }

    /** Returns {@code 0.5 + f^5}. */
    public float mercer(float f) {
        return 0.5f + ((float) Math.pow(f, 5.0d));
    }

    /** Returns the number of classes. */
    public int num_classes() {
        return this.num_classes_;
    }

    /** Returns the live (not copied) per-class weight vectors. */
    public Vector<Vector<Float>> parameters() {
        return this.parameters_;
    }

    /**
     * Replaces this model with one read in the format produced by {@link #write}.
     *
     * <p>The format is: class count (int), dimension count (int), aggressiveness (float), then all
     * weights (float), class by class. Fields are only replaced after the whole model has been
     * read, so a failed read leaves the previous model intact.
     *
     * @throws IOException on a stream error, a truncated stream, or negative counts in the header
     */
    public void read(DataInputStream dataInputStream) throws IOException {
        int numClasses = dataInputStream.readInt();
        int numDimensions = dataInputStream.readInt();
        float aggressiveness = dataInputStream.readFloat();
        if (numClasses < 0 || numDimensions < 0) {
            throw new IOException("Corrupt MulticlassPA header: num_classes=" + numClasses
                    + ", num_dimensions=" + numDimensions);
        }
        Vector<Vector<Float>> parameters = new Vector<>(numClasses);
        for (int c = 0; c < numClasses; c++) {
            Vector<Float> weights = new Vector<>(numDimensions);
            for (int d = 0; d < numDimensions; d++) {
                weights.add(dataInputStream.readFloat());
            }
            parameters.add(weights);
        }
        this.num_classes_ = numClasses;
        this.num_dimensions_ = numDimensions;
        this.aggressiveness_ = aggressiveness;
        this.parameters_ = parameters;
    }

    /**
     * Grows every class vector to {@code numDimensions}, padding with zero weights. Does nothing
     * if {@code numDimensions} is not larger than the current dimension count.
     */
    public void resize(int numDimensions) {
        if (numDimensions <= this.num_dimensions_) {
            return;
        }
        for (int c = 0; c < this.num_classes_; c++) {
            Vector<Float> weights = this.parameters_.get(c);
            weights.setSize(numDimensions);
            for (int d = this.num_dimensions_; d < numDimensions; d++) {
                weights.set(d, 0.0f);
            }
        }
        this.num_dimensions_ = numDimensions;
    }

    /**
     * Grows the model to at least {@code numClasses} classes and {@code numDimensions} dimensions.
     * New classes and new dimensions start with zero weights. The model never shrinks.
     */
    public void resize(int numDimensions, int numClasses) {
        if (numClasses > this.num_classes_) {
            this.parameters_.setSize(numClasses);
            for (int c = this.num_classes_; c < numClasses; c++) {
                this.parameters_.set(c, zeroVector(this.num_dimensions_));
            }
            this.num_classes_ = numClasses;
        }
        resize(numDimensions);
    }

    /**
     * Serializes the model as: class count (int), dimension count (int), aggressiveness (float),
     * then all weights (float), class by class.
     */
    public void write(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeInt(this.num_classes_);
        dataOutputStream.writeInt(this.num_dimensions_);
        dataOutputStream.writeFloat(this.aggressiveness_);
        for (int c = 0; c < this.num_classes_; c++) {
            Vector<Float> weights = this.parameters_.get(c);
            for (int d = 0; d < this.num_dimensions_; d++) {
                dataOutputStream.writeFloat(weights.get(d));
            }
        }
    }
}
