feat: Add fallback mechanism for agent and basic chain llm (#16617)
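Editorial note: at its core, this commit wires an optional second language-model input into the Tools Agent and, when it is connected, wraps the primary model with LangChain's runnable fallback mechanism. A minimal sketch of that behavior (not part of the commit; the model classes and model names are illustrative):

```ts
// Sketch only: `withFallbacks` is the @langchain/core Runnable API the commit
// calls as `model.withFallbacks([fallbackModel])`. If invoking the primary
// model throws, the same input is retried on each fallback in order.
import { ChatOpenAI } from '@langchain/openai';
import { ChatAnthropic } from '@langchain/anthropic';

const primary = new ChatOpenAI({ model: 'gpt-4o-mini' }); // illustrative
const fallback = new ChatAnthropic({ model: 'claude-3-5-sonnet-latest' }); // illustrative

const modelWithFallback = primary.withFallbacks([fallback]);

// If the primary call fails (rate limit, outage, ...), the identical prompt
// is transparently re-sent to the fallback model.
const reply = await modelWithFallback.invoke('Hello!');
```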
@@ -17,33 +17,25 @@ import { toolsAgentExecute } from '../agents/ToolsAgent/V2/execute';
 // Function used in the inputs expression to figure out which inputs to
 // display based on the agent type
-function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeInputConfiguration> {
+function getInputs(
+  hasOutputParser?: boolean,
+  needsFallback?: boolean,
+): Array<NodeConnectionType | INodeInputConfiguration> {
   interface SpecialInput {
     type: NodeConnectionType;
     filter?: INodeInputFilter;
+    displayName: string;
+    required?: boolean;
   }

   const getInputData = (
     inputs: SpecialInput[],
   ): Array<NodeConnectionType | INodeInputConfiguration> => {
-    const displayNames: { [key: string]: string } = {
-      ai_languageModel: 'Model',
-      ai_memory: 'Memory',
-      ai_tool: 'Tool',
-      ai_outputParser: 'Output Parser',
-    };
-
-    return inputs.map(({ type, filter }) => {
-      const isModelType = type === 'ai_languageModel';
-      let displayName = type in displayNames ? displayNames[type] : undefined;
-      if (isModelType) {
-        displayName = 'Chat Model';
-      }
+    return inputs.map(({ type, filter, displayName, required }) => {
       const input: INodeInputConfiguration = {
         type,
         displayName,
-        required: isModelType,
+        required,
         maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes(
           type as NodeConnectionType,
         )
@@ -62,33 +54,40 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
   let specialInputs: SpecialInput[] = [
     {
       type: 'ai_languageModel',
+      displayName: 'Chat Model',
+      required: true,
       filter: {
-        nodes: [
-          '@n8n/n8n-nodes-langchain.lmChatAnthropic',
-          '@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
-          '@n8n/n8n-nodes-langchain.lmChatAwsBedrock',
-          '@n8n/n8n-nodes-langchain.lmChatMistralCloud',
-          '@n8n/n8n-nodes-langchain.lmChatOllama',
-          '@n8n/n8n-nodes-langchain.lmChatOpenAi',
-          '@n8n/n8n-nodes-langchain.lmChatGroq',
-          '@n8n/n8n-nodes-langchain.lmChatGoogleVertex',
-          '@n8n/n8n-nodes-langchain.lmChatGoogleGemini',
-          '@n8n/n8n-nodes-langchain.lmChatDeepSeek',
-          '@n8n/n8n-nodes-langchain.lmChatOpenRouter',
-          '@n8n/n8n-nodes-langchain.lmChatXAiGrok',
-          '@n8n/n8n-nodes-langchain.code',
-          '@n8n/n8n-nodes-langchain.modelSelector',
+        excludedNodes: [
+          '@n8n/n8n-nodes-langchain.lmCohere',
+          '@n8n/n8n-nodes-langchain.lmOllama',
+          'n8n/n8n-nodes-langchain.lmOpenAi',
+          '@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
         ],
       },
     },
+    {
+      type: 'ai_languageModel',
+      displayName: 'Fallback Model',
+      required: true,
+      filter: {
+        excludedNodes: [
+          '@n8n/n8n-nodes-langchain.lmCohere',
+          '@n8n/n8n-nodes-langchain.lmOllama',
+          'n8n/n8n-nodes-langchain.lmOpenAi',
+          '@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
+        ],
+      },
+    },
     {
+      displayName: 'Memory',
       type: 'ai_memory',
     },
     {
+      displayName: 'Tool',
       type: 'ai_tool',
+      required: true,
     },
     {
+      displayName: 'Output Parser',
       type: 'ai_outputParser',
     },
   ];
@@ -96,6 +95,9 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
   if (hasOutputParser === false) {
     specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser');
   }
+  if (needsFallback === false) {
+    specialInputs = specialInputs.filter((input) => input.displayName !== 'Fallback Model');
+  }
   return ['main', ...getInputData(specialInputs)];
 }

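Editorial note: taken together, the two filters above mean the rendered inputs depend on both booleans. A sketch of the resulting shapes (approximate; filter contents and maxConnections abbreviated):

```ts
// Sketch only: approximate return values of the new getInputs().
getInputs(true, true);
// ≈ ['main',
//    { displayName: 'Chat Model',     type: 'ai_languageModel', required: true, maxConnections: 1, filter: {...} },
//    { displayName: 'Fallback Model', type: 'ai_languageModel', required: true, maxConnections: 1, filter: {...} },
//    { displayName: 'Memory',         type: 'ai_memory',        maxConnections: 1 },
//    { displayName: 'Tool',           type: 'ai_tool',          required: true },
//    { displayName: 'Output Parser',  type: 'ai_outputParser',  maxConnections: 1 }]

getInputs(true, false);  // same list without the 'Fallback Model' entry
getInputs(false, false); // additionally drops the 'Output Parser' entry
```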
@@ -111,10 +113,10 @@ export class AgentV2 implements INodeType {
       color: '#404040',
     },
     inputs: `={{
-      ((hasOutputParser) => {
+      ((hasOutputParser, needsFallback) => {
         ${getInputs.toString()};
-        return getInputs(hasOutputParser)
-      })($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true)
+        return getInputs(hasOutputParser, needsFallback)
+      })($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true, $parameter.needsFallback === undefined || $parameter.needsFallback === true)
     }}`,
     outputs: [NodeConnectionTypes.Main],
     properties: [
@@ -160,6 +162,25 @@ export class AgentV2 implements INodeType {
         },
       },
     },
+    {
+      displayName: 'Enable Fallback Model',
+      name: 'needsFallback',
+      type: 'boolean',
+      default: false,
+      noDataExpression: true,
+    },
+    {
+      displayName:
+        'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
+      name: 'fallbackNotice',
+      type: 'notice',
+      default: '',
+      displayOptions: {
+        show: {
+          needsFallback: [true],
+        },
+      },
+    },
     ...toolsAgentProperties,
   ],
 };

@@ -1,3 +1,4 @@
+import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import { RunnableSequence } from '@langchain/core/runnables';
 import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
 import omit from 'lodash/omit';
@@ -40,7 +41,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE

   for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
     try {
-      const model = await getChatModel(this);
+      const model = (await getChatModel(this)) as BaseLanguageModel;
       const memory = await getOptionalMemory(this);

       const input = getPromptInputByType({
@@ -1,12 +1,19 @@
+import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import type { ChatPromptTemplate } from '@langchain/core/prompts';
 import { RunnableSequence } from '@langchain/core/runnables';
 import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
+import type { BaseChatMemory } from 'langchain/memory';
+import type { DynamicStructuredTool, Tool } from 'langchain/tools';
 import omit from 'lodash/omit';
 import { jsonParse, NodeOperationError, sleep } from 'n8n-workflow';
 import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
+import assert from 'node:assert';

 import { getPromptInputByType } from '@utils/helpers';
-import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
+import {
+  getOptionalOutputParser,
+  type N8nOutputParser,
+} from '@utils/output_parsers/N8nOutputParser';

 import {
   fixEmptyContentMessage,
@@ -19,6 +26,41 @@ import {
 } from '../common';
 import { SYSTEM_MESSAGE } from '../prompt';

+/**
+ * Creates an agent executor with the given configuration
+ */
+function createAgentExecutor(
+  model: BaseChatModel,
+  tools: Array<DynamicStructuredTool | Tool>,
+  prompt: ChatPromptTemplate,
+  options: { maxIterations?: number; returnIntermediateSteps?: boolean },
+  outputParser?: N8nOutputParser,
+  memory?: BaseChatMemory,
+  fallbackModel?: BaseChatModel | null,
+) {
+  const modelWithFallback = fallbackModel ? model.withFallbacks([fallbackModel]) : model;
+  const agent = createToolCallingAgent({
+    llm: modelWithFallback,
+    tools,
+    prompt,
+    streamRunnable: false,
+  });
+
+  const runnableAgent = RunnableSequence.from([
+    agent,
+    getAgentStepsParser(outputParser, memory),
+    fixEmptyContentMessage,
+  ]);
+
+  return AgentExecutor.fromAgentAndTools({
+    agent: runnableAgent,
+    memory,
+    tools,
+    returnIntermediateSteps: options.returnIntermediateSteps === true,
+    maxIterations: options.maxIterations ?? 10,
+  });
+}
+
 /* -----------------------------------------------------------
    Main Executor Function
 ----------------------------------------------------------- */
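Editorial note: a usage sketch of the new helper, mirroring the call site that appears further down in this diff (argument values are whatever the surrounding execute function has already resolved):

```ts
// Sketch only: how createAgentExecutor is invoked below. When fallbackModel
// is null (fallback disabled), the primary model is used unwrapped; otherwise
// the agent's llm becomes model.withFallbacks([fallbackModel]).
const executor = createAgentExecutor(
  model,         // primary model, getChatModel(this, 0)
  tools,
  prompt,
  options,       // { maxIterations?, returnIntermediateSteps? }
  outputParser,  // optional N8nOutputParser
  memory,        // optional BaseChatMemory
  fallbackModel, // getChatModel(this, 1) when 'needsFallback' is enabled, else null
);
```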
@@ -42,8 +84,18 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
     0,
     0,
   ) as number;
+  const needsFallback = this.getNodeParameter('needsFallback', 0, false) as boolean;
   const memory = await getOptionalMemory(this);
-  const model = await getChatModel(this);
+  const model = await getChatModel(this, 0);
+  assert(model, 'Please connect a model to the Chat Model input');
+  const fallbackModel = needsFallback ? await getChatModel(this, 1) : null;
+
+  if (needsFallback && !fallbackModel) {
+    throw new NodeOperationError(
+      this.getNode(),
+      'Please connect a model to the Fallback Model input or disable the fallback option',
+    );
+  }

   for (let i = 0; i < items.length; i += batchSize) {
     const batch = items.slice(i, i + batchSize);
@@ -57,7 +109,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
         promptTypeKey: 'promptType',
       });
       if (input === undefined) {
-        throw new NodeOperationError(this.getNode(), 'The “text” parameter is empty.');
+        throw new NodeOperationError(this.getNode(), 'The "text" parameter is empty.');
       }
       const outputParser = await getOptionalOutputParser(this, itemIndex);
       const tools = await getTools(this, outputParser);
@@ -76,38 +128,26 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
         });
         const prompt: ChatPromptTemplate = preparePrompt(messages);

-        // Create the base agent that calls tools.
-        const agent = createToolCallingAgent({
-          llm: model,
-          tools,
-          prompt,
-          streamRunnable: false,
-        });
-        agent.streamRunnable = false;
-        // Wrap the agent with parsers and fixes.
-        const runnableAgent = RunnableSequence.from([
-          agent,
-          getAgentStepsParser(outputParser, memory),
-          fixEmptyContentMessage,
-        ]);
-        const executor = AgentExecutor.fromAgentAndTools({
-          agent: runnableAgent,
-          memory,
-          tools,
-          returnIntermediateSteps: options.returnIntermediateSteps === true,
-          maxIterations: options.maxIterations ?? 10,
-        });
+        // Create executors for primary and fallback models
+        const executor = createAgentExecutor(
+          model,
+          tools,
+          prompt,
+          options,
+          outputParser,
+          memory,
+          fallbackModel,
+        );
-
-        // Invoke the executor with the given input and system message.
-        return await executor.invoke(
-          {
-            input,
-            system_message: options.systemMessage ?? SYSTEM_MESSAGE,
-            formatting_instructions:
-              'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
-          },
-          { signal: this.getExecutionCancelSignal() },
-        );
+        // Invoke with fallback logic
+        const invokeParams = {
+          input,
+          system_message: options.systemMessage ?? SYSTEM_MESSAGE,
+          formatting_instructions:
+            'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
+        };
+        const executeOptions = { signal: this.getExecutionCancelSignal() };
+
+        return await executor.invoke(invokeParams, executeOptions);
       });

       const batchResults = await Promise.allSettled(batchPromises);
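Editorial note: because the per-item promises are gathered with Promise.allSettled (visible at the end of this hunk), a failure of both the primary and the fallback model on one item is reported for that item without aborting the rest of the batch. A reduced sketch of the pattern (`executeItem` is a hypothetical stand-in for the arrow function above):

```ts
// Sketch only: the surrounding batching pattern, not the commit's exact code.
declare function executeItem(index: number): Promise<unknown>; // hypothetical

const batchPromises = batch.map((_item, index) => executeItem(index));
const batchResults = await Promise.allSettled(batchPromises);

for (const result of batchResults) {
  if (result.status === 'rejected') {
    // per-item error handling (e.g. n8n's continueOnFail) happens here
  }
}
```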
@@ -263,8 +263,25 @@ export const getAgentStepsParser =
  * @param ctx - The execution context
  * @returns The validated chat model
  */
-export async function getChatModel(ctx: IExecuteFunctions): Promise<BaseChatModel> {
-  const model = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+export async function getChatModel(
+  ctx: IExecuteFunctions,
+  index: number = 0,
+): Promise<BaseChatModel | undefined> {
+  const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+
+  let model;
+
+  if (Array.isArray(connectedModels) && index !== undefined) {
+    if (connectedModels.length <= index) {
+      return undefined;
+    }
+    // We get the models in reversed order from the workflow so we need to reverse them to match the right index
+    const reversedModels = [...connectedModels].reverse();
+    model = reversedModels[index] as BaseChatModel;
+  } else {
+    model = connectedModels as BaseChatModel;
+  }
+
   if (!isChatInstance(model) || !model.bindTools) {
     throw new NodeOperationError(
       ctx.getNode(),
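Editorial note: the index-to-connection mapping, illustrated. Per the code comment above, getInputConnectionData hands back multiple connections in reverse order, so reversing makes index 0 always resolve to the primary 'Chat Model' input (sketch; variable names illustrative):

```ts
// Sketch only: with both inputs connected, n8n delivers
//   connectedModels = [fallbackModel, primaryModel]
// and the helper reverses before indexing:
const reversed = [...connectedModels].reverse(); // [primaryModel, fallbackModel]
reversed[0]; // -> getChatModel(ctx, 0): primary 'Chat Model'
reversed[1]; // -> getChatModel(ctx, 1): 'Fallback Model'
```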
@@ -52,6 +52,7 @@ describe('toolsAgentExecute', () => {
     // Mock getNodeParameter to return default values
     mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
       if (param === 'text') return 'test input';
+      if (param === 'needsFallback') return false;
       if (param === 'options.batching.batchSize') return defaultValue;
       if (param === 'options.batching.delayBetweenBatches') return defaultValue;
       if (param === 'options')
@@ -104,6 +105,7 @@ describe('toolsAgentExecute', () => {
       if (param === 'options.batching.batchSize') return 2;
       if (param === 'options.batching.delayBetweenBatches') return 100;
       if (param === 'text') return 'test input';
+      if (param === 'needsFallback') return false;
       if (param === 'options')
         return {
           systemMessage: 'You are a helpful assistant',
@@ -157,6 +159,7 @@ describe('toolsAgentExecute', () => {
       if (param === 'options.batching.batchSize') return 2;
       if (param === 'options.batching.delayBetweenBatches') return 0;
       if (param === 'text') return 'test input';
+      if (param === 'needsFallback') return false;
       if (param === 'options')
         return {
           systemMessage: 'You are a helpful assistant',
@@ -206,6 +209,7 @@ describe('toolsAgentExecute', () => {
       if (param === 'options.batching.batchSize') return 2;
       if (param === 'options.batching.delayBetweenBatches') return 0;
       if (param === 'text') return 'test input';
+      if (param === 'needsFallback') return false;
       if (param === 'options')
         return {
           systemMessage: 'You are a helpful assistant',
@@ -2,6 +2,7 @@ import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { HumanMessage } from '@langchain/core/messages';
 import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
+import { FakeLLM, FakeStreamingChatModel } from '@langchain/core/utils/testing';
 import { Buffer } from 'buffer';
 import { mock } from 'jest-mock-extended';
 import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
@@ -163,6 +164,72 @@ describe('getChatModel', () => {
     mockContext.getNode.mockReturnValue(mock());
     await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
   });
+
+  it('should return the first model when multiple models are connected and no index specified', async () => {
+    const fakeChatModel1 = new FakeStreamingChatModel({});
+    const fakeChatModel2 = new FakeStreamingChatModel({});
+
+    mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+    const model = await getChatModel(mockContext);
+    expect(model).toEqual(fakeChatModel2); // Should return the last model (reversed array)
+  });
+
+  it('should return the model at specified index when multiple models are connected', async () => {
+    const fakeChatModel1 = new FakeStreamingChatModel({});
+
+    const fakeChatModel2 = new FakeStreamingChatModel({});
+
+    mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+    const model = await getChatModel(mockContext, 0);
+    expect(model).toEqual(fakeChatModel2); // Should return the first model after reversal (index 0)
+  });
+
+  it('should return the fallback model at index 1 when multiple models are connected', async () => {
+    const fakeChatModel1 = new FakeStreamingChatModel({});
+    const fakeChatModel2 = new FakeStreamingChatModel({});
+
+    mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+    const model = await getChatModel(mockContext, 1);
+    expect(model).toEqual(fakeChatModel1); // Should return the second model after reversal (index 1)
+  });
+
+  it('should return undefined when requested index is out of bounds', async () => {
+    const fakeChatModel1 = mock<BaseChatModel>();
+    fakeChatModel1.bindTools = jest.fn();
+    fakeChatModel1.lc_namespace = ['chat_models'];
+
+    mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1]);
+    mockContext.getNode.mockReturnValue(mock());
+
+    const result = await getChatModel(mockContext, 2);
+
+    expect(result).toBeUndefined();
+  });
+
+  it('should throw error when single model does not support tools', async () => {
+    const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
+
+    mockContext.getInputConnectionData.mockResolvedValue(fakeInvalidModel);
+    mockContext.getNode.mockReturnValue(mock());
+
+    await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
+    await expect(getChatModel(mockContext)).rejects.toThrow(
+      'Tools Agent requires Chat Model which supports Tools calling',
+    );
+  });
+
+  it('should throw error when model at specified index does not support tools', async () => {
+    const fakeChatModel1 = new FakeStreamingChatModel({});
+    const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
+
+    mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeInvalidModel]);
+    mockContext.getNode.mockReturnValue(mock());
+
+    await expect(getChatModel(mockContext, 0)).rejects.toThrow(NodeOperationError);
+  });
 });

 describe('getOptionalMemory', () => {