/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.action.search;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.TopDocs;
import org.opensearch.action.search.ArraySearchPhaseResults;
import org.opensearch.action.search.SearchPhaseController;
import org.opensearch.action.search.SearchProgressListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchShard;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.tasks.TaskCancelledException;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.query.QuerySearchResult;

public class QueryPhaseResultConsumer
extends ArraySearchPhaseResults<SearchPhaseResult>
implements Releasable {
    private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class);
    private final Executor executor;
    private final CircuitBreaker circuitBreaker;
    private final SearchPhaseController controller;
    private final SearchProgressListener progressListener;
    private final InternalAggregation.ReduceContextBuilder aggReduceContextBuilder;
    private final NamedWriteableRegistry namedWriteableRegistry;
    private final int topNSize;
    private final boolean hasTopDocs;
    private final boolean hasAggs;
    private final boolean performFinalReduce;
    final PendingReduces pendingReduces;
    private final Consumer<Exception> cancelTaskOnFailure;
    private final BooleanSupplier isTaskCancelled;

    public QueryPhaseResultConsumer(SearchRequest request, Executor executor, CircuitBreaker circuitBreaker, SearchPhaseController controller, SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, int expectedResultSize, Consumer<Exception> cancelTaskOnFailure) {
        this(request, executor, circuitBreaker, controller, progressListener, namedWriteableRegistry, expectedResultSize, cancelTaskOnFailure, () -> false);
    }

    public QueryPhaseResultConsumer(SearchRequest request, Executor executor, CircuitBreaker circuitBreaker, SearchPhaseController controller, SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, int expectedResultSize, Consumer<Exception> cancelTaskOnFailure, BooleanSupplier isTaskCancelled) {
        super(expectedResultSize);
        this.executor = executor;
        this.circuitBreaker = circuitBreaker;
        this.controller = controller;
        this.progressListener = progressListener;
        this.aggReduceContextBuilder = controller.getReduceContext(request);
        this.namedWriteableRegistry = namedWriteableRegistry;
        this.topNSize = SearchPhaseController.getTopDocsSize(request);
        this.performFinalReduce = request.isFinalReduce();
        this.cancelTaskOnFailure = cancelTaskOnFailure;
        SearchSourceBuilder source = request.source();
        this.hasTopDocs = source == null || source.size() != 0;
        this.hasAggs = source != null && source.aggregations() != null;
        int batchReduceSize = this.getBatchReduceSize(request.getBatchedReduceSize(), expectedResultSize);
        this.pendingReduces = new PendingReduces(batchReduceSize, request.resolveTrackTotalHitsUpTo());
        this.isTaskCancelled = isTaskCancelled;
    }

    int getBatchReduceSize(int requestBatchedReduceSize, int minBatchReduceSize) {
        return this.hasAggs || this.hasTopDocs ? Math.min(requestBatchedReduceSize, minBatchReduceSize) : minBatchReduceSize;
    }

    public void close() {
        Releasables.close((Releasable)this.pendingReduces);
    }

    @Override
    public void consumeResult(SearchPhaseResult result, Runnable next) {
        super.consumeResult(result, () -> {});
        QuerySearchResult querySearchResult = result.queryResult();
        this.progressListener.notifyQueryResult(querySearchResult.getShardIndex());
        this.pendingReduces.consume(querySearchResult, next);
    }

    @Override
    public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
        if (this.pendingReduces.hasPendingReduceTask()) {
            throw new AssertionError((Object)"partial reduce in-flight");
        }
        this.checkCancellation();
        if (this.pendingReduces.hasFailure()) {
            throw this.pendingReduces.failure.get();
        }
        this.pendingReduces.sortBuffer();
        SearchPhaseController.TopDocsStats topDocsStats = this.pendingReduces.consumeTopDocsStats();
        List<TopDocs> topDocsList = this.pendingReduces.consumeTopDocs();
        List<InternalAggregations> aggsList = this.pendingReduces.consumeAggs();
        long breakerSize = this.pendingReduces.circuitBreakerBytes;
        if (this.hasAggs) {
            breakerSize = this.pendingReduces.addEstimateAndMaybeBreak(this.pendingReduces.estimateRamBytesUsedForReduce(breakerSize));
        }
        SearchPhaseController.ReducedQueryPhase reducePhase = this.controller.reducedQueryPhase(this.results.asList(), aggsList, topDocsList, topDocsStats, this.pendingReduces.numReducePhases, false, this.aggReduceContextBuilder, this.performFinalReduce);
        if (this.hasAggs) {
            long finalSize = reducePhase.aggregations.getSerializedSize() - breakerSize;
            this.pendingReduces.addWithoutBreaking(finalSize);
            logger.trace("aggs final reduction [{}] max [{}]", (Object)this.pendingReduces.aggsCurrentBufferSize, (Object)this.pendingReduces.maxAggsCurrentBufferSize);
        }
        this.progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(this.results.asList()), reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases);
        return reducePhase;
    }

    private ReduceResult partialReduce(QuerySearchResult[] toConsume, List<SearchShard> emptyResults, SearchPhaseController.TopDocsStats topDocsStats, ReduceResult lastReduceResult, int numReducePhases) {
        InternalAggregations newAggs;
        TopDocs newTopDocs;
        this.checkCancellation();
        if (this.pendingReduces.hasFailure()) {
            return lastReduceResult;
        }
        Arrays.sort(toConsume, Comparator.comparingInt(SearchPhaseResult::getShardIndex));
        for (QuerySearchResult result : toConsume) {
            topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
        }
        if (this.hasTopDocs) {
            ArrayList<TopDocs> topDocsList = new ArrayList<TopDocs>();
            if (lastReduceResult != null) {
                topDocsList.add(lastReduceResult.reducedTopDocs);
            }
            for (QuerySearchResult result : toConsume) {
                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
                SearchPhaseController.setShardIndex(topDocs.topDocs, result.getShardIndex());
                topDocsList.add(topDocs.topDocs);
            }
            newTopDocs = SearchPhaseController.mergeTopDocs(topDocsList, this.topNSize, 0);
        } else {
            newTopDocs = null;
        }
        if (this.hasAggs) {
            ArrayList<InternalAggregations> aggsList = new ArrayList<InternalAggregations>();
            if (lastReduceResult != null) {
                aggsList.add(lastReduceResult.reducedAggs);
            }
            for (QuerySearchResult result : toConsume) {
                aggsList.add(result.consumeAggs().expand());
            }
            newAggs = InternalAggregations.topLevelReduce(aggsList, this.aggReduceContextBuilder.forPartialReduction());
        } else {
            newAggs = null;
        }
        ArrayList<SearchShard> processedShards = new ArrayList<SearchShard>(emptyResults);
        if (lastReduceResult != null) {
            processedShards.addAll(lastReduceResult.processedShards);
        }
        for (QuerySearchResult result : toConsume) {
            SearchShardTarget target = result.getSearchShardTarget();
            processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
        }
        this.progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
        long serializedSize = this.hasAggs ? newAggs.getSerializedSize() : 0L;
        return new ReduceResult(processedShards, newTopDocs, newAggs, this.hasAggs ? serializedSize : 0L);
    }

    private void checkCancellation() {
        if (this.isTaskCancelled.getAsBoolean()) {
            this.pendingReduces.onFailure((Exception)new TaskCancelledException("request has been terminated"));
        }
    }

    public int getNumReducePhases() {
        return this.pendingReduces.numReducePhases;
    }

    class PendingReduces
    implements Releasable {
        private final int batchReduceSize;
        private final List<QuerySearchResult> buffer = new ArrayList<QuerySearchResult>();
        private final List<SearchShard> emptyResults = new ArrayList<SearchShard>();
        private volatile long circuitBreakerBytes;
        private volatile long aggsCurrentBufferSize;
        private volatile long maxAggsCurrentBufferSize = 0L;
        private final ArrayDeque<ReduceTask> queue = new ArrayDeque();
        private final AtomicReference<ReduceTask> runningTask = new AtomicReference();
        private final AtomicReference<Exception> failure = new AtomicReference();
        private final SearchPhaseController.TopDocsStats topDocsStats;
        private volatile ReduceResult reduceResult;
        private volatile boolean hasPartialReduce;
        private volatile int numReducePhases;

        PendingReduces(int batchReduceSize, int trackTotalHitsUpTo) {
            this.batchReduceSize = batchReduceSize;
            this.topDocsStats = new SearchPhaseController.TopDocsStats(trackTotalHitsUpTo);
        }

        public synchronized void close() {
            assert (!this.hasPendingReduceTask()) : "cannot close with partial reduce in-flight";
            if (this.hasFailure()) {
                assert (this.circuitBreakerBytes == 0L);
                return;
            }
            assert (this.circuitBreakerBytes >= 0L);
            QueryPhaseResultConsumer.this.circuitBreaker.addWithoutBreaking(-this.circuitBreakerBytes);
            this.circuitBreakerBytes = 0L;
        }

        private boolean hasFailure() {
            return this.failure.get() != null;
        }

        private boolean hasPendingReduceTask() {
            return !this.queue.isEmpty() || this.runningTask.get() != null;
        }

        private void sortBuffer() {
            if (this.buffer.size() > 0) {
                Collections.sort(this.buffer, Comparator.comparingInt(SearchPhaseResult::getShardIndex));
            }
        }

        private synchronized long addWithoutBreaking(long size) {
            if (this.hasFailure()) {
                return this.circuitBreakerBytes;
            }
            QueryPhaseResultConsumer.this.circuitBreaker.addWithoutBreaking(size);
            this.circuitBreakerBytes += size;
            this.maxAggsCurrentBufferSize = Math.max(this.maxAggsCurrentBufferSize, this.circuitBreakerBytes);
            return this.circuitBreakerBytes;
        }

        private synchronized long addEstimateAndMaybeBreak(long estimatedSize) {
            if (this.hasFailure()) {
                return this.circuitBreakerBytes;
            }
            QueryPhaseResultConsumer.this.circuitBreaker.addEstimateBytesAndMaybeBreak(estimatedSize, "<reduce_aggs>");
            this.circuitBreakerBytes += estimatedSize;
            this.maxAggsCurrentBufferSize = Math.max(this.maxAggsCurrentBufferSize, this.circuitBreakerBytes);
            return this.circuitBreakerBytes;
        }

        private synchronized void resetCircuitBreaker() {
            if (this.circuitBreakerBytes > 0L) {
                QueryPhaseResultConsumer.this.circuitBreaker.addWithoutBreaking(-this.circuitBreakerBytes);
                this.circuitBreakerBytes = 0L;
            }
        }

        private long ramBytesUsedQueryResult(QuerySearchResult result) {
            if (!QueryPhaseResultConsumer.this.hasAggs) {
                return 0L;
            }
            return result.aggregations().asSerialized((Writeable.Reader<InternalAggregations>)((Writeable.Reader)InternalAggregations::readFrom), QueryPhaseResultConsumer.this.namedWriteableRegistry).ramBytesUsed();
        }

        private long estimateRamBytesUsedForReduce(long size) {
            return Math.round(0.5 * (double)size);
        }

        void consume(QuerySearchResult result, Runnable callback) {
            QueryPhaseResultConsumer.this.checkCancellation();
            if (this.consumeResult(result, callback)) {
                callback.run();
            }
        }

        private synchronized boolean consumeResult(QuerySearchResult result, Runnable callback) {
            int size;
            if (this.hasFailure()) {
                result.consumeAll();
                return true;
            }
            if (result.isNull()) {
                SearchShardTarget target = result.getSearchShardTarget();
                this.emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
                return true;
            }
            if (QueryPhaseResultConsumer.this.hasAggs) {
                long aggsSize = this.ramBytesUsedQueryResult(result);
                try {
                    this.addEstimateAndMaybeBreak(aggsSize);
                    this.aggsCurrentBufferSize += aggsSize;
                }
                catch (CircuitBreakingException e) {
                    this.onFailure((Exception)((Object)e));
                    return true;
                }
            }
            if ((size = this.buffer.size() + (this.hasPartialReduce ? 1 : 0)) >= this.batchReduceSize) {
                this.hasPartialReduce = true;
                QuerySearchResult[] clone = (QuerySearchResult[])this.buffer.toArray(QuerySearchResult[]::new);
                ReduceTask task = new ReduceTask(clone, this.aggsCurrentBufferSize, new ArrayList<SearchShard>(this.emptyResults), callback);
                this.aggsCurrentBufferSize = 0L;
                this.buffer.clear();
                this.emptyResults.clear();
                this.queue.add(task);
                this.tryExecuteNext();
                this.buffer.add(result);
                return false;
            }
            this.buffer.add(result);
            return true;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private void tryExecuteNext() {
            ReduceTask task;
            PendingReduces pendingReduces = this;
            synchronized (pendingReduces) {
                if (this.hasFailure()) {
                    return;
                }
                if (this.queue.isEmpty() || this.runningTask.get() != null) {
                    return;
                }
                task = this.queue.poll();
                this.runningTask.compareAndSet(null, task);
            }
            QueryPhaseResultConsumer.this.executor.execute(new AbstractRunnable(this){
                final /* synthetic */ PendingReduces this$1;
                {
                    this.this$1 = this$1;
                }

                @Override
                protected void doRun() {
                    ReduceResult newReduceResult;
                    long estimateRamBytesUsedForReduce;
                    ReduceResult thisReduceResult = this.this$1.reduceResult;
                    long estimatedTotalSize = (thisReduceResult != null ? thisReduceResult.estimatedSize : 0L) + task.aggsBufferSize;
                    try {
                        QuerySearchResult[] toConsume = task.consumeBuffer();
                        if (toConsume == null) {
                            this.this$1.onAfterReduce(task, null, 0L);
                            return;
                        }
                        estimateRamBytesUsedForReduce = this.this$1.estimateRamBytesUsedForReduce(estimatedTotalSize);
                        this.this$1.addEstimateAndMaybeBreak(estimateRamBytesUsedForReduce);
                        ++this.this$1.numReducePhases;
                        newReduceResult = this.this$1.QueryPhaseResultConsumer.this.partialReduce(toConsume, task.emptyResults, this.this$1.topDocsStats, thisReduceResult, this.this$1.numReducePhases);
                    }
                    catch (Exception t) {
                        this.this$1.onFailure(t);
                        return;
                    }
                    this.this$1.onAfterReduce(task, newReduceResult, estimatedTotalSize += estimateRamBytesUsedForReduce);
                }

                @Override
                public void onFailure(Exception exc) {
                    this.this$1.onFailure(exc);
                }
            });
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        private void onAfterReduce(ReduceTask task, ReduceResult newResult, long estimatedSize) {
            if (newResult != null) {
                PendingReduces pendingReduces = this;
                synchronized (pendingReduces) {
                    if (this.hasFailure()) {
                        return;
                    }
                    this.runningTask.compareAndSet(task, null);
                    this.reduceResult = newResult;
                    if (QueryPhaseResultConsumer.this.hasAggs) {
                        long newSize = this.reduceResult.estimatedSize - estimatedSize;
                        this.addWithoutBreaking(newSize);
                        logger.trace("aggs partial reduction [{}->{}] max [{}]", (Object)estimatedSize, (Object)this.reduceResult.estimatedSize, (Object)this.maxAggsCurrentBufferSize);
                    }
                }
            }
            task.consumeListener();
            QueryPhaseResultConsumer.this.executor.execute(this::tryExecuteNext);
        }

        private synchronized void onFailure(Exception exc) {
            if (this.hasFailure()) {
                assert (this.circuitBreakerBytes == 0L);
                return;
            }
            assert (this.circuitBreakerBytes >= 0L);
            this.resetCircuitBreaker();
            this.failure.compareAndSet(null, exc);
            this.clearReduceTaskQueue();
            QueryPhaseResultConsumer.this.cancelTaskOnFailure.accept(exc);
        }

        private synchronized void clearReduceTaskQueue() {
            ReduceTask task = this.runningTask.get();
            this.runningTask.compareAndSet(task, null);
            ArrayList<ReduceTask> toCancels = new ArrayList<ReduceTask>();
            if (task != null) {
                toCancels.add(task);
            }
            toCancels.addAll(this.queue);
            this.queue.clear();
            this.reduceResult = null;
            for (ReduceTask toCancel : toCancels) {
                toCancel.cancel();
            }
        }

        private synchronized SearchPhaseController.TopDocsStats consumeTopDocsStats() {
            for (QuerySearchResult result : this.buffer) {
                this.topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
            }
            return this.topDocsStats;
        }

        private synchronized List<TopDocs> consumeTopDocs() {
            if (!QueryPhaseResultConsumer.this.hasTopDocs) {
                return Collections.emptyList();
            }
            ArrayList<TopDocs> topDocsList = new ArrayList<TopDocs>();
            if (this.reduceResult != null) {
                topDocsList.add(this.reduceResult.reducedTopDocs);
            }
            for (QuerySearchResult result : this.buffer) {
                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
                SearchPhaseController.setShardIndex(topDocs.topDocs, result.getShardIndex());
                topDocsList.add(topDocs.topDocs);
            }
            return topDocsList;
        }

        private synchronized List<InternalAggregations> consumeAggs() {
            if (!QueryPhaseResultConsumer.this.hasAggs) {
                return Collections.emptyList();
            }
            ArrayList<InternalAggregations> aggsList = new ArrayList<InternalAggregations>();
            if (this.reduceResult != null) {
                aggsList.add(this.reduceResult.reducedAggs);
            }
            for (QuerySearchResult result : this.buffer) {
                aggsList.add(result.consumeAggs().expand());
            }
            return aggsList;
        }
    }

    private record ReduceResult(List<SearchShard> processedShards, TopDocs reducedTopDocs, InternalAggregations reducedAggs, long estimatedSize) {
    }

    private static class ReduceTask {
        private final List<SearchShard> emptyResults;
        private QuerySearchResult[] buffer;
        private final long aggsBufferSize;
        private Runnable next;

        private ReduceTask(QuerySearchResult[] buffer, long aggsBufferSize, List<SearchShard> emptyResults, Runnable next) {
            this.buffer = buffer;
            this.aggsBufferSize = aggsBufferSize;
            this.emptyResults = emptyResults;
            this.next = next;
        }

        public synchronized QuerySearchResult[] consumeBuffer() {
            QuerySearchResult[] toRet = this.buffer;
            this.buffer = null;
            return toRet;
        }

        public void consumeListener() {
            if (this.next != null) {
                this.next.run();
                this.next = null;
            }
        }

        public synchronized void cancel() {
            this.consumeBuffer();
            this.consumeListener();
        }
    }
}

