fix(AI Agent Node): Fix tool calling when tools run in a loop (#15250)

Co-authored-by: JP van Oosten <jp@n8n.io>
Co-authored-by: कारतोफ्फेलस्क्रिप्ट™ <aditya@netroy.in>
This commit is contained in:
Yiorgis Gozadinos
2025-05-13 14:40:07 +02:00
committed by GitHub
parent 52f27a76ac
commit cd1d6c9dfc
8 changed files with 113 additions and 15 deletions

View File

@@ -39,6 +39,7 @@ export class ToolWorkflowV2 implements INodeType {
const description = this.getNodeParameter('description', itemIndex) as string; const description = this.getNodeParameter('description', itemIndex) as string;
const tool = await workflowToolService.createTool({ const tool = await workflowToolService.createTool({
ctx: this,
name, name,
description, description,
itemIndex, itemIndex,

View File

@@ -13,6 +13,10 @@ import { WorkflowToolService } from './utils/WorkflowToolService';
// Mock ISupplyDataFunctions interface // Mock ISupplyDataFunctions interface
function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDataFunctions { function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDataFunctions {
let runIndex = 0;
const getNextRunIndex = jest.fn(() => {
return runIndex++;
});
return { return {
runIndex: 0, runIndex: 0,
getNodeParameter: jest.fn(), getNodeParameter: jest.fn(),
@@ -26,6 +30,7 @@ function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDa
getInputData: jest.fn(), getInputData: jest.fn(),
getMode: jest.fn(), getMode: jest.fn(),
getRestApiUrl: jest.fn(), getRestApiUrl: jest.fn(),
getNextRunIndex,
getTimezone: jest.fn(), getTimezone: jest.fn(),
getWorkflow: jest.fn(), getWorkflow: jest.fn(),
getWorkflowStaticData: jest.fn(), getWorkflowStaticData: jest.fn(),
@@ -56,6 +61,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
describe('createTool', () => { describe('createTool', () => {
it('should create a basic dynamic tool when schema is not used', async () => { it('should create a basic dynamic tool when schema is not used', async () => {
const toolParams = { const toolParams = {
ctx: context,
name: 'TestTool', name: 'TestTool',
description: 'Test Description', description: 'Test Description',
itemIndex: 0, itemIndex: 0,
@@ -70,6 +76,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
it('should create a tool that can handle successful execution', async () => { it('should create a tool that can handle successful execution', async () => {
const toolParams = { const toolParams = {
ctx: context,
name: 'TestTool', name: 'TestTool',
description: 'Test Description', description: 'Test Description',
itemIndex: 0, itemIndex: 0,
@@ -112,6 +119,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
it('should handle errors during tool execution', async () => { it('should handle errors during tool execution', async () => {
const toolParams = { const toolParams = {
ctx: context,
name: 'TestTool', name: 'TestTool',
description: 'Test Description', description: 'Test Description',
itemIndex: 0, itemIndex: 0,

View File

@@ -60,17 +60,21 @@ export class WorkflowToolService {
// Creates the tool based on the provided parameters // Creates the tool based on the provided parameters
async createTool({ async createTool({
ctx,
name, name,
description, description,
itemIndex, itemIndex,
}: { }: {
ctx: ISupplyDataFunctions;
name: string; name: string;
description: string; description: string;
itemIndex: number; itemIndex: number;
}): Promise<DynamicTool | DynamicStructuredTool> { }): Promise<DynamicTool | DynamicStructuredTool> {
let runIndex = 0;
// Handler for the tool execution, will be called when the tool is executed // Handler for the tool execution, will be called when the tool is executed
// This function will execute the sub-workflow and return the response // This function will execute the sub-workflow and return the response
// We get the runIndex from the context to handle multiple executions
// of the same tool when the tool is used in a loop or in a parallel execution.
let runIndex: number = ctx.getNextRunIndex();
const toolHandler = async ( const toolHandler = async (
query: string | IDataObject, query: string | IDataObject,
runManager?: CallbackManagerForToolRun, runManager?: CallbackManagerForToolRun,

View File

@@ -15,6 +15,7 @@ import type {
ICredentialDataDecryptedObject, ICredentialDataDecryptedObject,
NodeConnectionType, NodeConnectionType,
} from 'n8n-workflow'; } from 'n8n-workflow';
import type { IRunData } from 'n8n-workflow';
import { ApplicationError, NodeConnectionTypes } from 'n8n-workflow'; import { ApplicationError, NodeConnectionTypes } from 'n8n-workflow';
import { describeCommonTests } from './shared-tests'; import { describeCommonTests } from './shared-tests';
@@ -186,4 +187,36 @@ describe('SupplyDataContext', () => {
expect(clone).not.toBe(supplyDataContext); expect(clone).not.toBe(supplyDataContext);
}); });
}); });
describe('getNextRunIndex', () => {
it('should return 0 as the default latest run index', () => {
const latestRunIndex = supplyDataContext.getNextRunIndex();
expect(latestRunIndex).toBe(0);
});
it('should return the length of the run execution data for the node', () => {
const runData = mock<IRunData>();
const runExecutionData = mock<IRunExecutionData>({
resultData: { runData: { [node.name]: [runData, runData] } },
});
const supplyDataContext = new SupplyDataContext(
workflow,
node,
additionalData,
mode,
runExecutionData,
runIndex,
connectionInputData,
inputData,
connectionType,
executeData,
[closeFn],
abortSignal,
);
const latestRunIndex = supplyDataContext.getNextRunIndex();
expect(latestRunIndex).toBe(2);
});
});
}); });

View File

@@ -167,16 +167,18 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData
return super.getInputItems(inputIndex, connectionType) ?? []; return super.getInputItems(inputIndex, connectionType) ?? [];
} }
getNextRunIndex(): number {
const nodeName = this.node.name;
return this.runExecutionData.resultData.runData[nodeName]?.length ?? 0;
}
/** @deprecated create a context object with inputData for every runIndex */ /** @deprecated create a context object with inputData for every runIndex */
addInputData( addInputData(
connectionType: AINodeConnectionType, connectionType: AINodeConnectionType,
data: INodeExecutionData[][], data: INodeExecutionData[][],
): { index: number } { ): { index: number } {
const nodeName = this.node.name; const nodeName = this.node.name;
let currentNodeRunIndex = 0; const currentNodeRunIndex = this.getNextRunIndex();
if (this.runExecutionData.resultData.runData.hasOwnProperty(nodeName)) {
currentNodeRunIndex = this.runExecutionData.resultData.runData[nodeName].length;
}
this.addExecutionDataFunctions( this.addExecutionDataFunctions(
'input', 'input',

View File

@@ -11,7 +11,9 @@ import type {
INodeType, INodeType,
INodeTypes, INodeTypes,
IExecuteFunctions, IExecuteFunctions,
IRunData,
} from 'n8n-workflow'; } from 'n8n-workflow';
import type { ITaskData } from 'n8n-workflow';
import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow'; import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
import { ExecuteContext } from '../../execute-context'; import { ExecuteContext } from '../../execute-context';
@@ -377,11 +379,19 @@ describe('makeHandleToolInvocation', () => {
execute, execute,
}); });
const contextFactory = jest.fn(); const contextFactory = jest.fn();
const taskData = mock<ITaskData>();
let runExecutionData = mock<IRunExecutionData>({
resultData: {
runData: mock<IRunData>(),
},
});
const toolArgs = { key: 'value' }; const toolArgs = { key: 'value' };
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks();
}); });
it('should return stringified results when execution is successful', async () => { it('should return stringified results when execution is successful', async () => {
const mockContext = mock<IExecuteFunctions>(); const mockContext = mock<IExecuteFunctions>();
contextFactory.mockReturnValue(mockContext); contextFactory.mockReturnValue(mockContext);
@@ -393,6 +403,7 @@ describe('makeHandleToolInvocation', () => {
contextFactory, contextFactory,
connectedNode, connectedNode,
connectedNodeType, connectedNodeType,
runExecutionData,
); );
const result = await handleToolInvocation(toolArgs); const result = await handleToolInvocation(toolArgs);
@@ -405,7 +416,6 @@ describe('makeHandleToolInvocation', () => {
it('should handle binary data and return a warning message', async () => { it('should handle binary data and return a warning message', async () => {
const mockContext = mock<IExecuteFunctions>(); const mockContext = mock<IExecuteFunctions>();
contextFactory.mockReturnValue(mockContext); contextFactory.mockReturnValue(mockContext);
const mockResult = [[{ json: {}, binary: { file: 'data' } }]]; const mockResult = [[{ json: {}, binary: { file: 'data' } }]];
execute.mockResolvedValueOnce(mockResult); execute.mockResolvedValueOnce(mockResult);
@@ -413,6 +423,7 @@ describe('makeHandleToolInvocation', () => {
contextFactory, contextFactory,
connectedNode, connectedNode,
connectedNodeType, connectedNodeType,
runExecutionData,
); );
const result = await handleToolInvocation(toolArgs); const result = await handleToolInvocation(toolArgs);
@@ -439,7 +450,6 @@ describe('makeHandleToolInvocation', () => {
}, },
}); });
contextFactory.mockReturnValue(mockContext); contextFactory.mockReturnValue(mockContext);
const mockResult = [[{ json: { a: 3 }, binary: { file: 'data' } }]]; const mockResult = [[{ json: { a: 3 }, binary: { file: 'data' } }]];
execute.mockResolvedValueOnce(mockResult); execute.mockResolvedValueOnce(mockResult);
@@ -447,6 +457,7 @@ describe('makeHandleToolInvocation', () => {
contextFactory, contextFactory,
connectedNode, connectedNode,
connectedNodeType, connectedNodeType,
runExecutionData,
); );
const result = await handleToolInvocation(toolArgs); const result = await handleToolInvocation(toolArgs);
@@ -466,7 +477,6 @@ describe('makeHandleToolInvocation', () => {
it('should handle execution errors and return an error message', async () => { it('should handle execution errors and return an error message', async () => {
const mockContext = mock<IExecuteFunctions>(); const mockContext = mock<IExecuteFunctions>();
contextFactory.mockReturnValue(mockContext); contextFactory.mockReturnValue(mockContext);
const error = new Error('Execution failed'); const error = new Error('Execution failed');
execute.mockRejectedValueOnce(error); execute.mockRejectedValueOnce(error);
@@ -474,6 +484,7 @@ describe('makeHandleToolInvocation', () => {
contextFactory, contextFactory,
connectedNode, connectedNode,
connectedNodeType, connectedNodeType,
runExecutionData,
); );
const result = await handleToolInvocation(toolArgs); const result = await handleToolInvocation(toolArgs);
@@ -489,14 +500,42 @@ describe('makeHandleToolInvocation', () => {
const mockContext = mock<IExecuteFunctions>(); const mockContext = mock<IExecuteFunctions>();
contextFactory.mockReturnValue(mockContext); contextFactory.mockReturnValue(mockContext);
const handleToolInvocation = makeHandleToolInvocation( let handleToolInvocation = makeHandleToolInvocation(
contextFactory, contextFactory,
connectedNode, connectedNode,
connectedNodeType, connectedNodeType,
runExecutionData,
); );
await handleToolInvocation(toolArgs);
runExecutionData = mock<IRunExecutionData>({
resultData: {
runData: {
[connectedNode.name]: [taskData],
},
},
});
handleToolInvocation = makeHandleToolInvocation(
contextFactory,
connectedNode,
connectedNodeType,
runExecutionData,
);
await handleToolInvocation(toolArgs); await handleToolInvocation(toolArgs);
await handleToolInvocation(toolArgs);
runExecutionData = mock<IRunExecutionData>({
resultData: {
runData: {
[connectedNode.name]: [taskData, taskData],
},
},
});
handleToolInvocation = makeHandleToolInvocation(
contextFactory,
connectedNode,
connectedNodeType,
runExecutionData,
);
await handleToolInvocation(toolArgs); await handleToolInvocation(toolArgs);
expect(contextFactory).toHaveBeenCalledWith(0); expect(contextFactory).toHaveBeenCalledWith(0);

View File

@@ -28,19 +28,28 @@ import type { ExecuteContext, WebhookContext } from '../../node-execution-contex
// eslint-disable-next-line import/no-cycle // eslint-disable-next-line import/no-cycle
import { SupplyDataContext } from '../../node-execution-context/supply-data-context'; import { SupplyDataContext } from '../../node-execution-context/supply-data-context';
function getNextRunIndex(runExecutionData: IRunExecutionData, nodeName: string) {
return runExecutionData.resultData.runData[nodeName]?.length ?? 0;
}
export function makeHandleToolInvocation( export function makeHandleToolInvocation(
contextFactory: (runIndex: number) => ISupplyDataFunctions, contextFactory: (runIndex: number) => ISupplyDataFunctions,
node: INode, node: INode,
nodeType: INodeType, nodeType: INodeType,
runExecutionData: IRunExecutionData,
) { ) {
/** /**
* This keeps track of how many times this specific AI tool node has been invoked. * This keeps track of how many times this specific AI tool node has been invoked.
* It is incremented on every invocation of the tool to keep the output of each invocation separate from each other. * It is incremented on every invocation of the tool to keep the output of each invocation separate from each other.
*/ */
let toolRunIndex = 0; // We get the runIndex from the context to handle multiple executions
// of the same tool when the tool is used in a loop or in a parallel execution.
let runIndex = getNextRunIndex(runExecutionData, node.name);
return async (toolArgs: IDataObject) => { return async (toolArgs: IDataObject) => {
const runIndex = toolRunIndex++; // Increment the runIndex for the next invocation
const context = contextFactory(runIndex); const localRunIndex = runIndex++;
const context = contextFactory(localRunIndex);
context.addInputData(NodeConnectionTypes.AiTool, [[{ json: toolArgs }]]); context.addInputData(NodeConnectionTypes.AiTool, [[{ json: toolArgs }]]);
try { try {
@@ -64,13 +73,13 @@ export function makeHandleToolInvocation(
} }
// Add output data to the context // Add output data to the context
context.addOutputData(NodeConnectionTypes.AiTool, runIndex, [[{ json: { response } }]]); context.addOutputData(NodeConnectionTypes.AiTool, localRunIndex, [[{ json: { response } }]]);
// Return the stringified results // Return the stringified results
return JSON.stringify(response); return JSON.stringify(response);
} catch (error) { } catch (error) {
const nodeError = new NodeOperationError(node, error as Error); const nodeError = new NodeOperationError(node, error as Error);
context.addOutputData(NodeConnectionTypes.AiTool, runIndex, nodeError); context.addOutputData(NodeConnectionTypes.AiTool, localRunIndex, nodeError);
return 'Error during node execution: ' + (nodeError.description ?? nodeError.message); return 'Error during node execution: ' + (nodeError.description ?? nodeError.message);
} }
}; };
@@ -153,6 +162,7 @@ export async function getInputConnectionData(
(i) => contextFactory(i, {}), (i) => contextFactory(i, {}),
connectedNode, connectedNode,
connectedNodeType, connectedNodeType,
runExecutionData,
), ),
}); });
nodes.push(supplyData); nodes.push(supplyData);

View File

@@ -992,6 +992,7 @@ export type ISupplyDataFunctions = ExecuteFunctions.GetNodeParameterFn &
| 'sendMessageToUI' | 'sendMessageToUI'
| 'helpers' | 'helpers'
> & { > & {
getNextRunIndex(): number;
continueOnFail(): boolean; continueOnFail(): boolean;
evaluateExpression(expression: string, itemIndex: number): NodeParameterValueType; evaluateExpression(expression: string, itemIndex: number): NodeParameterValueType;
getWorkflowDataProxy(itemIndex: number): IWorkflowDataProxyData; getWorkflowDataProxy(itemIndex: number): IWorkflowDataProxyData;