/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.flowframework.workflow;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchParseException;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Nullable;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLMemorySpec;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;

/**
 * Workflow step that builds an {@link MLAgent} from the step inputs and registers it with ML Commons.
 *
 * <p>On successful registration the returned agent id is recorded through
 * {@link FlowFrameworkIndicesHandler#addResourceToStateIndex}. Any failure, including invalid user input,
 * is reported by failing the returned future rather than by throwing from {@link #execute}.
 */
public class RegisterAgentStep implements WorkflowStep {
    private static final Logger logger = LogManager.getLogger(RegisterAgentStep.class);
    private final MachineLearningNodeClient mlClient;
    private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
    /** The name of this step, used as a key in the template and the {@link WorkflowStepFactory}. */
    public static final String NAME = "register_agent";
    /** The model id key, used both in the {@code llm} field and in previous step outputs. */
    public static final String MODEL_ID = "model_id";

    /**
     * Instantiate this class.
     *
     * @param mlClient client used to register the agent
     * @param flowFrameworkIndicesHandler handler used to record the created agent in the state index
     */
    public RegisterAgentStep(MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler) {
        this.mlClient = mlClient;
        this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler;
    }

    /**
     * Registers an agent built from the current node inputs and previous step outputs.
     *
     * @return a future completed with the state-index update result, or failed with a
     *         {@link FlowFrameworkException} / {@link WorkflowStepException}
     */
    @Override
    public PlainActionFuture<WorkflowData> execute(
        final String currentNodeId,
        final WorkflowData currentNodeInputs,
        Map<String, WorkflowData> outputs,
        Map<String, String> previousNodeInputs,
        Map<String, String> params,
        final String tenantId
    ) {
        final PlainActionFuture<WorkflowData> registerAgentModelFuture = PlainActionFuture.newFuture();

        ActionListener<MLRegisterAgentResponse> actionListener = new ActionListener<MLRegisterAgentResponse>() {
            @Override
            public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) {
                logger.info("Agent registration successful for the agent {}", mlRegisterAgentResponse.getAgentId());
                // Record the new agent id; the state-index handler completes the step future.
                flowFrameworkIndicesHandler.addResourceToStateIndex(
                    currentNodeInputs,
                    currentNodeId,
                    RegisterAgentStep.this.getName(),
                    mlRegisterAgentResponse.getAgentId(),
                    tenantId,
                    registerAgentModelFuture
                );
            }

            @Override
            public void onFailure(Exception ex) {
                // getSafeException may return null; fall back to a generic message in that case.
                Exception e = WorkflowStepException.getSafeException(ex);
                String errorMessage = e == null ? "Failed to register the agent" : e.getMessage();
                logger.error(errorMessage, e);
                registerAgentModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(e)));
            }
        };

        Set<String> requiredKeys = Set.of("name", "type");
        Set<String> optionalKeys = Set.of(
            "description",
            "llm",
            "tools",
            "tools_order",
            "parameters",
            "memory",
            "created_time",
            "last_updated_time",
            "app_type"
        );

        try {
            Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
                requiredKeys,
                optionalKeys,
                currentNodeInputs,
                outputs,
                previousNodeInputs,
                params
            );

            String type = (String) inputs.get("type");
            String name = (String) inputs.get("name");
            String description = (String) inputs.get("description");
            String llmField = (String) inputs.get("llm");
            String[] toolsOrder = (String[]) inputs.get("tools_order");
            List<MLToolSpec> toolsList = getTools(toolsOrder, previousNodeInputs, outputs);
            Object parameters = inputs.get("parameters");
            Map<String, String> parametersMap = parameters == null
                ? Collections.emptyMap()
                : ParseUtils.getStringToStringMap(parameters, "parameters");
            MLMemorySpec memory = getMLMemorySpec(inputs.get("memory"));
            Instant createdTime = Instant.now();
            Instant lastUpdateTime = createdTime;
            String appType = (String) inputs.get("app_type");

            String llmModelId = null;
            Map<String, String> llmParameters = new HashMap<>();
            if (llmField != null) {
                try {
                    Map<String, Object> llmFieldMap = getParseFieldMap(llmField);
                    Object modelIdValue = llmFieldMap.get(MODEL_ID);
                    if (modelIdValue != null && !(modelIdValue instanceof String)) {
                        throw new IllegalArgumentException("llm field [model_id] must be a string");
                    }
                    llmModelId = (String) modelIdValue;
                    Object llmParams = llmFieldMap.get("parameters");
                    if (llmParams != null) {
                        llmParameters.putAll(toLLMParametersMap(llmParams));
                    }
                } catch (IllegalArgumentException | OpenSearchParseException ex) {
                    // Malformed JSON surfaces as OpenSearchParseException; treat it like any other bad llm input.
                    String errorMessage = "Failed to parse llm field: " + ex.getMessage();
                    logger.error(errorMessage, ex);
                    registerAgentModelFuture.onFailure(new WorkflowStepException(ex.getMessage(), RestStatus.BAD_REQUEST));
                    return registerAgentModelFuture;
                }
            }
            // Fall back to a model id produced by a previous step when the llm field did not supply one.
            if (llmModelId == null) {
                llmModelId = getLlmModelId(previousNodeInputs, outputs);
            }
            LLMSpec llmSpec = getLLMSpec(llmModelId, llmParameters);

            MLAgent.MLAgentBuilder builder = MLAgent.builder().name(name);
            if (description != null) {
                builder.description(description);
            }
            if (memory != null) {
                builder.memory(memory);
            }
            if (llmSpec != null) {
                builder.llm(llmSpec);
            }
            builder.type(type)
                .tools(toolsList)
                .parameters(parametersMap)
                .createdTime(createdTime)
                .lastUpdateTime(lastUpdateTime)
                .appType(appType)
                .tenantId(tenantId);

            MLAgent mlAgent = builder.build();
            mlClient.registerAgent(mlAgent, actionListener);
        } catch (FlowFrameworkException e) {
            registerAgentModelFuture.onFailure(e);
        } catch (IllegalArgumentException e) {
            // Invalid memory/parameters input must fail the future instead of escaping execute().
            String errorMessage = "Invalid agent configuration: " + e.getMessage();
            logger.error(errorMessage, e);
            registerAgentModelFuture.onFailure(new WorkflowStepException(e.getMessage(), RestStatus.BAD_REQUEST));
        }
        return registerAgentModelFuture;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Collects tool specs produced by previous nodes.
     *
     * <p>Nodes named in {@code tools} come first, in that order. They are followed by any other previous
     * node whose input mapping is {@code "tools"}.
     *
     * @param tools optional explicit ordering of tool node ids
     * @param previousNodeInputs map of previous node id to the input key it provides
     * @param outputs outputs of previous nodes keyed by node id
     * @return the ordered list of tool specs; never null
     */
    private List<MLToolSpec> getTools(@Nullable String[] tools, Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
        // Use a growable copy: Arrays.asList is fixed-size and would reject the additions below.
        List<String> orderedNodes = new ArrayList<>();
        if (tools != null) {
            orderedNodes.addAll(Arrays.asList(tools));
        }
        for (Map.Entry<String, String> entry : previousNodeInputs.entrySet()) {
            if ("tools".equals(entry.getValue()) && !orderedNodes.contains(entry.getKey())) {
                orderedNodes.add(entry.getKey());
            }
        }

        List<MLToolSpec> mlToolSpecList = new ArrayList<>();
        for (String node : orderedNodes) {
            WorkflowData previousNodeOutput = outputs.get(node);
            if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey("tools")) {
                MLToolSpec mlToolSpec = (MLToolSpec) previousNodeOutput.getContent().get("tools");
                logger.info("Tool added {}", mlToolSpec.getType());
                mlToolSpecList.add(mlToolSpec);
            }
        }
        return mlToolSpecList;
    }

    /**
     * Finds the model id produced by the first previous node whose input mapping is {@code model_id}.
     *
     * @return the model id as a string, or null if no such node/output/value exists
     */
    private String getLlmModelId(Map<String, String> previousNodeInputs, Map<String, WorkflowData> outputs) {
        return previousNodeInputs.entrySet()
            .stream()
            .filter(e -> MODEL_ID.equals(e.getValue()))
            .map(Map.Entry::getKey)
            .findFirst()
            .map(outputs::get)
            .map(previousNodeOutput -> previousNodeOutput.getContent().get(MODEL_ID))
            .map(Object::toString)
            .orElse(null);
    }

    /**
     * Builds the LLM spec for the agent.
     *
     * @return the spec, or null when no model id is available
     */
    private LLMSpec getLLMSpec(String llmModelId, Map<String, String> llmParameters) {
        if (llmModelId == null) {
            return null;
        }
        LLMSpec.LLMSpecBuilder builder = LLMSpec.builder();
        builder.modelId(llmModelId);
        if (llmParameters != null) {
            builder.parameters(llmParameters);
        }
        return builder.build();
    }

    /**
     * Converts the {@code memory} input into an {@link MLMemorySpec}.
     *
     * @param mlMemory expected to be a map with a required {@code type} and optional
     *                 {@code session_id} / {@code window_size}; may be null
     * @return the memory spec, or null when {@code mlMemory} is null
     * @throws IllegalArgumentException if the input is not a map, the type is missing,
     *                                  or {@code window_size} is not an integer
     */
    private MLMemorySpec getMLMemorySpec(Object mlMemory) {
        if (mlMemory == null) {
            return null;
        }
        if (!(mlMemory instanceof Map)) {
            throw new IllegalArgumentException("memory field must be a map");
        }
        Map<?, ?> map = (Map<?, ?>) mlMemory;

        String type = (String) map.get("type");
        if (type == null) {
            throw new IllegalArgumentException("memory type is null");
        }
        String sessionId = (String) map.get("session_id");
        Integer windowSize = parseWindowSize(map.get("window_size"));

        MLMemorySpec.MLMemorySpecBuilder builder = MLMemorySpec.builder();
        builder.type(type);
        if (sessionId != null) {
            builder.sessionId(sessionId);
        }
        if (windowSize != null) {
            builder.windowSize(windowSize);
        }
        return builder.build();
    }

    /** Accepts a numeric value or a numeric string (template values are often strings); null stays null. */
    private static Integer parseWindowSize(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Number) {
            return ((Number) value).intValue();
        }
        if (value instanceof String) {
            // NumberFormatException is an IllegalArgumentException and is handled by the caller.
            return Integer.valueOf(((String) value).trim());
        }
        throw new IllegalArgumentException("memory field [window_size] must be an integer");
    }

    /**
     * Parses a JSON string into a map.
     *
     * @throws OpenSearchParseException if the string is not valid JSON
     */
    private Map<String, Object> getParseFieldMap(String llmFieldMapString) throws OpenSearchParseException {
        BytesReference llmFieldBytes = new BytesArray(llmFieldMapString.getBytes(StandardCharsets.UTF_8));
        return XContentHelper.convertToMap(llmFieldBytes, false, MediaTypeRegistry.JSON).v2();
    }

    /**
     * Validates that the llm {@code parameters} value is a map whose values are all strings,
     * and returns it as a typed copy.
     *
     * @throws IllegalArgumentException if the value is not a map or any value is not a string
     */
    private Map<String, String> toLLMParametersMap(Object llmParams) {
        String errorMessage = "llm field [parameters] must be a string to string map";
        if (!(llmParams instanceof Map)) {
            throw new IllegalArgumentException(errorMessage);
        }
        Map<String, String> result = new HashMap<>();
        for (Map.Entry<?, ?> entry : ((Map<?, ?>) llmParams).entrySet()) {
            if (!(entry.getValue() instanceof String)) {
                throw new IllegalArgumentException(errorMessage);
            }
            result.put(String.valueOf(entry.getKey()), (String) entry.getValue());
        }
        return result;
    }
}

