mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 01:56:46 +00:00
fix(AI Agent Node): Fix tool calling when tools run in a loop (#15250)
Co-authored-by: JP van Oosten <jp@n8n.io> Co-authored-by: कारतोफ्फेलस्क्रिप्ट™ <aditya@netroy.in>
This commit is contained in:
committed by
GitHub
parent
52f27a76ac
commit
cd1d6c9dfc
@@ -39,6 +39,7 @@ export class ToolWorkflowV2 implements INodeType {
|
||||
const description = this.getNodeParameter('description', itemIndex) as string;
|
||||
|
||||
const tool = await workflowToolService.createTool({
|
||||
ctx: this,
|
||||
name,
|
||||
description,
|
||||
itemIndex,
|
||||
|
||||
@@ -13,6 +13,10 @@ import { WorkflowToolService } from './utils/WorkflowToolService';
|
||||
|
||||
// Mock ISupplyDataFunctions interface
|
||||
function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDataFunctions {
|
||||
let runIndex = 0;
|
||||
const getNextRunIndex = jest.fn(() => {
|
||||
return runIndex++;
|
||||
});
|
||||
return {
|
||||
runIndex: 0,
|
||||
getNodeParameter: jest.fn(),
|
||||
@@ -26,6 +30,7 @@ function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDa
|
||||
getInputData: jest.fn(),
|
||||
getMode: jest.fn(),
|
||||
getRestApiUrl: jest.fn(),
|
||||
getNextRunIndex,
|
||||
getTimezone: jest.fn(),
|
||||
getWorkflow: jest.fn(),
|
||||
getWorkflowStaticData: jest.fn(),
|
||||
@@ -56,6 +61,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
|
||||
describe('createTool', () => {
|
||||
it('should create a basic dynamic tool when schema is not used', async () => {
|
||||
const toolParams = {
|
||||
ctx: context,
|
||||
name: 'TestTool',
|
||||
description: 'Test Description',
|
||||
itemIndex: 0,
|
||||
@@ -70,6 +76,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
|
||||
|
||||
it('should create a tool that can handle successful execution', async () => {
|
||||
const toolParams = {
|
||||
ctx: context,
|
||||
name: 'TestTool',
|
||||
description: 'Test Description',
|
||||
itemIndex: 0,
|
||||
@@ -112,6 +119,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
|
||||
|
||||
it('should handle errors during tool execution', async () => {
|
||||
const toolParams = {
|
||||
ctx: context,
|
||||
name: 'TestTool',
|
||||
description: 'Test Description',
|
||||
itemIndex: 0,
|
||||
|
||||
@@ -60,17 +60,21 @@ export class WorkflowToolService {
|
||||
|
||||
// Creates the tool based on the provided parameters
|
||||
async createTool({
|
||||
ctx,
|
||||
name,
|
||||
description,
|
||||
itemIndex,
|
||||
}: {
|
||||
ctx: ISupplyDataFunctions;
|
||||
name: string;
|
||||
description: string;
|
||||
itemIndex: number;
|
||||
}): Promise<DynamicTool | DynamicStructuredTool> {
|
||||
let runIndex = 0;
|
||||
// Handler for the tool execution, will be called when the tool is executed
|
||||
// This function will execute the sub-workflow and return the response
|
||||
// We get the runIndex from the context to handle multiple executions
|
||||
// of the same tool when the tool is used in a loop or in a parallel execution.
|
||||
let runIndex: number = ctx.getNextRunIndex();
|
||||
const toolHandler = async (
|
||||
query: string | IDataObject,
|
||||
runManager?: CallbackManagerForToolRun,
|
||||
|
||||
@@ -15,6 +15,7 @@ import type {
|
||||
ICredentialDataDecryptedObject,
|
||||
NodeConnectionType,
|
||||
} from 'n8n-workflow';
|
||||
import type { IRunData } from 'n8n-workflow';
|
||||
import { ApplicationError, NodeConnectionTypes } from 'n8n-workflow';
|
||||
|
||||
import { describeCommonTests } from './shared-tests';
|
||||
@@ -186,4 +187,36 @@ describe('SupplyDataContext', () => {
|
||||
expect(clone).not.toBe(supplyDataContext);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getNextRunIndex', () => {
|
||||
it('should return 0 as the default latest run index', () => {
|
||||
const latestRunIndex = supplyDataContext.getNextRunIndex();
|
||||
expect(latestRunIndex).toBe(0);
|
||||
});
|
||||
|
||||
it('should return the length of the run execution data for the node', () => {
|
||||
const runData = mock<IRunData>();
|
||||
const runExecutionData = mock<IRunExecutionData>({
|
||||
resultData: { runData: { [node.name]: [runData, runData] } },
|
||||
});
|
||||
const supplyDataContext = new SupplyDataContext(
|
||||
workflow,
|
||||
node,
|
||||
additionalData,
|
||||
mode,
|
||||
runExecutionData,
|
||||
runIndex,
|
||||
connectionInputData,
|
||||
inputData,
|
||||
connectionType,
|
||||
executeData,
|
||||
[closeFn],
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
const latestRunIndex = supplyDataContext.getNextRunIndex();
|
||||
|
||||
expect(latestRunIndex).toBe(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -167,16 +167,18 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData
|
||||
return super.getInputItems(inputIndex, connectionType) ?? [];
|
||||
}
|
||||
|
||||
getNextRunIndex(): number {
|
||||
const nodeName = this.node.name;
|
||||
return this.runExecutionData.resultData.runData[nodeName]?.length ?? 0;
|
||||
}
|
||||
|
||||
/** @deprecated create a context object with inputData for every runIndex */
|
||||
addInputData(
|
||||
connectionType: AINodeConnectionType,
|
||||
data: INodeExecutionData[][],
|
||||
): { index: number } {
|
||||
const nodeName = this.node.name;
|
||||
let currentNodeRunIndex = 0;
|
||||
if (this.runExecutionData.resultData.runData.hasOwnProperty(nodeName)) {
|
||||
currentNodeRunIndex = this.runExecutionData.resultData.runData[nodeName].length;
|
||||
}
|
||||
const currentNodeRunIndex = this.getNextRunIndex();
|
||||
|
||||
this.addExecutionDataFunctions(
|
||||
'input',
|
||||
|
||||
@@ -11,7 +11,9 @@ import type {
|
||||
INodeType,
|
||||
INodeTypes,
|
||||
IExecuteFunctions,
|
||||
IRunData,
|
||||
} from 'n8n-workflow';
|
||||
import type { ITaskData } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
|
||||
|
||||
import { ExecuteContext } from '../../execute-context';
|
||||
@@ -377,11 +379,19 @@ describe('makeHandleToolInvocation', () => {
|
||||
execute,
|
||||
});
|
||||
const contextFactory = jest.fn();
|
||||
const taskData = mock<ITaskData>();
|
||||
|
||||
let runExecutionData = mock<IRunExecutionData>({
|
||||
resultData: {
|
||||
runData: mock<IRunData>(),
|
||||
},
|
||||
});
|
||||
const toolArgs = { key: 'value' };
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should return stringified results when execution is successful', async () => {
|
||||
const mockContext = mock<IExecuteFunctions>();
|
||||
contextFactory.mockReturnValue(mockContext);
|
||||
@@ -393,6 +403,7 @@ describe('makeHandleToolInvocation', () => {
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
const result = await handleToolInvocation(toolArgs);
|
||||
|
||||
@@ -405,7 +416,6 @@ describe('makeHandleToolInvocation', () => {
|
||||
it('should handle binary data and return a warning message', async () => {
|
||||
const mockContext = mock<IExecuteFunctions>();
|
||||
contextFactory.mockReturnValue(mockContext);
|
||||
|
||||
const mockResult = [[{ json: {}, binary: { file: 'data' } }]];
|
||||
execute.mockResolvedValueOnce(mockResult);
|
||||
|
||||
@@ -413,6 +423,7 @@ describe('makeHandleToolInvocation', () => {
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
const result = await handleToolInvocation(toolArgs);
|
||||
|
||||
@@ -439,7 +450,6 @@ describe('makeHandleToolInvocation', () => {
|
||||
},
|
||||
});
|
||||
contextFactory.mockReturnValue(mockContext);
|
||||
|
||||
const mockResult = [[{ json: { a: 3 }, binary: { file: 'data' } }]];
|
||||
execute.mockResolvedValueOnce(mockResult);
|
||||
|
||||
@@ -447,6 +457,7 @@ describe('makeHandleToolInvocation', () => {
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
const result = await handleToolInvocation(toolArgs);
|
||||
|
||||
@@ -466,7 +477,6 @@ describe('makeHandleToolInvocation', () => {
|
||||
it('should handle execution errors and return an error message', async () => {
|
||||
const mockContext = mock<IExecuteFunctions>();
|
||||
contextFactory.mockReturnValue(mockContext);
|
||||
|
||||
const error = new Error('Execution failed');
|
||||
execute.mockRejectedValueOnce(error);
|
||||
|
||||
@@ -474,6 +484,7 @@ describe('makeHandleToolInvocation', () => {
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
const result = await handleToolInvocation(toolArgs);
|
||||
|
||||
@@ -489,14 +500,42 @@ describe('makeHandleToolInvocation', () => {
|
||||
const mockContext = mock<IExecuteFunctions>();
|
||||
contextFactory.mockReturnValue(mockContext);
|
||||
|
||||
const handleToolInvocation = makeHandleToolInvocation(
|
||||
let handleToolInvocation = makeHandleToolInvocation(
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
await handleToolInvocation(toolArgs);
|
||||
|
||||
runExecutionData = mock<IRunExecutionData>({
|
||||
resultData: {
|
||||
runData: {
|
||||
[connectedNode.name]: [taskData],
|
||||
},
|
||||
},
|
||||
});
|
||||
handleToolInvocation = makeHandleToolInvocation(
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
await handleToolInvocation(toolArgs);
|
||||
await handleToolInvocation(toolArgs);
|
||||
|
||||
runExecutionData = mock<IRunExecutionData>({
|
||||
resultData: {
|
||||
runData: {
|
||||
[connectedNode.name]: [taskData, taskData],
|
||||
},
|
||||
},
|
||||
});
|
||||
handleToolInvocation = makeHandleToolInvocation(
|
||||
contextFactory,
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
);
|
||||
await handleToolInvocation(toolArgs);
|
||||
|
||||
expect(contextFactory).toHaveBeenCalledWith(0);
|
||||
|
||||
@@ -28,19 +28,28 @@ import type { ExecuteContext, WebhookContext } from '../../node-execution-contex
|
||||
// eslint-disable-next-line import/no-cycle
|
||||
import { SupplyDataContext } from '../../node-execution-context/supply-data-context';
|
||||
|
||||
function getNextRunIndex(runExecutionData: IRunExecutionData, nodeName: string) {
|
||||
return runExecutionData.resultData.runData[nodeName]?.length ?? 0;
|
||||
}
|
||||
|
||||
export function makeHandleToolInvocation(
|
||||
contextFactory: (runIndex: number) => ISupplyDataFunctions,
|
||||
node: INode,
|
||||
nodeType: INodeType,
|
||||
runExecutionData: IRunExecutionData,
|
||||
) {
|
||||
/**
|
||||
* This keeps track of how many times this specific AI tool node has been invoked.
|
||||
* It is incremented on every invocation of the tool to keep the output of each invocation separate from each other.
|
||||
*/
|
||||
let toolRunIndex = 0;
|
||||
// We get the runIndex from the context to handle multiple executions
|
||||
// of the same tool when the tool is used in a loop or in a parallel execution.
|
||||
let runIndex = getNextRunIndex(runExecutionData, node.name);
|
||||
|
||||
return async (toolArgs: IDataObject) => {
|
||||
const runIndex = toolRunIndex++;
|
||||
const context = contextFactory(runIndex);
|
||||
// Increment the runIndex for the next invocation
|
||||
const localRunIndex = runIndex++;
|
||||
const context = contextFactory(localRunIndex);
|
||||
context.addInputData(NodeConnectionTypes.AiTool, [[{ json: toolArgs }]]);
|
||||
|
||||
try {
|
||||
@@ -64,13 +73,13 @@ export function makeHandleToolInvocation(
|
||||
}
|
||||
|
||||
// Add output data to the context
|
||||
context.addOutputData(NodeConnectionTypes.AiTool, runIndex, [[{ json: { response } }]]);
|
||||
context.addOutputData(NodeConnectionTypes.AiTool, localRunIndex, [[{ json: { response } }]]);
|
||||
|
||||
// Return the stringified results
|
||||
return JSON.stringify(response);
|
||||
} catch (error) {
|
||||
const nodeError = new NodeOperationError(node, error as Error);
|
||||
context.addOutputData(NodeConnectionTypes.AiTool, runIndex, nodeError);
|
||||
context.addOutputData(NodeConnectionTypes.AiTool, localRunIndex, nodeError);
|
||||
return 'Error during node execution: ' + (nodeError.description ?? nodeError.message);
|
||||
}
|
||||
};
|
||||
@@ -153,6 +162,7 @@ export async function getInputConnectionData(
|
||||
(i) => contextFactory(i, {}),
|
||||
connectedNode,
|
||||
connectedNodeType,
|
||||
runExecutionData,
|
||||
),
|
||||
});
|
||||
nodes.push(supplyData);
|
||||
|
||||
@@ -992,6 +992,7 @@ export type ISupplyDataFunctions = ExecuteFunctions.GetNodeParameterFn &
|
||||
| 'sendMessageToUI'
|
||||
| 'helpers'
|
||||
> & {
|
||||
getNextRunIndex(): number;
|
||||
continueOnFail(): boolean;
|
||||
evaluateExpression(expression: string, itemIndex: number): NodeParameterValueType;
|
||||
getWorkflowDataProxy(itemIndex: number): IWorkflowDataProxyData;
|
||||
|
||||
Reference in New Issue
Block a user