/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

public final class VotingCombiner
implements EnsembleCombiner<Label> {
    private static final long serialVersionUID = 1L;

    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
        int numPredictions = predictions.size();
        int numUsed = 0;
        double weight = 1.0 / (double)numPredictions;
        double[] score = new double[outputInfo.size()];
        for (Prediction<Label> p : predictions) {
            if (numUsed < p.getNumActiveFeatures()) {
                numUsed = p.getNumActiveFeatures();
            }
            int n = outputInfo.getID(p.getOutput());
            score[n] = score[n] + weight;
        }
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        LinkedHashMap<String, Label> predictionMap = new LinkedHashMap<String, Label>();
        for (int i = 0; i < score.length; ++i) {
            String name = ((Label)outputInfo.getOutput(i)).getLabel();
            Label label = new Label(name, score[i]);
            predictionMap.put(name, label);
            if (!(label.getScore() > maxScore)) continue;
            maxScore = label.getScore();
            maxLabel = label;
        }
        Example example = predictions.get(0).getExample();
        return new Prediction(maxLabel, predictionMap, numUsed, example, true);
    }

    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()=" + predictions.size() + ", weights.length=" + weights.length);
        }
        int numUsed = 0;
        double sum = 0.0;
        double[] score = new double[outputInfo.size()];
        for (int i = 0; i < weights.length; ++i) {
            Prediction<Label> p = predictions.get(i);
            if (numUsed < p.getNumActiveFeatures()) {
                numUsed = p.getNumActiveFeatures();
            }
            int n = outputInfo.getID(p.getOutput());
            score[n] = score[n] + (double)weights[i];
            sum += (double)weights[i];
        }
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        LinkedHashMap<String, Label> predictionMap = new LinkedHashMap<String, Label>();
        for (int i = 0; i < score.length; ++i) {
            String name = ((Label)outputInfo.getOutput(i)).getLabel();
            Label label = new Label(name, score[i] / sum);
            predictionMap.put(name, label);
            if (!(label.getScore() > maxScore)) continue;
            maxScore = label.getScore();
            maxLabel = label;
        }
        Example example = predictions.get(0).getExample();
        return new Prediction(maxLabel, predictionMap, numUsed, example, true);
    }

    public String toString() {
        return "VotingCombiner()";
    }

    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "EnsembleCombiner");
    }

    public ONNXNode exportCombiner(ONNXNode input) {
        HashMap<String, Object> attributes = new HashMap<String, Object>();
        attributes.put("axes", new int[]{2});
        attributes.put("keepdims", 0);
        return input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1)).apply(ONNXOperators.REDUCE_MEAN, attributes);
    }

    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0L, 1L});
        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2L});
        ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, (ONNXRef)unsqueezeAxes);
        ONNXNode mulByWeights = input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1)).apply(ONNXOperators.MUL, (ONNXRef)unsqueezed);
        ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);
        return mulByWeights.apply(ONNXOperators.REDUCE_SUM, (ONNXRef)sumAxes, Collections.singletonMap("keepdims", 0)).apply(ONNXOperators.DIV, (ONNXRef)weightSum);
    }
}

