/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.search.processor.mmr;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.index.KNNVectorSimilarityFunction;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.search.processor.mmr.MMRRerankContext;
import org.opensearch.knn.search.processor.mmr.MMRUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.ProcessorGenerationContext;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.pipeline.SystemGeneratedProcessor;
import org.opensearch.search.profile.SearchProfileShardResults;

public class MMRRerankProcessor
implements SearchResponseProcessor,
SystemGeneratedProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(MMRRerankProcessor.class);
    public static final String TYPE = "mmr_rerank";
    public static final String DESCRIPTION = "This is a system generated processor that will rerank the response basedon Maximal Marginal Relevance.";
    private final String tag;
    private final boolean ignoreFailure;

    public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
        throw new UnsupportedOperationException(String.format(Locale.ROOT, "Should not try to use %s to process a search response without PipelineProcessingContext.", TYPE));
    }

    public SearchResponse processResponse(SearchRequest request, SearchResponse searchResponse, PipelineProcessingContext requestContext) throws IOException {
        long startNanos = System.nanoTime();
        if (this.isEmptyResponse(searchResponse)) {
            return searchResponse;
        }
        MMRRerankContext mmrContext = this.requireMMRContext(requestContext);
        KNNVectorSimilarityFunction similarityFunction = mmrContext.getSpaceType().getKnnVectorSimilarityFunction();
        int originalQuerySize = mmrContext.getOriginalQuerySize();
        float diversity = mmrContext.getDiversity().floatValue();
        boolean isFloatVector = VectorDataType.FLOAT.equals((Object)mmrContext.getVectorDataType());
        ArrayList<SearchHit> candidates = new ArrayList<SearchHit>(List.of(searchResponse.getHits().getHits()));
        Map<String, Object> docVectors = this.extractVectors(candidates, mmrContext.getVectorFieldPath(), mmrContext.getIndexToVectorFieldPathMap(), isFloatVector);
        List<SearchHit> selected = this.selectHitsWithMMR(candidates, docVectors, similarityFunction, diversity, originalQuerySize, isFloatVector);
        this.applyFetchSourceFilterIfNeeded(selected, mmrContext);
        float maxSelectedScore = selected.stream().map(SearchHit::getScore).max(Float::compare).orElse(Float.valueOf(Float.NEGATIVE_INFINITY)).floatValue();
        SearchHits newHits = new SearchHits(selected.toArray(new SearchHit[0]), searchResponse.getHits().getTotalHits(), maxSelectedScore, searchResponse.getHits().getSortFields(), searchResponse.getHits().getCollapseField(), searchResponse.getHits().getCollapseValues());
        SearchResponseSections newSections = new SearchResponseSections(newHits, searchResponse.getAggregations(), searchResponse.getSuggest(), searchResponse.isTimedOut(), searchResponse.isTerminatedEarly(), new SearchProfileShardResults(searchResponse.getProfileResults()), searchResponse.getNumReducePhases(), searchResponse.getInternalResponse().getSearchExtBuilders());
        long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
        log.debug("MMR rerank took: {} ms", (Object)elapsedMillis);
        return new SearchResponse(newSections, searchResponse.getScrollId(), searchResponse.getTotalShards(), searchResponse.getSuccessfulShards(), searchResponse.getSkippedShards(), searchResponse.getTook().millis(), searchResponse.getPhaseTook(), searchResponse.getShardFailures(), searchResponse.getClusters(), searchResponse.pointInTimeId());
    }

    private boolean isEmptyResponse(SearchResponse response) {
        return response == null || response.getHits() == null || response.getHits().getHits() == null || response.getHits().getHits().length == 0;
    }

    private MMRRerankContext requireMMRContext(PipelineProcessingContext requestContext) {
        Object attr = requestContext.getAttribute("mmr.rerank_context");
        if (attr == null) {
            throw new IllegalStateException("MMR rerank context cannot be null");
        }
        MMRRerankContext ctx = (MMRRerankContext)attr;
        if (ctx.getSpaceType() == null) {
            throw new IllegalStateException("Space type in MMR rerank context cannot be null");
        }
        if (ctx.getOriginalQuerySize() == null) {
            throw new IllegalStateException("Original query size in MMR rerank context cannot be null");
        }
        if (ctx.getDiversity() == null) {
            throw new IllegalStateException("Diversity in MMR rerank context cannot be null");
        }
        if (ctx.getVectorDataType() == null) {
            throw new IllegalStateException("Vector data type in MMR rerank context cannot be null");
        }
        return ctx;
    }

    private Map<String, Object> extractVectors(List<SearchHit> hits, String defaultVectorFieldPath, Map<String, String> indexToVectorFieldPathMap, boolean isFloatVector) {
        ConcurrentHashMap<String, Object> vectors = new ConcurrentHashMap<String, Object>();
        for (SearchHit hit : hits) {
            String overridePath;
            String vectorPath = defaultVectorFieldPath;
            if (indexToVectorFieldPathMap != null && (overridePath = indexToVectorFieldPathMap.get(hit.getIndex())) != null && !overridePath.isBlank()) {
                vectorPath = overridePath;
            }
            Object embedding = MMRUtil.extractVectorFromHit(hit.getSourceAsMap(), vectorPath, hit.getId(), isFloatVector);
            vectors.put(hit.getId(), embedding);
        }
        return vectors;
    }

    private List<SearchHit> selectHitsWithMMR(List<SearchHit> candidates, Map<String, Object> docVectors, KNNVectorSimilarityFunction similarityFunction, float diversity, int targetSize, boolean isFloatVector) {
        ArrayList<SearchHit> selected = new ArrayList<SearchHit>();
        HashMap<String, Float> simCache = new HashMap<String, Float>();
        while (selected.size() < targetSize && !candidates.isEmpty()) {
            Pair bestCandidate = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (SearchHit candidate : candidates) {
                String candidateId = candidate.getId();
                float maxSimToSelected = 0.0f;
                for (SearchHit sel : selected) {
                    String selId = sel.getId();
                    String key = this.cacheKey(candidateId, selId);
                    String symKey = this.cacheKey(selId, candidateId);
                    float sim = simCache.computeIfAbsent(key, k -> {
                        if (isFloatVector) {
                            return Float.valueOf(similarityFunction.compare((float[])docVectors.get(candidateId), (float[])docVectors.get(selId)));
                        }
                        return Float.valueOf(similarityFunction.compare((byte[])docVectors.get(candidateId), (byte[])docVectors.get(selId)));
                    }).floatValue();
                    simCache.putIfAbsent(symKey, Float.valueOf(sim));
                    maxSimToSelected = Math.max(maxSimToSelected, sim);
                }
                double score = (1.0f - diversity) * candidate.getScore() - diversity * maxSimToSelected;
                if (!(score > bestScore)) continue;
                bestScore = score;
                bestCandidate = Pair.of((Object)candidate, (Object)score);
            }
            if (bestCandidate == null) continue;
            SearchHit bestHit = (SearchHit)bestCandidate.getLeft();
            selected.add(bestHit);
            candidates.remove(bestHit);
        }
        return selected;
    }

    private void applyFetchSourceFilterIfNeeded(List<SearchHit> hits, MMRRerankContext mmrContext) throws IOException {
        FetchSourceContext fetchSourceContext = mmrContext.getOriginalFetchSourceContext();
        if (fetchSourceContext == null) {
            return;
        }
        if (!fetchSourceContext.fetchSource()) {
            for (SearchHit hit : hits) {
                hit.sourceRef(null);
            }
            return;
        }
        Function filter = fetchSourceContext.getFilter();
        for (SearchHit hit : hits) {
            Map filtered = (Map)filter.apply(hit.getSourceAsMap());
            hit.sourceRef(BytesReference.bytes((XContentBuilder)XContentFactory.jsonBuilder().map(filtered)));
        }
    }

    private String cacheKey(String id1, String id2) {
        return String.join((CharSequence)":", id1, id2);
    }

    public SystemGeneratedProcessor.ExecutionStage getExecutionStage() {
        return SystemGeneratedProcessor.ExecutionStage.PRE_USER_DEFINED;
    }

    public String getType() {
        return TYPE;
    }

    public String getTag() {
        return this.tag;
    }

    public String getDescription() {
        return DESCRIPTION;
    }

    public boolean isIgnoreFailure() {
        return this.ignoreFailure;
    }

    @Generated
    public MMRRerankProcessor(String tag, boolean ignoreFailure) {
        this.tag = tag;
        this.ignoreFailure = ignoreFailure;
    }

    public static class MMRRerankProcessorFactory
    implements SystemGeneratedProcessor.SystemGeneratedFactory<SearchResponseProcessor> {
        public static final String TYPE = "mmr_rerank_factory";

        public boolean shouldGenerate(ProcessorGenerationContext context) {
            return MMRUtil.shouldGenerateMMRProcessor(context);
        }

        public SearchResponseProcessor create(Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, String tag, String description, boolean ignoreFailure, Map<String, Object> config, Processor.PipelineContext pipelineContext) throws Exception {
            return new MMRRerankProcessor(tag, ignoreFailure);
        }
    }
}

