/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.memoryoptsearch.faiss;

import java.io.IOException;
import java.util.EnumMap;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.memoryoptsearch.faiss.FaissIndex;
import org.opensearch.knn.memoryoptsearch.faiss.FaissSection;
import org.opensearch.knn.memoryoptsearch.faiss.UnsupportedFaissIndexException;
import org.opensearch.knn.memoryoptsearch.faiss.reconstruct.FaissQuantizedValueReconstructor;
import org.opensearch.knn.memoryoptsearch.faiss.reconstruct.FaissQuantizedValueReconstructorFactory;
import org.opensearch.knn.memoryoptsearch.faiss.reconstruct.FaissQuantizerType;

public class FaissIndexScalarQuantizedFlat
extends FaissIndex {
    @Generated
    private static final Logger log = LogManager.getLogger(FaissIndexScalarQuantizedFlat.class);
    private static EnumMap<FaissQuantizerType, VectorEncoding> VECTOR_DATA_TYPES = new EnumMap<FaissQuantizerType, VectorEncoding>(Map.of(FaissQuantizerType.QT_8BIT_DIRECT_SIGNED, VectorEncoding.BYTE, FaissQuantizerType.QT_FP16, VectorEncoding.FLOAT32));
    public static final String IXSQ = "IxSQ";
    private FaissQuantizerType quantizerType;
    private FaissQuantizedValueReconstructor reconstructor;
    private RangeStat rangeStat;
    private float rangeStatArgument;
    private int dimension;
    private long oneVectorByteSize;
    private int oneVectorElementBits;
    private FaissSection trainedValues;
    private FaissSection flatVectors;
    private VectorEncoding vectorEncoding;

    public FaissIndexScalarQuantizedFlat() {
        super(IXSQ);
    }

    @Override
    protected void doLoad(IndexInput input) throws IOException {
        this.readCommonHeader(input);
        this.quantizerType = FaissQuantizerType.values()[input.readInt()];
        if (!VECTOR_DATA_TYPES.containsKey((Object)this.quantizerType)) {
            throw new UnsupportedFaissIndexException("Unsupported quantizer type: " + String.valueOf((Object)this.quantizerType));
        }
        this.vectorEncoding = VECTOR_DATA_TYPES.get((Object)this.quantizerType);
        this.rangeStat = RangeStat.values()[input.readInt()];
        float[] singleFloat = new float[1];
        input.readFloats(singleFloat, 0, 1);
        this.rangeStatArgument = singleFloat[0];
        this.dimension = Math.toIntExact(input.readLong());
        input.readLong();
        this.trainedValues = new FaissSection(input, 4);
        this.setDerivedSizes();
        this.flatVectors = new FaissSection(input, 1);
        this.reconstructor = FaissQuantizedValueReconstructorFactory.create(this.quantizerType, this.dimension, this.oneVectorElementBits);
    }

    @Override
    public VectorEncoding getVectorEncoding() {
        return this.vectorEncoding;
    }

    @Override
    public FloatVectorValues getFloatValues(IndexInput indexInput) {
        final class FloatVectorValuesImpl
        extends FloatVectorValues {
            final IndexInput indexInput;
            final byte[] bytesBuffer;
            final float[] floatBuffer;

            public float[] vectorValue(int internalVectorId) throws IOException {
                this.indexInput.seek(FaissIndexScalarQuantizedFlat.this.flatVectors.getBaseOffset() + (long)internalVectorId * FaissIndexScalarQuantizedFlat.this.oneVectorByteSize);
                this.indexInput.readBytes(this.bytesBuffer, 0, this.bytesBuffer.length);
                FaissIndexScalarQuantizedFlat.this.reconstructor.reconstruct(this.bytesBuffer, this.floatBuffer);
                return this.floatBuffer;
            }

            public int dimension() {
                return FaissIndexScalarQuantizedFlat.this.dimension;
            }

            public int size() {
                return FaissIndexScalarQuantizedFlat.this.totalNumberOfVectors;
            }

            public FloatVectorValuesImpl copy() {
                return new FloatVectorValuesImpl(this.indexInput.clone());
            }

            @Generated
            public FloatVectorValuesImpl(IndexInput indexInput) {
                this.bytesBuffer = new byte[(int)FaissIndexScalarQuantizedFlat.this.oneVectorByteSize];
                this.floatBuffer = new float[FaissIndexScalarQuantizedFlat.this.dimension];
                this.indexInput = indexInput;
            }
        }
        return new FloatVectorValuesImpl(indexInput);
    }

    @Override
    public ByteVectorValues getByteValues(IndexInput indexInput) {
        final class ByteVectorValuesImpl
        extends ByteVectorValues {
            final IndexInput indexInput;
            final byte[] buffer;

            public byte[] vectorValue(int internalVectorId) throws IOException {
                this.indexInput.seek(FaissIndexScalarQuantizedFlat.this.flatVectors.getBaseOffset() + (long)internalVectorId * FaissIndexScalarQuantizedFlat.this.oneVectorByteSize);
                this.indexInput.readBytes(this.buffer, 0, this.buffer.length);
                FaissIndexScalarQuantizedFlat.this.reconstructor.reconstruct(this.buffer, this.buffer);
                return this.buffer;
            }

            public int dimension() {
                return FaissIndexScalarQuantizedFlat.this.dimension;
            }

            public int size() {
                return FaissIndexScalarQuantizedFlat.this.totalNumberOfVectors;
            }

            public ByteVectorValues copy() {
                return new ByteVectorValuesImpl(this.indexInput.clone());
            }

            @Generated
            public ByteVectorValuesImpl(IndexInput indexInput) {
                this.buffer = new byte[(int)FaissIndexScalarQuantizedFlat.this.oneVectorByteSize];
                this.indexInput = indexInput;
            }
        }
        return new ByteVectorValuesImpl(indexInput);
    }

    @Override
    public String getIndexType() {
        return IXSQ;
    }

    private void setDerivedSizes() {
        switch (this.quantizerType) {
            case QT_8BIT: 
            case QT_8BIT_UNIFORM: 
            case QT_8BIT_DIRECT: 
            case QT_8BIT_DIRECT_SIGNED: {
                this.oneVectorByteSize = this.dimension;
                this.oneVectorElementBits = 8;
                break;
            }
            case QT_4BIT: 
            case QT_4BIT_UNIFORM: {
                this.oneVectorByteSize = (this.dimension + 1) / 2;
                this.oneVectorElementBits = 4;
                break;
            }
            case QT_6BIT: {
                this.oneVectorByteSize = (this.dimension * 6 + 7) / 8;
                this.oneVectorElementBits = 6;
                break;
            }
            case QT_FP16: 
            case QT_BF16: {
                this.oneVectorByteSize = this.dimension * 2;
                this.oneVectorElementBits = 16;
            }
        }
    }

    @Generated
    public FaissQuantizerType getQuantizerType() {
        return this.quantizerType;
    }

    @Generated
    public FaissQuantizedValueReconstructor getReconstructor() {
        return this.reconstructor;
    }

    @Generated
    public RangeStat getRangeStat() {
        return this.rangeStat;
    }

    @Generated
    public float getRangeStatArgument() {
        return this.rangeStatArgument;
    }

    @Override
    @Generated
    public int getDimension() {
        return this.dimension;
    }

    @Generated
    public long getOneVectorByteSize() {
        return this.oneVectorByteSize;
    }

    @Generated
    public int getOneVectorElementBits() {
        return this.oneVectorElementBits;
    }

    @Generated
    public FaissSection getTrainedValues() {
        return this.trainedValues;
    }

    @Generated
    public FaissSection getFlatVectors() {
        return this.flatVectors;
    }

    public static enum RangeStat {
        MIN_MAX,
        MEAN_STD,
        QUANTILES,
        OPTIM;

    }
}

