/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.rules;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.OpenSearchRuleConfig;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.opensearch.planner.rules.ImmutableDedupPushdownRule;
import org.opensearch.sql.opensearch.planner.rules.InterruptibleRelRule;
import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan;
import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan;
import shaded.com.google.common.collect.ImmutableList;

@Value.Enclosing
public class DedupPushdownRule
extends InterruptibleRelRule<Config> {
    private static final Logger LOG = LogManager.getLogger();

    protected DedupPushdownRule(Config config) {
        super(config);
    }

    @Override
    protected void onMatchImpl(RelOptRuleCall call) {
        LogicalProject finalProject = (LogicalProject)call.rel(0);
        LogicalFilter numOfDedupFilter = (LogicalFilter)call.rel(1);
        LogicalProject projectWithWindow = (LogicalProject)call.rel(2);
        if (call.rels.length != 5) {
            throw new AssertionError((Object)String.format("The length of rels should be %s but got %s", this.operands.size(), call.rels.length));
        }
        CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(4);
        this.apply(call, finalProject, numOfDedupFilter, projectWithWindow, scan);
    }

    protected void apply(RelOptRuleCall call, LogicalProject finalProject, LogicalFilter numOfDedupFilter, LogicalProject projectWithWindow, CalciteLogicalIndexScan scan) {
        List<RexWindow> windows = PlanUtils.getRexWindowFromProject(projectWithWindow);
        if (windows.size() != 1) {
            return;
        }
        ImmutableList dedupColumns = windows.get((int)0).partitionKeys;
        if (dedupColumns.stream().filter(rex -> rex.isA(SqlKind.INPUT_REF)).anyMatch(rex -> rex.getType().getSqlTypeName() == SqlTypeName.MAP)) {
            LOG.debug("Cannot pushdown the dedup since the dedup fields contains MAP type");
            return;
        }
        if (projectWithWindow.getProjects().stream().filter(rex -> !rex.isA(SqlKind.ROW_NUMBER)).filter(Predicate.not(((List)dedupColumns)::contains)).anyMatch(rex -> !rex.isA(SqlKind.INPUT_REF))) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Cannot pushdown the dedup since the final outputs contain a column which is not included in table schema");
            }
            return;
        }
        List<RexNode> rexCallsExceptWindow = projectWithWindow.getProjects().stream().filter(rex -> !rex.isA(SqlKind.ROW_NUMBER)).filter(rex -> rex instanceof RexCall).toList();
        if (!rexCallsExceptWindow.isEmpty() && DedupPushdownRule.dedupColumnsContainRexCall(rexCallsExceptWindow, (List<RexNode>)dedupColumns)) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Cannot pushdown the dedup since the dedup columns contain RexCall");
            }
            return;
        }
        assert (numOfDedupFilter.getCondition().isA(SqlKind.LESS_THAN_OR_EQUAL));
        RexLiteral literal = (RexLiteral)((RexCall)numOfDedupFilter.getCondition()).getOperands().getLast();
        Integer dedupNumer = (Integer)literal.getValueAs(Integer.class);
        RelBuilder relBuilder = call.builder();
        relBuilder.push((RelNode)scan);
        ArrayList<RexNode> mergedRexList = new ArrayList<RexNode>();
        ArrayList<String> mergedFieldNames = new ArrayList<String>();
        ImmutableList builderFields = relBuilder.fields();
        List projectFields = projectWithWindow.getProjects();
        List builderFieldNames = relBuilder.peek().getRowType().getFieldNames();
        List projectFieldNames = projectWithWindow.getRowType().getFieldNames();
        for (RexNode field : builderFields) {
            mergedRexList.add(field);
            int projectIndex = projectFields.indexOf(field);
            if (projectIndex >= 0) {
                mergedFieldNames.add((String)projectFieldNames.get(projectIndex));
                continue;
            }
            mergedFieldNames.add((String)builderFieldNames.get(builderFields.indexOf(field)));
        }
        for (RexNode field : projectFields) {
            if (field.isA(SqlKind.ROW_NUMBER) || builderFields.contains(field)) continue;
            mergedRexList.add(field);
            mergedFieldNames.add(field.toString());
        }
        relBuilder.project(mergedRexList, mergedFieldNames, true);
        LogicalProject baseline = (LogicalProject)relBuilder.peek();
        Mapping mappingForDedupColumns = PlanUtils.mapping((List<RexNode>)dedupColumns, relBuilder.peek().getRowType());
        ArrayList<RexInputRef> reordered = new ArrayList<RexInputRef>(PlanUtils.getInputRefs((List<RexNode>)dedupColumns));
        baseline.getProjects().stream().filter(Predicate.not(((List)dedupColumns)::contains)).forEach(reordered::add);
        relBuilder.project(reordered);
        LogicalProject childProject = (LogicalProject)relBuilder.peek();
        List newDedupColumns = RexUtil.apply((Mappings.TargetMapping)mappingForDedupColumns, (Iterable)dedupColumns);
        relBuilder.aggregate(relBuilder.groupKey((Iterable)newDedupColumns), new RelBuilder.AggCall[]{relBuilder.literalAgg((Object)dedupNumer)});
        PlanUtils.addIgnoreNullBucketHintToAggregate(relBuilder);
        LogicalAggregate aggregate = (LogicalAggregate)relBuilder.build();
        CalciteLogicalIndexScan newScan = (CalciteLogicalIndexScan)scan.pushDownAggregate((Aggregate)aggregate, (Project)childProject);
        if (newScan != null) {
            call.transformTo((RelNode)newScan.copyWithNewSchema(finalProject.getRowType()));
        }
    }

    private static boolean dedupColumnsContainRexCall(List<RexNode> calls, List<RexNode> dedupColumns) {
        List dedupColumnIndicesFromCall = PlanUtils.getSelectColumns(calls).stream().distinct().toList();
        List dedupColumnsIndicesFromPartitionKeys = PlanUtils.getSelectColumns(dedupColumns).stream().distinct().toList();
        return dedupColumnsIndicesFromPartitionKeys.stream().anyMatch(dedupColumnIndicesFromCall::contains);
    }

    @Value.Immutable
    public static interface Config
    extends OpenSearchRuleConfig {
        public static final Config DEFAULT = ImmutableDedupPushdownRule.Config.builder().build().withDescription("Dedup-to-Aggregate").withOperandSupplier(b0 -> b0.operand(LogicalProject.class).predicate(Predicate.not(PlanUtils::containsRowNumberDedup)).oneInput(b1 -> b1.operand(LogicalFilter.class).predicate(Config::validDedupNumberChecker).oneInput(b2 -> b2.operand(LogicalProject.class).predicate(PlanUtils::containsRowNumberDedup).oneInput(b3 -> b3.operand(LogicalFilter.class).predicate(PlanUtils::mayBeFilterFromBucketNonNull).oneInput(b4 -> b4.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed).and(AbstractCalciteIndexScan::isProjectPushed)).noInputs())))));
        public static final Config DEDUP_EXPR = ImmutableDedupPushdownRule.Config.builder().build().withDescription("DedupWithExpression-to-Aggregate").withOperandSupplier(b0 -> b0.operand(LogicalProject.class).predicate(Predicate.not(PlanUtils::containsRowNumberDedup)).oneInput(b1 -> b1.operand(LogicalFilter.class).predicate(Config::validDedupNumberChecker).oneInput(b2 -> b2.operand(LogicalProject.class).predicate(PlanUtils::containsRowNumberDedup).oneInput(b3 -> b3.operand(LogicalFilter.class).predicate(Config::isNotNull).oneInput(b4 -> b4.operand(LogicalProject.class).predicate(PlanUtils::containsRexCall).oneInput(b5 -> b5.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed).and(AbstractCalciteIndexScan::isProjectPushed)).noInputs()))))));

        default public DedupPushdownRule toRule() {
            return new DedupPushdownRule(this);
        }

        private static boolean validDedupNumberChecker(LogicalFilter filter) {
            return filter.getCondition().isA(SqlKind.LESS_THAN_OR_EQUAL) && PlanUtils.containsRowNumberDedup((RelNode)filter);
        }

        private static boolean isNotNull(LogicalFilter filter) {
            return filter.getCondition().isA(SqlKind.IS_NOT_NULL);
        }
    }
}

