/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.prediction;

import java.util.Optional;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLCommonsSettings;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action registered as {@code cluster:admin/opensearch/ml/predict}.
 *
 * <p>For each request it resolves the target model (from the local model cache when present,
 * otherwise via {@link MLModelManager#getModel}), verifies that the caller may access the model's
 * group, rejects disabled models, applies model- and user-level rate limiting to deep-learning
 * models, validates the input against the model's cached input schema (if any), and finally
 * hands the request to the predict task runner.
 */
public class TransportPredictionTaskAction extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);
    private final MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private final TransportService transportService;
    private final MLModelCacheHelper modelCacheHelper;
    private final Client client;
    private final ClusterService clusterService;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final ModelAccessControlHelper modelAccessControlHelper;
    // Mirrors ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; kept current by the settings consumer registered
    // in the constructor (volatile because that consumer runs on a different thread).
    private volatile boolean enableAutomaticDeployment;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    @Inject
    public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLModelCacheHelper modelCacheHelper, MLPredictTaskRunner mlPredictTaskRunner, ClusterService clusterService, Client client, NamedXContentRegistry xContentRegistry, MLModelManager mlModelManager, ModelAccessControlHelper modelAccessControlHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting, Settings settings) {
        super("cluster:admin/opensearch/ml/predict", transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mlPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = modelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.mlModelManager = mlModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.enableAutomaticDeployment = MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> {
            this.enableAutomaticDeployment = it;
        });
    }

    /**
     * Entry point for a predict request.
     *
     * <p>If the request carries no user, the user is taken from the client's thread context and
     * stored on the request. Model lookup runs with a stashed thread context; the original
     * context is restored before the model is processed and again before {@code listener} is
     * notified.
     *
     * @param task the transport task (unused here)
     * @param request the incoming request, converted via {@link MLPredictionTaskRequest#fromActionRequest}
     * @param listener receives the prediction result or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
        MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request);
        String modelId = mlPredictionTaskRequest.getModelId();
        User user = mlPredictionTaskRequest.getUser();
        if (user == null) {
            user = RestActionUtils.getUserContext(this.client);
            mlPredictionTaskRequest.setUser(user);
        }
        final User userInfo = user;
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
            MLModel cachedMlModel = this.modelCacheHelper.getModelInfo(modelId);
            ActionListener<MLModel> modelActionListener = new ActionListener<MLModel>() {

                @Override
                public void onResponse(MLModel mlModel) {
                    context.restore();
                    modelCacheHelper.setModelInfo(modelId, mlModel);
                    FunctionName functionName = mlModel.getAlgorithm();
                    if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
                        // Report through the listener rather than throwing out of a callback: on the
                        // asynchronous getModel path a thrown exception may never reach the caller.
                        wrappedListener.onFailure(new IllegalStateException("Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true."));
                        return;
                    }
                    mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
                    modelAccessControlHelper.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(
                        access -> onModelGroupAccessChecked(access, modelId, userInfo, functionName, mlPredictionTaskRequest, wrappedListener),
                        e -> onModelGroupAccessCheckFailed(e, modelId, wrappedListener)));
                }

                @Override
                public void onFailure(Exception e) {
                    log.error("Failed to find model {}", modelId, e);
                    wrappedListener.onFailure(e);
                }
            };
            if (cachedMlModel != null) {
                modelActionListener.onResponse(cachedMlModel);
            } else {
                this.mlModelManager.getModel(modelId, modelActionListener);
            }
        }
    }

    /**
     * Continues a predict request after the model-group access check has answered.
     *
     * <p>Rejects with 403 when access is not granted or the cache marks the model as disabled,
     * and, for deep-learning models only, with 429 when the model-level or user-level rate
     * limiter refuses the request. Otherwise validates the input schema and executes the
     * prediction. Exceptions thrown here (e.g. from {@link #validateInputSchema}) propagate to
     * the enclosing {@code ActionListener.wrap}, which routes them to its failure handler.
     *
     * <p>Each cache lookup is performed exactly once so that a concurrent eviction between a null
     * check and its use cannot cause a NullPointerException.
     */
    private void onModelGroupAccessChecked(Boolean access, String modelId, User userInfo, FunctionName functionName, MLPredictionTaskRequest mlPredictionTaskRequest, ActionListener<MLTaskResponse> wrappedListener) {
        if (!Boolean.TRUE.equals(access)) {
            wrappedListener.onFailure(new OpenSearchStatusException("User Doesn't have privilege to perform this operation on this model", RestStatus.FORBIDDEN));
            return;
        }
        // A null entry means "unknown", which is treated as enabled.
        if (Boolean.FALSE.equals(modelCacheHelper.getIsModelEnabled(modelId))) {
            wrappedListener.onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN));
            return;
        }
        if (FunctionName.isDLModel(functionName)) {
            // A missing limiter means "no limit"; request() is only called when a limiter exists.
            boolean modelThrottled = Optional.ofNullable(modelCacheHelper.getRateLimiter(modelId))
                .map(limiter -> !limiter.request())
                .orElse(false);
            if (modelThrottled) {
                wrappedListener.onFailure(new OpenSearchStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS));
                return;
            }
            // Only consult (and consume from) the user limiter once the model limiter has passed.
            boolean userThrottled = userInfo != null
                && Optional.ofNullable(modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName()))
                    .map(limiter -> !limiter.request())
                    .orElse(false);
            if (userThrottled) {
                wrappedListener.onFailure(new OpenSearchStatusException("Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", RestStatus.TOO_MANY_REQUESTS));
                return;
            }
        }
        validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput());
        executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
    }

    /**
     * Maps a failure raised during the access check (or while continuing the request) to the
     * response sent to the caller.
     *
     * <p>Status exceptions keep their status, {@link MLResourceNotFoundException} becomes 404,
     * circuit-breaker exceptions are passed through unchanged, and anything else becomes 403.
     */
    private void onModelGroupAccessCheckFailed(Exception e, String modelId, ActionListener<MLTaskResponse> wrappedListener) {
        log.error("Failed to Validate Access for ModelId {}", modelId, e);
        if (e instanceof OpenSearchStatusException) {
            wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), ((OpenSearchStatusException) e).status()));
        } else if (e instanceof MLResourceNotFoundException) {
            wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
        } else if (e instanceof CircuitBreakingException) {
            wrappedListener.onFailure(e);
        } else {
            wrappedListener.onFailure(new OpenSearchStatusException("Failed to Validate Access for ModelId " + modelId, RestStatus.FORBIDDEN));
        }
    }

    /**
     * Runs the prediction through the task runner.
     *
     * <p>After the listener is notified, records the request duration (in milliseconds) and
     * refreshes the model's last-access time in the cache. The function name comes from the
     * model cache when present, otherwise from the request's input algorithm.
     */
    private void executePredict(MLPredictionTaskRequest mlPredictionTaskRequest, ActionListener<MLTaskResponse> wrappedListener, String modelId) {
        String requestId = mlPredictionTaskRequest.getRequestID();
        log.debug("receive predict request {} for model {}", requestId, mlPredictionTaskRequest.getModelId());
        long startTime = System.nanoTime();
        FunctionName functionName = this.modelCacheHelper.getOptionalFunctionName(modelId).orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
        this.mlPredictTaskRunner.run(functionName, mlPredictionTaskRequest, this.transportService, ActionListener.runAfter(wrappedListener, () -> {
            long endTime = System.nanoTime();
            double durationInMs = (double) (endTime - startTime) / 1000000.0;
            this.modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
            this.modelCacheHelper.refreshLastAccessTime(modelId);
            log.debug("completed predict request {} for model {}", requestId, modelId);
        }));
    }

    /**
     * Validates {@code mlInput} against the model's cached {@code "input"} schema.
     *
     * <p>Does nothing when the cache holds no model interface or no input schema for the model.
     * The interface is read from the cache once so a concurrent eviction cannot cause a
     * NullPointerException between the check and its use.
     *
     * @throws OpenSearchStatusException with {@link RestStatus#BAD_REQUEST} if serialization or
     *     validation fails; the original exception is kept as the cause
     */
    public void validateInputSchema(String modelId, MLInput mlInput) {
        String inputSchemaString = Optional.ofNullable(this.modelCacheHelper.getModelInterface(modelId))
            .map(modelInterface -> modelInterface.get("input"))
            .orElse(null);
        if (inputSchemaString == null) {
            return;
        }
        try {
            MLNodeUtils.validateSchema(inputSchemaString, mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
        } catch (Exception e) {
            throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST, e);
        }
    }
}

