/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.rank.textsimilarity;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;

public class TextSimilarityRankFeaturePhaseRankCoordinatorContext
extends RankFeaturePhaseRankCoordinatorContext {
    protected final Client client;
    protected final String inferenceId;
    protected final String inferenceText;
    protected final Float minScore;

    public TextSimilarityRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize, Client client, String inferenceId, String inferenceText, Float minScore) {
        super(size, from, rankWindowSize);
        this.client = client;
        this.inferenceId = inferenceId;
        this.inferenceText = inferenceText;
        this.minScore = minScore;
    }

    protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
        ActionListener inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
            InferenceServiceResults results = r.getResults();
            assert (results instanceof RankedDocsResults);
            List rankedDocs = ((RankedDocsResults)results).getRankedDocs();
            if (rankedDocs.size() != featureDocs.length) {
                l.onFailure((Exception)new IllegalStateException("Reranker input document count and returned score count mismatch: [" + featureDocs.length + "] vs [" + rankedDocs.size() + "]"));
            } else {
                float[] scores = this.extractScoresFromRankedDocs(rankedDocs);
                l.onResponse((Object)scores);
            }
        });
        ActionListener topNListener = scoreListener.delegateFailureAndWrap((l, r) -> {
            TaskSettings patt1$temp;
            TaskSettings patt0$temp;
            Integer configuredTopN = null;
            if (!r.getEndpoints().isEmpty() && (patt0$temp = ((ModelConfigurations)r.getEndpoints().get(0)).getTaskSettings()) instanceof CohereRerankTaskSettings) {
                CohereRerankTaskSettings cohereTaskSettings = (CohereRerankTaskSettings)patt0$temp;
                configuredTopN = cohereTaskSettings.getTopNDocumentsOnly();
            } else if (!r.getEndpoints().isEmpty() && (patt1$temp = ((ModelConfigurations)r.getEndpoints().get(0)).getTaskSettings()) instanceof GoogleVertexAiRerankTaskSettings) {
                GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings = (GoogleVertexAiRerankTaskSettings)patt1$temp;
                configuredTopN = googleVertexAiTaskSettings.topN();
            }
            if (configuredTopN != null && configuredTopN < this.rankWindowSize) {
                l.onFailure((Exception)new IllegalArgumentException("Inference endpoint [" + this.inferenceId + "] is configured to return the top [" + configuredTopN + "] results, but rank_window_size is [" + this.rankWindowSize + "]. Reduce rank_window_size to be less than or equal to the configured top N value."));
                return;
            }
            if (featureDocs.length == 0) {
                inferenceListener.onResponse((Object)new InferenceAction.Response((InferenceServiceResults)new RankedDocsResults(List.of())));
            } else {
                List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
                InferenceAction.Request inferenceRequest = this.generateRequest(featureData);
                try {
                    this.client.execute((ActionType)InferenceAction.INSTANCE, (ActionRequest)inferenceRequest, inferenceListener);
                }
                finally {
                    inferenceRequest.decRef();
                }
            }
        });
        GetInferenceModelAction.Request getModelRequest = new GetInferenceModelAction.Request(this.inferenceId, TaskType.RERANK);
        this.client.execute((ActionType)GetInferenceModelAction.INSTANCE, (ActionRequest)getModelRequest, topNListener);
    }

    protected RankFeatureDoc[] preprocess(RankFeatureDoc[] originalDocs) {
        ArrayList<RankFeatureDoc> docs = new ArrayList<RankFeatureDoc>();
        for (RankFeatureDoc doc : originalDocs) {
            if (this.minScore != null && !(doc.score >= this.minScore.floatValue())) continue;
            doc.score = TextSimilarityRankFeaturePhaseRankCoordinatorContext.normalizeScore(doc.score);
            docs.add(doc);
        }
        docs.sort(RankDoc::compareTo);
        return docs.toArray(new RankFeatureDoc[0]);
    }

    protected InferenceAction.Request generateRequest(List<String> docFeatures) {
        return new InferenceAction.Request(TaskType.RERANK, this.inferenceId, this.inferenceText, docFeatures, Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, false);
    }

    private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
        float[] scores = new float[rankedDocs.size()];
        for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
            scores[rankedDoc.index()] = rankedDoc.relevanceScore();
        }
        return scores;
    }

    private static float normalizeScore(float score) {
        return Math.max(score, 0.0f) + Math.min((float)Math.exp(score), 1.0f);
    }
}

