From 677f534661634c74340f50723e55e241570d5a56 Mon Sep 17 00:00:00 2001 From: oleg Date: Wed, 15 May 2024 12:02:21 +0200 Subject: [PATCH] feat(AI Agent Node): Implement Tool calling agent (#9339) Signed-off-by: Oleg Ivaniv --- .../nodes/agents/Agent/Agent.node.ts | 124 ++++++++---- .../Agent/agents/ToolsAgent/description.ts | 43 ++++ .../agents/Agent/agents/ToolsAgent/execute.ts | 189 ++++++++++++++++++ .../agents/Agent/agents/ToolsAgent/prompt.ts | 1 + .../@n8n/nodes-langchain/utils/logWrapper.ts | 62 +++--- 5 files changed, 344 insertions(+), 75 deletions(-) create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/prompt.ts diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts index 9518728e35..f655ebd254 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts @@ -7,6 +7,7 @@ import type { INodeExecutionData, INodeType, INodeTypeDescription, + INodeProperties, } from 'n8n-workflow'; import { getTemplateNoticeField } from '../../../utils/sharedFields'; import { promptTypeOptions, textInput } from '../../../utils/descriptions'; @@ -20,11 +21,13 @@ import { reActAgentAgentProperties } from './agents/ReActAgent/description'; import { reActAgentAgentExecute } from './agents/ReActAgent/execute'; import { sqlAgentAgentProperties } from './agents/SqlAgent/description'; import { sqlAgentAgentExecute } from './agents/SqlAgent/execute'; +import { toolsAgentProperties } from './agents/ToolsAgent/description'; +import { toolsAgentExecute } from './agents/ToolsAgent/execute'; // Function used in the inputs expression to figure out which inputs to // display based on the agent type function getInputs( - agent: 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent', + agent: 'toolsAgent' | 'conversationalAgent' | 'openAiFunctionsAgent' | 'reActAgent' | 'sqlAgent', hasOutputParser?: boolean, ): Array { interface SpecialInput { @@ -92,6 +95,31 @@ function getInputs( type: NodeConnectionType.AiOutputParser, }, ]; + } else if (agent === 'toolsAgent') { + specialInputs = [ + { + type: NodeConnectionType.AiLanguageModel, + filter: { + nodes: [ + '@n8n/n8n-nodes-langchain.lmChatAnthropic', + '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', + '@n8n/n8n-nodes-langchain.lmChatMistralCloud', + '@n8n/n8n-nodes-langchain.lmChatOpenAi', + '@n8n/n8n-nodes-langchain.lmChatGroq', + ], + }, + }, + { + type: NodeConnectionType.AiMemory, + }, + { + type: NodeConnectionType.AiTool, + required: true, + }, + { + type: NodeConnectionType.AiOutputParser, + }, + ]; } else if (agent === 'openAiFunctionsAgent') { specialInputs = [ { @@ -157,16 +185,60 @@ function getInputs( return [NodeConnectionType.Main, ...getInputData(specialInputs)]; } +const agentTypeProperty: INodeProperties = { + displayName: 'Agent', + name: 'agent', + type: 'options', + noDataExpression: true, + options: [ + { + name: 'Conversational Agent', + value: 'conversationalAgent', + description: + 'Selects tools to accomplish its task and uses memory to recall previous conversations', + }, + { + name: 'OpenAI Functions Agent', + value: 'openAiFunctionsAgent', + description: + "Utilizes OpenAI's Function Calling feature to select the appropriate tool and arguments for execution", + }, + { + name: 'Plan and Execute Agent', + value: 'planAndExecuteAgent', + description: + 'Plan and execute agents accomplish an objective by first planning what to do, then executing the sub tasks', + }, + { + name: 'ReAct Agent', + value: 'reActAgent', + description: 'Strategically select tools to accomplish a given task', + }, + { + name: 'SQL Agent', + value: 'sqlAgent', + description: 'Answers questions about data in an SQL database', + }, + { + name: 'Tools Agent', + value: 'toolsAgent', + description: + 'Utilized unified Tool calling interface to select the appropriate tools and argument for execution', + }, + ], + default: '', +}; + export class Agent implements INodeType { description: INodeTypeDescription = { displayName: 'AI Agent', name: 'agent', icon: 'fa:robot', group: ['transform'], - version: [1, 1.1, 1.2, 1.3, 1.4, 1.5], + version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6], description: 'Generates an action plan and executes it. Can use external tools.', subtitle: - "={{ { conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}", + "={{ { toolsAgent: 'Tools Agent', conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}", defaults: { name: 'AI Agent', color: '#404040', @@ -225,43 +297,18 @@ export class Agent implements INodeType { }, }, }, + // Make Conversational Agent the default agent for versions 1.5 and below { - displayName: 'Agent', - name: 'agent', - type: 'options', - noDataExpression: true, - options: [ - { - name: 'Conversational Agent', - value: 'conversationalAgent', - description: - 'Selects tools to accomplish its task and uses memory to recall previous conversations', - }, - { - name: 'OpenAI Functions Agent', - value: 'openAiFunctionsAgent', - description: - "Utilizes OpenAI's Function Calling feature to select the appropriate tool and arguments for execution", - }, - { - name: 'Plan and Execute Agent', - value: 'planAndExecuteAgent', - description: - 'Plan and execute agents accomplish an objective by first planning what to do, then executing the sub tasks', - }, - { - name: 'ReAct Agent', - value: 'reActAgent', - description: 'Strategically select tools to accomplish a given task', - }, - { - name: 'SQL Agent', - value: 'sqlAgent', - description: 'Answers questions about data in an SQL database', - }, - ], + ...agentTypeProperty, + displayOptions: { show: { '@version': [{ _cnd: { lte: 1.5 } }] } }, default: 'conversationalAgent', }, + // Make Tools Agent the default agent for versions 1.6 and above + { + ...agentTypeProperty, + displayOptions: { show: { '@version': [{ _cnd: { gte: 1.6 } }] } }, + default: 'toolsAgent', + }, { ...promptTypeOptions, displayOptions: { @@ -307,6 +354,7 @@ export class Agent implements INodeType { }, }, + ...toolsAgentProperties, ...conversationalAgentProperties, ...openAiFunctionsAgentProperties, ...reActAgentAgentProperties, @@ -321,6 +369,8 @@ export class Agent implements INodeType { if (agentType === 'conversationalAgent') { return await conversationalAgentExecute.call(this, nodeVersion); + } else if (agentType === 'toolsAgent') { + return await toolsAgentExecute.call(this, nodeVersion); } else if (agentType === 'openAiFunctionsAgent') { return await openAiFunctionsAgentExecute.call(this, nodeVersion); } else if (agentType === 'reActAgent') { diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts new file mode 100644 index 0000000000..4597909f7f --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts @@ -0,0 +1,43 @@ +import type { INodeProperties } from 'n8n-workflow'; +import { SYSTEM_MESSAGE } from './prompt'; + +export const toolsAgentProperties: INodeProperties[] = [ + { + displayName: 'Options', + name: 'options', + type: 'collection', + displayOptions: { + show: { + agent: ['toolsAgent'], + }, + }, + default: {}, + placeholder: 'Add Option', + options: [ + { + displayName: 'System Message', + name: 'systemMessage', + type: 'string', + default: SYSTEM_MESSAGE, + description: 'The message that will be sent to the agent before the conversation starts', + typeOptions: { + rows: 6, + }, + }, + { + displayName: 'Max Iterations', + name: 'maxIterations', + type: 'number', + default: 10, + description: 'The maximum number of iterations the agent will run before stopping', + }, + { + displayName: 'Return Intermediate Steps', + name: 'returnIntermediateSteps', + type: 'boolean', + default: false, + description: 'Whether or not the output should include intermediate steps the agent took', + }, + ], + }, +]; diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts new file mode 100644 index 0000000000..65265f704e --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts @@ -0,0 +1,189 @@ +import { NodeConnectionType, NodeOperationError } from 'n8n-workflow'; +import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; + +import type { AgentAction, AgentFinish, AgentStep } from 'langchain/agents'; +import { AgentExecutor, createToolCallingAgent } from 'langchain/agents'; +import type { BaseChatMemory } from '@langchain/community/memory/chat_memory'; +import { ChatPromptTemplate } from '@langchain/core/prompts'; +import { omit } from 'lodash'; +import type { Tool } from '@langchain/core/tools'; +import { DynamicStructuredTool } from '@langchain/core/tools'; +import { RunnableSequence } from '@langchain/core/runnables'; +import type { ZodObject } from 'zod'; +import { z } from 'zod'; +import type { BaseOutputParser, StructuredOutputParser } from '@langchain/core/output_parsers'; +import { OutputFixingParser } from 'langchain/output_parsers'; +import { + isChatInstance, + getPromptInputByType, + getOptionalOutputParsers, + getConnectedTools, +} from '../../../../../utils/helpers'; +import { SYSTEM_MESSAGE } from './prompt'; + +function getOutputParserSchema(outputParser: BaseOutputParser): ZodObject { + const parserType = outputParser.lc_namespace[outputParser.lc_namespace.length - 1]; + let schema: ZodObject; + + if (parserType === 'structured') { + // If the output parser is a structured output parser, we will use the schema from the parser + schema = (outputParser as StructuredOutputParser>).schema; + } else if (parserType === 'fix' && outputParser instanceof OutputFixingParser) { + // If the output parser is a fixing parser, we will use the schema from the connected structured output parser + schema = (outputParser.parser as StructuredOutputParser>).schema; + } else { + // If the output parser is not a structured output parser, we will use a fallback schema + schema = z.object({ text: z.string() }); + } + + return schema; +} + +export async function toolsAgentExecute( + this: IExecuteFunctions, + nodeVersion: number, +): Promise { + this.logger.verbose('Executing Tools Agent'); + const model = await this.getInputConnectionData(NodeConnectionType.AiLanguageModel, 0); + + if (!isChatInstance(model) || !model.bindTools) { + throw new NodeOperationError( + this.getNode(), + 'Tools Agent requires Chat Model which supports Tools calling', + ); + } + + const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as + | BaseChatMemory + | undefined; + + const tools = (await getConnectedTools(this, true)) as Array; + const outputParser = (await getOptionalOutputParsers(this))?.[0]; + let structuredOutputParserTool: DynamicStructuredTool | undefined; + + async function agentStepsParser( + steps: AgentFinish | AgentAction[], + ): Promise { + if (Array.isArray(steps)) { + const responseParserTool = steps.find((step) => step.tool === 'format_final_response'); + if (responseParserTool) { + const toolInput = responseParserTool?.toolInput; + const returnValues = (await outputParser.parse(toolInput as unknown as string)) as Record< + string, + unknown + >; + + return { + returnValues, + log: 'Final response formatted', + }; + } + } + + // If the steps are an AgentFinish and the outputParser is defined it must mean that the LLM didn't use `format_final_response` tool so we will parse the output manually + if (outputParser && typeof steps === 'object' && (steps as AgentFinish).returnValues) { + const finalResponse = (steps as AgentFinish).returnValues; + const returnValues = (await outputParser.parse(finalResponse as unknown as string)) as Record< + string, + unknown + >; + + return { + returnValues, + log: 'Final response formatted', + }; + } + return steps; + } + + if (outputParser) { + const schema = getOutputParserSchema(outputParser); + structuredOutputParserTool = new DynamicStructuredTool({ + schema, + name: 'format_final_response', + description: + 'Always use this tool for the final output to the user. It validates the output so only use it when you are sure the output is final.', + // We will not use the function here as we will use the parser to intercept & parse the output in the agentStepsParser + func: async () => '', + }); + + tools.push(structuredOutputParserTool); + } + + const options = this.getNodeParameter('options', 0, {}) as { + systemMessage?: string; + maxIterations?: number; + returnIntermediateSteps?: boolean; + }; + + const prompt = ChatPromptTemplate.fromMessages([ + ['system', `{system_message}${outputParser ? '\n\n{formatting_instructions}' : ''}`], + ['placeholder', '{chat_history}'], + ['human', '{input}'], + ['placeholder', '{agent_scratchpad}'], + ]); + + const agent = createToolCallingAgent({ + llm: model, + tools, + prompt, + streamRunnable: false, + }); + agent.streamRunnable = false; + + const runnableAgent = RunnableSequence.from<{ + steps: AgentStep[]; + }>([agent, agentStepsParser]); + + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnableAgent, + memory, + tools, + returnIntermediateSteps: options.returnIntermediateSteps === true, + maxIterations: options.maxIterations ?? 10, + }); + const returnData: INodeExecutionData[] = []; + + const items = this.getInputData(); + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + try { + const input = getPromptInputByType({ + ctx: this, + i: itemIndex, + inputKey: 'text', + promptTypeKey: 'promptType', + }); + + if (input === undefined) { + throw new NodeOperationError(this.getNode(), 'The ‘text parameter is empty.'); + } + + const response = await executor.invoke({ + input, + system_message: options.systemMessage ?? SYSTEM_MESSAGE, + formatting_instructions: + 'IMPORTANT: Always call `format_final_response` to format your final response!', //outputParser?.getFormatInstructions(), + }); + + returnData.push({ + json: omit( + response, + 'system_message', + 'formatting_instructions', + 'input', + 'chat_history', + 'agent_scratchpad', + ), + }); + } catch (error) { + if (this.continueOnFail()) { + returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); + continue; + } + + throw error; + } + } + + return await this.prepareOutputData(returnData); +} diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/prompt.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/prompt.ts new file mode 100644 index 0000000000..069a2629b5 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/prompt.ts @@ -0,0 +1 @@ +export const SYSTEM_MESSAGE = 'You are a helpful assistant'; diff --git a/packages/@n8n/nodes-langchain/utils/logWrapper.ts b/packages/@n8n/nodes-langchain/utils/logWrapper.ts index 1ce6924813..4cdec6fbfc 100644 --- a/packages/@n8n/nodes-langchain/utils/logWrapper.ts +++ b/packages/@n8n/nodes-langchain/utils/logWrapper.ts @@ -13,7 +13,6 @@ import type { Document } from '@langchain/core/documents'; import { TextSplitter } from 'langchain/text_splitter'; import { BaseChatMemory } from '@langchain/community/memory/chat_memory'; import { BaseRetriever } from '@langchain/core/retrievers'; -import type { FormatInstructionsOptions } from '@langchain/core/output_parsers'; import { BaseOutputParser, OutputParserException } from '@langchain/core/output_parsers'; import { isObject } from 'lodash'; import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base'; @@ -222,31 +221,7 @@ export function logWrapper( // ========== BaseOutputParser ========== if (originalInstance instanceof BaseOutputParser) { - if (prop === 'getFormatInstructions' && 'getFormatInstructions' in target) { - return (options?: FormatInstructionsOptions): string => { - connectionType = NodeConnectionType.AiOutputParser; - const { index } = executeFunctions.addInputData(connectionType, [ - [{ json: { action: 'getFormatInstructions' } }], - ]); - - // @ts-ignore - const response = callMethodSync.call(target, { - executeFunctions, - connectionType, - currentNodeRunIndex: index, - method: target[prop], - arguments: [options], - }) as string; - - executeFunctions.addOutputData(connectionType, index, [ - [{ json: { action: 'getFormatInstructions', response } }], - ]); - void logAiEvent(executeFunctions, 'n8n.ai.output.parser.get.instructions', { - response, - }); - return response; - }; - } else if (prop === 'parse' && 'parse' in target) { + if (prop === 'parse' && 'parse' in target) { return async (text: string | Record): Promise => { connectionType = NodeConnectionType.AiOutputParser; const stringifiedText = isObject(text) ? JSON.stringify(text) : text; @@ -254,19 +229,30 @@ export function logWrapper( [{ json: { action: 'parse', text: stringifiedText } }], ]); - const response = (await callMethodAsync.call(target, { - executeFunctions, - connectionType, - currentNodeRunIndex: index, - method: target[prop], - arguments: [stringifiedText], - })) as object; + try { + const response = (await callMethodAsync.call(target, { + executeFunctions, + connectionType, + currentNodeRunIndex: index, + method: target[prop], + arguments: [stringifiedText], + })) as object; - void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { text, response }); - executeFunctions.addOutputData(connectionType, index, [ - [{ json: { action: 'parse', response } }], - ]); - return response; + void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { text, response }); + executeFunctions.addOutputData(connectionType, index, [ + [{ json: { action: 'parse', response } }], + ]); + return response; + } catch (error) { + void logAiEvent(executeFunctions, 'n8n.ai.output.parser.parsed', { + text, + response: error.message ?? error, + }); + executeFunctions.addOutputData(connectionType, index, [ + [{ json: { action: 'parse', response: error.message ?? error } }], + ]); + throw error; + } }; } }