/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.util.onnx;

import ai.onnx.proto.OnnxMl;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;
import org.tribuo.util.onnx.ONNXShape;
import org.tribuo.util.onnx.ONNXUtils;

public final class ONNXContext {
    private final Map<String, Long> nameMap = new HashMap<String, Long>();
    private final OnnxMl.GraphProto.Builder protoBuilder = OnnxMl.GraphProto.newBuilder();

    public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperators op, List<T> inputs, List<String> outputs, Map<String, Object> attributes) {
        if (!inputs.stream().allMatch(n -> n.context == this)) {
            throw new IllegalArgumentException("All input nodes must belong to this ONNXContext");
        }
        OnnxMl.NodeProto opProto = op.build(this, (String[])inputs.stream().map(ONNXRef::getReference).toArray(String[]::new), (String[])outputs.stream().map(this::generateUniqueName).toArray(String[]::new), attributes);
        this.protoBuilder.addNode(opProto);
        return IntStream.range(0, outputs.size()).mapToObj(i -> new ONNXNode(this, opProto, (String)outputs.get(i), i)).collect(Collectors.toList());
    }

    public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> inputs, String outputName, Map<String, Object> attributes) {
        List<ONNXNode> opOutputs = this.operation(op, inputs, Collections.singletonList(outputName), attributes);
        if (((OnnxMl.NodeProto)opOutputs.get((int)0).backRef).getOutputList().size() > 1) {
            throw new IllegalStateException("Requested a single output from operation " + op.opName + " which produced " + ((OnnxMl.NodeProto)opOutputs.get((int)0).backRef).getOutputList().size() + " outputs");
        }
        return opOutputs.get(0);
    }

    public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> inputs, String outputName) {
        return this.operation(op, inputs, outputName, Collections.emptyMap());
    }

    public <LHS extends ONNXRef<?>, RHS extends ONNXRef<?>> LHS assignTo(RHS input, LHS output) {
        if (input.context != output.context || input.context != this) {
            throw new IllegalArgumentException("both input and output must both belong to this ONNXContext");
        }
        OnnxMl.NodeProto idNode = ONNXOperators.IDENTITY.build(this, input.getReference(), output.getReference());
        this.protoBuilder.addNode(idNode);
        return output;
    }

    public ONNXPlaceholder floatInput(String name, int featureDimension) {
        OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1L, featureDimension}, new String[]{"batch", null}), OnnxMl.TensorProto.DataType.FLOAT);
        OnnxMl.ValueInfoProto inputValue = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName(name).build();
        this.protoBuilder.addInput(inputValue);
        return new ONNXPlaceholder(this, inputValue, name);
    }

    public ONNXPlaceholder floatInput(int featureDimension) {
        return this.floatInput("input", featureDimension);
    }

    public ONNXPlaceholder floatOutput(String name, int outputDimension) {
        OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1L, outputDimension}, new String[]{"batch", null}), OnnxMl.TensorProto.DataType.FLOAT);
        OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName(name).build();
        this.protoBuilder.addOutput(outputValueProto);
        return new ONNXPlaceholder(this, outputValueProto, name);
    }

    public ONNXPlaceholder floatOutput(int outputDimension) {
        return this.floatOutput("output", outputDimension);
    }

    public ONNXInitializer floatTensor(String baseName, List<Integer> dims, Consumer<FloatBuffer> populate) {
        OnnxMl.TensorProto tens = ONNXUtils.floatTensorBuilder(this, baseName, dims, populate);
        this.protoBuilder.addInitializer(tens);
        return new ONNXInitializer(this, tens, baseName);
    }

    public ONNXInitializer array(String baseName, long[] parameters) {
        OnnxMl.TensorProto tens = ONNXUtils.arrayBuilder(this, baseName, parameters);
        this.protoBuilder.addInitializer(tens);
        return new ONNXInitializer(this, tens, baseName);
    }

    public ONNXInitializer array(String baseName, int[] parameters) {
        OnnxMl.TensorProto tens = ONNXUtils.arrayBuilder(this, baseName, parameters);
        this.protoBuilder.addInitializer(tens);
        return new ONNXInitializer(this, tens, baseName);
    }

    public ONNXInitializer array(String baseName, float[] parameters) {
        OnnxMl.TensorProto tens = ONNXUtils.arrayBuilder(this, baseName, parameters);
        this.protoBuilder.addInitializer(tens);
        return new ONNXInitializer(this, tens, baseName);
    }

    public ONNXInitializer array(String baseName, double[] parameters, boolean downcast) {
        OnnxMl.TensorProto tens = ONNXUtils.arrayBuilder(this, baseName, parameters, downcast);
        this.protoBuilder.addInitializer(tens);
        return new ONNXInitializer(this, tens, baseName);
    }

    public ONNXInitializer array(String baseName, double[] parameters) {
        return this.array(baseName, parameters, true);
    }

    public ONNXInitializer constant(String baseName, float value) {
        OnnxMl.TensorProto constant = OnnxMl.TensorProto.newBuilder().setName(this.generateUniqueName(baseName)).setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()).addFloatData(value).build();
        this.protoBuilder.addInitializer(constant);
        return new ONNXInitializer(this, constant, baseName);
    }

    public ONNXInitializer constant(String baseName, long value) {
        OnnxMl.TensorProto constant = OnnxMl.TensorProto.newBuilder().setName(this.generateUniqueName(baseName)).setDataType(OnnxMl.TensorProto.DataType.INT64.getNumber()).addInt64Data(value).build();
        this.protoBuilder.addInitializer(constant);
        return new ONNXInitializer(this, constant, baseName);
    }

    String generateUniqueName(String name) {
        long counter = this.nameMap.computeIfAbsent(name, k -> 0L);
        String newName = name + "_" + counter;
        this.nameMap.put(name, counter + 1L);
        return newName;
    }

    public void setName(String name) {
        this.protoBuilder.setName(name);
    }

    public OnnxMl.GraphProto buildGraph() {
        return this.protoBuilder.build();
    }
}

