/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.judgments;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.StepListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
import org.opensearch.searchrelevance.common.MLConstants;
import org.opensearch.searchrelevance.dao.JudgmentCacheDao;
import org.opensearch.searchrelevance.dao.QuerySetDao;
import org.opensearch.searchrelevance.dao.SearchConfigurationDao;
import org.opensearch.searchrelevance.exception.SearchRelevanceException;
import org.opensearch.searchrelevance.executors.LlmJudgmentTaskManager;
import org.opensearch.searchrelevance.judgments.BaseJudgmentsProcessor;
import org.opensearch.searchrelevance.judgments.JudgmentDataTransformer;
import org.opensearch.searchrelevance.ml.ChunkResult;
import org.opensearch.searchrelevance.ml.MLAccessor;
import org.opensearch.searchrelevance.model.JudgmentCache;
import org.opensearch.searchrelevance.model.JudgmentType;
import org.opensearch.searchrelevance.model.QuerySet;
import org.opensearch.searchrelevance.model.SearchConfiguration;
import org.opensearch.searchrelevance.model.builder.SearchRequestBuilder;
import org.opensearch.searchrelevance.stats.events.EventStatName;
import org.opensearch.searchrelevance.stats.events.EventStatsManager;
import org.opensearch.searchrelevance.utils.ParserUtils;
import org.opensearch.searchrelevance.utils.TimeUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.Client;

/**
 * Generates LLM-based relevance judgments for every query of a query set.
 *
 * <p>For each query the processor performs four steps:
 * <ol>
 *   <li>It runs every configured search and unions the hits by document id.</li>
 *   <li>It reuses ratings already stored in the judgment cache.</li>
 *   <li>It asks the model identified by {@code modelId} to rate the remaining documents through
 *       {@link MLAccessor}.</li>
 *   <li>It writes the new ratings back to the judgment cache.</li>
 * </ol>
 * Per-query work is scheduled through {@link LlmJudgmentTaskManager}.
 *
 * <p>A query string may carry a reference answer in the form {@code "queryText#referenceAnswer"}.
 */
public class LlmJudgmentsProcessor implements BaseJudgmentsProcessor {
    private static final Logger log = LogManager.getLogger(LlmJudgmentsProcessor.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    /** Separates the query text from an optional reference answer in a query-set entry. */
    private static final String QUERY_REFERENCE_SEPARATOR = "#";
    /** Shape of a sanitized LLM chunk response: a JSON array of rating objects ("id", "rating_score"). */
    private static final TypeReference<List<Map<String, Object>>> RATING_LIST_TYPE = new TypeReference<List<Map<String, Object>>>() {
    };

    private final MLAccessor mlAccessor;
    private final QuerySetDao querySetDao;
    private final SearchConfigurationDao searchConfigurationDao;
    private final JudgmentCacheDao judgmentCacheDao;
    private final Client client;
    private final ThreadPool threadPool;
    private final LlmJudgmentTaskManager taskManager;

    @Inject
    public LlmJudgmentsProcessor(
        MLAccessor mlAccessor,
        QuerySetDao querySetDao,
        SearchConfigurationDao searchConfigurationDao,
        JudgmentCacheDao judgmentCacheDao,
        Client client,
        ThreadPool threadPool
    ) {
        this.mlAccessor = mlAccessor;
        this.querySetDao = querySetDao;
        this.searchConfigurationDao = searchConfigurationDao;
        this.judgmentCacheDao = judgmentCacheDao;
        this.client = client;
        this.threadPool = threadPool;
        this.taskManager = new LlmJudgmentTaskManager(threadPool);
    }

    /** @return {@link JudgmentType#LLM_JUDGMENT}, the judgment type this processor produces */
    @Override
    public JudgmentType getJudgmentType() {
        return JudgmentType.LLM_JUDGMENT;
    }

    /**
     * Starts LLM judgment generation on the {@code generic} executor and returns immediately.
     *
     * @param metadata request parameters; reads the keys {@code querySetId}, {@code searchConfigurationList},
     *                 {@code size}, {@code modelId}, {@code tokenLimit}, {@code contextFields} and {@code ignoreFailure}
     * @param listener receives the per-query judgment results, or a failure
     */
    @Override
    public void generateJudgmentRating(Map<String, Object> metadata, ActionListener<List<Map<String, Object>>> listener) {
        threadPool.executor("generic").execute(() -> generateJudgmentRatingInternal(metadata, listener));
    }

    /**
     * Reads the request parameters and loads the query set and search configurations synchronously.
     * It then hands off to {@link #generateLLMJudgmentsAsync}. Any exception raised while doing so
     * is reported to {@code listener} as a {@link SearchRelevanceException}.
     */
    @SuppressWarnings("unchecked")
    private void generateJudgmentRatingInternal(Map<String, Object> metadata, ActionListener<List<Map<String, Object>>> listener) {
        try {
            EventStatsManager.increment(EventStatName.LLM_JUDGMENT_RATING_GENERATIONS);
            String querySetId = (String) metadata.get("querySetId");
            List<String> searchConfigurationList = (List<String>) metadata.get("searchConfigurationList");
            int size = (Integer) metadata.get("size");
            String modelId = (String) metadata.get("modelId");
            int tokenLimit = (Integer) metadata.get("tokenLimit");
            List<String> contextFields = (List<String>) metadata.get("contextFields");
            boolean ignoreFailure = (Boolean) metadata.get("ignoreFailure");

            QuerySet querySet = querySetDao.getQuerySetSync(querySetId);
            List<SearchConfiguration> searchConfigurations = searchConfigurationList.stream()
                .map(id -> searchConfigurationDao.getSearchConfigurationSync(id))
                .collect(Collectors.toList());
            generateLLMJudgmentsAsync(modelId, size, tokenLimit, contextFields, querySet, searchConfigurations, ignoreFailure, listener);
        } catch (Exception e) {
            log.error("Failed to generate LLM judgments", e);
            listener.onFailure(new SearchRelevanceException("Failed to generate LLM judgments", e, RestStatus.INTERNAL_SERVER_ERROR));
        }
    }

    /**
     * Ensures the judgment cache index exists and then schedules one task per query.
     * A failure to create the index is logged and scheduling proceeds anyway.
     */
    private void generateLLMJudgmentsAsync(
        String modelId,
        int size,
        int tokenLimit,
        List<String> contextFields,
        QuerySet querySet,
        List<SearchConfiguration> searchConfigurations,
        boolean ignoreFailure,
        ActionListener<List<Map<String, Object>>> listener
    ) {
        List<String> queryTextWithReferences = querySet.querySetQueries()
            .stream()
            .map(e -> e.queryText())
            .collect(Collectors.toList());
        log.info("Starting LLM judgment generation for {} total queries", queryTextWithReferences.size());

        StepListener<Void> cacheIndexListener = new StepListener<>();
        judgmentCacheDao.createIndexIfAbsent(cacheIndexListener);
        cacheIndexListener.whenComplete(indexResult -> {
            log.debug("Judgment cache index creation completed, proceeding with task scheduling");
            scheduleJudgmentTasks(queryTextWithReferences, modelId, size, tokenLimit, contextFields, searchConfigurations, ignoreFailure, listener);
        }, indexError -> {
            log.warn("Failed to create judgment cache index, proceeding without cache optimization", indexError);
            scheduleJudgmentTasks(queryTextWithReferences, modelId, size, tokenLimit, contextFields, searchConfigurations, ignoreFailure, listener);
        });
    }

    /**
     * Submits every query to the task manager and reports a success/failure summary when all tasks finish.
     * A query counts as a success when its result has a non-empty {@code ratings} list.
     */
    private void scheduleJudgmentTasks(
        List<String> queryTextWithReferences,
        String modelId,
        int size,
        int tokenLimit,
        List<String> contextFields,
        List<SearchConfiguration> searchConfigurations,
        boolean ignoreFailure,
        ActionListener<List<Map<String, Object>>> listener
    ) {
        int totalQueries = queryTextWithReferences.size();
        taskManager.scheduleTasksAsync(
            queryTextWithReferences,
            queryTextWithReference -> processQuerySafely(
                queryTextWithReference,
                modelId,
                size,
                tokenLimit,
                contextFields,
                searchConfigurations,
                ignoreFailure
            ),
            ignoreFailure,
            ActionListener.<List<Map<String, Object>>>wrap(results -> {
                int processedQueries = results.size();
                int successQueries = (int) results.stream().filter(LlmJudgmentsProcessor::hasRatings).count();
                int failureQueries = processedQueries - successQueries;
                log.info(
                    "LLM judgment generation completed - Total: {}, Processed: {}, Success: {}, Failure: {}",
                    totalQueries,
                    processedQueries,
                    successQueries,
                    failureQueries
                );
                log.info("Calling final listener.onResponse with {} results", processedQueries);
                listener.onResponse(results);
            }, error -> {
                log.error("LLM judgment generation failed - Total: {}, All failed", totalQueries, error);
                listener.onFailure(error);
            })
        );
    }

    /**
     * Wraps {@link #processQueryTextAsync} with the per-query failure policy.
     *
     * <p>With {@code ignoreFailure}, an exception becomes an empty result for that query.
     * Otherwise the exception is rethrown so the task manager sees the failure.
     *
     * <p>NOTE(review): processQueryTextAsync already converts pipeline failures into partial
     * results, so this catch only sees exceptions escaping that method's own error path. Confirm
     * that {@code ignoreFailure=false} is not expected to fail the job on search/LLM errors.
     */
    private Map<String, Object> processQuerySafely(
        String queryTextWithReference,
        String modelId,
        int size,
        int tokenLimit,
        List<String> contextFields,
        List<SearchConfiguration> searchConfigurations,
        boolean ignoreFailure
    ) {
        try {
            return processQueryTextAsync(modelId, size, tokenLimit, contextFields, searchConfigurations, queryTextWithReference);
        } catch (Exception e) {
            if (ignoreFailure) {
                log.warn("Query processing failed, returning empty result for: {}", queryTextWithReference, e);
                return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, Map.of());
            }
            log.error("Query processing failed for: {}", queryTextWithReference, e);
            throw new RuntimeException("Query processing failed: " + queryTextWithReference, e);
        }
    }

    /** @return true if the judgment result carries a non-empty {@code ratings} list */
    private static boolean hasRatings(Map<String, Object> result) {
        Object ratings = result.get("ratings");
        return ratings instanceof List<?> && !((List<?>) ratings).isEmpty();
    }

    /**
     * Produces the judgment result for one query.
     *
     * <p>The steps are: search, cache lookup, then LLM rating for the documents the cache did not
     * cover. This method blocks until the searches, the cache lookups and the LLM call complete.
     * Any failure is logged and yields a result containing the ratings collected before it.
     */
    private Map<String, Object> processQueryTextAsync(
        String modelId,
        int size,
        int tokenLimit,
        List<String> contextFields,
        List<SearchConfiguration> searchConfigurations,
        String queryTextWithReference
    ) {
        log.info("Processing query text judgment: {}", queryTextWithReference);
        ConcurrentMap<String, SearchHit> allHits = new ConcurrentHashMap<>();
        ConcurrentMap<String, String> docIdToScore = new ConcurrentHashMap<>();
        String queryText = queryTextWithReference.split(QUERY_REFERENCE_SEPARATOR, 2)[0];
        try {
            if (searchConfigurations.isEmpty()) {
                log.warn("No search configurations provided for query: {}, returning empty result", queryTextWithReference);
                return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, docIdToScore);
            }
            processSearchConfigurationsAsync(searchConfigurations, queryText, size, allHits);
            List<String> docIds = new ArrayList<>(allHits.keySet());
            // NOTE(review): all hits are keyed to the FIRST configuration's index, even when the
            // configurations target different indices; confirm this is intended.
            String index = searchConfigurations.get(0).index();
            List<String> unprocessedDocIds = deduplicateFromCache(index, queryTextWithReference, contextFields, docIds, docIdToScore);
            if (!unprocessedDocIds.isEmpty()) {
                processWithLLM(modelId, queryTextWithReference, tokenLimit, contextFields, unprocessedDocIds, allHits, index, docIdToScore);
            }
            Map<String, Object> result = JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, docIdToScore);
            log.debug("Query processing completed for: {} with {} ratings", queryTextWithReference, docIdToScore.size());
            return result;
        } catch (Exception e) {
            log.warn(
                "Query processing failed for: {} with {} ratings collected. Error: {}",
                queryTextWithReference,
                docIdToScore.size(),
                e.getMessage(),
                e
            );
            return JudgmentDataTransformer.createJudgmentResult(queryTextWithReference, docIdToScore);
        }
    }

    /**
     * Runs every search configuration concurrently and waits for all of them.
     * Each hit is added to {@code allHits} keyed by document id. Individual search failures are
     * logged and skipped.
     */
    private void processSearchConfigurationsAsync(
        List<SearchConfiguration> searchConfigurations,
        String queryText,
        int size,
        ConcurrentMap<String, SearchHit> allHits
    ) {
        CompletableFuture<?>[] searchFutures = searchConfigurations.stream()
            .map(config -> searchAndCollectHits(config, queryText, size, allHits))
            .toArray(CompletableFuture<?>[]::new);
        CompletableFuture.allOf(searchFutures).join();
        log.info("Search phase completed. Total hits collected: {}", allHits.size());
    }

    /** Issues one search and adds its hits to {@code allHits}; the returned future never completes exceptionally. */
    private CompletableFuture<Void> searchAndCollectHits(
        SearchConfiguration config,
        String queryText,
        int size,
        ConcurrentMap<String, SearchHit> allHits
    ) {
        CompletableFuture<SearchResponse> future = new CompletableFuture<>();
        SearchRequest searchRequest = SearchRequestBuilder.buildSearchRequest(
            config.index(),
            config.query(),
            queryText,
            config.searchPipeline(),
            size
        );
        client.search(searchRequest, ActionListener.wrap(future::complete, future::completeExceptionally));
        return future.thenAccept(response -> {
            // Inspect the hits array directly: getTotalHits() may be null when total-hit tracking is disabled.
            SearchHit[] hits = response.getHits().getHits();
            for (SearchHit hit : hits) {
                allHits.put(hit.getId(), hit);
            }
            if (hits.length > 0) {
                log.debug("Collected {} hits from index: {}", hits.length, config.index());
            }
        }).exceptionally(e -> {
            log.warn("Search failed for index: {}, continuing with other searches", config.index(), e);
            return null;
        });
    }

    /**
     * Looks up every document in the judgment cache concurrently.
     * Cached ratings are copied into {@code docIdToScore}.
     *
     * @return the document ids (in {@code docIds} order) that have no usable cached rating
     */
    private List<String> deduplicateFromCache(
        String index,
        String queryTextWithReference,
        List<String> contextFields,
        List<String> docIds,
        ConcurrentMap<String, String> docIdToScore
    ) {
        // A concurrent set keeps the membership test below O(1) per document.
        Set<String> cachedDocIds = ConcurrentHashMap.newKeySet();
        CompletableFuture<?>[] cacheFutures = docIds.stream()
            .map(docId -> lookupCachedRating(index, queryTextWithReference, contextFields, docId, docIdToScore, cachedDocIds))
            .toArray(CompletableFuture<?>[]::new);
        CompletableFuture.allOf(cacheFutures).join();
        List<String> unprocessedDocIds = docIds.stream().filter(docId -> !cachedDocIds.contains(docId)).collect(Collectors.toList());
        log.info("Cache deduplication completed. Cached: {}, Unprocessed: {}", cachedDocIds.size(), unprocessedDocIds.size());
        return unprocessedDocIds;
    }

    /** Fetches one cached rating; lookup failures are logged and treated as a cache miss. */
    private CompletableFuture<Void> lookupCachedRating(
        String index,
        String queryTextWithReference,
        List<String> contextFields,
        String docId,
        ConcurrentMap<String, String> docIdToScore,
        Set<String> cachedDocIds
    ) {
        String compositeKey = ParserUtils.combinedIndexAndDocId(index, docId);
        CompletableFuture<SearchResponse> future = new CompletableFuture<>();
        judgmentCacheDao.getJudgmentCache(
            queryTextWithReference,
            compositeKey,
            contextFields,
            ActionListener.wrap(future::complete, future::completeExceptionally)
        );
        return future.thenAccept(response -> {
            SearchHit[] hits = response.getHits().getHits();
            if (hits.length == 0) {
                return;
            }
            Object rating = hits[0].getSourceAsMap().get("rating");
            if (rating == null) {
                // ConcurrentHashMap rejects null values; treat as a cache miss.
                return;
            }
            log.debug("Found cached judgment for docId: {}, rating: {}", docId, rating);
            docIdToScore.put(docId, rating.toString());
            cachedDocIds.add(docId);
        }).exceptionally(e -> {
            log.debug("Cache lookup failed for docId: {} - continuing without cache", docId);
            return null;
        });
    }

    /**
     * Builds the context map for the uncached documents and rates them with the LLM.
     * It blocks for the result and merges the ratings into {@code docIdToScore}.
     */
    private void processWithLLM(
        String modelId,
        String queryTextWithReference,
        int tokenLimit,
        List<String> contextFields,
        List<String> unprocessedDocIds,
        ConcurrentMap<String, SearchHit> allHits,
        String index,
        ConcurrentMap<String, String> docIdToScore
    ) {
        Map<String, String> unionHits = new HashMap<>();
        for (String docId : unprocessedDocIds) {
            SearchHit hit = allHits.get(docId);
            unionHits.put(ParserUtils.combinedIndexAndDocId(index, docId), getContextSource(hit, contextFields));
        }
        log.info("Processing {} uncached docs with LLM", unionHits.size());
        PlainActionFuture<Map<String, String>> llmFuture = PlainActionFuture.newFuture();
        generateLLMJudgmentForQueryText(modelId, queryTextWithReference, tokenLimit, contextFields, unionHits, new HashMap<>(), llmFuture);
        Map<String, String> llmResults = llmFuture.actionGet();
        docIdToScore.putAll(llmResults);
        log.info("LLM processing completed. Generated {} ratings", llmResults.size());
    }

    /**
     * Asks the model to rate {@code unprocessedUnionHits}, which maps composite key to context JSON.
     *
     * <p>On success, {@code listener} receives {@code docIdToRating} merged with the new ratings,
     * keyed by document id. Each new rating is also written to the judgment cache. The first chunk
     * processing error fails {@code listener}, and later chunks are ignored.
     */
    private void generateLLMJudgmentForQueryText(
        String modelId,
        String queryTextWithReference,
        int tokenLimit,
        List<String> contextFields,
        Map<String, String> unprocessedUnionHits,
        Map<String, String> docIdToRating,
        ActionListener<Map<String, String>> listener
    ) {
        log.debug("calculating LLM evaluation with modelId: {} and unprocessed unionHits: {}", modelId, unprocessedUnionHits);
        log.debug("processed docIdToRating before llm evaluation: {}", docIdToRating);
        if (unprocessedUnionHits.isEmpty()) {
            log.info("All hits found in cache, returning cached results for query: {}", queryTextWithReference);
            listener.onResponse(docIdToRating);
            return;
        }
        String[] queryTextRefArr = queryTextWithReference.split(QUERY_REFERENCE_SEPARATOR, 2);
        String queryText = queryTextRefArr[0];
        // An empty reference ("query#") is treated as no reference at all.
        String referenceAnswer = queryTextRefArr.length > 1 && !queryTextRefArr[1].isEmpty() ? queryTextRefArr[1] : null;

        ConcurrentMap<String, String> processedRatings = new ConcurrentHashMap<>(docIdToRating);
        ConcurrentMap<Integer, List<Map<String, Object>>> combinedResponses = new ConcurrentHashMap<>();
        AtomicBoolean hasFailure = new AtomicBoolean(false);

        mlAccessor.predict(modelId, tokenLimit, queryText, referenceAnswer, unprocessedUnionHits, new ActionListener<ChunkResult>() {
            @Override
            public void onResponse(ChunkResult chunkResult) {
                boolean complete = false;
                try {
                    parseSucceededChunks(chunkResult, combinedResponses);
                    logFailedChunks(chunkResult);
                    complete = chunkResult.isLastChunk() && !hasFailure.get();
                    if (complete) {
                        log.info(
                            "Processing final results for query: {}. Successful chunks: {}, Failed chunks: {}",
                            queryTextWithReference,
                            chunkResult.getSuccessfulChunksCount(),
                            chunkResult.getFailedChunksCount()
                        );
                        applyRatings(combinedResponses, processedRatings, queryTextWithReference, contextFields, modelId);
                    }
                } catch (Exception e) {
                    handleProcessingError(e);
                    return;
                }
                if (complete) {
                    // Completed outside the try block so an exception thrown by the listener itself
                    // cannot trigger a second (onFailure) completion.
                    listener.onResponse(processedRatings);
                }
            }

            @Override
            public void onFailure(Exception e) {
                handleProcessingError(e);
            }

            /** Reports only the first failure to the listener; subsequent failures are dropped. */
            private void handleProcessingError(Exception e) {
                if (!hasFailure.getAndSet(true)) {
                    log.error("Failed to process chunk response", e);
                    listener.onFailure(new SearchRelevanceException("Failed to process chunk response", e, RestStatus.INTERNAL_SERVER_ERROR));
                }
            }
        });
    }

    /**
     * Sanitizes and parses every succeeded chunk not yet present in {@code combinedResponses}.
     * Chunk indices already parsed are skipped; presumably a ChunkResult can repeat earlier chunks (TODO confirm).
     */
    private static void parseSucceededChunks(
        ChunkResult chunkResult,
        ConcurrentMap<Integer, List<Map<String, Object>>> combinedResponses
    ) throws JsonProcessingException {
        for (Map.Entry<Integer, String> entry : chunkResult.getSucceededChunks().entrySet()) {
            Integer chunkIndex = entry.getKey();
            if (combinedResponses.containsKey(chunkIndex)) {
                continue;
            }
            log.debug("response before sanitization: {}", entry.getValue());
            String sanitizedResponse = MLConstants.sanitizeLLMResponse(entry.getValue());
            log.debug("response after sanitization: {}", sanitizedResponse);
            combinedResponses.put(chunkIndex, OBJECT_MAPPER.readValue(sanitizedResponse, RATING_LIST_TYPE));
        }
    }

    /**
     * Copies every parsed rating into {@code processedRatings}, keyed by document id, and writes it to
     * the judgment cache. Entries without a string {@code id} or a numeric {@code rating_score} are
     * logged and skipped, so one malformed entry does not discard the whole query's ratings.
     */
    private void applyRatings(
        ConcurrentMap<Integer, List<Map<String, Object>>> combinedResponses,
        Map<String, String> processedRatings,
        String queryTextWithReference,
        List<String> contextFields,
        String modelId
    ) {
        for (List<Map<String, Object>> ratings : combinedResponses.values()) {
            for (Map<String, Object> rating : ratings) {
                Object id = rating == null ? null : rating.get("id");
                Object score = rating == null ? null : rating.get("rating_score");
                if (!(id instanceof String) || !(score instanceof Number)) {
                    log.warn("Skipping malformed LLM rating entry for query: {}: {}", queryTextWithReference, rating);
                    continue;
                }
                String compositeKey = (String) id;
                String ratingValue = Double.toString(((Number) score).doubleValue());
                processedRatings.put(ParserUtils.getDocIdFromCompositeKey(compositeKey), ratingValue);
                updateJudgmentCache(compositeKey, queryTextWithReference, contextFields, ratingValue, modelId);
            }
        }
    }

    /**
     * Best-effort write of a single rating into the judgment cache.
     * Every failure is logged and never propagated.
     */
    private void updateJudgmentCache(String compositeKey, String queryText, List<String> contextFields, String rating, String modelId) {
        try {
            JudgmentCache judgmentCache = new JudgmentCache(
                ParserUtils.generateUniqueId(queryText, compositeKey, contextFields),
                TimeUtils.getTimestamp(),
                queryText,
                compositeKey,
                contextFields,
                rating,
                modelId
            );
            StepListener<Void> createIndexStep = new StepListener<>();
            judgmentCacheDao.createIndexIfAbsent(createIndexStep);
            createIndexStep.whenComplete(
                v -> judgmentCacheDao.upsertJudgmentCache(
                    judgmentCache,
                    ActionListener.wrap(
                        response -> log.debug(
                            "Successfully processed judgment cache for queryText: {} and compositeKey: {}, contextFields: {}",
                            queryText,
                            compositeKey,
                            contextFields
                        ),
                        e -> log.warn(
                            "Failed to process judgment cache for queryText: {} and compositeKey: {}, contextFields: {} - continuing without cache",
                            queryText,
                            compositeKey,
                            contextFields,
                            e
                        )
                    )
                ),
                e -> log.warn(
                    "Failed to create judgment cache index for queryText: {} and compositeKey: {}, contextFields: {} - continuing without cache",
                    queryText,
                    compositeKey,
                    contextFields,
                    e
                )
            );
        } catch (Exception e) {
            log.warn("Cache operation failed for queryText: {} - continuing without cache", queryText, e);
        }
    }

    /** Logs a warning for every failed chunk in {@code chunkResult}. */
    private static void logFailedChunks(ChunkResult chunkResult) {
        chunkResult.getFailedChunks().forEach((index, error) -> log.warn("Chunk {} failed: {}", index, error));
    }

    /**
     * Returns the text sent to the LLM for one hit.
     *
     * <p>With no {@code contextFields}, this is the full {@code _source}. Otherwise it is a JSON
     * object holding only the listed fields that are present in the source.
     *
     * @throws RuntimeException if the filtered source cannot be serialized
     */
    private static String getContextSource(SearchHit hit, List<String> contextFields) {
        if (contextFields == null || contextFields.isEmpty()) {
            return hit.getSourceAsString();
        }
        Map<String, Object> sourceAsMap = hit.getSourceAsMap();
        Map<String, Object> filteredSource = new HashMap<>();
        for (String field : contextFields) {
            if (sourceAsMap.containsKey(field)) {
                filteredSource.put(field, sourceAsMap.get(field));
            }
        }
        try {
            return OBJECT_MAPPER.writeValueAsString(filteredSource);
        } catch (JsonProcessingException e) {
            log.error("Failed to process context source for hit: {}", hit.getId(), e);
            throw new RuntimeException("Failed to process context source", e);
        }
    }
}

