/*
 * Decompiled with CFR 0.152.
 */
package org.trustyai.arrowconverters;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.channels.Channels;
import java.nio.channels.SeekableByteChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.IntFunction;
import java.util.stream.IntStream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.arrow.vector.util.Text;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

/**
 * Static helpers converting between Kogito explainability model objects
 * ({@link PredictionInput} / {@link PredictionOutput}) and Apache Arrow columnar data,
 * plus (de)serialization of a {@link VectorSchemaRoot} in the Arrow file format.
 *
 * <p>Utility class: not instantiable.
 */
public class ArrowConverters {
    private ArrowConverters() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Converts Arrow column vectors into row-wise prediction outputs.
     *
     * <p>Column {@code c} becomes the {@code c}-th {@link Output} of every row, named after its
     * vector; every output gets confidence {@code 1.0}. The row count is taken from the first
     * vector. NOTE(review): assumes all vectors have the same value count.
     *
     * @param fvs column vectors; an empty list yields an empty result
     * @return one {@link PredictionOutput} per row
     * @throws IllegalArgumentException if a vector's minor type is not FLOAT8, FLOAT4, INT,
     *     BIGINT, BIT or VARCHAR
     */
    public static List<PredictionOutput> convertFieldVectorstoPO(List<FieldVector> fvs) {
        if (fvs.isEmpty()) {
            return new ArrayList<>();
        }
        int colCount = fvs.size();
        int rowCount = fvs.get(0).getValueCount();
        Output[][] outputBuffer = new Output[rowCount][colCount];
        for (int col = 0; col < colCount; ++col) {
            // Resolve the column's reader before touching any rows so unsupported types are
            // rejected even when the vectors are empty.
            IntFunction<Output> cellReader = outputCellReader(fvs.get(col));
            for (int row = 0; row < rowCount; ++row) {
                outputBuffer[row][col] = cellReader.apply(row);
            }
        }
        List<PredictionOutput> converted = new ArrayList<>(rowCount);
        for (int i = 0; i < rowCount; ++i) {
            converted.add(new PredictionOutput(Arrays.asList(outputBuffer[i])));
        }
        return converted;
    }

    /**
     * Returns a function mapping a row index of {@code fv} to an {@link Output} for that cell.
     *
     * @throws IllegalArgumentException if the vector's minor type is unsupported
     */
    private static IntFunction<Output> outputCellReader(FieldVector fv) {
        String name = fv.getName();
        switch (fv.getMinorType()) {
            case FLOAT8: {
                Float8Vector v = (Float8Vector) fv;
                return row -> new Output(name, Type.NUMBER, new Value(v.get(row)), 1.0);
            }
            case FLOAT4: {
                Float4Vector v = (Float4Vector) fv;
                return row -> new Output(name, Type.NUMBER, new Value(v.get(row)), 1.0);
            }
            case INT: {
                IntVector v = (IntVector) fv;
                return row -> new Output(name, Type.NUMBER, new Value(v.get(row)), 1.0);
            }
            case BIGINT: {
                BigIntVector v = (BigIntVector) fv;
                return row -> new Output(name, Type.NUMBER, new Value(v.get(row)), 1.0);
            }
            case BIT: {
                BitVector v = (BitVector) fv;
                // Boolean outputs carry an Integer 1/0 payload, as the original conversion did.
                return row -> new Output(name, Type.BOOLEAN, new Value(v.get(row) == 1 ? 1 : 0), 1.0);
            }
            case VARCHAR: {
                VarCharVector v = (VarCharVector) fv;
                // VarCharVector.get returns the raw UTF-8 bytes; decode them so TEXT outputs hold
                // a String rather than a byte[] (null slots stay null).
                return row -> {
                    byte[] bytes = v.get(row);
                    String text = bytes == null ? null : new String(bytes, StandardCharsets.UTF_8);
                    return new Output(name, Type.TEXT, new Value(text), 1.0);
                };
            }
            default:
                throw new IllegalArgumentException(String.format("FieldVector Type %s currently unsupported", fv.getMinorType()));
        }
    }

    /**
     * Builds an Arrow schema mirroring the features of a sample input.
     *
     * <p>NUMBER features map to nullable float64, BOOLEAN to nullable bool, and TEXT/CATEGORICAL
     * to nullable UTF-8; field names are the feature names, in feature order.
     *
     * @param paradigm sample input whose feature names and types define the schema
     * @return the generated schema
     * @throws IllegalArgumentException for any other feature type
     */
    public static Schema generatePrototypePISchema(PredictionInput paradigm) {
        List<Field> fields = new ArrayList<>();
        for (Feature f : paradigm.getFeatures()) {
            ArrowType arrowType;
            if (f.getType() == Type.NUMBER) {
                arrowType = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
            } else if (f.getType() == Type.BOOLEAN) {
                arrowType = new ArrowType.Bool();
            } else if (f.getType() == Type.TEXT || f.getType() == Type.CATEGORICAL) {
                arrowType = new ArrowType.Utf8();
            } else {
                throw new IllegalArgumentException(String.format("Output type %s currently unsupported", f.getType()));
            }
            fields.add(new Field(f.getName(), FieldType.nullable(arrowType), null));
        }
        return new Schema(fields, null);
    }

    /**
     * Copies prediction inputs into a newly allocated {@link VectorSchemaRoot}.
     *
     * <p>Column {@code c} of the schema is filled from feature {@code c} of every input, so the
     * schema is expected to come from {@link #generatePrototypePISchema} for the same inputs.
     * The caller owns (and must close) the returned root. If filling fails, the root is closed
     * before the exception propagates.
     *
     * @param inputs rows to convert
     * @param sourceSchema schema describing the columns
     * @param allocator allocator for the vectors
     * @return the populated root, with its row count set to {@code inputs.size()}
     * @throws IllegalArgumentException if the schema contains a column type other than
     *     FLOAT8, BIT or VARCHAR
     */
    public static VectorSchemaRoot convertPItoVSR(List<PredictionInput> inputs, Schema sourceSchema, RootAllocator allocator) {
        int nrows = inputs.size();
        VectorSchemaRoot sourceRoot = VectorSchemaRoot.create(sourceSchema, allocator);
        try {
            List<FieldVector> fvs = sourceRoot.getFieldVectors();
            for (int col = 0; col < fvs.size(); ++col) {
                fillColumn(fvs.get(col), inputs, col);
            }
            sourceRoot.setRowCount(nrows);
            return sourceRoot;
        } catch (RuntimeException e) {
            // Release the Arrow buffers allocated above; the caller never receives the root.
            sourceRoot.close();
            throw e;
        }
    }

    /** Allocates {@code fv} and fills it with feature {@code col} of each input row. */
    private static void fillColumn(FieldVector fv, List<PredictionInput> inputs, int col) {
        int nrows = inputs.size();
        switch (fv.getMinorType()) {
            case FLOAT8: {
                Float8Vector v = (Float8Vector) fv;
                v.allocateNew(nrows);
                for (int row = 0; row < nrows; ++row) {
                    // Accept any Number (Double, Integer, ...) held by a NUMBER feature.
                    Number n = (Number) featureValue(inputs, row, col).getUnderlyingObject();
                    v.setSafe(row, n.doubleValue());
                }
                break;
            }
            case BIT: {
                BitVector v = (BitVector) fv;
                v.allocateNew(nrows);
                for (int row = 0; row < nrows; ++row) {
                    v.setSafe(row, (int) featureValue(inputs, row, col).asNumber());
                }
                break;
            }
            case VARCHAR: {
                VarCharVector v = (VarCharVector) fv;
                v.allocateNew(nrows);
                for (int row = 0; row < nrows; ++row) {
                    v.setSafe(row, new Text(featureValue(inputs, row, col).asString()));
                }
                break;
            }
            default:
                throw new IllegalArgumentException(String.format("Output type %s currently unsupported, but this error should never arise", fv.getMinorType()));
        }
        fv.setValueCount(nrows);
    }

    /** Returns the value of feature {@code col} in input {@code row}. */
    private static Value featureValue(List<PredictionInput> inputs, int row, int col) {
        return inputs.get(row).getFeatures().get(col).getValue();
    }

    /**
     * Deserializes an Arrow file-format buffer into prediction outputs.
     *
     * <p>Every record batch in the file is converted in order and the rows are concatenated; a
     * file with no batches yields an empty list. The reader (and the Arrow memory it allocates
     * from {@code allocator}) is closed before returning.
     *
     * @param resultBuffer bytes in Arrow file format
     * @param allocator allocator used while reading
     * @return one {@link PredictionOutput} per row
     * @throws UncheckedIOException (a {@link RuntimeException}) if the buffer cannot be read
     * @throws IllegalArgumentException if a column type is unsupported
     */
    public static List<PredictionOutput> read(byte[] resultBuffer, RootAllocator allocator) {
        try (ArrowFileReader reader = new ArrowFileReader(new ByteArrayReadableSeekableByteChannel(resultBuffer), allocator)) {
            List<PredictionOutput> outputs = new ArrayList<>();
            for (ArrowBlock block : reader.getRecordBlocks()) {
                reader.loadRecordBatch(block);
                // Values are copied into Output objects, so they outlive the reader's vectors.
                outputs.addAll(convertFieldVectorstoPO(reader.getVectorSchemaRoot().getFieldVectors()));
            }
            return outputs;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Serializes the current contents of {@code vsr} as a single-batch Arrow file.
     *
     * @param vsr the populated root to write (not closed by this method)
     * @return the Arrow file bytes
     * @throws UncheckedIOException (a {@link RuntimeException}) if writing fails
     */
    public static byte[] write(VectorSchemaRoot vsr) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (ArrowFileWriter writer = new ArrowFileWriter(vsr, null, Channels.newChannel(out))) {
            writer.start();
            writer.writeBatch();
            writer.end();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return out.toByteArray();
    }
}

