mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 10:02:05 +00:00
fix(AI Agent Node): Fix tool calling when tools run in a loop (#15250)
Co-authored-by: JP van Oosten <jp@n8n.io> Co-authored-by: कारतोफ्फेलस्क्रिप्ट™ <aditya@netroy.in>
This commit is contained in:
committed by
GitHub
parent
52f27a76ac
commit
cd1d6c9dfc
@@ -39,6 +39,7 @@ export class ToolWorkflowV2 implements INodeType {
|
|||||||
const description = this.getNodeParameter('description', itemIndex) as string;
|
const description = this.getNodeParameter('description', itemIndex) as string;
|
||||||
|
|
||||||
const tool = await workflowToolService.createTool({
|
const tool = await workflowToolService.createTool({
|
||||||
|
ctx: this,
|
||||||
name,
|
name,
|
||||||
description,
|
description,
|
||||||
itemIndex,
|
itemIndex,
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ import { WorkflowToolService } from './utils/WorkflowToolService';
|
|||||||
|
|
||||||
// Mock ISupplyDataFunctions interface
|
// Mock ISupplyDataFunctions interface
|
||||||
function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDataFunctions {
|
function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDataFunctions {
|
||||||
|
let runIndex = 0;
|
||||||
|
const getNextRunIndex = jest.fn(() => {
|
||||||
|
return runIndex++;
|
||||||
|
});
|
||||||
return {
|
return {
|
||||||
runIndex: 0,
|
runIndex: 0,
|
||||||
getNodeParameter: jest.fn(),
|
getNodeParameter: jest.fn(),
|
||||||
@@ -26,6 +30,7 @@ function createMockContext(overrides?: Partial<ISupplyDataFunctions>): ISupplyDa
|
|||||||
getInputData: jest.fn(),
|
getInputData: jest.fn(),
|
||||||
getMode: jest.fn(),
|
getMode: jest.fn(),
|
||||||
getRestApiUrl: jest.fn(),
|
getRestApiUrl: jest.fn(),
|
||||||
|
getNextRunIndex,
|
||||||
getTimezone: jest.fn(),
|
getTimezone: jest.fn(),
|
||||||
getWorkflow: jest.fn(),
|
getWorkflow: jest.fn(),
|
||||||
getWorkflowStaticData: jest.fn(),
|
getWorkflowStaticData: jest.fn(),
|
||||||
@@ -56,6 +61,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
|
|||||||
describe('createTool', () => {
|
describe('createTool', () => {
|
||||||
it('should create a basic dynamic tool when schema is not used', async () => {
|
it('should create a basic dynamic tool when schema is not used', async () => {
|
||||||
const toolParams = {
|
const toolParams = {
|
||||||
|
ctx: context,
|
||||||
name: 'TestTool',
|
name: 'TestTool',
|
||||||
description: 'Test Description',
|
description: 'Test Description',
|
||||||
itemIndex: 0,
|
itemIndex: 0,
|
||||||
@@ -70,6 +76,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
|
|||||||
|
|
||||||
it('should create a tool that can handle successful execution', async () => {
|
it('should create a tool that can handle successful execution', async () => {
|
||||||
const toolParams = {
|
const toolParams = {
|
||||||
|
ctx: context,
|
||||||
name: 'TestTool',
|
name: 'TestTool',
|
||||||
description: 'Test Description',
|
description: 'Test Description',
|
||||||
itemIndex: 0,
|
itemIndex: 0,
|
||||||
@@ -112,6 +119,7 @@ describe('WorkflowTool::WorkflowToolService', () => {
|
|||||||
|
|
||||||
it('should handle errors during tool execution', async () => {
|
it('should handle errors during tool execution', async () => {
|
||||||
const toolParams = {
|
const toolParams = {
|
||||||
|
ctx: context,
|
||||||
name: 'TestTool',
|
name: 'TestTool',
|
||||||
description: 'Test Description',
|
description: 'Test Description',
|
||||||
itemIndex: 0,
|
itemIndex: 0,
|
||||||
|
|||||||
@@ -60,17 +60,21 @@ export class WorkflowToolService {
|
|||||||
|
|
||||||
// Creates the tool based on the provided parameters
|
// Creates the tool based on the provided parameters
|
||||||
async createTool({
|
async createTool({
|
||||||
|
ctx,
|
||||||
name,
|
name,
|
||||||
description,
|
description,
|
||||||
itemIndex,
|
itemIndex,
|
||||||
}: {
|
}: {
|
||||||
|
ctx: ISupplyDataFunctions;
|
||||||
name: string;
|
name: string;
|
||||||
description: string;
|
description: string;
|
||||||
itemIndex: number;
|
itemIndex: number;
|
||||||
}): Promise<DynamicTool | DynamicStructuredTool> {
|
}): Promise<DynamicTool | DynamicStructuredTool> {
|
||||||
let runIndex = 0;
|
|
||||||
// Handler for the tool execution, will be called when the tool is executed
|
// Handler for the tool execution, will be called when the tool is executed
|
||||||
// This function will execute the sub-workflow and return the response
|
// This function will execute the sub-workflow and return the response
|
||||||
|
// We get the runIndex from the context to handle multiple executions
|
||||||
|
// of the same tool when the tool is used in a loop or in a parallel execution.
|
||||||
|
let runIndex: number = ctx.getNextRunIndex();
|
||||||
const toolHandler = async (
|
const toolHandler = async (
|
||||||
query: string | IDataObject,
|
query: string | IDataObject,
|
||||||
runManager?: CallbackManagerForToolRun,
|
runManager?: CallbackManagerForToolRun,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import type {
|
|||||||
ICredentialDataDecryptedObject,
|
ICredentialDataDecryptedObject,
|
||||||
NodeConnectionType,
|
NodeConnectionType,
|
||||||
} from 'n8n-workflow';
|
} from 'n8n-workflow';
|
||||||
|
import type { IRunData } from 'n8n-workflow';
|
||||||
import { ApplicationError, NodeConnectionTypes } from 'n8n-workflow';
|
import { ApplicationError, NodeConnectionTypes } from 'n8n-workflow';
|
||||||
|
|
||||||
import { describeCommonTests } from './shared-tests';
|
import { describeCommonTests } from './shared-tests';
|
||||||
@@ -186,4 +187,36 @@ describe('SupplyDataContext', () => {
|
|||||||
expect(clone).not.toBe(supplyDataContext);
|
expect(clone).not.toBe(supplyDataContext);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('getNextRunIndex', () => {
|
||||||
|
it('should return 0 as the default latest run index', () => {
|
||||||
|
const latestRunIndex = supplyDataContext.getNextRunIndex();
|
||||||
|
expect(latestRunIndex).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return the length of the run execution data for the node', () => {
|
||||||
|
const runData = mock<IRunData>();
|
||||||
|
const runExecutionData = mock<IRunExecutionData>({
|
||||||
|
resultData: { runData: { [node.name]: [runData, runData] } },
|
||||||
|
});
|
||||||
|
const supplyDataContext = new SupplyDataContext(
|
||||||
|
workflow,
|
||||||
|
node,
|
||||||
|
additionalData,
|
||||||
|
mode,
|
||||||
|
runExecutionData,
|
||||||
|
runIndex,
|
||||||
|
connectionInputData,
|
||||||
|
inputData,
|
||||||
|
connectionType,
|
||||||
|
executeData,
|
||||||
|
[closeFn],
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
|
|
||||||
|
const latestRunIndex = supplyDataContext.getNextRunIndex();
|
||||||
|
|
||||||
|
expect(latestRunIndex).toBe(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -167,16 +167,18 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData
|
|||||||
return super.getInputItems(inputIndex, connectionType) ?? [];
|
return super.getInputItems(inputIndex, connectionType) ?? [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getNextRunIndex(): number {
|
||||||
|
const nodeName = this.node.name;
|
||||||
|
return this.runExecutionData.resultData.runData[nodeName]?.length ?? 0;
|
||||||
|
}
|
||||||
|
|
||||||
/** @deprecated create a context object with inputData for every runIndex */
|
/** @deprecated create a context object with inputData for every runIndex */
|
||||||
addInputData(
|
addInputData(
|
||||||
connectionType: AINodeConnectionType,
|
connectionType: AINodeConnectionType,
|
||||||
data: INodeExecutionData[][],
|
data: INodeExecutionData[][],
|
||||||
): { index: number } {
|
): { index: number } {
|
||||||
const nodeName = this.node.name;
|
const nodeName = this.node.name;
|
||||||
let currentNodeRunIndex = 0;
|
const currentNodeRunIndex = this.getNextRunIndex();
|
||||||
if (this.runExecutionData.resultData.runData.hasOwnProperty(nodeName)) {
|
|
||||||
currentNodeRunIndex = this.runExecutionData.resultData.runData[nodeName].length;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.addExecutionDataFunctions(
|
this.addExecutionDataFunctions(
|
||||||
'input',
|
'input',
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ import type {
|
|||||||
INodeType,
|
INodeType,
|
||||||
INodeTypes,
|
INodeTypes,
|
||||||
IExecuteFunctions,
|
IExecuteFunctions,
|
||||||
|
IRunData,
|
||||||
} from 'n8n-workflow';
|
} from 'n8n-workflow';
|
||||||
|
import type { ITaskData } from 'n8n-workflow';
|
||||||
import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
|
import { NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
|
||||||
|
|
||||||
import { ExecuteContext } from '../../execute-context';
|
import { ExecuteContext } from '../../execute-context';
|
||||||
@@ -377,11 +379,19 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
execute,
|
execute,
|
||||||
});
|
});
|
||||||
const contextFactory = jest.fn();
|
const contextFactory = jest.fn();
|
||||||
|
const taskData = mock<ITaskData>();
|
||||||
|
|
||||||
|
let runExecutionData = mock<IRunExecutionData>({
|
||||||
|
resultData: {
|
||||||
|
runData: mock<IRunData>(),
|
||||||
|
},
|
||||||
|
});
|
||||||
const toolArgs = { key: 'value' };
|
const toolArgs = { key: 'value' };
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return stringified results when execution is successful', async () => {
|
it('should return stringified results when execution is successful', async () => {
|
||||||
const mockContext = mock<IExecuteFunctions>();
|
const mockContext = mock<IExecuteFunctions>();
|
||||||
contextFactory.mockReturnValue(mockContext);
|
contextFactory.mockReturnValue(mockContext);
|
||||||
@@ -393,6 +403,7 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
contextFactory,
|
contextFactory,
|
||||||
connectedNode,
|
connectedNode,
|
||||||
connectedNodeType,
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
);
|
);
|
||||||
const result = await handleToolInvocation(toolArgs);
|
const result = await handleToolInvocation(toolArgs);
|
||||||
|
|
||||||
@@ -405,7 +416,6 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
it('should handle binary data and return a warning message', async () => {
|
it('should handle binary data and return a warning message', async () => {
|
||||||
const mockContext = mock<IExecuteFunctions>();
|
const mockContext = mock<IExecuteFunctions>();
|
||||||
contextFactory.mockReturnValue(mockContext);
|
contextFactory.mockReturnValue(mockContext);
|
||||||
|
|
||||||
const mockResult = [[{ json: {}, binary: { file: 'data' } }]];
|
const mockResult = [[{ json: {}, binary: { file: 'data' } }]];
|
||||||
execute.mockResolvedValueOnce(mockResult);
|
execute.mockResolvedValueOnce(mockResult);
|
||||||
|
|
||||||
@@ -413,6 +423,7 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
contextFactory,
|
contextFactory,
|
||||||
connectedNode,
|
connectedNode,
|
||||||
connectedNodeType,
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
);
|
);
|
||||||
const result = await handleToolInvocation(toolArgs);
|
const result = await handleToolInvocation(toolArgs);
|
||||||
|
|
||||||
@@ -439,7 +450,6 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
contextFactory.mockReturnValue(mockContext);
|
contextFactory.mockReturnValue(mockContext);
|
||||||
|
|
||||||
const mockResult = [[{ json: { a: 3 }, binary: { file: 'data' } }]];
|
const mockResult = [[{ json: { a: 3 }, binary: { file: 'data' } }]];
|
||||||
execute.mockResolvedValueOnce(mockResult);
|
execute.mockResolvedValueOnce(mockResult);
|
||||||
|
|
||||||
@@ -447,6 +457,7 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
contextFactory,
|
contextFactory,
|
||||||
connectedNode,
|
connectedNode,
|
||||||
connectedNodeType,
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
);
|
);
|
||||||
const result = await handleToolInvocation(toolArgs);
|
const result = await handleToolInvocation(toolArgs);
|
||||||
|
|
||||||
@@ -466,7 +477,6 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
it('should handle execution errors and return an error message', async () => {
|
it('should handle execution errors and return an error message', async () => {
|
||||||
const mockContext = mock<IExecuteFunctions>();
|
const mockContext = mock<IExecuteFunctions>();
|
||||||
contextFactory.mockReturnValue(mockContext);
|
contextFactory.mockReturnValue(mockContext);
|
||||||
|
|
||||||
const error = new Error('Execution failed');
|
const error = new Error('Execution failed');
|
||||||
execute.mockRejectedValueOnce(error);
|
execute.mockRejectedValueOnce(error);
|
||||||
|
|
||||||
@@ -474,6 +484,7 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
contextFactory,
|
contextFactory,
|
||||||
connectedNode,
|
connectedNode,
|
||||||
connectedNodeType,
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
);
|
);
|
||||||
const result = await handleToolInvocation(toolArgs);
|
const result = await handleToolInvocation(toolArgs);
|
||||||
|
|
||||||
@@ -489,14 +500,42 @@ describe('makeHandleToolInvocation', () => {
|
|||||||
const mockContext = mock<IExecuteFunctions>();
|
const mockContext = mock<IExecuteFunctions>();
|
||||||
contextFactory.mockReturnValue(mockContext);
|
contextFactory.mockReturnValue(mockContext);
|
||||||
|
|
||||||
const handleToolInvocation = makeHandleToolInvocation(
|
let handleToolInvocation = makeHandleToolInvocation(
|
||||||
contextFactory,
|
contextFactory,
|
||||||
connectedNode,
|
connectedNode,
|
||||||
connectedNodeType,
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
);
|
);
|
||||||
|
await handleToolInvocation(toolArgs);
|
||||||
|
|
||||||
|
runExecutionData = mock<IRunExecutionData>({
|
||||||
|
resultData: {
|
||||||
|
runData: {
|
||||||
|
[connectedNode.name]: [taskData],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
handleToolInvocation = makeHandleToolInvocation(
|
||||||
|
contextFactory,
|
||||||
|
connectedNode,
|
||||||
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
|
);
|
||||||
await handleToolInvocation(toolArgs);
|
await handleToolInvocation(toolArgs);
|
||||||
await handleToolInvocation(toolArgs);
|
|
||||||
|
runExecutionData = mock<IRunExecutionData>({
|
||||||
|
resultData: {
|
||||||
|
runData: {
|
||||||
|
[connectedNode.name]: [taskData, taskData],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
handleToolInvocation = makeHandleToolInvocation(
|
||||||
|
contextFactory,
|
||||||
|
connectedNode,
|
||||||
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
|
);
|
||||||
await handleToolInvocation(toolArgs);
|
await handleToolInvocation(toolArgs);
|
||||||
|
|
||||||
expect(contextFactory).toHaveBeenCalledWith(0);
|
expect(contextFactory).toHaveBeenCalledWith(0);
|
||||||
|
|||||||
@@ -28,19 +28,28 @@ import type { ExecuteContext, WebhookContext } from '../../node-execution-contex
|
|||||||
// eslint-disable-next-line import/no-cycle
|
// eslint-disable-next-line import/no-cycle
|
||||||
import { SupplyDataContext } from '../../node-execution-context/supply-data-context';
|
import { SupplyDataContext } from '../../node-execution-context/supply-data-context';
|
||||||
|
|
||||||
|
function getNextRunIndex(runExecutionData: IRunExecutionData, nodeName: string) {
|
||||||
|
return runExecutionData.resultData.runData[nodeName]?.length ?? 0;
|
||||||
|
}
|
||||||
|
|
||||||
export function makeHandleToolInvocation(
|
export function makeHandleToolInvocation(
|
||||||
contextFactory: (runIndex: number) => ISupplyDataFunctions,
|
contextFactory: (runIndex: number) => ISupplyDataFunctions,
|
||||||
node: INode,
|
node: INode,
|
||||||
nodeType: INodeType,
|
nodeType: INodeType,
|
||||||
|
runExecutionData: IRunExecutionData,
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* This keeps track of how many times this specific AI tool node has been invoked.
|
* This keeps track of how many times this specific AI tool node has been invoked.
|
||||||
* It is incremented on every invocation of the tool to keep the output of each invocation separate from each other.
|
* It is incremented on every invocation of the tool to keep the output of each invocation separate from each other.
|
||||||
*/
|
*/
|
||||||
let toolRunIndex = 0;
|
// We get the runIndex from the context to handle multiple executions
|
||||||
|
// of the same tool when the tool is used in a loop or in a parallel execution.
|
||||||
|
let runIndex = getNextRunIndex(runExecutionData, node.name);
|
||||||
|
|
||||||
return async (toolArgs: IDataObject) => {
|
return async (toolArgs: IDataObject) => {
|
||||||
const runIndex = toolRunIndex++;
|
// Increment the runIndex for the next invocation
|
||||||
const context = contextFactory(runIndex);
|
const localRunIndex = runIndex++;
|
||||||
|
const context = contextFactory(localRunIndex);
|
||||||
context.addInputData(NodeConnectionTypes.AiTool, [[{ json: toolArgs }]]);
|
context.addInputData(NodeConnectionTypes.AiTool, [[{ json: toolArgs }]]);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -64,13 +73,13 @@ export function makeHandleToolInvocation(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add output data to the context
|
// Add output data to the context
|
||||||
context.addOutputData(NodeConnectionTypes.AiTool, runIndex, [[{ json: { response } }]]);
|
context.addOutputData(NodeConnectionTypes.AiTool, localRunIndex, [[{ json: { response } }]]);
|
||||||
|
|
||||||
// Return the stringified results
|
// Return the stringified results
|
||||||
return JSON.stringify(response);
|
return JSON.stringify(response);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const nodeError = new NodeOperationError(node, error as Error);
|
const nodeError = new NodeOperationError(node, error as Error);
|
||||||
context.addOutputData(NodeConnectionTypes.AiTool, runIndex, nodeError);
|
context.addOutputData(NodeConnectionTypes.AiTool, localRunIndex, nodeError);
|
||||||
return 'Error during node execution: ' + (nodeError.description ?? nodeError.message);
|
return 'Error during node execution: ' + (nodeError.description ?? nodeError.message);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -153,6 +162,7 @@ export async function getInputConnectionData(
|
|||||||
(i) => contextFactory(i, {}),
|
(i) => contextFactory(i, {}),
|
||||||
connectedNode,
|
connectedNode,
|
||||||
connectedNodeType,
|
connectedNodeType,
|
||||||
|
runExecutionData,
|
||||||
),
|
),
|
||||||
});
|
});
|
||||||
nodes.push(supplyData);
|
nodes.push(supplyData);
|
||||||
|
|||||||
@@ -992,6 +992,7 @@ export type ISupplyDataFunctions = ExecuteFunctions.GetNodeParameterFn &
|
|||||||
| 'sendMessageToUI'
|
| 'sendMessageToUI'
|
||||||
| 'helpers'
|
| 'helpers'
|
||||||
> & {
|
> & {
|
||||||
|
getNextRunIndex(): number;
|
||||||
continueOnFail(): boolean;
|
continueOnFail(): boolean;
|
||||||
evaluateExpression(expression: string, itemIndex: number): NodeParameterValueType;
|
evaluateExpression(expression: string, itemIndex: number): NodeParameterValueType;
|
||||||
getWorkflowDataProxy(itemIndex: number): IWorkflowDataProxyData;
|
getWorkflowDataProxy(itemIndex: number): IWorkflowDataProxyData;
|
||||||
|
|||||||
Reference in New Issue
Block a user