/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.function_calling;

import com.jayway.jsonpath.DocumentContext;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.function_calling.BedrockMessage;
import org.opensearch.ml.engine.function_calling.FunctionCalling;
import org.opensearch.ml.engine.function_calling.LLMMessage;

public class BedrockConverseFunctionCalling
implements FunctionCalling {
    @Generated
    private static final Logger log = LogManager.getLogger(BedrockConverseFunctionCalling.class);
    public static final String FINISH_REASON_PATH = "$.stopReason";
    public static final String FINISH_REASON = "tool_use";
    public static final String CALL_PATH = "$.output.message.content[*].toolUse";
    public static final String NAME = "name";
    public static final String INPUT = "input";
    public static final String ID_PATH = "toolUseId";
    public static final String TOOL_ERROR = "tool_error";
    public static final String BEDROCK_CONVERSE_TOOL_TEMPLATE = "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}";

    @Override
    public void configure(Map<String, String> params) {
        if (!params.containsKey("no_escape_params")) {
            params.put("no_escape_params", "_chat_history,_tools,_interactions,tool_configs");
        }
        params.put("llm_response_filter", "$.output.message.content[0].text");
        params.put("tool_template", BEDROCK_CONVERSE_TOOL_TEMPLATE);
        params.put("tool_calls_path", CALL_PATH);
        params.put("tool_calls.tool_name", NAME);
        params.put("tool_calls.tool_input", INPUT);
        params.put("tool_calls.id_path", ID_PATH);
        params.put("tool_configs", ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}");
        params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
        params.put("interaction_template.tool_response", "{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"${_interactions.tool_call_id}\",\"content\":[{\"text\":\"${_interactions.tool_response}\"}]}}]}");
        params.put("chat_history_template.user_question", "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}");
        params.put("chat_history_template.ai_response", "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}");
        params.put("llm_finish_reason_path", FINISH_REASON_PATH);
        params.put("llm_finish_reason_tool_use", FINISH_REASON);
    }

    @Override
    public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
        String llmFinishReason;
        ArrayList<Map<String, String>> output = new ArrayList<Map<String, String>>();
        Map<String, ?> dataAsMap = ((ModelTensor)((ModelTensors)tmpModelTensorOutput.getMlModelOutputs().get(0)).getMlModelTensors().get(0)).getDataAsMap();
        String llmResponseExcludePath = parameters.get("llm_response_exclude_path");
        if (llmResponseExcludePath != null) {
            dataAsMap = AgentUtils.removeJsonPath(dataAsMap, llmResponseExcludePath, true);
        }
        if (!(llmFinishReason = (String)JsonPath.read(dataAsMap, (String)FINISH_REASON_PATH, (Predicate[])new Predicate[0])).contentEquals(FINISH_REASON)) {
            return output;
        }
        List toolCalls = (List)JsonPath.read(dataAsMap, (String)CALL_PATH, (Predicate[])new Predicate[0]);
        if (CollectionUtils.isEmpty((Collection)toolCalls)) {
            return output;
        }
        for (Object call : toolCalls) {
            String toolName = (String)JsonPath.read(call, (String)NAME, (Predicate[])new Predicate[0]);
            String toolInput = StringUtils.toJson((Object)JsonPath.read(call, (String)INPUT, (Predicate[])new Predicate[0]));
            String toolCallId = (String)JsonPath.read(call, (String)ID_PATH, (Predicate[])new Predicate[0]);
            output.add(Map.of("tool_name", toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
        }
        return output;
    }

    @Override
    public List<LLMMessage> supply(List<Map<String, Object>> toolResults) {
        BedrockMessage toolMessage = new BedrockMessage();
        for (Map<String, Object> toolResult : toolResults) {
            String toolUseId = (String)toolResult.get("tool_call_id");
            if (toolUseId == null) continue;
            ToolResult result = new ToolResult();
            result.setToolUseId(toolUseId);
            result.getContent().add(toolResult.get("tool_result"));
            if (toolResult.containsKey(TOOL_ERROR)) {
                result.setStatus("error");
            }
            toolMessage.getContent().add(Map.of("toolResult", result));
        }
        return List.of(toolMessage);
    }

    @Override
    public Map<String, ?> filterToFirstToolCall(Map<String, ?> dataAsMap, Map<String, String> parameters) {
        try {
            List contentList = (List)JsonPath.read(dataAsMap, (String)"$.output.message.content", (Predicate[])new Predicate[0]);
            if (contentList == null || contentList.size() <= 1) {
                return dataAsMap;
            }
            ArrayList filteredContent = new ArrayList();
            ArrayList<String> allToolNames = new ArrayList<String>();
            String selectedToolName = null;
            boolean foundFirstToolUse = false;
            for (Object item : contentList) {
                if (item instanceof Map && ((Map)item).containsKey("toolUse")) {
                    Map toolUseMap = (Map)((Map)item).get("toolUse");
                    String toolName = toolUseMap != null ? String.valueOf(toolUseMap.get(NAME)) : "unknown";
                    allToolNames.add(toolName);
                    if (foundFirstToolUse) continue;
                    filteredContent.add(item);
                    selectedToolName = toolName;
                    foundFirstToolUse = true;
                    continue;
                }
                filteredContent.add(item);
            }
            if (!foundFirstToolUse) {
                return dataAsMap;
            }
            if (allToolNames.size() > 1) {
                log.info("LLM suggested {} tool(s): {}. Selected first tool: {}", (Object)allToolNames.size(), allToolNames, selectedToolName);
            }
            Map mutableCopy = (Map)StringUtils.gson.fromJson(StringUtils.toJson(dataAsMap), Map.class);
            DocumentContext context = JsonPath.parse((Object)mutableCopy);
            context.set("$.output.message.content", filteredContent, new Predicate[0]);
            return (Map)context.json();
        }
        catch (Exception e) {
            log.error("Failed to filter out to only first tool call", (Throwable)e);
            return dataAsMap;
        }
    }

    public static class ToolResult {
        private String toolUseId;
        private List<Object> content = new ArrayList<Object>();
        private String status;

        @Generated
        public ToolResult() {
        }

        @Generated
        public String getToolUseId() {
            return this.toolUseId;
        }

        @Generated
        public List<Object> getContent() {
            return this.content;
        }

        @Generated
        public String getStatus() {
            return this.status;
        }

        @Generated
        public void setToolUseId(String toolUseId) {
            this.toolUseId = toolUseId;
        }

        @Generated
        public void setContent(List<Object> content) {
            this.content = content;
        }

        @Generated
        public void setStatus(String status) {
            this.status = status;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ToolResult)) {
                return false;
            }
            ToolResult other = (ToolResult)o;
            if (!other.canEqual(this)) {
                return false;
            }
            String this$toolUseId = this.getToolUseId();
            String other$toolUseId = other.getToolUseId();
            if (this$toolUseId == null ? other$toolUseId != null : !this$toolUseId.equals(other$toolUseId)) {
                return false;
            }
            List<Object> this$content = this.getContent();
            List<Object> other$content = other.getContent();
            if (this$content == null ? other$content != null : !((Object)this$content).equals(other$content)) {
                return false;
            }
            String this$status = this.getStatus();
            String other$status = other.getStatus();
            return !(this$status == null ? other$status != null : !this$status.equals(other$status));
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof ToolResult;
        }

        @Generated
        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            String $toolUseId = this.getToolUseId();
            result = result * 59 + ($toolUseId == null ? 43 : $toolUseId.hashCode());
            List<Object> $content = this.getContent();
            result = result * 59 + ($content == null ? 43 : ((Object)$content).hashCode());
            String $status = this.getStatus();
            result = result * 59 + ($status == null ? 43 : $status.hashCode());
            return result;
        }

        @Generated
        public String toString() {
            return "BedrockConverseFunctionCalling.ToolResult(toolUseId=" + this.getToolUseId() + ", content=" + String.valueOf(this.getContent()) + ", status=" + this.getStatus() + ")";
        }
    }
}

