/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.util.HybridSearchSortUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

/**
 * Runs score normalization and score combination over a list of query search results and
 * writes the resulting scores back into the original query results and, when one is
 * supplied, into the fetch result.
 */
public class NormalizationProcessorWorkflow {
    @Generated
    private static final Logger log = LogManager.getLogger(NormalizationProcessorWorkflow.class);
    private final ScoreNormalizer scoreNormalizer;
    private final ScoreCombiner scoreCombiner;

    /**
     * Creates a workflow that delegates to the given normalizer and combiner.
     *
     * @param scoreNormalizer component used to normalize scores
     * @param scoreCombiner component used to combine scores
     */
    @Generated
    public NormalizationProcessorWorkflow(ScoreNormalizer scoreNormalizer, ScoreCombiner scoreCombiner) {
        this.scoreNormalizer = scoreNormalizer;
        this.scoreCombiner = scoreCombiner;
    }

    /**
     * Normalizes and combines scores, then updates the query results and the optional fetch
     * result in place.
     *
     * @param querySearchResults query results to process; their top docs are replaced
     * @param fetchSearchResultOptional fetch result whose hits are re-ordered and re-scored, if present
     * @param normalizationTechnique technique passed to the score normalizer
     * @param combinationTechnique technique passed to the score combiner
     * @throws IllegalStateException if the query or fetch results fail the consistency checks
     */
    public void execute(List<QuerySearchResult> querySearchResults, Optional<FetchSearchResult> fetchSearchResultOptional, ScoreNormalizationTechnique normalizationTechnique, ScoreCombinationTechnique combinationTechnique) {
        // Capture the doc ids before the query results are overwritten below; the fetch
        // hits are later matched to these ids by position.
        List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

        log.debug("Pre-process query results");
        List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

        log.debug("Do score normalization");
        scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

        CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
            .queryTopDocs(queryTopDocs)
            .scoreCombinationTechnique(combinationTechnique)
            .querySearchResults(querySearchResults)
            .sort(HybridSearchSortUtil.evaluateSortCriteria(querySearchResults, queryTopDocs))
            .build();

        log.debug("Do score combination");
        scoreCombiner.combineScores(combineScoresDTO);

        log.debug("Post-process query results after score normalization and combination");
        updateOriginalQueryResults(combineScoresDTO);
        updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
    }

    /**
     * Wraps the top docs of every query result in a {@link CompoundTopDocs}.
     *
     * @param querySearchResults query results to convert
     * @return one compound top docs per query result, in the same order
     * @throws IllegalStateException if any query result has no top docs, so the sizes differ
     */
    private List<CompoundTopDocs> getQueryTopDocs(List<QuerySearchResult> querySearchResults) {
        List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
            .filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
            .map(querySearchResult -> querySearchResult.topDocs().topDocs)
            .map(CompoundTopDocs::new)
            .collect(Collectors.toList());
        if (queryTopDocs.size() != querySearchResults.size()) {
            throw sizeMismatch(querySearchResults.size(), queryTopDocs.size());
        }
        return queryTopDocs;
    }

    /**
     * Replaces the top docs of each query result with the matching entry from the combined
     * top docs, pairing the two lists by index.
     *
     * @param combineScoresDTO holder of the query results, combined top docs and sort
     * @throws IllegalStateException if the two lists differ in size
     */
    private void updateOriginalQueryResults(CombineScoresDto combineScoresDTO) {
        List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
        List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
        Sort sort = combineScoresDTO.getSort();
        boolean isSortEnabled = sort != null;
        for (int index = 0; index < querySearchResults.size(); ++index) {
            QuerySearchResult querySearchResult = querySearchResults.get(index);
            CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
            TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
                buildTopDocs(updatedTopDocs, sort),
                maxScoreForShard(updatedTopDocs, isSortEnabled)
            );
            querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
        }
    }

    /**
     * Returns the combined top docs from the DTO after checking they line up with the query results.
     *
     * @param combineScoresDTO holder of the combined top docs
     * @param querySearchResults query results the top docs must match in size
     * @return the top docs list stored in the DTO
     * @throws IllegalStateException if the two lists differ in size
     */
    private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
        List<CompoundTopDocs> queryTopDocs = combineScoresDTO.getQueryTopDocs();
        if (querySearchResults.size() != queryTopDocs.size()) {
            throw sizeMismatch(querySearchResults.size(), queryTopDocs.size());
        }
        return queryTopDocs;
    }

    /** Builds the exception raised when query results and top docs differ in size. */
    private static IllegalStateException sizeMismatch(int querySearchResultsSize, int queryTopDocsSize) {
        return new IllegalStateException(String.format(Locale.ROOT, "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match", querySearchResultsSize, queryTopDocsSize));
    }

    /**
     * Computes the max score to report for one set of top docs.
     *
     * @param updatedTopDocs top docs to inspect
     * @param isSortEnabled whether a sort is applied to the results
     * @return {@code ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND} when there are no hits or no
     *         score docs; the largest score found by scanning when sort is enabled; otherwise
     *         the score of the first score doc
     */
    private float maxScoreForShard(CompoundTopDocs updatedTopDocs, boolean isSortEnabled) {
        List<ScoreDoc> scoreDocs = updatedTopDocs.getScoreDocs();
        if (updatedTopDocs.getTotalHits().value == 0L || scoreDocs.isEmpty()) {
            return ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND.floatValue();
        }
        if (!isSortEnabled) {
            // NOTE(review): presumably score docs are ordered by descending score when no
            // sort is applied, which is why the first entry is taken - confirm.
            return scoreDocs.get(0).score;
        }
        float maxScore = ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND.floatValue();
        for (ScoreDoc scoreDoc : scoreDocs) {
            maxScore = Math.max(maxScore, scoreDoc.score);
        }
        return maxScore;
    }

    /**
     * Converts compound top docs into Lucene top docs.
     *
     * @param updatedTopDocs source of total hits and score docs
     * @param sort sort to attach, or {@code null} for plain score-ordered top docs
     * @return a {@link TopFieldDocs} when {@code sort} is non-null, otherwise a {@link TopDocs}
     */
    private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
        if (sort == null) {
            return new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0]));
        }
        // Copying into a FieldDoc[] requires every score doc to be a FieldDoc; anything else
        // fails with ArrayStoreException.
        FieldDoc[] fieldDocs = updatedTopDocs.getScoreDocs().toArray(new FieldDoc[0]);
        return new TopFieldDocs(updatedTopDocs.getTotalHits(), fieldDocs, sort.getSort());
    }

    /**
     * Re-orders and re-scores the fetch hits so they follow the updated top docs of the first
     * query result. Does nothing when no fetch result is present.
     *
     * @param querySearchResults query results; only the first one is read
     * @param fetchSearchResultOptional fetch result to update, if present
     * @param docIds doc ids captured before processing, positionally aligned with the fetch hits
     * @throws IllegalStateException if the fetch hits are missing, their count is inconsistent
     *         with {@code docIds}, or a doc in the updated top docs has no fetched hit
     */
    private void updateOriginalFetchResults(List<QuerySearchResult> querySearchResults, Optional<FetchSearchResult> fetchSearchResultOptional, List<Integer> docIds) {
        if (fetchSearchResultOptional.isEmpty()) {
            return;
        }
        FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
        boolean hasQueryResults = Objects.nonNull(querySearchResults) && !querySearchResults.isEmpty();
        // Boolean.TRUE.equals covers both a null and a false request-cache flag.
        boolean requestCache = hasQueryResults
            && Boolean.TRUE.equals(querySearchResults.get(0).getShardSearchRequest().requestCache());

        SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache);
        if (!hasQueryResults) {
            // No query result to take the new order and scores from; leave the fetch result as is.
            return;
        }

        // Later entries overwrite earlier ones, so repeated doc ids collapse to a single hit.
        HashMap<Integer, SearchHit> docIdToSearchHit = new HashMap<>();
        for (int i = 0; i < searchHitArray.length; ++i) {
            docIdToSearchHit.put(docIds.get(i), searchHitArray[i]);
        }

        QuerySearchResult querySearchResult = querySearchResults.get(0);
        ScoreDoc[] scoreDocs = querySearchResult.topDocs().topDocs.scoreDocs;
        SearchHit[] updatedSearchHitArray = new SearchHit[scoreDocs.length];
        for (int i = 0; i < scoreDocs.length; ++i) {
            ScoreDoc scoreDoc = scoreDocs[i];
            SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
            if (searchHit == null) {
                // Possible when fewer hits than doc ids were fetched (request cache path).
                throw new IllegalStateException(String.format(Locale.ROOT, "score normalization processor cannot produce final query result, fetch phase returned no hit for document [%d]", scoreDoc.doc));
            }
            searchHit.score(scoreDoc.score);
            updatedSearchHitArray[i] = searchHit;
        }
        SearchHits updatedSearchHits = new SearchHits(updatedSearchHitArray, querySearchResult.getTotalHits(), querySearchResult.getMaxScore());
        fetchSearchResult.hits(updatedSearchHits);
    }

    /**
     * Extracts the hit array from the fetch result and validates its size against the doc ids.
     *
     * @param docIds doc ids collected from the query phase
     * @param fetchSearchResult fetch result to read the hits from
     * @param requestCache when {@code true}, fewer hits than doc ids are tolerated; otherwise
     *        the counts must be equal
     * @return the fetched hits
     * @throws IllegalStateException if the hits are missing or the counts are inconsistent
     */
    private SearchHit[] getSearchHits(List<Integer> docIds, FetchSearchResult fetchSearchResult, boolean requestCache) {
        SearchHits searchHits = fetchSearchResult.hits();
        SearchHit[] searchHitArray = Objects.isNull(searchHits) ? null : searchHits.getHits();
        if (Objects.isNull(searchHitArray)) {
            throw new IllegalStateException("score normalization processor cannot produce final query result, fetch query phase returns empty results");
        }
        boolean sizeMismatch = requestCache
            ? docIds.size() < searchHitArray.length
            : searchHitArray.length != docIds.size();
        if (sizeMismatch) {
            throw new IllegalStateException(String.format(Locale.ROOT, "score normalization processor cannot produce final query result, the number of documents after fetch phase [%d] is different from number of documents from query phase [%d]", searchHitArray.length, docIds.size()));
        }
        return searchHitArray;
    }

    /**
     * Collects the doc ids of the first query result's score docs, in their current order.
     *
     * @param querySearchResults query results; only the first one is read
     * @return the doc ids, or an empty list when there are no query results
     */
    private List<Integer> unprocessedDocIds(List<QuerySearchResult> querySearchResults) {
        if (querySearchResults.isEmpty()) {
            return List.of();
        }
        ScoreDoc[] scoreDocs = querySearchResults.get(0).topDocs().topDocs.scoreDocs;
        return Arrays.stream(scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList());
    }
}

