/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.memoryoptsearch.faiss;

import java.io.IOException;
import lombok.Generated;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.opensearch.knn.memoryoptsearch.faiss.AbstractFaissHNSWIndex;
import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSW;
import org.opensearch.knn.memoryoptsearch.faiss.FaissHNSWProvider;
import org.opensearch.knn.memoryoptsearch.faiss.FaissIndex;
import org.opensearch.knn.memoryoptsearch.faiss.MonotonicIntegerSequenceEncoder;
import org.opensearch.knn.memoryoptsearch.faiss.binary.FaissBinaryHnswIndex;
import org.opensearch.knn.memoryoptsearch.faiss.binary.FaissBinaryIndex;

public class FaissIdMapIndex
extends FaissBinaryIndex
implements FaissHNSWProvider {
    public static final String IXMP = "IxMp";
    public static final String IBMP = "IBMp";
    private FaissIndex nestedIndex;
    private FaissHNSWProvider hnswGetter;
    private DirectMonotonicReader idMappingReader;

    public FaissIdMapIndex(String indexType) {
        super(indexType);
    }

    @Override
    protected void doLoad(IndexInput input) throws IOException {
        if (this.indexType.equals(IXMP)) {
            this.readCommonHeader(input);
        } else {
            this.readBinaryCommonHeader(input);
        }
        FaissIndex nestedIndex = FaissIndex.load(input);
        if (!(nestedIndex instanceof AbstractFaissHNSWIndex) && !(nestedIndex instanceof FaissBinaryHnswIndex)) {
            throw new IllegalStateException("Invalid nested HNSW index type, got index type=" + nestedIndex.getIndexType());
        }
        this.nestedIndex = nestedIndex;
        this.hnswGetter = (FaissHNSWProvider)((Object)nestedIndex);
        int numElements = Math.toIntExact(input.readLong());
        this.idMappingReader = MonotonicIntegerSequenceEncoder.encode(numElements, input);
    }

    @Override
    public VectorEncoding getVectorEncoding() {
        return this.nestedIndex.getVectorEncoding();
    }

    @Override
    public FloatVectorValues getFloatValues(IndexInput indexInput) throws IOException {
        if (this.idMappingReader == null) {
            return this.nestedIndex.getFloatValues(indexInput);
        }
        return this.sparseFloatValues(indexInput);
    }

    @Override
    public ByteVectorValues getByteValues(IndexInput indexInput) throws IOException {
        if (this.idMappingReader == null) {
            return this.nestedIndex.getByteValues(indexInput);
        }
        return this.sparseByteValues(indexInput);
    }

    private ByteVectorValues sparseByteValues(IndexInput indexInput) throws IOException {
        ByteVectorValues vectorValues = this.nestedIndex.getByteValues(indexInput);
        class SparseByteVectorValuesImpl
        extends ByteVectorValues {
            private final ByteVectorValues vectorValues;

            public byte[] vectorValue(int internalVectorId) throws IOException {
                return this.vectorValues.vectorValue(internalVectorId);
            }

            public int dimension() {
                return this.vectorValues.dimension();
            }

            public int ordToDoc(int internalVectorId) {
                return (int)FaissIdMapIndex.this.idMappingReader.get((long)internalVectorId);
            }

            public Bits getAcceptOrds(Bits acceptDocs) {
                if (acceptDocs != null) {
                    final Bits internalBits = this.vectorValues.getAcceptOrds(acceptDocs);
                    return new Bits(){
                        final /* synthetic */ SparseByteVectorValuesImpl this$1;
                        {
                            this.this$1 = this$1;
                        }

                        public boolean get(int internalVectorId) {
                            return internalBits.get((int)this.this$1.FaissIdMapIndex.this.idMappingReader.get((long)internalVectorId));
                        }

                        public int length() {
                            return internalBits.length();
                        }
                    };
                }
                return null;
            }

            public int size() {
                return this.vectorValues.size();
            }

            public ByteVectorValues copy() throws IOException {
                return new SparseByteVectorValuesImpl(this.vectorValues.copy());
            }

            @Generated
            public SparseByteVectorValuesImpl(ByteVectorValues vectorValues) {
                this.vectorValues = vectorValues;
            }
        }
        return new SparseByteVectorValuesImpl(vectorValues);
    }

    private FloatVectorValues sparseFloatValues(IndexInput indexInput) throws IOException {
        FloatVectorValues vectorValues = this.nestedIndex.getFloatValues(indexInput);
        class SparseFloatVectorValuesImpl
        extends FloatVectorValues {
            private final FloatVectorValues vectorValues;

            public float[] vectorValue(int internalVectorId) throws IOException {
                return this.vectorValues.vectorValue(internalVectorId);
            }

            public int dimension() {
                return this.vectorValues.dimension();
            }

            public int ordToDoc(int internalVectorId) {
                return (int)FaissIdMapIndex.this.idMappingReader.get((long)internalVectorId);
            }

            public Bits getAcceptOrds(Bits acceptDocs) {
                if (acceptDocs != null) {
                    final Bits internalBits = this.vectorValues.getAcceptOrds(acceptDocs);
                    return new Bits(){
                        final /* synthetic */ SparseFloatVectorValuesImpl this$1;
                        {
                            this.this$1 = this$1;
                        }

                        public boolean get(int internalVectorId) {
                            return internalBits.get((int)this.this$1.FaissIdMapIndex.this.idMappingReader.get((long)internalVectorId));
                        }

                        public int length() {
                            return internalBits.length();
                        }
                    };
                }
                return null;
            }

            public int size() {
                return this.vectorValues.size();
            }

            public FloatVectorValues copy() throws IOException {
                return new SparseFloatVectorValuesImpl(this.vectorValues.copy());
            }

            @Generated
            public SparseFloatVectorValuesImpl(FloatVectorValues vectorValues) {
                this.vectorValues = vectorValues;
            }
        }
        return new SparseFloatVectorValuesImpl(vectorValues);
    }

    @Override
    public FaissHNSW getFaissHnsw() {
        return this.hnswGetter.getFaissHnsw();
    }

    @Generated
    public FaissIndex getNestedIndex() {
        return this.nestedIndex;
    }
}

