/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetAction;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.MultiGetAction;
import org.opensearch.action.get.MultiGetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.processor.TextInferenceRequest;
import org.opensearch.neuralsearch.processor.optimization.InferenceFilter;
import org.opensearch.neuralsearch.processor.optimization.TextEmbeddingInferenceFilter;
import org.opensearch.neuralsearch.sparse.common.SparseFieldUtils;
import org.opensearch.neuralsearch.stats.events.EventStatName;
import org.opensearch.neuralsearch.stats.events.EventStatsManager;
import org.opensearch.neuralsearch.util.TokenWeightUtil;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;
import org.opensearch.transport.client.OpenSearchClient;

public final class SparseEncodingProcessor
extends InferenceProcessor {
    /** Class logger (marked as Lombok-generated). */
    @Generated
    private static final Logger log = LogManager.getLogger(SparseEncodingProcessor.class);
    /** Processor type name; the same literal value is handed to the superclass constructor. */
    public static final String TYPE = "sparse_encoding";
    /** Key name for list-typed nested maps; the same literal value is handed to the superclass constructor. */
    public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
    /** Client used to issue the Get / MultiGet requests of the skip-existing code paths. */
    private final OpenSearchClient openSearchClient;
    /** When {@code true}, the already-indexed document is fetched first so existing embeddings can be reused. */
    private final boolean skipExisting;
    /** Filter that copies reusable embeddings from the existing document and returns the fields still needing inference. */
    private final TextEmbeddingInferenceFilter textEmbeddingInferenceFilter;
    /** Inference parameters requesting the {@code TOKEN_ID} sparse embedding format; used for sparse ANN fields. */
    private static final AsymmetricTextEmbeddingParameters TOKEN_ID_PARAMETER = AsymmetricTextEmbeddingParameters.builder().sparseEmbeddingFormat(SparseEmbeddingFormat.TOKEN_ID).build();
    /** Prune strategy passed to {@code PruneUtils.pruneSparseVector} for every produced sparse vector. */
    private final PruneType pruneType;
    /** Prune ratio passed to {@code PruneUtils.pruneSparseVector} together with {@link #pruneType}. */
    private final float pruneRatio;
    /** Cluster service passed to {@code SparseFieldUtils.getSparseAnnFields} to look up an index's sparse ANN fields. */
    private final ClusterService clusterService;

    /**
     * Creates a sparse encoding processor.
     *
     * @param tag processor tag, forwarded to {@link InferenceProcessor}
     * @param description processor description, forwarded to {@link InferenceProcessor}
     * @param batchSize batch size, forwarded to {@link InferenceProcessor}
     * @param modelId id of the model used for inference, forwarded to {@link InferenceProcessor}
     * @param fieldMap field mapping configuration, forwarded to {@link InferenceProcessor}
     * @param skipExisting whether existing embeddings should be looked up and reused before inferring
     * @param textEmbeddingInferenceFilter filter used on the skip-existing paths
     * @param pruneType prune strategy applied to produced sparse vectors
     * @param pruneRatio prune ratio applied together with {@code pruneType}
     * @param openSearchClient client used to fetch existing documents
     * @param clientAccessor ML client accessor, forwarded to {@link InferenceProcessor}
     * @param environment environment, forwarded to {@link InferenceProcessor}
     * @param clusterService cluster service, forwarded to {@link InferenceProcessor} and kept for sparse ANN field lookups
     */
    public SparseEncodingProcessor(String tag, String description, int batchSize, String modelId, Map<String, Object> fieldMap, boolean skipExisting, TextEmbeddingInferenceFilter textEmbeddingInferenceFilter, PruneType pruneType, float pruneRatio, OpenSearchClient openSearchClient, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService);
        // Skip-existing collaborators.
        this.skipExisting = skipExisting;
        this.textEmbeddingInferenceFilter = textEmbeddingInferenceFilter;
        this.openSearchClient = openSearchClient;
        // Pruning configuration.
        this.pruneType = pruneType;
        this.pruneRatio = pruneRatio;
        // Needed locally for sparse ANN field lookups.
        this.clusterService = clusterService;
    }

    /**
     * Generates sparse embeddings for a single document and writes them into it.
     *
     * <p>Without {@code skipExisting}, inference runs for every entry of {@code processMap}. With
     * {@code skipExisting}, the currently stored document is fetched by {@code _index}/{@code _id};
     * the inference filter then decides which fields still need inference. If the document has no
     * index or id, or no stored source is found, all fields are inferred.
     *
     * @param ingestDocument document being ingested
     * @param processMap fields to run inference on
     * @param inferenceList texts to run inference on
     * @param handler receives the updated document, or {@code null} plus the exception when the lookup fails
     */
    @Override
    public void doExecute(IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList, BiConsumer<IngestDocument, Exception> handler) {
        EventStatsManager.increment(EventStatName.SPARSE_ENCODING_PROCESSOR_EXECUTIONS);
        // Shared fallback: infer everything that was requested, ignoring any stored embeddings.
        final Runnable inferEverything = () -> this.generateAndSetMapInference(ingestDocument, processMap, inferenceList, this.pruneType, this.pruneRatio, handler);
        if (!this.skipExisting) {
            inferEverything.run();
            return;
        }
        EventStatsManager.increment(EventStatName.SKIP_EXISTING_EXECUTIONS);
        final Object index = ingestDocument.getSourceAndMetadata().get("_index");
        final Object id = ingestDocument.getSourceAndMetadata().get("_id");
        if (index == null || id == null) {
            // Nothing to look up without both coordinates.
            inferEverything.run();
            return;
        }
        final GetRequest getRequest = new GetRequest(index.toString(), id.toString());
        this.openSearchClient.execute(GetAction.INSTANCE, getRequest, ActionListener.wrap(response -> {
            final Map<String, Object> existingDocument = response.getSourceAsMap();
            if (existingDocument == null || existingDocument.isEmpty()) {
                inferEverything.run();
                return;
            }
            final Map<String, Object> filteredProcessMap = this.textEmbeddingInferenceFilter.filterAndCopyExistingEmbeddings(existingDocument, ingestDocument.getSourceAndMetadata(), processMap);
            final List<String> filteredInferenceList = this.createInferenceList(filteredProcessMap).stream().filter(Objects::nonNull).collect(Collectors.toList());
            if (!filteredInferenceList.isEmpty()) {
                this.generateAndSetMapInference(ingestDocument, filteredProcessMap, filteredInferenceList, this.pruneType, this.pruneRatio, handler);
            } else {
                // Every requested field was satisfied by the filter; no model call is needed.
                handler.accept(ingestDocument, null);
            }
        }, e -> handler.accept(null, e)));
    }

    /**
     * Runs inference for one sub-batch, issuing separate model calls for sparse ANN fields
     * ({@code TOKEN_ID} format) and for all remaining fields ({@code WORD} format).
     *
     * <p>The sparse ANN field lookup uses the {@code _index} of the first document in the sub-batch.
     * When that document carries no index, when the index has no sparse ANN fields, or when none of
     * the documents contain such a field, processing is delegated to the superclass unchanged.
     * Otherwise {@code handler} is invoked once, after every issued model call has completed; if any
     * call failed, all documents are updated with the first recorded exception.
     *
     * @param ingestDocumentWrappers documents of this sub-batch
     * @param inferenceList texts to run inference on (used only when delegating to the superclass)
     * @param dataForInferences per-document inference data
     * @param handler invoked with the document wrappers once processing has finished
     */
    @Override
    protected void doSubBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, List<String> inferenceList, List<InferenceProcessor.DataForInference> dataForInferences, Consumer<List<IngestDocumentWrapper>> handler) {
        if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
            handler.accept(ingestDocumentWrappers);
            return;
        }
        Object indexObj = ingestDocumentWrappers.getFirst().getIngestDocument().getSourceAndMetadata().get("_index");
        if (indexObj == null) {
            // Without an index name there is nothing to look sparse ANN fields up for; use the plain
            // path instead of failing with a NullPointerException.
            super.doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
            return;
        }
        Set<String> sparseAnnFields = SparseFieldUtils.getSparseAnnFields(indexObj.toString(), this.clusterService);
        if (sparseAnnFields.isEmpty()) {
            super.doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
            return;
        }
        SplitDataResponse splitDataResponse = this.splitData(dataForInferences, sparseAnnFields);
        List<InferenceProcessor.DataForInference> tokenIdData = splitDataResponse.getTokenIdDataForInference();
        if (tokenIdData.isEmpty()) {
            // No document actually contains a sparse ANN field: a single regular call is enough.
            super.doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
            return;
        }
        List<InferenceProcessor.DataForInference> wordData = splitDataResponse.getWordDataForInference();
        boolean hasWordData = !wordData.isEmpty();
        // Number of model calls that must complete before the caller's handler may run.
        AtomicInteger pendingCalls = new AtomicInteger(hasWordData ? 2 : 1);
        List<Exception> exceptions = Collections.synchronizedList(new ArrayList<>());
        this.doBatchExecuteWithType(splitDataResponse.getTokenIdResponseInferenceList(), tokenIdData, SparseEmbeddingFormat.TOKEN_ID, this.getCountDownBatchDataHandler(pendingCalls, ingestDocumentWrappers, exceptions, handler));
        if (hasWordData) {
            this.doBatchExecuteWithType(splitDataResponse.getWordResponseInferenceList(), wordData, SparseEmbeddingFormat.WORD, this.getCountDownBatchDataHandler(pendingCalls, ingestDocumentWrappers, exceptions, handler));
        }
    }

    /**
     * Sends one batch of texts to the model in the requested sparse embedding format and applies
     * the pruned results to the given inference data.
     *
     * @param inferenceList texts to run inference on
     * @param dataForInferences per-document data the results are written back to
     * @param format embedding format; {@code TOKEN_ID} passes {@code TOKEN_ID_PARAMETER}, anything else passes no parameters
     * @param handler receives {@code dataForInferences} and {@code null} on success, or the exception on failure
     */
    private void doBatchExecuteWithType(List<String> inferenceList, List<InferenceProcessor.DataForInference> dataForInferences, SparseEmbeddingFormat format, BiConsumer<List<InferenceProcessor.DataForInference>, Exception> handler) {
        final Tuple<List<String>, Map<Integer, Integer>> sorted = this.sortByLengthAndReturnOriginalOrder(inferenceList);
        final List<String> sortedTexts = sorted.v1();
        // Handed to batchExecuteHandler together with the results.
        final Map<Integer, Integer> originalOrder = sorted.v2();
        MLAlgoParams parameters = null;
        if (format == SparseEmbeddingFormat.TOKEN_ID) {
            parameters = TOKEN_ID_PARAMETER;
        }
        final TextInferenceRequest request = TextInferenceRequest.builder().modelId(this.modelId).inputTexts(sortedTexts).build();
        this.mlCommonsClientAccessor.inferenceSentencesWithMapResult(request, parameters, ActionListener.wrap(resultMaps -> {
            var sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
                .stream()
                .map(vector -> PruneUtils.pruneSparseVector(this.pruneType, this.pruneRatio, vector))
                .toList();
            this.batchExecuteHandler(sparseVectors, dataForInferences, originalOrder);
            handler.accept(dataForInferences, null);
        }, exception -> handler.accept(dataForInferences, exception)));
    }

    /**
     * Partitions every document's process map into sparse ANN fields and all other fields.
     *
     * <p>For each input element the two partial maps are turned into inference lists via
     * {@code createInferenceList}; the texts are appended to the matching aggregate list, and a new
     * {@code DataForInference} is recorded for each non-empty partition.
     *
     * @param dataForInferences per-document inference data to split
     * @param sparseAnnFields names of the sparse ANN fields
     * @return the token-ID and word partitions
     */
    private SplitDataResponse splitData(List<InferenceProcessor.DataForInference> dataForInferences, Set<String> sparseAnnFields) {
        final SplitDataResponse result = new SplitDataResponse();
        for (InferenceProcessor.DataForInference data : dataForInferences) {
            final Map<String, Object> tokenIdProcessMap = new HashMap<>();
            final Map<String, Object> wordProcessMap = new HashMap<>();
            for (Map.Entry<String, Object> field : data.getProcessMap().entrySet()) {
                final Map<String, Object> target = this.isSparseAnnField(sparseAnnFields, field.getKey()) ? tokenIdProcessMap : wordProcessMap;
                target.put(field.getKey(), field.getValue());
            }
            final List<String> tokenIdTexts = this.createInferenceList(tokenIdProcessMap);
            final List<String> wordTexts = this.createInferenceList(wordProcessMap);
            result.getTokenIdResponseInferenceList().addAll(tokenIdTexts);
            result.getWordResponseInferenceList().addAll(wordTexts);
            if (!tokenIdTexts.isEmpty()) {
                result.getTokenIdDataForInference().add(new InferenceProcessor.DataForInference(data.getIngestDocumentWrapper(), tokenIdProcessMap, tokenIdTexts));
            }
            if (!wordTexts.isEmpty()) {
                result.getWordDataForInference().add(new InferenceProcessor.DataForInference(data.getIngestDocumentWrapper(), wordProcessMap, wordTexts));
            }
        }
        return result;
    }

    /**
     * Runs inference on the given texts without extra model parameters and hands the pruned
     * sparse vectors to {@code handler}.
     *
     * @param inferenceList texts to run inference on
     * @param handler receives the list of pruned sparse vectors
     * @param onException receives the failure if the inference call fails
     */
    @Override
    public void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException) {
        final TextInferenceRequest request = TextInferenceRequest.builder().modelId(this.modelId).inputTexts(inferenceList).build();
        this.mlCommonsClientAccessor.inferenceSentencesWithMapResult(request, null, ActionListener.wrap(resultMaps -> {
            var sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)
                .stream()
                .map(vector -> PruneUtils.pruneSparseVector(this.pruneType, this.pruneRatio, vector))
                .toList();
            handler.accept(sparseVectors);
        }, onException));
    }

    /**
     * Processes one sub-batch of documents.
     *
     * <p>Empty batches, and batches that yield no inference texts, are passed straight to
     * {@code handler}. Without {@code skipExisting} the batch goes to {@code doSubBatchExecute};
     * with it, the stored documents are fetched via a multi-get first and
     * {@code reuseOrGenerateEmbedding} takes over. Any exception raised here, or reported by the
     * multi-get, is recorded on the affected document wrappers through {@code updateWithExceptions}.
     *
     * @param ingestDocumentWrappers documents of this sub-batch
     * @param handler invoked with the document wrappers once processing has finished
     */
    @Override
    public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
        EventStatsManager.increment(EventStatName.SPARSE_ENCODING_PROCESSOR_EXECUTIONS);
        try {
            if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
                handler.accept(ingestDocumentWrappers);
                return;
            }
            final List<InferenceProcessor.DataForInference> dataForInferences = this.getDataForInference(ingestDocumentWrappers);
            final List<String> inferenceList = this.constructInferenceTexts(dataForInferences);
            if (inferenceList.isEmpty()) {
                handler.accept(ingestDocumentWrappers);
                return;
            }
            if (this.skipExisting) {
                this.openSearchClient.execute(
                    MultiGetAction.INSTANCE,
                    this.buildMultiGetRequest(dataForInferences),
                    ActionListener.wrap(
                        response -> this.reuseOrGenerateEmbedding(response, ingestDocumentWrappers, inferenceList, dataForInferences, handler, this.textEmbeddingInferenceFilter),
                        // A failed multi-get fails every document that was part of the lookup.
                        e -> this.updateWithExceptions(this.getIngestDocumentWrappers(dataForInferences), handler, e)
                    )
                );
            } else {
                this.doSubBatchExecute(ingestDocumentWrappers, inferenceList, dataForInferences, handler);
            }
        } catch (Exception failure) {
            this.updateWithExceptions(ingestDocumentWrappers, handler, failure);
        }
    }

    /**
     * Runs inference for a single document, using the {@code TOKEN_ID} format for sparse ANN fields
     * and the default format for every other field.
     *
     * <p>If {@code processMap} contains no sparse ANN field, one regular inference call is made with
     * the caller's {@code processMap} and {@code inferenceList}. Otherwise one call is issued per
     * non-empty partition, each with an inference list rebuilt from that partition, and
     * {@code handler} is invoked once after all of them have completed.
     *
     * @param ingestDocument document being ingested
     * @param processMap fields to run inference on
     * @param inferenceList texts to run inference on (used only when no sparse ANN field is present)
     * @param pruneType prune strategy forwarded to the inference call
     * @param pruneRatio prune ratio forwarded to the inference call
     * @param handler receives the updated document, or {@code null} plus the first recorded exception
     */
    private void generateAndSetMapInference(IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList, PruneType pruneType, float pruneRatio, BiConsumer<IngestDocument, Exception> handler) {
        Object indexObj = ingestDocument.getSourceAndMetadata().get("_index");
        String index = indexObj == null ? null : indexObj.toString();
        Set<String> sparseAnnFields = SparseFieldUtils.getSparseAnnFields(index, this.clusterService);
        Map<String, Object> tokenIdProcessMap = new HashMap<>();
        Map<String, Object> wordProcessMap = new HashMap<>();
        for (Map.Entry<String, Object> entry : processMap.entrySet()) {
            if (this.isSparseAnnField(sparseAnnFields, entry.getKey())) {
                tokenIdProcessMap.put(entry.getKey(), entry.getValue());
            } else {
                wordProcessMap.put(entry.getKey(), entry.getValue());
            }
        }
        if (tokenIdProcessMap.isEmpty()) {
            // No sparse ANN field involved: one regular call covers the whole document.
            this.generateAndSetMapInference(ingestDocument, processMap, inferenceList, pruneType, pruneRatio, null, handler);
            return;
        }
        boolean hasWordFields = !wordProcessMap.isEmpty();
        // Number of inference calls that must complete before the caller's handler may run.
        AtomicInteger pendingCalls = new AtomicInteger(hasWordFields ? 2 : 1);
        List<Exception> exceptions = Collections.synchronizedList(new ArrayList<>());
        List<String> tokenIdInferenceList = this.createInferenceList(tokenIdProcessMap);
        this.generateAndSetMapInference(ingestDocument, tokenIdProcessMap, tokenIdInferenceList, pruneType, pruneRatio, (MLAlgoParams) TOKEN_ID_PARAMETER, this.getCountDownHandler(pendingCalls, exceptions, handler));
        if (hasWordFields) {
            List<String> wordInferenceList = this.createInferenceList(wordProcessMap);
            this.generateAndSetMapInference(ingestDocument, wordProcessMap, wordInferenceList, pruneType, pruneRatio, null, this.getCountDownHandler(pendingCalls, exceptions, handler));
        }
    }

    /**
     * Builds a completion callback shared by several inference calls for one document.
     *
     * <p>Each invocation records a non-null exception and decrements {@code counter}. The
     * invocation that brings the counter to zero notifies {@code originalHandler}: with the
     * document when no exception was recorded, otherwise with {@code null} and the first recorded
     * exception.
     *
     * @param counter number of outstanding calls
     * @param exceptions shared collector for failures
     * @param originalHandler handler to notify once all calls have completed
     * @return the counting callback
     */
    private BiConsumer<IngestDocument, Exception> getCountDownHandler(AtomicInteger counter, List<Exception> exceptions, BiConsumer<IngestDocument, Exception> originalHandler) {
        return (document, failure) -> {
            if (failure != null) {
                exceptions.add(failure);
            }
            if (counter.decrementAndGet() != 0) {
                // Other calls are still outstanding.
                return;
            }
            if (exceptions.isEmpty()) {
                originalHandler.accept(document, null);
            } else {
                originalHandler.accept(null, exceptions.getFirst());
            }
        };
    }

    /**
     * Builds a completion callback shared by several batch inference calls.
     *
     * <p>Each invocation records a non-null exception and decrements {@code counter}. The
     * invocation that brings the counter to zero finishes the batch: {@code handler} receives the
     * document wrappers directly when no exception was recorded, otherwise
     * {@code updateWithExceptions} is called with the first recorded exception.
     *
     * @param counter number of outstanding calls
     * @param ingestDocumentWrappers documents of the batch
     * @param exceptions shared collector for failures
     * @param handler handler to notify once all calls have completed
     * @return the counting callback; its first argument is not used
     */
    private BiConsumer<List<InferenceProcessor.DataForInference>, Exception> getCountDownBatchDataHandler(AtomicInteger counter, List<IngestDocumentWrapper> ingestDocumentWrappers, List<Exception> exceptions, Consumer<List<IngestDocumentWrapper>> handler) {
        return (unusedData, failure) -> {
            if (failure != null) {
                exceptions.add(failure);
            }
            if (counter.decrementAndGet() != 0) {
                // Other calls are still outstanding.
                return;
            }
            if (exceptions.isEmpty()) {
                handler.accept(ingestDocumentWrappers);
            } else {
                this.updateWithExceptions(ingestDocumentWrappers, handler, exceptions.getFirst());
            }
        };
    }

    /**
     * Tells whether {@code field} is a sparse ANN field: its name must contain no {@code '.'}
     * and must be a member of {@code sparseAnnFields}.
     *
     * @param sparseAnnFields names of the sparse ANN fields
     * @param field field name to test
     * @return {@code true} if the name has no dot and is contained in the set
     */
    private boolean isSparseAnnField(Set<String> sparseAnnFields, String field) {
        if (field.indexOf('.') >= 0) {
            return false;
        }
        return sparseAnnFields.contains(field);
    }

    /**
     * @return the prune strategy this processor was created with
     */
    @Generated
    public PruneType getPruneType() {
        return pruneType;
    }

    /**
     * @return the prune ratio this processor was created with
     */
    @Generated
    public float getPruneRatio() {
        return pruneRatio;
    }

    /**
     * Mutable holder for the outcome of splitting inference data into a token-ID partition and a
     * word partition. Each partition has an aggregate list of texts and a list of per-document
     * inference data. All four lists start out empty and modifiable.
     */
    private static class SplitDataResponse {
        /** Texts belonging to the token-ID partition. */
        private List<String> tokenIdResponseInferenceList = new ArrayList<>();
        /** Texts belonging to the word partition. */
        private List<String> wordResponseInferenceList = new ArrayList<>();
        /** Per-document inference data of the token-ID partition. */
        private List<InferenceProcessor.DataForInference> tokenIdDataForInference = new ArrayList<>();
        /** Per-document inference data of the word partition. */
        private List<InferenceProcessor.DataForInference> wordDataForInference = new ArrayList<>();

        @Generated
        public SplitDataResponse() {
        }

        @Generated
        public List<String> getTokenIdResponseInferenceList() {
            return tokenIdResponseInferenceList;
        }

        @Generated
        public List<String> getWordResponseInferenceList() {
            return wordResponseInferenceList;
        }

        @Generated
        public List<InferenceProcessor.DataForInference> getTokenIdDataForInference() {
            return tokenIdDataForInference;
        }

        @Generated
        public List<InferenceProcessor.DataForInference> getWordDataForInference() {
            return wordDataForInference;
        }

        @Generated
        public void setTokenIdResponseInferenceList(List<String> tokenIdResponseInferenceList) {
            this.tokenIdResponseInferenceList = tokenIdResponseInferenceList;
        }

        @Generated
        public void setWordResponseInferenceList(List<String> wordResponseInferenceList) {
            this.wordResponseInferenceList = wordResponseInferenceList;
        }

        @Generated
        public void setTokenIdDataForInference(List<InferenceProcessor.DataForInference> tokenIdDataForInference) {
            this.tokenIdDataForInference = tokenIdDataForInference;
        }

        @Generated
        public void setWordDataForInference(List<InferenceProcessor.DataForInference> wordDataForInference) {
            this.wordDataForInference = wordDataForInference;
        }

        /** Two instances are equal when all four lists are pairwise equal (or both {@code null}). */
        @Generated
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof SplitDataResponse)) {
                return false;
            }
            final SplitDataResponse that = (SplitDataResponse) o;
            return that.canEqual(this)
                && Objects.equals(this.getTokenIdResponseInferenceList(), that.getTokenIdResponseInferenceList())
                && Objects.equals(this.getWordResponseInferenceList(), that.getWordResponseInferenceList())
                && Objects.equals(this.getTokenIdDataForInference(), that.getTokenIdDataForInference())
                && Objects.equals(this.getWordDataForInference(), that.getWordDataForInference());
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof SplitDataResponse;
        }

        /** Combines the four lists with multiplier 59, using 43 in place of a {@code null} list. */
        @Generated
        public int hashCode() {
            final Object[] components = {
                this.getTokenIdResponseInferenceList(),
                this.getWordResponseInferenceList(),
                this.getTokenIdDataForInference(),
                this.getWordDataForInference()
            };
            int result = 1;
            for (Object component : components) {
                result = result * 59 + (component == null ? 43 : component.hashCode());
            }
            return result;
        }

        @Generated
        public String toString() {
            return "SparseEncodingProcessor.SplitDataResponse(tokenIdResponseInferenceList=" + this.getTokenIdResponseInferenceList()
                + ", wordResponseInferenceList=" + this.getWordResponseInferenceList()
                + ", tokenIdDataForInference=" + this.getTokenIdDataForInference()
                + ", wordDataForInference=" + this.getWordDataForInference()
                + ")";
        }
    }
}

