From ff156930c5f1b75da59bc27b424925a5535cd908 Mon Sep 17 00:00:00 2001 From: Benjamin Schroth <68321970+schrothbn@users.noreply.github.com> Date: Tue, 13 May 2025 13:58:38 +0200 Subject: [PATCH] feat: Optimise langchain calls in batching mode (#15243) --- .../nodes/agents/Agent/Agent.node.ts | 510 ++---------------- .../nodes/agents/Agent/V1/AgentV1.node.ts | 460 ++++++++++++++++ .../nodes/agents/Agent/V2/AgentV2.node.ts | 169 ++++++ .../Agent/agents/ToolsAgent/V1/description.ts | 19 + .../Agent/agents/ToolsAgent/V1/execute.ts | 138 +++++ .../Agent/agents/ToolsAgent/V2/description.ts | 16 + .../Agent/agents/ToolsAgent/V2/execute.ts | 161 ++++++ .../ToolsAgent/{execute.ts => common.ts} | 142 +---- .../Agent/agents/ToolsAgent/description.ts | 52 -- .../agents/Agent/agents/ToolsAgent/options.ts | 38 ++ .../test/ToolsAgent/ToolsAgentV1.test.ts | 159 ++++++ .../test/ToolsAgent/ToolsAgentV2.test.ts | 223 ++++++++ .../commons.test.ts} | 2 +- .../nodes/chains/ChainLLM/ChainLlm.node.ts | 163 +++--- .../nodes/chains/ChainLLM/methods/config.ts | 7 +- .../chains/ChainLLM/methods/processItem.ts | 54 ++ .../ChainLLM/test/ChainLlm.node.test.ts | 144 ++++- .../ChainRetrievalQA/ChainRetrievalQa.node.ts | 205 +++---- .../chains/ChainRetrievalQA/constants.ts | 20 + .../chains/ChainRetrievalQA/processItem.ts | 100 ++++ .../test/ChainRetrievalQa.node.test.ts | 124 ++++- .../ChainSummarization.node.ts | 3 +- .../V2/ChainSummarizationV2.node.ts | 165 ++---- .../ChainSummarization/V2/processItem.ts | 107 ++++ .../InformationExtractor.node.ts | 94 ++-- .../chains/InformationExtractor/constants.ts | 3 + .../InformationExtractor/processItem.ts | 39 ++ .../test/InformationExtraction.node.test.ts | 140 ++++- .../SentimentAnalysis.node.ts | 343 ++++++++---- .../TextClassifier/TextClassifier.node.ts | 152 ++++-- .../nodes/chains/TextClassifier/constants.ts | 2 + .../chains/TextClassifier/processItem.ts | 57 ++ .../nodes-langchain/utils/sharedFields.ts | 35 +- .../src/stores/workflows.store.test.ts | 61 ++- .../editor-ui/src/stores/workflows.store.ts | 10 +- 35 files changed, 2946 insertions(+), 1171 deletions(-) create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/description.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/description.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts rename packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/{execute.ts => common.ts} (74%) delete mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/options.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV1.test.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts rename packages/@n8n/nodes-langchain/nodes/agents/Agent/test/{ToolsAgent.test.ts => ToolsAgent/commons.test.ts} (99%) create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/constants.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/processItem.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/processItem.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/constants.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/processItem.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/constants.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/processItem.ts diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts index fb80f166e6..f86ad3f9e1 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/Agent.node.ts @@ -1,475 +1,49 @@ -import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; -import type { - INodeInputConfiguration, - INodeInputFilter, - IExecuteFunctions, - INodeExecutionData, - INodeType, - INodeTypeDescription, - INodeProperties, - NodeConnectionType, -} from 'n8n-workflow'; +import type { INodeTypeBaseDescription, IVersionedNodeType } from 'n8n-workflow'; +import { VersionedNodeType } from 'n8n-workflow'; -import { promptTypeOptions, textFromPreviousNode, textInput } from '@utils/descriptions'; +import { AgentV1 } from './V1/AgentV1.node'; +import { AgentV2 } from './V2/AgentV2.node'; -import { conversationalAgentProperties } from './agents/ConversationalAgent/description'; -import { conversationalAgentExecute } from './agents/ConversationalAgent/execute'; -import { openAiFunctionsAgentProperties } from './agents/OpenAiFunctionsAgent/description'; -import { openAiFunctionsAgentExecute } from './agents/OpenAiFunctionsAgent/execute'; -import { planAndExecuteAgentProperties } from './agents/PlanAndExecuteAgent/description'; -import { planAndExecuteAgentExecute } from './agents/PlanAndExecuteAgent/execute'; -import { reActAgentAgentProperties } from './agents/ReActAgent/description'; -import { reActAgentAgentExecute } from './agents/ReActAgent/execute'; -import { sqlAgentAgentProperties } from './agents/SqlAgent/description'; -import { sqlAgentAgentExecute } from './agents/SqlAgent/execute'; -import { toolsAgentProperties } from './agents/ToolsAgent/description'; -import { toolsAgentExecute } from './agents/ToolsAgent/execute'; - -// Function used in the inputs expression to figure out which inputs to -// display based on the agent type -function getInputs( - agent: - | 'toolsAgent' - | 'conversationalAgent' - | 'openAiFunctionsAgent' - | 'planAndExecuteAgent' - | 'reActAgent' - | 'sqlAgent', - hasOutputParser?: boolean, -): Array { - interface SpecialInput { - type: NodeConnectionType; - filter?: INodeInputFilter; - required?: boolean; - } - - const getInputData = ( - inputs: SpecialInput[], - ): Array => { - const displayNames: { [key: string]: string } = { - ai_languageModel: 'Model', - ai_memory: 'Memory', - ai_tool: 'Tool', - ai_outputParser: 'Output Parser', +export class Agent extends VersionedNodeType { + constructor() { + const baseDescription: INodeTypeBaseDescription = { + displayName: 'AI Agent', + name: 'agent', + icon: 'fa:robot', + iconColor: 'black', + group: ['transform'], + description: 'Generates an action plan and executes it. Can use external tools.', + codex: { + alias: ['LangChain', 'Chat', 'Conversational', 'Plan and Execute', 'ReAct', 'Tools'], + categories: ['AI'], + subcategories: { + AI: ['Agents', 'Root Nodes'], + }, + resources: { + primaryDocumentation: [ + { + url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.agent/', + }, + ], + }, + }, + defaultVersion: 2, }; - return inputs.map(({ type, filter }) => { - const isModelType = type === ('ai_languageModel' as NodeConnectionType); - let displayName = type in displayNames ? displayNames[type] : undefined; - if ( - isModelType && - ['openAiFunctionsAgent', 'toolsAgent', 'conversationalAgent'].includes(agent) - ) { - displayName = 'Chat Model'; - } - const input: INodeInputConfiguration = { - type, - displayName, - required: isModelType, - maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes( - type as NodeConnectionType, - ) - ? 1 - : undefined, - }; + const nodeVersions: IVersionedNodeType['nodeVersions'] = { + 1: new AgentV1(baseDescription), + 1.1: new AgentV1(baseDescription), + 1.2: new AgentV1(baseDescription), + 1.3: new AgentV1(baseDescription), + 1.4: new AgentV1(baseDescription), + 1.5: new AgentV1(baseDescription), + 1.6: new AgentV1(baseDescription), + 1.7: new AgentV1(baseDescription), + 1.8: new AgentV1(baseDescription), + 1.9: new AgentV1(baseDescription), + 2: new AgentV2(baseDescription), + }; - if (filter) { - input.filter = filter; - } - - return input; - }); - }; - - let specialInputs: SpecialInput[] = []; - - if (agent === 'conversationalAgent') { - specialInputs = [ - { - type: 'ai_languageModel', - filter: { - nodes: [ - '@n8n/n8n-nodes-langchain.lmChatAnthropic', - '@n8n/n8n-nodes-langchain.lmChatAwsBedrock', - '@n8n/n8n-nodes-langchain.lmChatGroq', - '@n8n/n8n-nodes-langchain.lmChatOllama', - '@n8n/n8n-nodes-langchain.lmChatOpenAi', - '@n8n/n8n-nodes-langchain.lmChatGoogleGemini', - '@n8n/n8n-nodes-langchain.lmChatGoogleVertex', - '@n8n/n8n-nodes-langchain.lmChatMistralCloud', - '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', - '@n8n/n8n-nodes-langchain.lmChatDeepSeek', - '@n8n/n8n-nodes-langchain.lmChatOpenRouter', - '@n8n/n8n-nodes-langchain.lmChatXAiGrok', - ], - }, - }, - { - type: 'ai_memory', - }, - { - type: 'ai_tool', - }, - { - type: 'ai_outputParser', - }, - ]; - } else if (agent === 'toolsAgent') { - specialInputs = [ - { - type: 'ai_languageModel', - filter: { - nodes: [ - '@n8n/n8n-nodes-langchain.lmChatAnthropic', - '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', - '@n8n/n8n-nodes-langchain.lmChatAwsBedrock', - '@n8n/n8n-nodes-langchain.lmChatMistralCloud', - '@n8n/n8n-nodes-langchain.lmChatOllama', - '@n8n/n8n-nodes-langchain.lmChatOpenAi', - '@n8n/n8n-nodes-langchain.lmChatGroq', - '@n8n/n8n-nodes-langchain.lmChatGoogleVertex', - '@n8n/n8n-nodes-langchain.lmChatGoogleGemini', - '@n8n/n8n-nodes-langchain.lmChatDeepSeek', - '@n8n/n8n-nodes-langchain.lmChatOpenRouter', - '@n8n/n8n-nodes-langchain.lmChatXAiGrok', - ], - }, - }, - { - type: 'ai_memory', - }, - { - type: 'ai_tool', - required: true, - }, - { - type: 'ai_outputParser', - }, - ]; - } else if (agent === 'openAiFunctionsAgent') { - specialInputs = [ - { - type: 'ai_languageModel', - filter: { - nodes: [ - '@n8n/n8n-nodes-langchain.lmChatOpenAi', - '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', - ], - }, - }, - { - type: 'ai_memory', - }, - { - type: 'ai_tool', - required: true, - }, - { - type: 'ai_outputParser', - }, - ]; - } else if (agent === 'reActAgent') { - specialInputs = [ - { - type: 'ai_languageModel', - }, - { - type: 'ai_tool', - }, - { - type: 'ai_outputParser', - }, - ]; - } else if (agent === 'sqlAgent') { - specialInputs = [ - { - type: 'ai_languageModel', - }, - { - type: 'ai_memory', - }, - ]; - } else if (agent === 'planAndExecuteAgent') { - specialInputs = [ - { - type: 'ai_languageModel', - }, - { - type: 'ai_tool', - }, - { - type: 'ai_outputParser', - }, - ]; - } - - if (hasOutputParser === false) { - specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser'); - } - return ['main', ...getInputData(specialInputs)]; -} - -const agentTypeProperty: INodeProperties = { - displayName: 'Agent', - name: 'agent', - type: 'options', - noDataExpression: true, - // eslint-disable-next-line n8n-nodes-base/node-param-options-type-unsorted-items - options: [ - { - name: 'Tools Agent', - value: 'toolsAgent', - description: - 'Utilizes structured tool schemas for precise and reliable tool selection and execution. Recommended for complex tasks requiring accurate and consistent tool usage, but only usable with models that support tool calling.', - }, - { - name: 'Conversational Agent', - value: 'conversationalAgent', - description: - 'Describes tools in the system prompt and parses JSON responses for tool calls. More flexible but potentially less reliable than the Tools Agent. Suitable for simpler interactions or with models not supporting structured schemas.', - }, - { - name: 'OpenAI Functions Agent', - value: 'openAiFunctionsAgent', - description: - "Leverages OpenAI's function calling capabilities to precisely select and execute tools. Excellent for tasks requiring structured outputs when working with OpenAI models.", - }, - { - name: 'Plan and Execute Agent', - value: 'planAndExecuteAgent', - description: - 'Creates a high-level plan for complex tasks and then executes each step. Suitable for multi-stage problems or when a strategic approach is needed.', - }, - { - name: 'ReAct Agent', - value: 'reActAgent', - description: - 'Combines reasoning and action in an iterative process. Effective for tasks that require careful analysis and step-by-step problem-solving.', - }, - { - name: 'SQL Agent', - value: 'sqlAgent', - description: - 'Specializes in interacting with SQL databases. Ideal for data analysis tasks, generating queries, or extracting insights from structured data.', - }, - ], - default: '', -}; - -export class Agent implements INodeType { - description: INodeTypeDescription = { - displayName: 'AI Agent', - name: 'agent', - icon: 'fa:robot', - iconColor: 'black', - group: ['transform'], - version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9], - description: 'Generates an action plan and executes it. Can use external tools.', - subtitle: - "={{ { toolsAgent: 'Tools Agent', conversationalAgent: 'Conversational Agent', openAiFunctionsAgent: 'OpenAI Functions Agent', reActAgent: 'ReAct Agent', sqlAgent: 'SQL Agent', planAndExecuteAgent: 'Plan and Execute Agent' }[$parameter.agent] }}", - defaults: { - name: 'AI Agent', - color: '#404040', - }, - codex: { - alias: ['LangChain', 'Chat', 'Conversational', 'Plan and Execute', 'ReAct', 'Tools'], - categories: ['AI'], - subcategories: { - AI: ['Agents', 'Root Nodes'], - }, - resources: { - primaryDocumentation: [ - { - url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.agent/', - }, - ], - }, - }, - inputs: `={{ - ((agent, hasOutputParser) => { - ${getInputs.toString()}; - return getInputs(agent, hasOutputParser) - })($parameter.agent, $parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true) - }}`, - outputs: [NodeConnectionTypes.Main], - credentials: [ - { - // eslint-disable-next-line n8n-nodes-base/node-class-description-credentials-name-unsuffixed - name: 'mySql', - required: true, - testedBy: 'mysqlConnectionTest', - displayOptions: { - show: { - agent: ['sqlAgent'], - '/dataSource': ['mysql'], - }, - }, - }, - { - name: 'postgres', - required: true, - displayOptions: { - show: { - agent: ['sqlAgent'], - '/dataSource': ['postgres'], - }, - }, - }, - ], - properties: [ - { - displayName: - 'Tip: Get a feel for agents with our quick tutorial or see an example of how this node works', - name: 'notice_tip', - type: 'notice', - default: '', - displayOptions: { - show: { - agent: ['conversationalAgent', 'toolsAgent'], - }, - }, - }, - { - displayName: - "This node is using Agent that has been deprecated. Please switch to using 'Tools Agent' instead.", - name: 'deprecated', - type: 'notice', - default: '', - displayOptions: { - show: { - agent: [ - 'conversationalAgent', - 'openAiFunctionsAgent', - 'planAndExecuteAgent', - 'reActAgent', - 'sqlAgent', - ], - }, - }, - }, - // Make Conversational Agent the default agent for versions 1.5 and below - { - ...agentTypeProperty, - options: agentTypeProperty?.options?.filter( - (o) => 'value' in o && o.value !== 'toolsAgent', - ), - displayOptions: { show: { '@version': [{ _cnd: { lte: 1.5 } }] } }, - default: 'conversationalAgent', - }, - // Make Tools Agent the default agent for versions 1.6 and 1.7 - { - ...agentTypeProperty, - displayOptions: { show: { '@version': [{ _cnd: { between: { from: 1.6, to: 1.7 } } }] } }, - default: 'toolsAgent', - }, - // Make Tools Agent the only agent option for versions 1.8 and above - { - ...agentTypeProperty, - type: 'hidden', - displayOptions: { show: { '@version': [{ _cnd: { gte: 1.8 } }] } }, - default: 'toolsAgent', - }, - { - ...promptTypeOptions, - displayOptions: { - hide: { - '@version': [{ _cnd: { lte: 1.2 } }], - agent: ['sqlAgent'], - }, - }, - }, - { - ...textFromPreviousNode, - displayOptions: { - show: { promptType: ['auto'], '@version': [{ _cnd: { gte: 1.7 } }] }, - // SQL Agent has data source and credentials parameters so we need to include this input there manually - // to preserve the order - hide: { - agent: ['sqlAgent'], - }, - }, - }, - { - ...textInput, - displayOptions: { - show: { - promptType: ['define'], - }, - hide: { - agent: ['sqlAgent'], - }, - }, - }, - { - displayName: 'For more reliable structured output parsing, consider using the Tools agent', - name: 'notice', - type: 'notice', - default: '', - displayOptions: { - show: { - hasOutputParser: [true], - agent: [ - 'conversationalAgent', - 'reActAgent', - 'planAndExecuteAgent', - 'openAiFunctionsAgent', - ], - }, - }, - }, - { - displayName: 'Require Specific Output Format', - name: 'hasOutputParser', - type: 'boolean', - default: false, - noDataExpression: true, - displayOptions: { - hide: { - '@version': [{ _cnd: { lte: 1.2 } }], - agent: ['sqlAgent'], - }, - }, - }, - { - displayName: `Connect an output parser on the canvas to specify the output format you require`, - name: 'notice', - type: 'notice', - default: '', - displayOptions: { - show: { - hasOutputParser: [true], - agent: ['toolsAgent'], - }, - }, - }, - - ...toolsAgentProperties, - ...conversationalAgentProperties, - ...openAiFunctionsAgentProperties, - ...reActAgentAgentProperties, - ...sqlAgentAgentProperties, - ...planAndExecuteAgentProperties, - ], - }; - - async execute(this: IExecuteFunctions): Promise { - const agentType = this.getNodeParameter('agent', 0, '') as string; - const nodeVersion = this.getNode().typeVersion; - - if (agentType === 'conversationalAgent') { - return await conversationalAgentExecute.call(this, nodeVersion); - } else if (agentType === 'toolsAgent') { - return await toolsAgentExecute.call(this); - } else if (agentType === 'openAiFunctionsAgent') { - return await openAiFunctionsAgentExecute.call(this, nodeVersion); - } else if (agentType === 'reActAgent') { - return await reActAgentAgentExecute.call(this, nodeVersion); - } else if (agentType === 'sqlAgent') { - return await sqlAgentAgentExecute.call(this); - } else if (agentType === 'planAndExecuteAgent') { - return await planAndExecuteAgentExecute.call(this, nodeVersion); - } - - throw new NodeOperationError(this.getNode(), `The agent type "${agentType}" is not supported`); + super(nodeVersions, baseDescription); } } diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts new file mode 100644 index 0000000000..669e36b109 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts @@ -0,0 +1,460 @@ +import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; +import type { + INodeInputConfiguration, + INodeInputFilter, + IExecuteFunctions, + INodeExecutionData, + INodeType, + INodeTypeDescription, + INodeProperties, + NodeConnectionType, + INodeTypeBaseDescription, +} from 'n8n-workflow'; + +import { promptTypeOptions, textFromPreviousNode, textInput } from '@utils/descriptions'; + +import { conversationalAgentProperties } from '../agents/ConversationalAgent/description'; +import { conversationalAgentExecute } from '../agents/ConversationalAgent/execute'; +import { openAiFunctionsAgentProperties } from '../agents/OpenAiFunctionsAgent/description'; +import { openAiFunctionsAgentExecute } from '../agents/OpenAiFunctionsAgent/execute'; +import { planAndExecuteAgentProperties } from '../agents/PlanAndExecuteAgent/description'; +import { planAndExecuteAgentExecute } from '../agents/PlanAndExecuteAgent/execute'; +import { reActAgentAgentProperties } from '../agents/ReActAgent/description'; +import { reActAgentAgentExecute } from '../agents/ReActAgent/execute'; +import { sqlAgentAgentProperties } from '../agents/SqlAgent/description'; +import { sqlAgentAgentExecute } from '../agents/SqlAgent/execute'; +import { toolsAgentProperties } from '../agents/ToolsAgent/V1/description'; +import { toolsAgentExecute } from '../agents/ToolsAgent/V1/execute'; + +// Function used in the inputs expression to figure out which inputs to +// display based on the agent type +function getInputs( + agent: + | 'toolsAgent' + | 'conversationalAgent' + | 'openAiFunctionsAgent' + | 'planAndExecuteAgent' + | 'reActAgent' + | 'sqlAgent', + hasOutputParser?: boolean, +): Array { + interface SpecialInput { + type: NodeConnectionType; + filter?: INodeInputFilter; + required?: boolean; + } + + const getInputData = ( + inputs: SpecialInput[], + ): Array => { + const displayNames: { [key: string]: string } = { + ai_languageModel: 'Model', + ai_memory: 'Memory', + ai_tool: 'Tool', + ai_outputParser: 'Output Parser', + }; + + return inputs.map(({ type, filter }) => { + const isModelType = type === ('ai_languageModel' as NodeConnectionType); + let displayName = type in displayNames ? displayNames[type] : undefined; + if ( + isModelType && + ['openAiFunctionsAgent', 'toolsAgent', 'conversationalAgent'].includes(agent) + ) { + displayName = 'Chat Model'; + } + const input: INodeInputConfiguration = { + type, + displayName, + required: isModelType, + maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes( + type as NodeConnectionType, + ) + ? 1 + : undefined, + }; + + if (filter) { + input.filter = filter; + } + + return input; + }); + }; + + let specialInputs: SpecialInput[] = []; + + if (agent === 'conversationalAgent') { + specialInputs = [ + { + type: 'ai_languageModel', + filter: { + nodes: [ + '@n8n/n8n-nodes-langchain.lmChatAnthropic', + '@n8n/n8n-nodes-langchain.lmChatAwsBedrock', + '@n8n/n8n-nodes-langchain.lmChatGroq', + '@n8n/n8n-nodes-langchain.lmChatOllama', + '@n8n/n8n-nodes-langchain.lmChatOpenAi', + '@n8n/n8n-nodes-langchain.lmChatGoogleGemini', + '@n8n/n8n-nodes-langchain.lmChatGoogleVertex', + '@n8n/n8n-nodes-langchain.lmChatMistralCloud', + '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', + '@n8n/n8n-nodes-langchain.lmChatDeepSeek', + '@n8n/n8n-nodes-langchain.lmChatOpenRouter', + '@n8n/n8n-nodes-langchain.lmChatXAiGrok', + ], + }, + }, + { + type: 'ai_memory', + }, + { + type: 'ai_tool', + }, + { + type: 'ai_outputParser', + }, + ]; + } else if (agent === 'toolsAgent') { + specialInputs = [ + { + type: 'ai_languageModel', + filter: { + nodes: [ + '@n8n/n8n-nodes-langchain.lmChatAnthropic', + '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', + '@n8n/n8n-nodes-langchain.lmChatAwsBedrock', + '@n8n/n8n-nodes-langchain.lmChatMistralCloud', + '@n8n/n8n-nodes-langchain.lmChatOllama', + '@n8n/n8n-nodes-langchain.lmChatOpenAi', + '@n8n/n8n-nodes-langchain.lmChatGroq', + '@n8n/n8n-nodes-langchain.lmChatGoogleVertex', + '@n8n/n8n-nodes-langchain.lmChatGoogleGemini', + '@n8n/n8n-nodes-langchain.lmChatDeepSeek', + '@n8n/n8n-nodes-langchain.lmChatOpenRouter', + '@n8n/n8n-nodes-langchain.lmChatXAiGrok', + ], + }, + }, + { + type: 'ai_memory', + }, + { + type: 'ai_tool', + required: true, + }, + { + type: 'ai_outputParser', + }, + ]; + } else if (agent === 'openAiFunctionsAgent') { + specialInputs = [ + { + type: 'ai_languageModel', + filter: { + nodes: [ + '@n8n/n8n-nodes-langchain.lmChatOpenAi', + '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', + ], + }, + }, + { + type: 'ai_memory', + }, + { + type: 'ai_tool', + required: true, + }, + { + type: 'ai_outputParser', + }, + ]; + } else if (agent === 'reActAgent') { + specialInputs = [ + { + type: 'ai_languageModel', + }, + { + type: 'ai_tool', + }, + { + type: 'ai_outputParser', + }, + ]; + } else if (agent === 'sqlAgent') { + specialInputs = [ + { + type: 'ai_languageModel', + }, + { + type: 'ai_memory', + }, + ]; + } else if (agent === 'planAndExecuteAgent') { + specialInputs = [ + { + type: 'ai_languageModel', + }, + { + type: 'ai_tool', + }, + { + type: 'ai_outputParser', + }, + ]; + } + + if (hasOutputParser === false) { + specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser'); + } + return ['main', ...getInputData(specialInputs)]; +} + +const agentTypeProperty: INodeProperties = { + displayName: 'Agent', + name: 'agent', + type: 'options', + noDataExpression: true, + // eslint-disable-next-line n8n-nodes-base/node-param-options-type-unsorted-items + options: [ + { + name: 'Tools Agent', + value: 'toolsAgent', + description: + 'Utilizes structured tool schemas for precise and reliable tool selection and execution. Recommended for complex tasks requiring accurate and consistent tool usage, but only usable with models that support tool calling.', + }, + { + name: 'Conversational Agent', + value: 'conversationalAgent', + description: + 'Describes tools in the system prompt and parses JSON responses for tool calls. More flexible but potentially less reliable than the Tools Agent. Suitable for simpler interactions or with models not supporting structured schemas.', + }, + { + name: 'OpenAI Functions Agent', + value: 'openAiFunctionsAgent', + description: + "Leverages OpenAI's function calling capabilities to precisely select and execute tools. Excellent for tasks requiring structured outputs when working with OpenAI models.", + }, + { + name: 'Plan and Execute Agent', + value: 'planAndExecuteAgent', + description: + 'Creates a high-level plan for complex tasks and then executes each step. Suitable for multi-stage problems or when a strategic approach is needed.', + }, + { + name: 'ReAct Agent', + value: 'reActAgent', + description: + 'Combines reasoning and action in an iterative process. Effective for tasks that require careful analysis and step-by-step problem-solving.', + }, + { + name: 'SQL Agent', + value: 'sqlAgent', + description: + 'Specializes in interacting with SQL databases. Ideal for data analysis tasks, generating queries, or extracting insights from structured data.', + }, + ], + default: '', +}; + +export class AgentV1 implements INodeType { + description: INodeTypeDescription; + + constructor(baseDescription: INodeTypeBaseDescription) { + this.description = { + version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9], + ...baseDescription, + defaults: { + name: 'AI Agent', + color: '#404040', + }, + inputs: `={{ + ((agent, hasOutputParser) => { + ${getInputs.toString()}; + return getInputs(agent, hasOutputParser) + })($parameter.agent, $parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true) + }}`, + outputs: [NodeConnectionTypes.Main], + credentials: [ + { + // eslint-disable-next-line n8n-nodes-base/node-class-description-credentials-name-unsuffixed + name: 'mySql', + required: true, + testedBy: 'mysqlConnectionTest', + displayOptions: { + show: { + agent: ['sqlAgent'], + '/dataSource': ['mysql'], + }, + }, + }, + { + name: 'postgres', + required: true, + displayOptions: { + show: { + agent: ['sqlAgent'], + '/dataSource': ['postgres'], + }, + }, + }, + ], + properties: [ + { + displayName: + 'Tip: Get a feel for agents with our quick tutorial or see an example of how this node works', + name: 'notice_tip', + type: 'notice', + default: '', + displayOptions: { + show: { + agent: ['conversationalAgent', 'toolsAgent'], + }, + }, + }, + { + displayName: + "This node is using Agent that has been deprecated. Please switch to using 'Tools Agent' instead.", + name: 'deprecated', + type: 'notice', + default: '', + displayOptions: { + show: { + agent: [ + 'conversationalAgent', + 'openAiFunctionsAgent', + 'planAndExecuteAgent', + 'reActAgent', + 'sqlAgent', + ], + }, + }, + }, + // Make Conversational Agent the default agent for versions 1.5 and below + { + ...agentTypeProperty, + options: agentTypeProperty?.options?.filter( + (o) => 'value' in o && o.value !== 'toolsAgent', + ), + displayOptions: { show: { '@version': [{ _cnd: { lte: 1.5 } }] } }, + default: 'conversationalAgent', + }, + // Make Tools Agent the default agent for versions 1.6 and 1.7 + { + ...agentTypeProperty, + displayOptions: { show: { '@version': [{ _cnd: { between: { from: 1.6, to: 1.7 } } }] } }, + default: 'toolsAgent', + }, + // Make Tools Agent the only agent option for versions 1.8 and above + { + ...agentTypeProperty, + type: 'hidden', + displayOptions: { show: { '@version': [{ _cnd: { gte: 1.8 } }] } }, + default: 'toolsAgent', + }, + { + ...promptTypeOptions, + displayOptions: { + hide: { + '@version': [{ _cnd: { lte: 1.2 } }], + agent: ['sqlAgent'], + }, + }, + }, + { + ...textFromPreviousNode, + displayOptions: { + show: { promptType: ['auto'], '@version': [{ _cnd: { gte: 1.7 } }] }, + // SQL Agent has data source and credentials parameters so we need to include this input there manually + // to preserve the order + hide: { + agent: ['sqlAgent'], + }, + }, + }, + { + ...textInput, + displayOptions: { + show: { + promptType: ['define'], + }, + hide: { + agent: ['sqlAgent'], + }, + }, + }, + { + displayName: + 'For more reliable structured output parsing, consider using the Tools agent', + name: 'notice', + type: 'notice', + default: '', + displayOptions: { + show: { + hasOutputParser: [true], + agent: [ + 'conversationalAgent', + 'reActAgent', + 'planAndExecuteAgent', + 'openAiFunctionsAgent', + ], + }, + }, + }, + { + displayName: 'Require Specific Output Format', + name: 'hasOutputParser', + type: 'boolean', + default: false, + noDataExpression: true, + displayOptions: { + hide: { + '@version': [{ _cnd: { lte: 1.2 } }], + agent: ['sqlAgent'], + }, + }, + }, + { + displayName: `Connect an output parser on the canvas to specify the output format you require`, + name: 'notice', + type: 'notice', + default: '', + displayOptions: { + show: { + hasOutputParser: [true], + agent: ['toolsAgent'], + }, + }, + }, + + ...toolsAgentProperties, + ...conversationalAgentProperties, + ...openAiFunctionsAgentProperties, + ...reActAgentAgentProperties, + ...sqlAgentAgentProperties, + ...planAndExecuteAgentProperties, + ], + }; + } + + async execute(this: IExecuteFunctions): Promise { + const agentType = this.getNodeParameter('agent', 0, '') as string; + const nodeVersion = this.getNode().typeVersion; + + if (agentType === 'conversationalAgent') { + return await conversationalAgentExecute.call(this, nodeVersion); + } else if (agentType === 'toolsAgent') { + return await toolsAgentExecute.call(this); + } else if (agentType === 'openAiFunctionsAgent') { + return await openAiFunctionsAgentExecute.call(this, nodeVersion); + } else if (agentType === 'reActAgent') { + return await reActAgentAgentExecute.call(this, nodeVersion); + } else if (agentType === 'sqlAgent') { + return await sqlAgentAgentExecute.call(this); + } else if (agentType === 'planAndExecuteAgent') { + return await planAndExecuteAgentExecute.call(this, nodeVersion); + } + + throw new NodeOperationError(this.getNode(), `The agent type "${agentType}" is not supported`); + } +} diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts new file mode 100644 index 0000000000..f54e46e68d --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts @@ -0,0 +1,169 @@ +import { NodeConnectionTypes } from 'n8n-workflow'; +import type { + INodeInputConfiguration, + INodeInputFilter, + IExecuteFunctions, + INodeExecutionData, + INodeType, + INodeTypeDescription, + NodeConnectionType, + INodeTypeBaseDescription, +} from 'n8n-workflow'; + +import { promptTypeOptions, textFromPreviousNode, textInput } from '@utils/descriptions'; + +import { toolsAgentProperties } from '../agents/ToolsAgent/V2/description'; +import { toolsAgentExecute } from '../agents/ToolsAgent/V2/execute'; + +// Function used in the inputs expression to figure out which inputs to +// display based on the agent type +function getInputs(hasOutputParser?: boolean): Array { + interface SpecialInput { + type: NodeConnectionType; + filter?: INodeInputFilter; + required?: boolean; + } + + const getInputData = ( + inputs: SpecialInput[], + ): Array => { + const displayNames: { [key: string]: string } = { + ai_languageModel: 'Model', + ai_memory: 'Memory', + ai_tool: 'Tool', + ai_outputParser: 'Output Parser', + }; + + return inputs.map(({ type, filter }) => { + const isModelType = type === 'ai_languageModel'; + let displayName = type in displayNames ? displayNames[type] : undefined; + if (isModelType) { + displayName = 'Chat Model'; + } + const input: INodeInputConfiguration = { + type, + displayName, + required: isModelType, + maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes( + type as NodeConnectionType, + ) + ? 1 + : undefined, + }; + + if (filter) { + input.filter = filter; + } + + return input; + }); + }; + + let specialInputs: SpecialInput[] = [ + { + type: 'ai_languageModel', + filter: { + nodes: [ + '@n8n/n8n-nodes-langchain.lmChatAnthropic', + '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi', + '@n8n/n8n-nodes-langchain.lmChatAwsBedrock', + '@n8n/n8n-nodes-langchain.lmChatMistralCloud', + '@n8n/n8n-nodes-langchain.lmChatOllama', + '@n8n/n8n-nodes-langchain.lmChatOpenAi', + '@n8n/n8n-nodes-langchain.lmChatGroq', + '@n8n/n8n-nodes-langchain.lmChatGoogleVertex', + '@n8n/n8n-nodes-langchain.lmChatGoogleGemini', + '@n8n/n8n-nodes-langchain.lmChatDeepSeek', + '@n8n/n8n-nodes-langchain.lmChatOpenRouter', + '@n8n/n8n-nodes-langchain.lmChatXAiGrok', + ], + }, + }, + { + type: 'ai_memory', + }, + { + type: 'ai_tool', + required: true, + }, + { + type: 'ai_outputParser', + }, + ]; + + if (hasOutputParser === false) { + specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser'); + } + return ['main', ...getInputData(specialInputs)]; +} + +export class AgentV2 implements INodeType { + description: INodeTypeDescription; + + constructor(baseDescription: INodeTypeBaseDescription) { + this.description = { + ...baseDescription, + version: 2, + defaults: { + name: 'AI Agent', + color: '#404040', + }, + inputs: `={{ + ((hasOutputParser) => { + ${getInputs.toString()}; + return getInputs(hasOutputParser) + })($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true) + }}`, + outputs: [NodeConnectionTypes.Main], + properties: [ + { + displayName: + 'Tip: Get a feel for agents with our quick tutorial or see an example of how this node works', + name: 'notice_tip', + type: 'notice', + default: '', + }, + promptTypeOptions, + { + ...textFromPreviousNode, + displayOptions: { + show: { + promptType: ['auto'], + }, + }, + }, + { + ...textInput, + displayOptions: { + show: { + promptType: ['define'], + }, + }, + }, + { + displayName: 'Require Specific Output Format', + name: 'hasOutputParser', + type: 'boolean', + default: false, + noDataExpression: true, + }, + { + displayName: `Connect an output parser on the canvas to specify the output format you require`, + name: 'notice', + type: 'notice', + default: '', + displayOptions: { + show: { + hasOutputParser: [true], + }, + }, + }, + ...toolsAgentProperties, + ], + }; + } + + async execute(this: IExecuteFunctions): Promise { + return await toolsAgentExecute.call(this); + } +} diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/description.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/description.ts new file mode 100644 index 0000000000..a7116eefb2 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/description.ts @@ -0,0 +1,19 @@ +import type { INodeProperties } from 'n8n-workflow'; + +import { commonOptions } from '../options'; + +export const toolsAgentProperties: INodeProperties[] = [ + { + displayName: 'Options', + name: 'options', + type: 'collection', + displayOptions: { + show: { + agent: ['toolsAgent'], + }, + }, + default: {}, + placeholder: 'Add Option', + options: [...commonOptions], + }, +]; diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts new file mode 100644 index 0000000000..a86f464f65 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts @@ -0,0 +1,138 @@ +import { RunnableSequence } from '@langchain/core/runnables'; +import { AgentExecutor, createToolCallingAgent } from 'langchain/agents'; +import { omit } from 'lodash'; +import { jsonParse, NodeOperationError } from 'n8n-workflow'; +import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; + +import { getPromptInputByType } from '@utils/helpers'; +import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser'; + +import { + fixEmptyContentMessage, + getAgentStepsParser, + getChatModel, + getOptionalMemory, + getTools, + prepareMessages, + preparePrompt, +} from '../common'; +import { SYSTEM_MESSAGE } from '../prompt'; + +/* ----------------------------------------------------------- + Main Executor Function +----------------------------------------------------------- */ +/** + * The main executor method for the Tools Agent. + * + * This function retrieves necessary components (model, memory, tools), prepares the prompt, + * creates the agent, and processes each input item. The error handling for each item is also + * managed here based on the node's continueOnFail setting. + * + * @returns The array of execution data for all processed items + */ +export async function toolsAgentExecute(this: IExecuteFunctions): Promise { + this.logger.debug('Executing Tools Agent'); + + const returnData: INodeExecutionData[] = []; + const items = this.getInputData(); + const outputParser = await getOptionalOutputParser(this); + const tools = await getTools(this, outputParser); + + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + try { + const model = await getChatModel(this); + const memory = await getOptionalMemory(this); + + const input = getPromptInputByType({ + ctx: this, + i: itemIndex, + inputKey: 'text', + promptTypeKey: 'promptType', + }); + if (input === undefined) { + throw new NodeOperationError(this.getNode(), 'The “text” parameter is empty.'); + } + + const options = this.getNodeParameter('options', itemIndex, {}) as { + systemMessage?: string; + maxIterations?: number; + returnIntermediateSteps?: boolean; + passthroughBinaryImages?: boolean; + }; + + // Prepare the prompt messages and prompt template. + const messages = await prepareMessages(this, itemIndex, { + systemMessage: options.systemMessage, + passthroughBinaryImages: options.passthroughBinaryImages ?? true, + outputParser, + }); + const prompt = preparePrompt(messages); + + // Create the base agent that calls tools. + const agent = createToolCallingAgent({ + llm: model, + tools, + prompt, + streamRunnable: false, + }); + agent.streamRunnable = false; + // Wrap the agent with parsers and fixes. + const runnableAgent = RunnableSequence.from([ + agent, + getAgentStepsParser(outputParser, memory), + fixEmptyContentMessage, + ]); + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnableAgent, + memory, + tools, + returnIntermediateSteps: options.returnIntermediateSteps === true, + maxIterations: options.maxIterations ?? 10, + }); + + // Invoke the executor with the given input and system message. + const response = await executor.invoke( + { + input, + system_message: options.systemMessage ?? SYSTEM_MESSAGE, + formatting_instructions: + 'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.', + }, + { signal: this.getExecutionCancelSignal() }, + ); + + // If memory and outputParser are connected, parse the output. + if (memory && outputParser) { + const parsedOutput = jsonParse<{ output: Record }>( + response.output as string, + ); + response.output = parsedOutput?.output ?? parsedOutput; + } + + // Omit internal keys before returning the result. + const itemResult = { + json: omit( + response, + 'system_message', + 'formatting_instructions', + 'input', + 'chat_history', + 'agent_scratchpad', + ), + }; + + returnData.push(itemResult); + } catch (error) { + if (this.continueOnFail()) { + returnData.push({ + json: { error: error.message }, + pairedItem: { item: itemIndex }, + }); + continue; + } + throw error; + } + } + + return [returnData]; +} diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/description.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/description.ts new file mode 100644 index 0000000000..d023ce7467 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/description.ts @@ -0,0 +1,16 @@ +import type { INodeProperties } from 'n8n-workflow'; + +import { getBatchingOptionFields } from '@utils/sharedFields'; + +import { commonOptions } from '../options'; + +export const toolsAgentProperties: INodeProperties[] = [ + { + displayName: 'Options', + name: 'options', + type: 'collection', + default: {}, + placeholder: 'Add Option', + options: [...commonOptions, getBatchingOptionFields(undefined, 1)], + }, +]; diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts new file mode 100644 index 0000000000..1fc5aeefa7 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts @@ -0,0 +1,161 @@ +import type { ChatPromptTemplate } from '@langchain/core/prompts'; +import { RunnableSequence } from '@langchain/core/runnables'; +import { AgentExecutor, createToolCallingAgent } from 'langchain/agents'; +import { omit } from 'lodash'; +import { jsonParse, NodeOperationError, sleep } from 'n8n-workflow'; +import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; + +import { getPromptInputByType } from '@utils/helpers'; +import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser'; + +import { + fixEmptyContentMessage, + getAgentStepsParser, + getChatModel, + getOptionalMemory, + getTools, + prepareMessages, + preparePrompt, +} from '../common'; +import { SYSTEM_MESSAGE } from '../prompt'; + +/* ----------------------------------------------------------- + Main Executor Function +----------------------------------------------------------- */ +/** + * The main executor method for the Tools Agent. + * + * This function retrieves necessary components (model, memory, tools), prepares the prompt, + * creates the agent, and processes each input item. The error handling for each item is also + * managed here based on the node's continueOnFail setting. + * + * @returns The array of execution data for all processed items + */ +export async function toolsAgentExecute(this: IExecuteFunctions): Promise { + this.logger.debug('Executing Tools Agent V2'); + + const returnData: INodeExecutionData[] = []; + const items = this.getInputData(); + const outputParser = await getOptionalOutputParser(this); + const tools = await getTools(this, outputParser); + const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 1) as number; + const delayBetweenBatches = this.getNodeParameter( + 'options.batching.delayBetweenBatches', + 0, + 0, + ) as number; + const memory = await getOptionalMemory(this); + const model = await getChatModel(this); + + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (_item, batchItemIndex) => { + const itemIndex = i + batchItemIndex; + + const input = getPromptInputByType({ + ctx: this, + i: itemIndex, + inputKey: 'text', + promptTypeKey: 'promptType', + }); + if (input === undefined) { + throw new NodeOperationError(this.getNode(), 'The “text” parameter is empty.'); + } + + const options = this.getNodeParameter('options', itemIndex, {}) as { + systemMessage?: string; + maxIterations?: number; + returnIntermediateSteps?: boolean; + passthroughBinaryImages?: boolean; + }; + + // Prepare the prompt messages and prompt template. + const messages = await prepareMessages(this, itemIndex, { + systemMessage: options.systemMessage, + passthroughBinaryImages: options.passthroughBinaryImages ?? true, + outputParser, + }); + const prompt: ChatPromptTemplate = preparePrompt(messages); + + // Create the base agent that calls tools. + const agent = createToolCallingAgent({ + llm: model, + tools, + prompt, + streamRunnable: false, + }); + agent.streamRunnable = false; + // Wrap the agent with parsers and fixes. + const runnableAgent = RunnableSequence.from([ + agent, + getAgentStepsParser(outputParser, memory), + fixEmptyContentMessage, + ]); + const executor = AgentExecutor.fromAgentAndTools({ + agent: runnableAgent, + memory, + tools, + returnIntermediateSteps: options.returnIntermediateSteps === true, + maxIterations: options.maxIterations ?? 10, + }); + + // Invoke the executor with the given input and system message. + return await executor.invoke( + { + input, + system_message: options.systemMessage ?? SYSTEM_MESSAGE, + formatting_instructions: + 'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.', + }, + { signal: this.getExecutionCancelSignal() }, + ); + }); + + const batchResults = await Promise.allSettled(batchPromises); + + batchResults.forEach((result, index) => { + const itemIndex = i + index; + if (result.status === 'rejected') { + const error = result.reason as Error; + if (this.continueOnFail()) { + returnData.push({ + json: { error: error.message }, + pairedItem: { item: itemIndex }, + }); + return; + } else { + throw new NodeOperationError(this.getNode(), error); + } + } + const response = result.value; + // If memory and outputParser are connected, parse the output. + if (memory && outputParser) { + const parsedOutput = jsonParse<{ output: Record }>( + response.output as string, + ); + response.output = parsedOutput?.output ?? parsedOutput; + } + + // Omit internal keys before returning the result. + const itemResult = { + json: omit( + response, + 'system_message', + 'formatting_instructions', + 'input', + 'chat_history', + 'agent_scratchpad', + ), + pairedItem: { item: itemIndex }, + }; + + returnData.push(itemResult); + }); + + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); + } + } + + return [returnData]; +} diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/common.ts similarity index 74% rename from packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts rename to packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/common.ts index 826059ab78..d438224b74 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/execute.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/common.ts @@ -1,29 +1,18 @@ -import type { BaseChatMemory } from '@langchain/community/memory/chat_memory'; import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { HumanMessage } from '@langchain/core/messages'; import type { BaseMessage } from '@langchain/core/messages'; -import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts'; -import { ChatPromptTemplate } from '@langchain/core/prompts'; -import { RunnableSequence } from '@langchain/core/runnables'; -import type { Tool } from '@langchain/core/tools'; -import { DynamicStructuredTool } from '@langchain/core/tools'; +import { ChatPromptTemplate, type BaseMessagePromptTemplateLike } from '@langchain/core/prompts'; import type { AgentAction, AgentFinish } from 'langchain/agents'; -import { AgentExecutor, createToolCallingAgent } from 'langchain/agents'; import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser'; -import { omit } from 'lodash'; +import type { BaseChatMemory } from 'langchain/memory'; +import { DynamicStructuredTool, type Tool } from 'langchain/tools'; import { BINARY_ENCODING, jsonParse, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; -import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; +import type { IExecuteFunctions } from 'n8n-workflow'; import type { ZodObject } from 'zod'; import { z } from 'zod'; -import { isChatInstance, getPromptInputByType, getConnectedTools } from '@utils/helpers'; -import { - getOptionalOutputParser, - type N8nOutputParser, -} from '@utils/output_parsers/N8nOutputParser'; - -import { SYSTEM_MESSAGE } from './prompt'; - +import { isChatInstance, getConnectedTools } from '@utils/helpers'; +import { type N8nOutputParser } from '@utils/output_parsers/N8nOutputParser'; /* ----------------------------------------------------------- Output Parser Helper ----------------------------------------------------------- */ @@ -387,122 +376,3 @@ export async function prepareMessages( export function preparePrompt(messages: BaseMessagePromptTemplateLike[]): ChatPromptTemplate { return ChatPromptTemplate.fromMessages(messages); } - -/* ----------------------------------------------------------- - Main Executor Function ------------------------------------------------------------ */ -/** - * The main executor method for the Tools Agent. - * - * This function retrieves necessary components (model, memory, tools), prepares the prompt, - * creates the agent, and processes each input item. The error handling for each item is also - * managed here based on the node's continueOnFail setting. - * - * @returns The array of execution data for all processed items - */ -export async function toolsAgentExecute(this: IExecuteFunctions): Promise { - this.logger.debug('Executing Tools Agent'); - - const returnData: INodeExecutionData[] = []; - const items = this.getInputData(); - const outputParser = await getOptionalOutputParser(this); - const tools = await getTools(this, outputParser); - - for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { - try { - const model = await getChatModel(this); - const memory = await getOptionalMemory(this); - - const input = getPromptInputByType({ - ctx: this, - i: itemIndex, - inputKey: 'text', - promptTypeKey: 'promptType', - }); - if (input === undefined) { - throw new NodeOperationError(this.getNode(), 'The “text” parameter is empty.'); - } - - const options = this.getNodeParameter('options', itemIndex, {}) as { - systemMessage?: string; - maxIterations?: number; - returnIntermediateSteps?: boolean; - passthroughBinaryImages?: boolean; - }; - - // Prepare the prompt messages and prompt template. - const messages = await prepareMessages(this, itemIndex, { - systemMessage: options.systemMessage, - passthroughBinaryImages: options.passthroughBinaryImages ?? true, - outputParser, - }); - const prompt = preparePrompt(messages); - - // Create the base agent that calls tools. - const agent = createToolCallingAgent({ - llm: model, - tools, - prompt, - streamRunnable: false, - }); - agent.streamRunnable = false; - // Wrap the agent with parsers and fixes. - const runnableAgent = RunnableSequence.from([ - agent, - getAgentStepsParser(outputParser, memory), - fixEmptyContentMessage, - ]); - const executor = AgentExecutor.fromAgentAndTools({ - agent: runnableAgent, - memory, - tools, - returnIntermediateSteps: options.returnIntermediateSteps === true, - maxIterations: options.maxIterations ?? 10, - }); - - // Invoke the executor with the given input and system message. - const response = await executor.invoke( - { - input, - system_message: options.systemMessage ?? SYSTEM_MESSAGE, - formatting_instructions: - 'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.', - }, - { signal: this.getExecutionCancelSignal() }, - ); - - // If memory and outputParser are connected, parse the output. - if (memory && outputParser) { - const parsedOutput = jsonParse<{ output: Record }>( - response.output as string, - ); - response.output = parsedOutput?.output ?? parsedOutput; - } - - // Omit internal keys before returning the result. - const itemResult = { - json: omit( - response, - 'system_message', - 'formatting_instructions', - 'input', - 'chat_history', - 'agent_scratchpad', - ), - }; - - returnData.push(itemResult); - } catch (error) { - if (this.continueOnFail()) { - returnData.push({ - json: { error: error.message }, - pairedItem: { item: itemIndex }, - }); - continue; - } - throw error; - } - } - - return [returnData]; -} diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts deleted file mode 100644 index 06b64a91de..0000000000 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/description.ts +++ /dev/null @@ -1,52 +0,0 @@ -import type { INodeProperties } from 'n8n-workflow'; - -import { SYSTEM_MESSAGE } from './prompt'; - -export const toolsAgentProperties: INodeProperties[] = [ - { - displayName: 'Options', - name: 'options', - type: 'collection', - displayOptions: { - show: { - agent: ['toolsAgent'], - }, - }, - default: {}, - placeholder: 'Add Option', - options: [ - { - displayName: 'System Message', - name: 'systemMessage', - type: 'string', - default: SYSTEM_MESSAGE, - description: 'The message that will be sent to the agent before the conversation starts', - typeOptions: { - rows: 6, - }, - }, - { - displayName: 'Max Iterations', - name: 'maxIterations', - type: 'number', - default: 10, - description: 'The maximum number of iterations the agent will run before stopping', - }, - { - displayName: 'Return Intermediate Steps', - name: 'returnIntermediateSteps', - type: 'boolean', - default: false, - description: 'Whether or not the output should include intermediate steps the agent took', - }, - { - displayName: 'Automatically Passthrough Binary Images', - name: 'passthroughBinaryImages', - type: 'boolean', - default: true, - description: - 'Whether or not binary images should be automatically passed through to the agent as image type messages', - }, - ], - }, -]; diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/options.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/options.ts new file mode 100644 index 0000000000..520678b11e --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/options.ts @@ -0,0 +1,38 @@ +import type { INodeProperties } from 'n8n-workflow'; + +import { SYSTEM_MESSAGE } from './prompt'; + +export const commonOptions: INodeProperties[] = [ + { + displayName: 'System Message', + name: 'systemMessage', + type: 'string', + default: SYSTEM_MESSAGE, + description: 'The message that will be sent to the agent before the conversation starts', + typeOptions: { + rows: 6, + }, + }, + { + displayName: 'Max Iterations', + name: 'maxIterations', + type: 'number', + default: 10, + description: 'The maximum number of iterations the agent will run before stopping', + }, + { + displayName: 'Return Intermediate Steps', + name: 'returnIntermediateSteps', + type: 'boolean', + default: false, + description: 'Whether or not the output should include intermediate steps the agent took', + }, + { + displayName: 'Automatically Passthrough Binary Images', + name: 'passthroughBinaryImages', + type: 'boolean', + default: true, + description: + 'Whether or not binary images should be automatically passed through to the agent as image type messages', + }, +]; diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV1.test.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV1.test.ts new file mode 100644 index 0000000000..93783de568 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV1.test.ts @@ -0,0 +1,159 @@ +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { mock } from 'jest-mock-extended'; +import { AgentExecutor } from 'langchain/agents'; +import type { Tool } from 'langchain/tools'; +import type { IExecuteFunctions, INode } from 'n8n-workflow'; + +import * as helpers from '../../../../../utils/helpers'; +import { toolsAgentExecute } from '../../agents/ToolsAgent/V1/execute'; + +const mockHelpers = mock(); +const mockContext = mock({ helpers: mockHelpers }); + +beforeEach(() => jest.resetAllMocks()); + +describe('toolsAgentExecute', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockContext.logger = { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }; + }); + + it('should process items', async () => { + const mockNode = mock(); + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + // Mock getNodeParameter to return default values + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'text') return 'test input'; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 1' }) }) + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 2' }) }), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + const result = await toolsAgentExecute.call(mockContext); + + expect(mockExecutor.invoke).toHaveBeenCalledTimes(2); + expect(result[0]).toHaveLength(2); + expect(result[0][0].json).toEqual({ output: { text: 'success 1' } }); + expect(result[0][1].json).toEqual({ output: { text: 'success 2' } }); + }); + + it('should handle errors when continueOnFail is true', async () => { + const mockNode = mock(); + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'text') return 'test input'; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + mockContext.continueOnFail.mockReturnValue(true); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: '{ "text": "success" }' }) + .mockRejectedValueOnce(new Error('Test error')), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + const result = await toolsAgentExecute.call(mockContext); + + expect(result[0]).toHaveLength(2); + expect(result[0][0].json).toEqual({ output: { text: 'success' } }); + expect(result[0][1].json).toEqual({ error: 'Test error' }); + }); + + it('should throw error in when continueOnFail is false', async () => { + const mockNode = mock(); + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'text') return 'test input'; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + mockContext.continueOnFail.mockReturnValue(false); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success' }) }) + .mockRejectedValueOnce(new Error('Test error')), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + await expect(toolsAgentExecute.call(mockContext)).rejects.toThrow('Test error'); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts new file mode 100644 index 0000000000..de8c0521e7 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts @@ -0,0 +1,223 @@ +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { mock } from 'jest-mock-extended'; +import { AgentExecutor } from 'langchain/agents'; +import type { Tool } from 'langchain/tools'; +import type { IExecuteFunctions, INode } from 'n8n-workflow'; + +import * as helpers from '../../../../../utils/helpers'; +import { toolsAgentExecute } from '../../agents/ToolsAgent/V2/execute'; + +const mockHelpers = mock(); +const mockContext = mock({ helpers: mockHelpers }); + +beforeEach(() => jest.resetAllMocks()); + +describe('toolsAgentExecute', () => { + beforeEach(() => { + jest.clearAllMocks(); + mockContext.logger = { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }; + }); + + it('should process items sequentially when batchSize is not set', async () => { + const mockNode = mock(); + mockNode.typeVersion = 2; + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + // Mock getNodeParameter to return default values + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'text') return 'test input'; + if (param === 'options.batching.batchSize') return defaultValue; + if (param === 'options.batching.delayBetweenBatches') return defaultValue; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 1' }) }) + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 2' }) }), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + const result = await toolsAgentExecute.call(mockContext); + + expect(mockExecutor.invoke).toHaveBeenCalledTimes(2); + expect(result[0]).toHaveLength(2); + expect(result[0][0].json).toEqual({ output: { text: 'success 1' } }); + expect(result[0][1].json).toEqual({ output: { text: 'success 2' } }); + }); + + it('should process items in parallel within batches when batchSize > 1', async () => { + const mockNode = mock(); + mockNode.typeVersion = 2; + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + { json: { text: 'test input 3' } }, + { json: { text: 'test input 4' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'options.batching.batchSize') return 2; + if (param === 'options.batching.delayBetweenBatches') return 100; + if (param === 'text') return 'test input'; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 1' }) }) + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 2' }) }) + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 3' }) }) + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success 4' }) }), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + const result = await toolsAgentExecute.call(mockContext); + + expect(mockExecutor.invoke).toHaveBeenCalledTimes(4); // Each item is processed individually + expect(result[0]).toHaveLength(4); + + expect(result[0][0].json).toEqual({ output: { text: 'success 1' } }); + expect(result[0][1].json).toEqual({ output: { text: 'success 2' } }); + expect(result[0][2].json).toEqual({ output: { text: 'success 3' } }); + expect(result[0][3].json).toEqual({ output: { text: 'success 4' } }); + }); + + it('should handle errors in batch processing when continueOnFail is true', async () => { + const mockNode = mock(); + mockNode.typeVersion = 2; + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'options.batching.batchSize') return 2; + if (param === 'options.batching.delayBetweenBatches') return 0; + if (param === 'text') return 'test input'; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + mockContext.continueOnFail.mockReturnValue(true); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: '{ "text": "success" }' }) + .mockRejectedValueOnce(new Error('Test error')), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + const result = await toolsAgentExecute.call(mockContext); + + expect(result[0]).toHaveLength(2); + expect(result[0][0].json).toEqual({ output: { text: 'success' } }); + expect(result[0][1].json).toEqual({ error: 'Test error' }); + }); + + it('should throw error in batch processing when continueOnFail is false', async () => { + const mockNode = mock(); + mockNode.typeVersion = 2; + mockContext.getNode.mockReturnValue(mockNode); + mockContext.getInputData.mockReturnValue([ + { json: { text: 'test input 1' } }, + { json: { text: 'test input 2' } }, + ]); + + const mockModel = mock(); + mockModel.bindTools = jest.fn(); + mockModel.lc_namespace = ['chat_models']; + mockContext.getInputConnectionData.mockResolvedValue(mockModel); + + const mockTools = [mock()]; + jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue(mockTools); + + mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => { + if (param === 'options.batching.batchSize') return 2; + if (param === 'options.batching.delayBetweenBatches') return 0; + if (param === 'text') return 'test input'; + if (param === 'options') + return { + systemMessage: 'You are a helpful assistant', + maxIterations: 10, + returnIntermediateSteps: false, + passthroughBinaryImages: true, + }; + return defaultValue; + }); + + mockContext.continueOnFail.mockReturnValue(false); + + const mockExecutor = { + invoke: jest + .fn() + .mockResolvedValueOnce({ output: JSON.stringify({ text: 'success' }) }) + .mockRejectedValueOnce(new Error('Test error')), + }; + + jest.spyOn(AgentExecutor, 'fromAgentAndTools').mockReturnValue(mockExecutor as any); + + await expect(toolsAgentExecute.call(mockContext)).rejects.toThrow('Test error'); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent.test.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts similarity index 99% rename from packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent.test.ts rename to packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts index 338e605b22..d9768dcb43 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts @@ -23,7 +23,7 @@ import { prepareMessages, preparePrompt, getTools, -} from '../agents/ToolsAgent/execute'; +} from '../../agents/ToolsAgent/common'; function getFakeOutputParser(returnSchema?: ZodType): N8nOutputParser { const fakeOutputParser = mock(); diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts index ce982e896d..66d102e405 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts @@ -1,23 +1,16 @@ -import type { BaseLanguageModel } from '@langchain/core/language_models/base'; import type { IExecuteFunctions, INodeExecutionData, INodeType, INodeTypeDescription, } from 'n8n-workflow'; -import { NodeApiError, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; +import { NodeApiError, NodeConnectionTypes, NodeOperationError, sleep } from 'n8n-workflow'; -import { getPromptInputByType } from '@utils/helpers'; import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser'; // Import from centralized module -import { - executeChain, - formatResponse, - getInputs, - nodeProperties, - type MessageTemplate, -} from './methods'; +import { formatResponse, getInputs, nodeProperties } from './methods'; +import { processItem } from './methods/processItem'; import { getCustomErrorMessage as getCustomOpenAiErrorMessage, isOpenAiError, @@ -34,7 +27,7 @@ export class ChainLlm implements INodeType { icon: 'fa:link', iconColor: 'black', group: ['transform'], - version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6], + version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7], description: 'A simple chain to prompt a large language model', defaults: { name: 'Basic LLM Chain', @@ -67,83 +60,97 @@ export class ChainLlm implements INodeType { this.logger.debug('Executing Basic LLM Chain'); const items = this.getInputData(); const returnData: INodeExecutionData[] = []; + const outputParser = await getOptionalOutputParser(this); + // If the node version is 1.6(and LLM is using `response_format: json_object`) or higher or an output parser is configured, + // we unwrap the response and return the object directly as JSON + const shouldUnwrapObjects = this.getNode().typeVersion >= 1.6 || !!outputParser; - // Process each input item - for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { - try { - // Get the language model - const llm = (await this.getInputConnectionData( - NodeConnectionTypes.AiLanguageModel, - 0, - )) as BaseLanguageModel; + const batchSize = this.getNodeParameter('batching.batchSize', 0, 5) as number; + const delayBetweenBatches = this.getNodeParameter( + 'batching.delayBetweenBatches', + 0, + 0, + ) as number; - // Get output parser if configured - const outputParser = await getOptionalOutputParser(this); - - // Get user prompt based on node version - let prompt: string; - - if (this.getNode().typeVersion <= 1.3) { - prompt = this.getNodeParameter('prompt', itemIndex) as string; - } else { - prompt = getPromptInputByType({ - ctx: this, - i: itemIndex, - inputKey: 'text', - promptTypeKey: 'promptType', - }); - } - - // Validate prompt - if (prompt === undefined) { - throw new NodeOperationError(this.getNode(), "The 'prompt' parameter is empty."); - } - - // Get chat messages if configured - const messages = this.getNodeParameter( - 'messages.messageValues', - itemIndex, - [], - ) as MessageTemplate[]; - - // Execute the chain - const responses = await executeChain({ - context: this, - itemIndex, - query: prompt, - llm, - outputParser, - messages, + if (this.getNode().typeVersion >= 1.7 && batchSize > 1) { + // Process items in batches + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (_item, batchItemIndex) => { + return await processItem(this, i + batchItemIndex); }); - // If the node version is 1.6(and LLM is using `response_format: json_object`) or higher or an output parser is configured, - // we unwrap the response and return the object directly as JSON - const shouldUnwrapObjects = this.getNode().typeVersion >= 1.6 || !!outputParser; - // Process each response and add to return data - responses.forEach((response) => { - returnData.push({ - json: formatResponse(response, shouldUnwrapObjects), + const batchResults = await Promise.allSettled(batchPromises); + + batchResults.forEach((promiseResult, batchItemIndex) => { + const itemIndex = i + batchItemIndex; + if (promiseResult.status === 'rejected') { + const error = promiseResult.reason as Error; + // Handle OpenAI specific rate limit errors + if (error instanceof NodeApiError && isOpenAiError(error.cause)) { + const openAiErrorCode: string | undefined = (error.cause as any).error?.code; + if (openAiErrorCode) { + const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode); + if (customMessage) { + error.message = customMessage; + } + } + } + + if (this.continueOnFail()) { + returnData.push({ + json: { error: error.message }, + pairedItem: { item: itemIndex }, + }); + return; + } + throw new NodeOperationError(this.getNode(), error); + } + + const responses = promiseResult.value; + responses.forEach((response: unknown) => { + returnData.push({ + json: formatResponse(response, shouldUnwrapObjects), + }); }); }); - } catch (error) { - // Handle OpenAI specific rate limit errors - if (error instanceof NodeApiError && isOpenAiError(error.cause)) { - const openAiErrorCode: string | undefined = (error.cause as any).error?.code; - if (openAiErrorCode) { - const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode); - if (customMessage) { - error.message = customMessage; + + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); + } + } + } else { + // Process each input item + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + try { + const responses = await processItem(this, itemIndex); + + // Process each response and add to return data + responses.forEach((response) => { + returnData.push({ + json: formatResponse(response, shouldUnwrapObjects), + }); + }); + } catch (error) { + // Handle OpenAI specific rate limit errors + if (error instanceof NodeApiError && isOpenAiError(error.cause)) { + const openAiErrorCode: string | undefined = (error.cause as any).error?.code; + if (openAiErrorCode) { + const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode); + if (customMessage) { + error.message = customMessage; + } } } - } - // Continue on failure if configured - if (this.continueOnFail()) { - returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); - continue; - } + // Continue on failure if configured + if (this.continueOnFail()) { + returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); + continue; + } - throw error; + throw error; + } } } diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts index 098bc8a8e1..cf9d08495a 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts @@ -7,7 +7,7 @@ import type { IDataObject, INodeInputConfiguration, INodeProperties } from 'n8n- import { NodeConnectionTypes } from 'n8n-workflow'; import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions'; -import { getTemplateNoticeField } from '@utils/sharedFields'; +import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields'; /** * Dynamic input configuration generation based on node parameters @@ -259,6 +259,11 @@ export const nodeProperties: INodeProperties[] = [ }, ], }, + getBatchingOptionFields({ + show: { + '@version': [{ _cnd: { gte: 1.7 } }], + }, + }), { displayName: `Connect an output parser on the canvas to specify the output format you require`, name: 'notice', diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts new file mode 100644 index 0000000000..bde406edb5 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts @@ -0,0 +1,54 @@ +import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; + +import { getPromptInputByType } from '@utils/helpers'; +import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser'; + +import { executeChain } from './chainExecutor'; +import { type MessageTemplate } from './types'; + +export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) => { + const llm = (await ctx.getInputConnectionData( + NodeConnectionTypes.AiLanguageModel, + 0, + )) as BaseLanguageModel; + + // Get output parser if configured + const outputParser = await getOptionalOutputParser(ctx); + + // Get user prompt based on node version + let prompt: string; + + if (ctx.getNode().typeVersion <= 1.3) { + prompt = ctx.getNodeParameter('prompt', itemIndex) as string; + } else { + prompt = getPromptInputByType({ + ctx, + i: itemIndex, + inputKey: 'text', + promptTypeKey: 'promptType', + }); + } + + // Validate prompt + if (prompt === undefined) { + throw new NodeOperationError(ctx.getNode(), "The 'prompt' parameter is empty."); + } + + // Get chat messages if configured + const messages = ctx.getNodeParameter( + 'messages.messageValues', + itemIndex, + [], + ) as MessageTemplate[]; + + // Execute the chain + return await executeChain({ + context: ctx, + itemIndex, + query: prompt, + llm, + outputParser, + messages, + }); +}; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts index d7916a0275..9069e47611 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts @@ -3,7 +3,7 @@ import { FakeChatModel } from '@langchain/core/utils/testing'; import { mock } from 'jest-mock-extended'; import type { IExecuteFunctions, INode } from 'n8n-workflow'; -import { NodeConnectionTypes } from 'n8n-workflow'; +import { NodeApiError, NodeConnectionTypes } from 'n8n-workflow'; import * as helperModule from '@utils/helpers'; import * as outputParserModule from '@utils/output_parsers/N8nOutputParser'; @@ -191,6 +191,148 @@ describe('ChainLlm Node', () => { expect(result[0]).toHaveLength(2); }); + describe('batching (version 1.7+)', () => { + beforeEach(() => { + mockExecuteFunction.getNode.mockReturnValue({ + name: 'Chain LLM', + typeVersion: 1.7, + parameters: {}, + } as INode); + }); + + it('should process items in batches with default settings', async () => { + mockExecuteFunction.getInputData.mockReturnValue([ + { json: { item: 1 } }, + { json: { item: 2 } }, + { json: { item: 3 } }, + ]); + + mockExecuteFunction.getNodeParameter.mockImplementation( + (param, _itemIndex, defaultValue) => { + if (param === 'messages.messageValues') return []; + return defaultValue; + }, + ); + + (helperModule.getPromptInputByType as jest.Mock) + .mockReturnValueOnce('Test prompt 1') + .mockReturnValueOnce('Test prompt 2') + .mockReturnValueOnce('Test prompt 3'); + + (executeChainModule.executeChain as jest.Mock) + .mockResolvedValueOnce(['Response 1']) + .mockResolvedValueOnce(['Response 2']) + .mockResolvedValueOnce(['Response 3']); + + const result = await node.execute.call(mockExecuteFunction); + + expect(executeChainModule.executeChain).toHaveBeenCalledTimes(3); + expect(result[0]).toHaveLength(3); + }); + + it('should process items in smaller batches', async () => { + mockExecuteFunction.getInputData.mockReturnValue([ + { json: { item: 1 } }, + { json: { item: 2 } }, + { json: { item: 3 } }, + { json: { item: 4 } }, + ]); + + mockExecuteFunction.getNodeParameter.mockImplementation( + (param, _itemIndex, defaultValue) => { + if (param === 'batching.batchSize') return 2; + if (param === 'batching.delayBetweenBatches') return 0; + if (param === 'messages.messageValues') return []; + return defaultValue; + }, + ); + + (helperModule.getPromptInputByType as jest.Mock) + .mockReturnValueOnce('Test prompt 1') + .mockReturnValueOnce('Test prompt 2') + .mockReturnValueOnce('Test prompt 3') + .mockReturnValueOnce('Test prompt 4'); + + (executeChainModule.executeChain as jest.Mock) + .mockResolvedValueOnce(['Response 1']) + .mockResolvedValueOnce(['Response 2']) + .mockResolvedValueOnce(['Response 3']) + .mockResolvedValueOnce(['Response 4']); + + const result = await node.execute.call(mockExecuteFunction); + + expect(executeChainModule.executeChain).toHaveBeenCalledTimes(4); + expect(result[0]).toHaveLength(4); + }); + + it('should handle errors in batches with continueOnFail', async () => { + mockExecuteFunction.getInputData.mockReturnValue([ + { json: { item: 1 } }, + { json: { item: 2 } }, + ]); + + mockExecuteFunction.getNodeParameter.mockImplementation( + (param, _itemIndex, defaultValue) => { + if (param === 'batching.batchSize') return 2; + if (param === 'batching.delayBetweenBatches') return 0; + if (param === 'messages.messageValues') return []; + return defaultValue; + }, + ); + + mockExecuteFunction.continueOnFail.mockReturnValue(true); + + (helperModule.getPromptInputByType as jest.Mock) + .mockReturnValueOnce('Test prompt 1') + .mockReturnValueOnce('Test prompt 2'); + + (executeChainModule.executeChain as jest.Mock) + .mockResolvedValueOnce(['Response 1']) + .mockRejectedValueOnce(new Error('Test error')); + + const result = await node.execute.call(mockExecuteFunction); + + expect(result[0]).toHaveLength(2); + expect(result[0][1].json).toEqual({ error: 'Test error' }); + }); + + it('should handle OpenAI rate limit errors in batches', async () => { + mockExecuteFunction.getInputData.mockReturnValue([ + { json: { item: 1 } }, + { json: { item: 2 } }, + ]); + + mockExecuteFunction.getNodeParameter.mockImplementation( + (param, _itemIndex, defaultValue) => { + if (param === 'batching.batchSize') return 2; + if (param === 'batching.delayBetweenBatches') return 0; + if (param === 'messages.messageValues') return []; + return defaultValue; + }, + ); + + mockExecuteFunction.continueOnFail.mockReturnValue(true); + + (helperModule.getPromptInputByType as jest.Mock) + .mockReturnValueOnce('Test prompt 1') + .mockReturnValueOnce('Test prompt 2'); + + const openAiError = new NodeApiError(mockExecuteFunction.getNode(), { + message: 'Rate limit exceeded', + cause: { error: { code: 'rate_limit_exceeded' } }, + }); + + (executeChainModule.executeChain as jest.Mock) + .mockResolvedValueOnce(['Response 1']) + .mockRejectedValueOnce(openAiError); + + const result = await node.execute.call(mockExecuteFunction); + + expect(result[0]).toHaveLength(2); + expect(result[0][1].json).toEqual({ error: expect.stringContaining('Rate limit') }); + }); + }); + it('should unwrap object responses when node version is 1.6 or higher', async () => { mockExecuteFunction.getNode.mockReturnValue({ name: 'Chain LLM', diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts index 4640603f69..ba135b5796 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/ChainRetrievalQa.node.ts @@ -1,16 +1,5 @@ -import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { NodeConnectionTypes, parseErrorMetadata, sleep } from 'n8n-workflow'; import { - ChatPromptTemplate, - SystemMessagePromptTemplate, - HumanMessagePromptTemplate, - PromptTemplate, -} from '@langchain/core/prompts'; -import type { BaseRetriever } from '@langchain/core/retrievers'; -import { createStuffDocumentsChain } from 'langchain/chains/combine_documents'; -import { createRetrievalChain } from 'langchain/chains/retrieval'; -import { NodeConnectionTypes, NodeOperationError, parseErrorMetadata } from 'n8n-workflow'; -import { - type INodeProperties, type IExecuteFunctions, type INodeExecutionData, type INodeType, @@ -18,28 +7,10 @@ import { } from 'n8n-workflow'; import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions'; -import { getPromptInputByType, isChatInstance } from '@utils/helpers'; -import { getTemplateNoticeField } from '@utils/sharedFields'; -import { getTracingConfig } from '@utils/tracing'; +import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields'; -const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. -If you don't know the answer, just say that you don't know, don't try to make up an answer. ----------------- -Context: {context}`; - -// Due to the refactoring in version 1.5, the variable name {question} needed to be changed to {input} in the prompt template. -const LEGACY_INPUT_TEMPLATE_KEY = 'question'; -const INPUT_TEMPLATE_KEY = 'input'; - -const systemPromptOption: INodeProperties = { - displayName: 'System Prompt Template', - name: 'systemPromptTemplate', - type: 'string', - default: SYSTEM_PROMPT_TEMPLATE, - typeOptions: { - rows: 6, - }, -}; +import { INPUT_TEMPLATE_KEY, LEGACY_INPUT_TEMPLATE_KEY, systemPromptOption } from './constants'; +import { processItem } from './processItem'; export class ChainRetrievalQa implements INodeType { description: INodeTypeDescription = { @@ -48,7 +19,7 @@ export class ChainRetrievalQa implements INodeType { icon: 'fa:link', iconColor: 'black', group: ['transform'], - version: [1, 1.1, 1.2, 1.3, 1.4, 1.5], + version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6], description: 'Answer questions about retrieved documents', defaults: { name: 'Question and Answer Chain', @@ -177,6 +148,11 @@ export class ChainRetrievalQa implements INodeType { }, }, }, + getBatchingOptionFields({ + show: { + '@version': [{ _cnd: { gte: 1.6 } }], + }, + }), ], }, ], @@ -187,109 +163,78 @@ export class ChainRetrievalQa implements INodeType { const items = this.getInputData(); const returnData: INodeExecutionData[] = []; + const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number; + const delayBetweenBatches = this.getNodeParameter( + 'options.batching.delayBetweenBatches', + 0, + 0, + ) as number; - // Run for each item - for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { - try { - const model = (await this.getInputConnectionData( - NodeConnectionTypes.AiLanguageModel, - 0, - )) as BaseLanguageModel; - - const retriever = (await this.getInputConnectionData( - NodeConnectionTypes.AiRetriever, - 0, - )) as BaseRetriever; - - let query; - - if (this.getNode().typeVersion <= 1.2) { - query = this.getNodeParameter('query', itemIndex) as string; - } else { - query = getPromptInputByType({ - ctx: this, - i: itemIndex, - inputKey: 'text', - promptTypeKey: 'promptType', - }); - } - - if (query === undefined) { - throw new NodeOperationError(this.getNode(), 'The ‘query‘ parameter is empty.'); - } - - const options = this.getNodeParameter('options', itemIndex, {}) as { - systemPromptTemplate?: string; - }; - - let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE; - - // Replace legacy input template key for versions 1.4 and below - if (this.getNode().typeVersion < 1.5) { - templateText = templateText.replace( - `{${LEGACY_INPUT_TEMPLATE_KEY}}`, - `{${INPUT_TEMPLATE_KEY}}`, - ); - } - - // Create prompt template based on model type and user configuration - let promptTemplate; - if (isChatInstance(model)) { - // For chat models, create a chat prompt template with system and human messages - const messages = [ - SystemMessagePromptTemplate.fromTemplate(templateText), - HumanMessagePromptTemplate.fromTemplate('{input}'), - ]; - promptTemplate = ChatPromptTemplate.fromMessages(messages); - } else { - // For non-chat models, create a text prompt template with Question/Answer format - const questionSuffix = - options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : ''; - - promptTemplate = new PromptTemplate({ - template: templateText + questionSuffix, - inputVariables: ['context', 'input'], - }); - } - - // Create the document chain that combines the retrieved documents - const combineDocsChain = await createStuffDocumentsChain({ - llm: model, - prompt: promptTemplate, + if (this.getNode().typeVersion >= 1.6 && batchSize >= 1) { + // Run in batches + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (_item, batchItemIndex) => { + return await processItem(this, i + batchItemIndex); }); - // Create the retrieval chain that handles the retrieval and then passes to the combine docs chain - const retrievalChain = await createRetrievalChain({ - combineDocsChain, - retriever, + const batchResults = await Promise.allSettled(batchPromises); + + batchResults.forEach((response, index) => { + if (response.status === 'rejected') { + const error = response.reason; + if (this.continueOnFail()) { + const metadata = parseErrorMetadata(error); + returnData.push({ + json: { error: error.message }, + pairedItem: { item: index }, + metadata, + }); + return; + } else { + throw error; + } + } + const output = response.value; + const answer = output.answer as string; + if (this.getNode().typeVersion >= 1.5) { + returnData.push({ json: { response: answer } }); + } else { + // Legacy format for versions 1.4 and below is { text: string } + returnData.push({ json: { response: { text: answer } } }); + } }); - // Execute the chain with tracing config - const tracingConfig = getTracingConfig(this); - const response = await retrievalChain - .withConfig(tracingConfig) - .invoke({ input: query }, { signal: this.getExecutionCancelSignal() }); - - // Get the answer from the response - const answer: string = response.answer; - if (this.getNode().typeVersion >= 1.5) { - returnData.push({ json: { response: answer } }); - } else { - // Legacy format for versions 1.4 and below is { text: string } - returnData.push({ json: { response: { text: answer } } }); - } - } catch (error) { - if (this.continueOnFail()) { - const metadata = parseErrorMetadata(error); - returnData.push({ - json: { error: error.message }, - pairedItem: { item: itemIndex }, - metadata, - }); - continue; + // Add delay between batches if not the last batch + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); } + } + } else { + // Run for each item + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + try { + const response = await processItem(this, itemIndex); + const answer = response.answer as string; + if (this.getNode().typeVersion >= 1.5) { + returnData.push({ json: { response: answer } }); + } else { + // Legacy format for versions 1.4 and below is { text: string } + returnData.push({ json: { response: { text: answer } } }); + } + } catch (error) { + if (this.continueOnFail()) { + const metadata = parseErrorMetadata(error); + returnData.push({ + json: { error: error.message }, + pairedItem: { item: itemIndex }, + metadata, + }); + continue; + } - throw error; + throw error; + } } } return [returnData]; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/constants.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/constants.ts new file mode 100644 index 0000000000..a45797e046 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/constants.ts @@ -0,0 +1,20 @@ +import type { INodeProperties } from 'n8n-workflow'; + +export const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. +If you don't know the answer, just say that you don't know, don't try to make up an answer. +---------------- +Context: {context}`; + +// Due to the refactoring in version 1.5, the variable name {question} needed to be changed to {input} in the prompt template. +export const LEGACY_INPUT_TEMPLATE_KEY = 'question'; +export const INPUT_TEMPLATE_KEY = 'input'; + +export const systemPromptOption: INodeProperties = { + displayName: 'System Prompt Template', + name: 'systemPromptTemplate', + type: 'string', + default: SYSTEM_PROMPT_TEMPLATE, + typeOptions: { + rows: 6, + }, +}; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/processItem.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/processItem.ts new file mode 100644 index 0000000000..f3ec34eef4 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/processItem.ts @@ -0,0 +1,100 @@ +import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, + SystemMessagePromptTemplate, +} from '@langchain/core/prompts'; +import type { BaseRetriever } from '@langchain/core/retrievers'; +import { createStuffDocumentsChain } from 'langchain/chains/combine_documents'; +import { createRetrievalChain } from 'langchain/chains/retrieval'; +import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; + +import { getPromptInputByType, isChatInstance } from '@utils/helpers'; +import { getTracingConfig } from '@utils/tracing'; + +import { INPUT_TEMPLATE_KEY, LEGACY_INPUT_TEMPLATE_KEY, SYSTEM_PROMPT_TEMPLATE } from './constants'; + +export const processItem = async ( + ctx: IExecuteFunctions, + itemIndex: number, +): Promise> => { + const model = (await ctx.getInputConnectionData( + NodeConnectionTypes.AiLanguageModel, + 0, + )) as BaseLanguageModel; + + const retriever = (await ctx.getInputConnectionData( + NodeConnectionTypes.AiRetriever, + 0, + )) as BaseRetriever; + + let query; + + if (ctx.getNode().typeVersion <= 1.2) { + query = ctx.getNodeParameter('query', itemIndex) as string; + } else { + query = getPromptInputByType({ + ctx, + i: itemIndex, + inputKey: 'text', + promptTypeKey: 'promptType', + }); + } + + if (query === undefined) { + throw new NodeOperationError(ctx.getNode(), 'The ‘query‘ parameter is empty.'); + } + + const options = ctx.getNodeParameter('options', itemIndex, {}) as { + systemPromptTemplate?: string; + }; + + let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE; + + // Replace legacy input template key for versions 1.4 and below + if (ctx.getNode().typeVersion < 1.5) { + templateText = templateText.replace( + `{${LEGACY_INPUT_TEMPLATE_KEY}}`, + `{${INPUT_TEMPLATE_KEY}}`, + ); + } + + // Create prompt template based on model type and user configuration + let promptTemplate; + if (isChatInstance(model)) { + // For chat models, create a chat prompt template with system and human messages + const messages = [ + SystemMessagePromptTemplate.fromTemplate(templateText), + HumanMessagePromptTemplate.fromTemplate('{input}'), + ]; + promptTemplate = ChatPromptTemplate.fromMessages(messages); + } else { + // For non-chat models, create a text prompt template with Question/Answer format + const questionSuffix = + options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : ''; + + promptTemplate = new PromptTemplate({ + template: templateText + questionSuffix, + inputVariables: ['context', 'input'], + }); + } + + // Create the document chain that combines the retrieved documents + const combineDocsChain = await createStuffDocumentsChain({ + llm: model, + prompt: promptTemplate, + }); + + // Create the retrieval chain that handles the retrieval and then passes to the combine docs chain + const retrievalChain = await createRetrievalChain({ + combineDocsChain, + retriever, + }); + + // Execute the chain with tracing config + const tracingConfig = getTracingConfig(ctx); + return await retrievalChain + .withConfig(tracingConfig) + .invoke({ input: query }, { signal: ctx.getExecutionCancelSignal() }); +}; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/test/ChainRetrievalQa.node.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/test/ChainRetrievalQa.node.test.ts index 36e118464b..0a43f68725 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/test/ChainRetrievalQa.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainRetrievalQA/test/ChainRetrievalQa.node.test.ts @@ -71,7 +71,7 @@ describe('ChainRetrievalQa', () => { node = new ChainRetrievalQa(); }); - it.each([1.3, 1.4, 1.5])( + it.each([1.3, 1.4, 1.5, 1.6])( 'should process a query using a chat model (version %s)', async (version) => { // Mock a chat model that returns a predefined answer @@ -103,7 +103,7 @@ describe('ChainRetrievalQa', () => { }, ); - it.each([1.3, 1.4, 1.5])( + it.each([1.3, 1.4, 1.5, 1.6])( 'should process a query using a text completion model (version %s)', async (version) => { // Mock a text completion model that returns a predefined answer @@ -143,7 +143,7 @@ describe('ChainRetrievalQa', () => { }, ); - it.each([1.3, 1.4, 1.5])( + it.each([1.3, 1.4, 1.5, 1.6])( 'should use a custom system prompt if provided (version %s)', async (version) => { const customSystemPrompt = `You are a geography expert. Use the following context to answer the question. @@ -177,7 +177,7 @@ describe('ChainRetrievalQa', () => { }, ); - it.each([1.3, 1.4, 1.5])( + it.each([1.3, 1.4, 1.5, 1.6])( 'should throw an error if the query is undefined (version %s)', async (version) => { const mockChatModel = new FakeChatModel({}); @@ -196,7 +196,7 @@ describe('ChainRetrievalQa', () => { }, ); - it.each([1.3, 1.4, 1.5])( + it.each([1.3, 1.4, 1.5, 1.6])( 'should add error to json if continueOnFail is true (version %s)', async (version) => { // Create a model that will throw an error @@ -226,4 +226,118 @@ describe('ChainRetrievalQa', () => { expect(result[0][0].json.error).toContain('Model error'); }, ); + + it('should process items in batches', async () => { + const mockChatModel = new FakeLLM({ response: 'Paris is the capital of France.' }); + const items = [ + { json: { input: 'What is the capital of France?' } }, + { json: { input: 'What is the capital of France?' } }, + { json: { input: 'What is the capital of France?' } }, + ]; + + const execMock = createExecuteFunctionsMock( + { + promptType: 'define', + text: '={{ $json.input }}', + options: { + batching: { + batchSize: 2, + delayBetweenBatches: 0, + }, + }, + }, + mockChatModel, + fakeRetriever, + 1.6, + ); + + execMock.getInputData = () => items; + + const result = await node.execute.call(execMock); + + expect(result).toHaveLength(1); + expect(result[0]).toHaveLength(3); + result[0].forEach((item) => { + expect(item.json.response).toBeDefined(); + }); + + expect(result[0][0].json.response).toContain('Paris is the capital of France.'); + expect(result[0][1].json.response).toContain('Paris is the capital of France.'); + expect(result[0][2].json.response).toContain('Paris is the capital of France.'); + }); + + it('should handle errors in batches with continueOnFail', async () => { + class ErrorLLM extends FakeLLM { + async _call(): Promise { + throw new UnexpectedError('Model error'); + } + } + + const errorModel = new ErrorLLM({}); + const items = [ + { json: { input: 'What is the capital of France?' } }, + { json: { input: 'What is the population of Paris?' } }, + ]; + + const execMock = createExecuteFunctionsMock( + { + promptType: 'define', + text: '={{ $json.input }}', + options: { + batching: { + batchSize: 2, + delayBetweenBatches: 0, + }, + }, + }, + errorModel, + fakeRetriever, + 1.6, + ); + + execMock.getInputData = () => items; + execMock.continueOnFail = () => true; + + const result = await node.execute.call(execMock); + + expect(result).toHaveLength(1); + expect(result[0]).toHaveLength(2); + result[0].forEach((item) => { + expect(item.json.error).toContain('Model error'); + }); + }); + + it('should respect delay between batches', async () => { + const mockChatModel = new FakeChatModel({}); + const items = [ + { json: { input: 'What is the capital of France?' } }, + { json: { input: 'What is the population of Paris?' } }, + { json: { input: 'What is France known for?' } }, + ]; + + const execMock = createExecuteFunctionsMock( + { + promptType: 'define', + text: '={{ $json.input }}', + options: { + batching: { + batchSize: 2, + delayBetweenBatches: 100, + }, + }, + }, + mockChatModel, + fakeRetriever, + 1.6, + ); + + execMock.getInputData = () => items; + + const startTime = Date.now(); + await node.execute.call(execMock); + const endTime = Date.now(); + + // Should take at least 100ms due to delay between batches + expect(endTime - startTime).toBeGreaterThanOrEqual(100); + }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/ChainSummarization.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/ChainSummarization.node.ts index 9c97190952..684d27c754 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/ChainSummarization.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/ChainSummarization.node.ts @@ -27,12 +27,13 @@ export class ChainSummarization extends VersionedNodeType { ], }, }, - defaultVersion: 2, + defaultVersion: 2.1, }; const nodeVersions: IVersionedNodeType['nodeVersions'] = { 1: new ChainSummarizationV1(baseDescription), 2: new ChainSummarizationV2(baseDescription), + 2.1: new ChainSummarizationV2(baseDescription), }; super(nodeVersions, baseDescription); diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/ChainSummarizationV2.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/ChainSummarizationV2.node.ts index 7cb2bed603..a8d58f0223 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/ChainSummarizationV2.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/ChainSummarizationV2.node.ts @@ -1,8 +1,3 @@ -import type { Document } from '@langchain/core/documents'; -import type { BaseLanguageModel } from '@langchain/core/language_models/base'; -import type { TextSplitter } from '@langchain/textsplitters'; -import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'; -import { loadSummarizationChain } from 'langchain/chains'; import type { INodeTypeBaseDescription, IExecuteFunctions, @@ -12,14 +7,11 @@ import type { IDataObject, INodeInputConfiguration, } from 'n8n-workflow'; -import { NodeConnectionTypes } from 'n8n-workflow'; +import { NodeConnectionTypes, sleep } from 'n8n-workflow'; -import { N8nBinaryLoader } from '@utils/N8nBinaryLoader'; -import { N8nJsonLoader } from '@utils/N8nJsonLoader'; -import { getTemplateNoticeField } from '@utils/sharedFields'; -import { getTracingConfig } from '@utils/tracing'; +import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields'; -import { getChainPromptsArgs } from '../helpers'; +import { processItem } from './processItem'; import { REFINE_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE } from '../prompt'; function getInputs(parameters: IDataObject) { @@ -63,7 +55,7 @@ export class ChainSummarizationV2 implements INodeType { constructor(baseDescription: INodeTypeBaseDescription) { this.description = { ...baseDescription, - version: [2], + version: [2, 2.1], defaults: { name: 'Summarization Chain', color: '#909298', @@ -306,6 +298,11 @@ export class ChainSummarizationV2 implements INodeType { }, ], }, + getBatchingOptionFields({ + show: { + '@version': [{ _cnd: { gte: 2.1 } }], + }, + }), ], }, ], @@ -325,108 +322,64 @@ export class ChainSummarizationV2 implements INodeType { const items = this.getInputData(); const returnData: INodeExecutionData[] = []; - for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { - try { - const model = (await this.getInputConnectionData( - NodeConnectionTypes.AiLanguageModel, - 0, - )) as BaseLanguageModel; + const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number; + const delayBetweenBatches = this.getNodeParameter( + 'options.batching.delayBetweenBatches', + 0, + 0, + ) as number; - const summarizationMethodAndPrompts = this.getNodeParameter( - 'options.summarizationMethodAndPrompts.values', - itemIndex, - {}, - ) as { - prompt?: string; - refineQuestionPrompt?: string; - refinePrompt?: string; - summarizationMethod: 'map_reduce' | 'stuff' | 'refine'; - combineMapPrompt?: string; - }; + if (this.getNode().typeVersion >= 2.1 && batchSize > 1) { + // Batch processing + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (item, batchItemIndex) => { + const itemIndex = i + batchItemIndex; + return await processItem(this, itemIndex, item, operationMode, chunkingMode); + }); - const chainArgs = getChainPromptsArgs( - summarizationMethodAndPrompts.summarizationMethod ?? 'map_reduce', - summarizationMethodAndPrompts, - ); - - const chain = loadSummarizationChain(model, chainArgs); - const item = items[itemIndex]; - - let processedDocuments: Document[]; - - // Use dedicated document loader input to load documents - if (operationMode === 'documentLoader') { - const documentInput = (await this.getInputConnectionData( - NodeConnectionTypes.AiDocument, - 0, - )) as N8nJsonLoader | Array>>; - - const isN8nLoader = - documentInput instanceof N8nJsonLoader || documentInput instanceof N8nBinaryLoader; - - processedDocuments = isN8nLoader - ? await documentInput.processItem(item, itemIndex) - : documentInput; - - const response = await chain.withConfig(getTracingConfig(this)).invoke({ - input_documents: processedDocuments, - }); - - returnData.push({ json: { response } }); - } - - // Take the input and use binary or json loader - if (['nodeInputJson', 'nodeInputBinary'].includes(operationMode)) { - let textSplitter: TextSplitter | undefined; - - switch (chunkingMode) { - // In simple mode we use recursive character splitter with default settings - case 'simple': - const chunkSize = this.getNodeParameter('chunkSize', itemIndex, 1000) as number; - const chunkOverlap = this.getNodeParameter('chunkOverlap', itemIndex, 200) as number; - - textSplitter = new RecursiveCharacterTextSplitter({ chunkOverlap, chunkSize }); - break; - - // In advanced mode user can connect text splitter node so we just retrieve it - case 'advanced': - textSplitter = (await this.getInputConnectionData( - NodeConnectionTypes.AiTextSplitter, - 0, - )) as TextSplitter | undefined; - break; - default: - break; - } - - let processor: N8nJsonLoader | N8nBinaryLoader; - if (operationMode === 'nodeInputBinary') { - const binaryDataKey = this.getNodeParameter( - 'options.binaryDataKey', - itemIndex, - 'data', - ) as string; - processor = new N8nBinaryLoader(this, 'options.', binaryDataKey, textSplitter); + const batchResults = await Promise.allSettled(batchPromises); + batchResults.forEach((response, index) => { + if (response.status === 'rejected') { + const error = response.reason as Error; + if (this.continueOnFail()) { + returnData.push({ + json: { error: error.message }, + pairedItem: { item: i + index }, + }); + } else { + throw error; + } } else { - processor = new N8nJsonLoader(this, 'options.', textSplitter); + const output = response.value; + returnData.push({ json: { output } }); } + }); - const processedItem = await processor.processItem(item, itemIndex); - const response = await chain.invoke( - { - input_documents: processedItem, - }, - { signal: this.getExecutionCancelSignal() }, + // Add delay between batches if not the last batch + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); + } + } + } else { + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + try { + const response = await processItem( + this, + itemIndex, + items[itemIndex], + operationMode, + chunkingMode, ); returnData.push({ json: { response } }); - } - } catch (error) { - if (this.continueOnFail()) { - returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); - continue; - } + } catch (error) { + if (this.continueOnFail()) { + returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); + continue; + } - throw error; + throw error; + } } } diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/processItem.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/processItem.ts new file mode 100644 index 0000000000..e335deb1a5 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainSummarization/V2/processItem.ts @@ -0,0 +1,107 @@ +import type { Document } from '@langchain/core/documents'; +import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import type { ChainValues } from '@langchain/core/utils/types'; +import { RecursiveCharacterTextSplitter, type TextSplitter } from '@langchain/textsplitters'; +import { loadSummarizationChain } from 'langchain/chains'; +import { type IExecuteFunctions, type INodeExecutionData, NodeConnectionTypes } from 'n8n-workflow'; + +import { N8nBinaryLoader } from '@utils/N8nBinaryLoader'; +import { N8nJsonLoader } from '@utils/N8nJsonLoader'; +import { getTracingConfig } from '@utils/tracing'; + +import { getChainPromptsArgs } from '../helpers'; + +export async function processItem( + ctx: IExecuteFunctions, + itemIndex: number, + item: INodeExecutionData, + operationMode: string, + chunkingMode: 'simple' | 'advanced' | 'none', +): Promise { + const model = (await ctx.getInputConnectionData( + NodeConnectionTypes.AiLanguageModel, + 0, + )) as BaseLanguageModel; + + const summarizationMethodAndPrompts = ctx.getNodeParameter( + 'options.summarizationMethodAndPrompts.values', + itemIndex, + {}, + ) as { + prompt?: string; + refineQuestionPrompt?: string; + refinePrompt?: string; + summarizationMethod: 'map_reduce' | 'stuff' | 'refine'; + combineMapPrompt?: string; + }; + + const chainArgs = getChainPromptsArgs( + summarizationMethodAndPrompts.summarizationMethod ?? 'map_reduce', + summarizationMethodAndPrompts, + ); + + const chain = loadSummarizationChain(model, chainArgs); + + let processedDocuments: Document[]; + + // Use dedicated document loader input to load documents + if (operationMode === 'documentLoader') { + const documentInput = (await ctx.getInputConnectionData(NodeConnectionTypes.AiDocument, 0)) as + | N8nJsonLoader + | Array>>; + + const isN8nLoader = + documentInput instanceof N8nJsonLoader || documentInput instanceof N8nBinaryLoader; + + processedDocuments = isN8nLoader + ? await documentInput.processItem(item, itemIndex) + : documentInput; + + return await chain.withConfig(getTracingConfig(ctx)).invoke({ + input_documents: processedDocuments, + }); + } else if (['nodeInputJson', 'nodeInputBinary'].includes(operationMode)) { + // Take the input and use binary or json loader + let textSplitter: TextSplitter | undefined; + + switch (chunkingMode) { + // In simple mode we use recursive character splitter with default settings + case 'simple': + const chunkSize = ctx.getNodeParameter('chunkSize', itemIndex, 1000) as number; + const chunkOverlap = ctx.getNodeParameter('chunkOverlap', itemIndex, 200) as number; + + textSplitter = new RecursiveCharacterTextSplitter({ chunkOverlap, chunkSize }); + break; + + // In advanced mode user can connect text splitter node so we just retrieve it + case 'advanced': + textSplitter = (await ctx.getInputConnectionData(NodeConnectionTypes.AiTextSplitter, 0)) as + | TextSplitter + | undefined; + break; + default: + break; + } + + let processor: N8nJsonLoader | N8nBinaryLoader; + if (operationMode === 'nodeInputBinary') { + const binaryDataKey = ctx.getNodeParameter( + 'options.binaryDataKey', + itemIndex, + 'data', + ) as string; + processor = new N8nBinaryLoader(ctx, 'options.', binaryDataKey, textSplitter); + } else { + processor = new N8nJsonLoader(ctx, 'options.', textSplitter); + } + + const processedItem = await processor.processItem(item, itemIndex); + return await chain.invoke( + { + input_documents: processedItem, + }, + { signal: ctx.getExecutionCancelSignal() }, + ); + } + return undefined; +} diff --git a/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/InformationExtractor.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/InformationExtractor.node.ts index 182e488782..c4b618a284 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/InformationExtractor.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/InformationExtractor.node.ts @@ -1,9 +1,7 @@ import type { BaseLanguageModel } from '@langchain/core/language_models/base'; -import { HumanMessage } from '@langchain/core/messages'; -import { ChatPromptTemplate, SystemMessagePromptTemplate } from '@langchain/core/prompts'; import type { JSONSchema7 } from 'json-schema'; import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers'; -import { jsonParse, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; +import { jsonParse, NodeConnectionTypes, NodeOperationError, sleep } from 'n8n-workflow'; import type { INodeType, INodeTypeDescription, @@ -15,15 +13,13 @@ import type { z } from 'zod'; import { inputSchemaField, jsonSchemaExampleField, schemaTypeField } from '@utils/descriptions'; import { convertJsonSchemaToZod, generateSchema } from '@utils/schemaParsing'; -import { getTracingConfig } from '@utils/tracing'; +import { getBatchingOptionFields } from '@utils/sharedFields'; +import { SYSTEM_PROMPT_TEMPLATE } from './constants'; import { makeZodSchemaFromAttributes } from './helpers'; +import { processItem } from './processItem'; import type { AttributeDefinition } from './types'; -const SYSTEM_PROMPT_TEMPLATE = `You are an expert extraction algorithm. -Only extract relevant information from the text. -If you do not know the value of an attribute asked to extract, you may omit the attribute's value.`; - export class InformationExtractor implements INodeType { description: INodeTypeDescription = { displayName: 'Information Extractor', @@ -31,7 +27,7 @@ export class InformationExtractor implements INodeType { icon: 'fa:project-diagram', iconColor: 'black', group: ['transform'], - version: 1, + version: [1, 1.1], description: 'Extract information from text in a structured format', codex: { alias: ['NER', 'parse', 'parsing', 'JSON', 'data extraction', 'structured'], @@ -213,6 +209,11 @@ export class InformationExtractor implements INodeType { rows: 6, }, }, + getBatchingOptionFields({ + show: { + '@version': [{ _cnd: { gte: 1.1 } }], + }, + }), ], }, ], @@ -265,38 +266,59 @@ export class InformationExtractor implements INodeType { } const resultData: INodeExecutionData[] = []; - for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { - const input = this.getNodeParameter('text', itemIndex) as string; - const inputPrompt = new HumanMessage(input); + const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number; + const delayBetweenBatches = this.getNodeParameter( + 'options.batching.delayBetweenBatches', + 0, + 0, + ) as number; + if (this.getNode().typeVersion >= 1.1 && batchSize >= 1) { + // Batch processing + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (_item, batchItemIndex) => { + const itemIndex = i + batchItemIndex; + return await processItem(this, itemIndex, llm, parser); + }); - const options = this.getNodeParameter('options', itemIndex, {}) as { - systemPromptTemplate?: string; - }; + const batchResults = await Promise.allSettled(batchPromises); - const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( - `${options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE} -{format_instructions}`, - ); + batchResults.forEach((response, index) => { + if (response.status === 'rejected') { + const error = response.reason as Error; + if (this.continueOnFail()) { + resultData.push({ + json: { error: error.message }, + pairedItem: { item: i + index }, + }); + return; + } else { + throw new NodeOperationError(this.getNode(), error.message); + } + } + const output = response.value; + resultData.push({ json: { output } }); + }); - const messages = [ - await systemPromptTemplate.format({ - format_instructions: parser.getFormatInstructions(), - }), - inputPrompt, - ]; - const prompt = ChatPromptTemplate.fromMessages(messages); - const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); - - try { - const output = await chain.invoke(messages); - resultData.push({ json: { output } }); - } catch (error) { - if (this.continueOnFail()) { - resultData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); - continue; + // Add delay between batches if not the last batch + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); } + } + } else { + // Sequential processing + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + try { + const output = await processItem(this, itemIndex, llm, parser); + resultData.push({ json: { output } }); + } catch (error) { + if (this.continueOnFail()) { + resultData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } }); + continue; + } - throw error; + throw error; + } } } diff --git a/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/constants.ts b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/constants.ts new file mode 100644 index 0000000000..32fb14dfa4 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/constants.ts @@ -0,0 +1,3 @@ +export const SYSTEM_PROMPT_TEMPLATE = `You are an expert extraction algorithm. +Only extract relevant information from the text. +If you do not know the value of an attribute asked to extract, you may omit the attribute's value.`; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/processItem.ts b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/processItem.ts new file mode 100644 index 0000000000..a7156e7518 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/processItem.ts @@ -0,0 +1,39 @@ +import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { HumanMessage } from '@langchain/core/messages'; +import { ChatPromptTemplate, SystemMessagePromptTemplate } from '@langchain/core/prompts'; +import type { OutputFixingParser } from 'langchain/output_parsers'; +import type { IExecuteFunctions } from 'n8n-workflow'; + +import { getTracingConfig } from '@utils/tracing'; + +import { SYSTEM_PROMPT_TEMPLATE } from './constants'; + +export async function processItem( + ctx: IExecuteFunctions, + itemIndex: number, + llm: BaseLanguageModel, + parser: OutputFixingParser, +) { + const input = ctx.getNodeParameter('text', itemIndex) as string; + const inputPrompt = new HumanMessage(input); + + const options = ctx.getNodeParameter('options', itemIndex, {}) as { + systemPromptTemplate?: string; + }; + + const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( + `${options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE} +{format_instructions}`, + ); + + const messages = [ + await systemPromptTemplate.format({ + format_instructions: parser.getFormatInstructions(), + }), + inputPrompt, + ]; + const prompt = ChatPromptTemplate.fromMessages(messages); + const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(ctx)); + + return await chain.invoke(messages); +} diff --git a/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/test/InformationExtraction.node.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/test/InformationExtraction.node.test.ts index 0444ba3cef..ef216c05cd 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/test/InformationExtraction.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/InformationExtractor/test/InformationExtraction.node.test.ts @@ -41,7 +41,11 @@ function formatFakeLlmResponse(object: Record) { return `\`\`\`json\n${JSON.stringify(object, null, 2)}\n\`\`\``; } -const createExecuteFunctionsMock = (parameters: IDataObject, fakeLlm: BaseLanguageModel) => { +const createExecuteFunctionsMock = ( + parameters: IDataObject, + fakeLlm: BaseLanguageModel, + inputData = [{ json: {} }], +) => { const nodeParameters = parameters; return { @@ -49,13 +53,15 @@ const createExecuteFunctionsMock = (parameters: IDataObject, fakeLlm: BaseLangua return get(nodeParameters, parameter); }, getNode() { - return {}; + return { + typeVersion: 1.1, + }; }, getInputConnectionData() { return fakeLlm; }, getInputData() { - return [{ json: {} }]; + return inputData; }, getWorkflow() { return { @@ -215,4 +221,132 @@ describe('InformationExtractor', () => { expect(response).toEqual([[{ json: { output: { name: 'John', age: 30 } } }]]); }); }); + + describe('Batch Processing', () => { + it('should process multiple items in batches', async () => { + const node = new InformationExtractor(); + const inputData = [ + { json: { text: 'John is 30 years old' } }, + { json: { text: 'Alice is 25 years old' } }, + { json: { text: 'Bob is 40 years old' } }, + ]; + + const response = await node.execute.call( + createExecuteFunctionsMock( + { + text: 'John is 30 years old', + attributes: { + attributes: mockPersonAttributes, + }, + options: { + batching: { + batchSize: 2, + delayBetweenBatches: 0, + }, + }, + schemaType: 'fromAttributes', + }, + new FakeListChatModel({ + responses: [ + formatFakeLlmResponse({ name: 'John', age: 30 }), + formatFakeLlmResponse({ name: 'Alice', age: 25 }), + formatFakeLlmResponse({ name: 'Bob', age: 40 }), + ], + }), + inputData, + ), + ); + + expect(response).toEqual([ + [ + { json: { output: { name: 'John', age: 30 } } }, + { json: { output: { name: 'Alice', age: 25 } } }, + { json: { output: { name: 'Bob', age: 40 } } }, + ], + ]); + }); + + it('should handle errors in batch processing', async () => { + const node = new InformationExtractor(); + const inputData = [ + { json: { text: 'John is 30 years old' } }, + { json: { text: 'Invalid text' } }, + { json: { text: 'Bob is 40 years old' } }, + ]; + + const mockExecuteFunctions = createExecuteFunctionsMock( + { + text: 'John is 30 years old', + attributes: { + attributes: mockPersonAttributesRequired, + }, + options: { + batching: { + batchSize: 2, + delayBetweenBatches: 0, + }, + }, + schemaType: 'fromAttributes', + }, + new FakeListChatModel({ + responses: [ + formatFakeLlmResponse({ name: 'John', age: 30 }), + formatFakeLlmResponse({ name: 'Invalid' }), // Missing required age + formatFakeLlmResponse({ name: 'Invalid' }), // Missing required age on retry + formatFakeLlmResponse({ name: 'Bob', age: 40 }), + ], + }), + inputData, + ); + + mockExecuteFunctions.continueOnFail = () => true; + + const response = await node.execute.call(mockExecuteFunctions); + + //expect(response).toBe({}); + expect(response[0]).toHaveLength(3); + expect(response[0][0]).toEqual({ json: { output: { name: 'John', age: 30 } } }); + expect(response[0][1]).toEqual({ + json: { error: expect.stringContaining('Failed to parse') }, + pairedItem: { item: 1 }, + }); + expect(response[0][2]).toEqual({ json: { output: { name: 'Bob', age: 40 } } }); + }); + + it('should throw error if batch processing fails and continueOnFail is false', async () => { + const node = new InformationExtractor(); + const inputData = [ + { json: { text: 'John is 30 years old' } }, + { json: { text: 'Invalid text' } }, + { json: { text: 'Bob is 40 years old' } }, + ]; + + const mockExecuteFunctions = createExecuteFunctionsMock( + { + text: 'John is 30 years old', + attributes: { + attributes: mockPersonAttributesRequired, + }, + options: { + batching: { + batchSize: 2, + delayBetweenBatches: 0, + }, + }, + schemaType: 'fromAttributes', + }, + new FakeListChatModel({ + responses: [ + formatFakeLlmResponse({ name: 'John', age: 30 }), + formatFakeLlmResponse({ name: 'Invalid' }), // Missing required age + formatFakeLlmResponse({ name: 'Invalid' }), // Missing required age on retry + formatFakeLlmResponse({ name: 'Bob', age: 40 }), + ], + }), + inputData, + ); + + await expect(node.execute.call(mockExecuteFunctions)).rejects.toThrow('Failed to parse'); + }); + }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/chains/SentimentAnalysis/SentimentAnalysis.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/SentimentAnalysis/SentimentAnalysis.node.ts index ac232eeee4..139cf6acc2 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/SentimentAnalysis/SentimentAnalysis.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/SentimentAnalysis/SentimentAnalysis.node.ts @@ -2,7 +2,7 @@ import type { BaseLanguageModel } from '@langchain/core/language_models/base'; import { HumanMessage } from '@langchain/core/messages'; import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts'; import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers'; -import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; +import { NodeConnectionTypes, NodeOperationError, sleep } from 'n8n-workflow'; import type { IDataObject, IExecuteFunctions, @@ -13,6 +13,7 @@ import type { } from 'n8n-workflow'; import { z } from 'zod'; +import { getBatchingOptionFields } from '@utils/sharedFields'; import { getTracingConfig } from '@utils/tracing'; const DEFAULT_SYSTEM_PROMPT_TEMPLATE = @@ -35,7 +36,7 @@ export class SentimentAnalysis implements INodeType { icon: 'fa:balance-scale-left', iconColor: 'black', group: ['transform'], - version: 1, + version: [1, 1.1], description: 'Analyze the sentiment of your text', codex: { categories: ['AI'], @@ -131,6 +132,11 @@ export class SentimentAnalysis implements INodeType { description: 'Whether to enable auto-fixing (may trigger an additional LLM call if output is broken)', }, + getBatchingOptionFields({ + show: { + '@version': [{ _cnd: { gte: 1.1 } }], + }, + }), ], }, ], @@ -146,110 +152,265 @@ export class SentimentAnalysis implements INodeType { const returnData: INodeExecutionData[][] = []; - for (let i = 0; i < items.length; i++) { - try { - const sentimentCategories = this.getNodeParameter( - 'options.categories', - i, - DEFAULT_CATEGORIES, - ) as string; + const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number; + const delayBetweenBatches = this.getNodeParameter( + 'options.batching.delayBetweenBatches', + 0, + 0, + ) as number; - const categories = sentimentCategories - .split(',') - .map((cat) => cat.trim()) - .filter(Boolean); + if (this.getNode().typeVersion >= 1.1 && batchSize > 1) { + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (_item, batchItemIndex) => { + const itemIndex = i + batchItemIndex; + const sentimentCategories = this.getNodeParameter( + 'options.categories', + itemIndex, + DEFAULT_CATEGORIES, + ) as string; - if (categories.length === 0) { - throw new NodeOperationError(this.getNode(), 'No sentiment categories provided', { - itemIndex: i, + const categories = sentimentCategories + .split(',') + .map((cat) => cat.trim()) + .filter(Boolean); + + if (categories.length === 0) { + return { + result: null, + itemIndex, + error: new NodeOperationError(this.getNode(), 'No sentiment categories provided', { + itemIndex, + }), + }; + } + + // Initialize returnData with empty arrays for each category + if (returnData.length === 0) { + returnData.push(...Array.from({ length: categories.length }, () => [])); + } + + const options = this.getNodeParameter('options', itemIndex, {}) as { + systemPromptTemplate?: string; + includeDetailedResults?: boolean; + enableAutoFixing?: boolean; + }; + + const schema = z.object({ + sentiment: z.enum(categories as [string, ...string[]]), + strength: z + .number() + .min(0) + .max(1) + .describe('Strength score for sentiment in relation to the category'), + confidence: z.number().min(0).max(1), }); - } - // Initialize returnData with empty arrays for each category - if (returnData.length === 0) { - returnData.push(...Array.from({ length: categories.length }, () => [])); - } + const structuredParser = StructuredOutputParser.fromZodSchema(schema); - const options = this.getNodeParameter('options', i, {}) as { - systemPromptTemplate?: string; - includeDetailedResults?: boolean; - enableAutoFixing?: boolean; - }; + const parser = options.enableAutoFixing + ? OutputFixingParser.fromLLM(llm, structuredParser) + : structuredParser; - const schema = z.object({ - sentiment: z.enum(categories as [string, ...string[]]), - strength: z - .number() - .min(0) - .max(1) - .describe('Strength score for sentiment in relation to the category'), - confidence: z.number().min(0).max(1), - }); - - const structuredParser = StructuredOutputParser.fromZodSchema(schema); - - const parser = options.enableAutoFixing - ? OutputFixingParser.fromLLM(llm, structuredParser) - : structuredParser; - - const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( - `${options.systemPromptTemplate ?? DEFAULT_SYSTEM_PROMPT_TEMPLATE} - {format_instructions}`, - ); - - const input = this.getNodeParameter('inputText', i) as string; - const inputPrompt = new HumanMessage(input); - const messages = [ - await systemPromptTemplate.format({ - categories: sentimentCategories, - format_instructions: parser.getFormatInstructions(), - }), - inputPrompt, - ]; - - const prompt = ChatPromptTemplate.fromMessages(messages); - const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); - - try { - const output = await chain.invoke(messages); - const sentimentIndex = categories.findIndex( - (s) => s.toLowerCase() === output.sentiment.toLowerCase(), + const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( + `${options.systemPromptTemplate ?? DEFAULT_SYSTEM_PROMPT_TEMPLATE} + {format_instructions}`, ); - if (sentimentIndex !== -1) { - const resultItem = { ...items[i] }; - const sentimentAnalysis: IDataObject = { - category: output.sentiment, - }; - if (options.includeDetailedResults) { - sentimentAnalysis.strength = output.strength; - sentimentAnalysis.confidence = output.confidence; + const input = this.getNodeParameter('inputText', itemIndex) as string; + const inputPrompt = new HumanMessage(input); + const messages = [ + await systemPromptTemplate.format({ + categories: sentimentCategories, + format_instructions: parser.getFormatInstructions(), + }), + inputPrompt, + ]; + + const prompt = ChatPromptTemplate.fromMessages(messages); + const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); + + try { + const output = await chain.invoke(messages); + const sentimentIndex = categories.findIndex( + (s) => s.toLowerCase() === output.sentiment.toLowerCase(), + ); + + if (sentimentIndex !== -1) { + const resultItem = { ...items[itemIndex] }; + const sentimentAnalysis: IDataObject = { + category: output.sentiment, + }; + if (options.includeDetailedResults) { + sentimentAnalysis.strength = output.strength; + sentimentAnalysis.confidence = output.confidence; + } + resultItem.json = { + ...resultItem.json, + sentimentAnalysis, + }; + + return { + result: { + resultItem, + sentimentIndex, + }, + itemIndex, + }; } - resultItem.json = { - ...resultItem.json, - sentimentAnalysis, + + return { + result: {}, + itemIndex, }; + } catch (error) { + return { + result: null, + itemIndex, + error: new NodeOperationError( + this.getNode(), + 'Error during parsing of LLM output, please check your LLM model and configuration', + { + itemIndex, + }, + ), + }; + } + }); + const batchResults = await Promise.all(batchPromises); + + batchResults.forEach(({ result, itemIndex, error }) => { + if (error) { + if (this.continueOnFail()) { + const executionErrorData = this.helpers.constructExecutionMetaData( + this.helpers.returnJsonArray({ error: error.message }), + { itemData: { item: itemIndex } }, + ); + + returnData[0].push(...executionErrorData); + return; + } else { + throw error; + } + } else if (result.resultItem && result.sentimentIndex) { + const sentimentIndex = result.sentimentIndex; + const resultItem = result.resultItem; returnData[sentimentIndex].push(resultItem); } - } catch (error) { - throw new NodeOperationError( - this.getNode(), - 'Error during parsing of LLM output, please check your LLM model and configuration', - { + }); + + // Add delay between batches if not the last batch + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); + } + } + } else { + // Sequential Processing + for (let i = 0; i < items.length; i++) { + try { + const sentimentCategories = this.getNodeParameter( + 'options.categories', + i, + DEFAULT_CATEGORIES, + ) as string; + + const categories = sentimentCategories + .split(',') + .map((cat) => cat.trim()) + .filter(Boolean); + + if (categories.length === 0) { + throw new NodeOperationError(this.getNode(), 'No sentiment categories provided', { itemIndex: i, - }, + }); + } + + // Initialize returnData with empty arrays for each category + if (returnData.length === 0) { + returnData.push(...Array.from({ length: categories.length }, () => [])); + } + + const options = this.getNodeParameter('options', i, {}) as { + systemPromptTemplate?: string; + includeDetailedResults?: boolean; + enableAutoFixing?: boolean; + }; + + const schema = z.object({ + sentiment: z.enum(categories as [string, ...string[]]), + strength: z + .number() + .min(0) + .max(1) + .describe('Strength score for sentiment in relation to the category'), + confidence: z.number().min(0).max(1), + }); + + const structuredParser = StructuredOutputParser.fromZodSchema(schema); + + const parser = options.enableAutoFixing + ? OutputFixingParser.fromLLM(llm, structuredParser) + : structuredParser; + + const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( + `${options.systemPromptTemplate ?? DEFAULT_SYSTEM_PROMPT_TEMPLATE} + {format_instructions}`, ); + + const input = this.getNodeParameter('inputText', i) as string; + const inputPrompt = new HumanMessage(input); + const messages = [ + await systemPromptTemplate.format({ + categories: sentimentCategories, + format_instructions: parser.getFormatInstructions(), + }), + inputPrompt, + ]; + + const prompt = ChatPromptTemplate.fromMessages(messages); + const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); + + try { + const output = await chain.invoke(messages); + const sentimentIndex = categories.findIndex( + (s) => s.toLowerCase() === output.sentiment.toLowerCase(), + ); + + if (sentimentIndex !== -1) { + const resultItem = { ...items[i] }; + const sentimentAnalysis: IDataObject = { + category: output.sentiment, + }; + if (options.includeDetailedResults) { + sentimentAnalysis.strength = output.strength; + sentimentAnalysis.confidence = output.confidence; + } + resultItem.json = { + ...resultItem.json, + sentimentAnalysis, + }; + returnData[sentimentIndex].push(resultItem); + } + } catch (error) { + throw new NodeOperationError( + this.getNode(), + 'Error during parsing of LLM output, please check your LLM model and configuration', + { + itemIndex: i, + }, + ); + } + } catch (error) { + if (this.continueOnFail()) { + const executionErrorData = this.helpers.constructExecutionMetaData( + this.helpers.returnJsonArray({ error: error.message }), + { itemData: { item: i } }, + ); + returnData[0].push(...executionErrorData); + continue; + } + throw error; } - } catch (error) { - if (this.continueOnFail()) { - const executionErrorData = this.helpers.constructExecutionMetaData( - this.helpers.returnJsonArray({ error: error.message }), - { itemData: { item: i } }, - ); - returnData[0].push(...executionErrorData); - continue; - } - throw error; } } return returnData; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts index 204e164370..656b3c829d 100644 --- a/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts @@ -1,8 +1,6 @@ import type { BaseLanguageModel } from '@langchain/core/language_models/base'; -import { HumanMessage } from '@langchain/core/messages'; -import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts'; import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers'; -import { NodeOperationError, NodeConnectionTypes } from 'n8n-workflow'; +import { NodeOperationError, NodeConnectionTypes, sleep } from 'n8n-workflow'; import type { IDataObject, IExecuteFunctions, @@ -13,7 +11,9 @@ import type { } from 'n8n-workflow'; import { z } from 'zod'; -import { getTracingConfig } from '@utils/tracing'; +import { getBatchingOptionFields } from '@utils/sharedFields'; + +import { processItem } from './processItem'; const SYSTEM_PROMPT_TEMPLATE = "Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json."; @@ -35,7 +35,7 @@ export class TextClassifier implements INodeType { icon: 'fa:tags', iconColor: 'black', group: ['transform'], - version: 1, + version: [1, 1.1], description: 'Classify your text into distinct categories', codex: { categories: ['AI'], @@ -158,6 +158,11 @@ export class TextClassifier implements INodeType { description: 'Whether to enable auto-fixing (may trigger an additional LLM call if output is broken)', }, + getBatchingOptionFields({ + show: { + '@version': [{ _cnd: { gte: 1.1 } }], + }, + }), ], }, ], @@ -165,6 +170,12 @@ export class TextClassifier implements INodeType { async execute(this: IExecuteFunctions): Promise { const items = this.getInputData(); + const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number; + const delayBetweenBatches = this.getNodeParameter( + 'options.batching.delayBetweenBatches', + 0, + 0, + ) as number; const llm = (await this.getInputConnectionData( NodeConnectionTypes.AiLanguageModel, @@ -223,68 +234,93 @@ export class TextClassifier implements INodeType { { length: categories.length + (fallback === 'other' ? 1 : 0) }, (_) => [], ); - for (let itemIdx = 0; itemIdx < items.length; itemIdx++) { - const item = items[itemIdx]; - item.pairedItem = { item: itemIdx }; - const input = this.getNodeParameter('inputText', itemIdx) as string; - if (input === undefined || input === null) { - if (this.continueOnFail()) { - returnData[0].push({ - json: { error: 'Text to classify is not defined' }, - pairedItem: { item: itemIdx }, - }); - continue; - } else { - throw new NodeOperationError( - this.getNode(), - `Text to classify for item ${itemIdx} is not defined`, + if (this.getNode().typeVersion >= 1.1 && batchSize > 1) { + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchPromises = batch.map(async (_item, batchItemIndex) => { + const itemIndex = i + batchItemIndex; + const item = items[itemIndex]; + item.pairedItem = { item: itemIndex }; + + return await processItem( + this, + itemIndex, + item, + llm, + parser, + categories, + multiClassPrompt, + fallbackPrompt, ); + }); + + const batchResults = await Promise.allSettled(batchPromises); + + batchResults.forEach((response, batchItemIndex) => { + const index = i + batchItemIndex; + if (response.status === 'rejected') { + const error = response.reason as Error; + if (this.continueOnFail()) { + returnData[0].push({ + json: { error: error.message }, + pairedItem: { item: index }, + }); + return; + } else { + throw new NodeOperationError(this.getNode(), error.message); + } + } else { + const output = response.value; + const item = items[index]; + + categories.forEach((cat, idx) => { + if (output[cat.category]) returnData[idx].push(item); + }); + + if (fallback === 'other' && output.fallback) + returnData[returnData.length - 1].push(item); + } + }); + + // Add delay between batches if not the last batch + if (i + batchSize < items.length && delayBetweenBatches > 0) { + await sleep(delayBetweenBatches); } } + } else { + for (let itemIndex = 0; itemIndex < items.length; itemIndex++) { + const item = items[itemIndex]; + item.pairedItem = { item: itemIndex }; - const inputPrompt = new HumanMessage(input); + try { + const output = await processItem( + this, + itemIndex, + item, + llm, + parser, + categories, + multiClassPrompt, + fallbackPrompt, + ); - const systemPromptTemplateOpt = this.getNodeParameter( - 'options.systemPromptTemplate', - itemIdx, - SYSTEM_PROMPT_TEMPLATE, - ) as string; - const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( - `${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE} -{format_instructions} -${multiClassPrompt} -${fallbackPrompt}`, - ); - - const messages = [ - await systemPromptTemplate.format({ - categories: categories.map((cat) => cat.category).join(', '), - format_instructions: parser.getFormatInstructions(), - }), - inputPrompt, - ]; - const prompt = ChatPromptTemplate.fromMessages(messages); - const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); - - try { - const output = await chain.invoke(messages); - - categories.forEach((cat, idx) => { - if (output[cat.category]) returnData[idx].push(item); - }); - if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item); - } catch (error) { - if (this.continueOnFail()) { - returnData[0].push({ - json: { error: error.message }, - pairedItem: { item: itemIdx }, + categories.forEach((cat, idx) => { + if (output[cat.category]) returnData[idx].push(item); }); + if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item); + } catch (error) { + if (this.continueOnFail()) { + returnData[0].push({ + json: { error: error.message }, + pairedItem: { item: itemIndex }, + }); - continue; + continue; + } + + throw error; } - - throw error; } } diff --git a/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/constants.ts b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/constants.ts new file mode 100644 index 0000000000..e7332e393d --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/constants.ts @@ -0,0 +1,2 @@ +export const SYSTEM_PROMPT_TEMPLATE = + "Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json."; diff --git a/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/processItem.ts b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/processItem.ts new file mode 100644 index 0000000000..5c4da3de08 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/processItem.ts @@ -0,0 +1,57 @@ +import type { BaseLanguageModel } from '@langchain/core/language_models/base'; +import { HumanMessage } from '@langchain/core/messages'; +import { ChatPromptTemplate, SystemMessagePromptTemplate } from '@langchain/core/prompts'; +import type { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers'; +import { NodeOperationError, type IExecuteFunctions, type INodeExecutionData } from 'n8n-workflow'; + +import { getTracingConfig } from '@utils/tracing'; + +import { SYSTEM_PROMPT_TEMPLATE } from './constants'; + +export async function processItem( + ctx: IExecuteFunctions, + itemIndex: number, + item: INodeExecutionData, + llm: BaseLanguageModel, + parser: StructuredOutputParser | OutputFixingParser, + categories: Array<{ category: string; description: string }>, + multiClassPrompt: string, + fallbackPrompt: string | undefined, +): Promise> { + const input = ctx.getNodeParameter('inputText', itemIndex) as string; + + if (input === undefined || input === null) { + throw new NodeOperationError( + ctx.getNode(), + `Text to classify for item ${itemIndex} is not defined`, + ); + } + + item.pairedItem = { item: itemIndex }; + + const inputPrompt = new HumanMessage(input); + + const systemPromptTemplateOpt = ctx.getNodeParameter( + 'options.systemPromptTemplate', + itemIndex, + SYSTEM_PROMPT_TEMPLATE, + ) as string; + const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( + `${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE} + {format_instructions} + ${multiClassPrompt} + ${fallbackPrompt}`, + ); + + const messages = [ + await systemPromptTemplate.format({ + categories: categories.map((cat) => cat.category).join(', '), + format_instructions: parser.getFormatInstructions(), + }), + inputPrompt, + ]; + const prompt = ChatPromptTemplate.fromMessages(messages); + const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(ctx)); + + return await chain.invoke(messages); +} diff --git a/packages/@n8n/nodes-langchain/utils/sharedFields.ts b/packages/@n8n/nodes-langchain/utils/sharedFields.ts index beb92b6f74..9e806c051c 100644 --- a/packages/@n8n/nodes-langchain/utils/sharedFields.ts +++ b/packages/@n8n/nodes-langchain/utils/sharedFields.ts @@ -1,4 +1,5 @@ -import { NodeConnectionTypes, type INodeProperties } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; +import type { IDisplayOptions, INodeProperties } from 'n8n-workflow'; export const metadataFilterField: INodeProperties = { displayName: 'Metadata Filter', @@ -42,6 +43,38 @@ export function getTemplateNoticeField(templateId: number): INodeProperties { }; } +export function getBatchingOptionFields( + displayOptions: IDisplayOptions | undefined, + defaultBatchSize: number = 5, +): INodeProperties { + return { + displayName: 'Batch Processing', + name: 'batching', + type: 'collection', + placeholder: 'Add Batch Processing Option', + description: 'Batch processing options for rate limiting', + default: {}, + options: [ + { + displayName: 'Batch Size', + name: 'batchSize', + default: defaultBatchSize, + type: 'number', + description: + 'How many items to process in parallel. This is useful for rate limiting, but might impact the log output ordering.', + }, + { + displayName: 'Delay Between Batches', + name: 'delayBetweenBatches', + default: 0, + type: 'number', + description: 'Delay in milliseconds between batches. This is useful for rate limiting.', + }, + ], + displayOptions, + }; +} + const connectionsString = { [NodeConnectionTypes.AiAgent]: { // Root AI view diff --git a/packages/frontend/editor-ui/src/stores/workflows.store.test.ts b/packages/frontend/editor-ui/src/stores/workflows.store.test.ts index 10952d54b9..d599d51f2a 100644 --- a/packages/frontend/editor-ui/src/stores/workflows.store.test.ts +++ b/packages/frontend/editor-ui/src/stores/workflows.store.test.ts @@ -11,7 +11,7 @@ import { useWorkflowsStore } from '@/stores/workflows.store'; import type { IExecutionResponse, INodeUi, IWorkflowDb, IWorkflowSettings } from '@/Interface'; import { useNodeTypesStore } from '@/stores/nodeTypes.store'; -import { SEND_AND_WAIT_OPERATION } from 'n8n-workflow'; +import { deepCopy, SEND_AND_WAIT_OPERATION } from 'n8n-workflow'; import type { IPinData, ExecutionSummary, @@ -658,6 +658,65 @@ describe('useWorkflowsStore', () => { TestNode1: [{ json: { test: false } }], }); }); + + it('should replace existing placeholder task data in new log view', () => { + settingsStore.settings = { + logsView: { + enabled: true, + }, + } as FrontendSettings; + const successEventWithExecutionIndex = deepCopy(successEvent); + successEventWithExecutionIndex.data.executionIndex = 1; + + const runWithExistingRunData = executionResponse; + runWithExistingRunData.data = { + resultData: { + runData: { + [successEventWithExecutionIndex.nodeName]: [ + { + hints: [], + startTime: 1727867966633, + executionIndex: successEventWithExecutionIndex.data.executionIndex, + executionTime: 1, + source: [], + executionStatus: 'running', + data: { + main: [ + [ + { + json: {}, + pairedItem: { + item: 0, + }, + }, + ], + ], + }, + }, + ], + }, + }, + }; + workflowsStore.setWorkflowExecutionData(runWithExistingRunData); + + workflowsStore.nodesByName[successEvent.nodeName] = mock({ + type: 'n8n-nodes-base.manualTrigger', + }); + + // ACT + workflowsStore.updateNodeExecutionData(successEventWithExecutionIndex); + + expect(workflowsStore.workflowExecutionData).toEqual({ + ...executionResponse, + data: { + resultData: { + runData: { + [successEvent.nodeName]: [successEventWithExecutionIndex.data], + }, + }, + }, + }); + }); }); describe('setNodeValue()', () => { diff --git a/packages/frontend/editor-ui/src/stores/workflows.store.ts b/packages/frontend/editor-ui/src/stores/workflows.store.ts index a6a83fbb4a..56f43ac62d 100644 --- a/packages/frontend/editor-ui/src/stores/workflows.store.ts +++ b/packages/frontend/editor-ui/src/stores/workflows.store.ts @@ -1541,10 +1541,16 @@ export const useWorkflowsStore = defineStore(STORES.WORKFLOWS, () => { openFormPopupWindow(testUrl); } } else { - const status = tasksData[tasksData.length - 1]?.executionStatus ?? 'unknown'; + // If we process items in paralell on subnodes we get several placeholder taskData items. + // We need to find and replace the item with the matching executionIndex and only append if we don't find anything matching. + const existingRunIndex = tasksData.findIndex( + (item) => item.executionIndex === data.executionIndex, + ); + const index = existingRunIndex > -1 ? existingRunIndex : tasksData.length - 1; + const status = tasksData[index]?.executionStatus ?? 'unknown'; if ('waiting' === status || (settingsStore.isNewLogsEnabled && 'running' === status)) { - tasksData.splice(tasksData.length - 1, 1, data); + tasksData.splice(index, 1, data); } else { tasksData.push(data); }