/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

/**
 * Score normalization technique that rescales scores into a 0..1 range using min-max scaling,
 * computed separately for each subquery.
 * <p>
 * For subquery index {@code j}, the minimum and maximum scores are gathered over every non-null
 * {@link CompoundTopDocs} in the list passed to {@link #normalize(List)}, and each score is rewritten
 * as {@code (score - min_j) / (max_j - min_j)}.
 */
public class MinMaxScoreNormalizationTechnique
implements ScoreNormalizationTechnique {
    public static final String TECHNIQUE_NAME = "min_max";
    /** Value substituted for a score that normalizes to exactly 0 (the minimum of its subquery). */
    private static final float MIN_SCORE = 0.001f;
    /** Value used when a subquery's scores are all identical, so min-max scaling would be 0/0. */
    private static final float SINGLE_RESULT_SCORE = 1.0f;

    /**
     * Normalizes, in place, every document score in {@code queryTopDocs} with per-subquery min-max scaling.
     * <p>
     * The number of subqueries is taken from some non-null entry that has at least one sub-query result
     * list. If there is no such entry (empty list, only nulls, or only entries without sub-query results),
     * the method returns without modifying anything.
     *
     * @param queryTopDocs results whose {@link ScoreDoc#score} values are overwritten; {@code null}
     *                     elements are skipped
     * @throws NullPointerException if {@code queryTopDocs} itself is {@code null}
     */
    @Override
    public void normalize(final List<CompoundTopDocs> queryTopDocs) {
        Objects.requireNonNull(queryTopDocs, "queryTopDocs");
        final int numOfSubqueries = queryTopDocs.stream()
            .filter(Objects::nonNull)
            .map(CompoundTopDocs::getTopDocs)
            .filter(topDocs -> !topDocs.isEmpty())
            .findAny()
            .map(List::size)
            .orElse(0);
        if (numOfSubqueries == 0) {
            // No sub-query results anywhere: nothing to normalize.
            return;
        }
        final float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries);
        final float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            // NOTE(review): assumes every entry has the same number of sub-query results as
            // numOfSubqueries; a longer list would index past the min/max arrays — confirm with callers.
            final List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]);
                }
            }
        }
    }

    /**
     * Computes, for each subquery index, the largest score seen across all non-null entries.
     *
     * @param queryTopDocs    results to scan; {@code null} elements are skipped
     * @param numOfSubqueries length of the returned array
     * @return per-subquery maximum scores; {@code -Float.MAX_VALUE} for a subquery with no documents
     */
    private float[] getMaxScores(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
        final float[] maxScores = new float[numOfSubqueries];
        // Float.MIN_VALUE is the smallest *positive* float, not the most negative one; seeding with it
        // would report a max of 1.4e-45 for a subquery whose scores are all 0.
        Arrays.fill(maxScores, -Float.MAX_VALUE);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            final List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    maxScores[j] = Math.max(maxScores[j], scoreDoc.score);
                }
            }
        }
        return maxScores;
    }

    /**
     * Computes, for each subquery index, the smallest score seen across all non-null entries.
     *
     * @param queryTopDocs results to scan; {@code null} elements are skipped
     * @param numOfScores  length of the returned array
     * @return per-subquery minimum scores; {@code Float.MAX_VALUE} for a subquery with no documents
     */
    private float[] getMinScores(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
        final float[] minScores = new float[numOfScores];
        Arrays.fill(minScores, Float.MAX_VALUE);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            final List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    minScores[j] = Math.min(minScores[j], scoreDoc.score);
                }
            }
        }
        return minScores;
    }

    /**
     * Applies min-max scaling to a single score.
     *
     * @param score    raw score
     * @param minScore minimum score of the score's subquery
     * @param maxScore maximum score of the score's subquery
     * @return {@link #SINGLE_RESULT_SCORE} when min, max and score are all equal; {@link #MIN_SCORE} when the
     *         scaled value is exactly 0; otherwise {@code (score - minScore) / (maxScore - minScore)}
     */
    private float normalizeSingleScore(final float score, final float minScore, final float maxScore) {
        // All scores identical: the formula below would be 0/0, so use a fixed value instead.
        if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) {
            return SINGLE_RESULT_SCORE;
        }
        final float normalizedScore = (score - minScore) / (maxScore - minScore);
        // Keep the subquery's lowest-scoring doc at a small positive value rather than 0
        // (presumably so it still contributes in later score combination — confirm).
        return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore;
    }

    @Generated
    @Override
    public String toString() {
        return "MinMaxScoreNormalizationTechnique(TECHNIQUE_NAME=min_max)";
    }
}

