/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.prometheus.storage.implementor;

import java.util.List;
import java.util.Optional;
import lombok.Generated;
import org.apache.commons.math3.util.Pair;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.planner.DefaultImplementor;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalRelation;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.prometheus.planner.logical.PrometheusLogicalMetricAgg;
import org.opensearch.sql.prometheus.planner.logical.PrometheusLogicalMetricScan;
import org.opensearch.sql.prometheus.storage.PrometheusMetricScan;
import org.opensearch.sql.prometheus.storage.PrometheusMetricTable;
import org.opensearch.sql.prometheus.storage.model.PrometheusResponseFieldNames;
import org.opensearch.sql.prometheus.storage.querybuilder.AggregationQueryBuilder;
import org.opensearch.sql.prometheus.storage.querybuilder.SeriesSelectionQueryBuilder;
import org.opensearch.sql.prometheus.storage.querybuilder.StepParameterResolver;
import org.opensearch.sql.prometheus.storage.querybuilder.TimeRangeParametersResolver;

public class PrometheusDefaultImplementor
extends DefaultImplementor<PrometheusMetricScan> {
    @Override
    public PhysicalPlan visitNode(LogicalPlan plan, PrometheusMetricScan context) {
        if (plan instanceof PrometheusLogicalMetricScan) {
            return this.visitIndexScan((PrometheusLogicalMetricScan)plan, context);
        }
        if (plan instanceof PrometheusLogicalMetricAgg) {
            return this.visitIndexAggregation((PrometheusLogicalMetricAgg)plan, context);
        }
        throw new IllegalStateException(StringUtils.format("unexpected plan node type %s", plan.getClass()));
    }

    public PhysicalPlan visitIndexScan(PrometheusLogicalMetricScan node, PrometheusMetricScan context) {
        String query = SeriesSelectionQueryBuilder.build(node.getMetricName(), node.getFilter());
        context.getRequest().setPromQl(query);
        this.setTimeRangeParameters(node.getFilter(), context);
        context.getRequest().setStep(StepParameterResolver.resolve(context.getRequest().getStartTime(), context.getRequest().getEndTime(), null));
        return context;
    }

    public PhysicalPlan visitIndexAggregation(PrometheusLogicalMetricAgg node, PrometheusMetricScan context) {
        this.setTimeRangeParameters(node.getFilter(), context);
        context.getRequest().setStep(StepParameterResolver.resolve(context.getRequest().getStartTime(), context.getRequest().getEndTime(), node.getGroupByList()));
        String step = context.getRequest().getStep();
        String seriesSelectionQuery = SeriesSelectionQueryBuilder.build(node.getMetricName(), node.getFilter());
        String aggregateQuery = AggregationQueryBuilder.build(node.getAggregatorList(), node.getGroupByList());
        String finalQuery = String.format(aggregateQuery, seriesSelectionQuery + "[" + step + "]");
        context.getRequest().setPromQl(finalQuery);
        this.setPrometheusResponseFieldNames(node, context);
        return context;
    }

    @Override
    public PhysicalPlan visitRelation(LogicalRelation node, PrometheusMetricScan context) {
        PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable)node.getTable();
        String query = SeriesSelectionQueryBuilder.build(node.getRelationName(), null);
        context.getRequest().setPromQl(query);
        this.setTimeRangeParameters(null, context);
        context.getRequest().setStep(StepParameterResolver.resolve(context.getRequest().getStartTime(), context.getRequest().getEndTime(), null));
        return context;
    }

    private void setTimeRangeParameters(Expression filter, PrometheusMetricScan context) {
        TimeRangeParametersResolver timeRangeParametersResolver = new TimeRangeParametersResolver();
        Pair<Long, Long> timeRange = timeRangeParametersResolver.resolve(filter);
        context.getRequest().setStartTime((Long)timeRange.getFirst());
        context.getRequest().setEndTime((Long)timeRange.getSecond());
    }

    private void setPrometheusResponseFieldNames(PrometheusLogicalMetricAgg node, PrometheusMetricScan context) {
        Optional<NamedExpression> spanExpression = this.getSpanExpression(node.getGroupByList());
        if (spanExpression.isEmpty()) {
            throw new RuntimeException("Prometheus Catalog doesn't support aggregations without span expression");
        }
        PrometheusResponseFieldNames prometheusResponseFieldNames = new PrometheusResponseFieldNames();
        prometheusResponseFieldNames.setValueFieldName(node.getAggregatorList().get(0).getName());
        prometheusResponseFieldNames.setValueType(node.getAggregatorList().get(0).type());
        prometheusResponseFieldNames.setTimestampFieldName(spanExpression.get().getName());
        prometheusResponseFieldNames.setGroupByList(node.getGroupByList());
        context.setPrometheusResponseFieldNames(prometheusResponseFieldNames);
    }

    private Optional<NamedExpression> getSpanExpression(List<NamedExpression> namedExpressionList) {
        if (namedExpressionList == null) {
            return Optional.empty();
        }
        return namedExpressionList.stream().filter(expression -> expression.getDelegated() instanceof SpanExpression).findFirst();
    }

    @Generated
    public PrometheusDefaultImplementor() {
    }
}

