feat: Add fallback mechanism for agent and basic chain llm (#16617)

This commit is contained in:
Benjamin Schroth
2025-06-26 16:14:03 +02:00
committed by GitHub
parent 0b7bca29f8
commit 6408d5a1b0
20 changed files with 476 additions and 140 deletions

View File

@@ -17,33 +17,25 @@ import { toolsAgentExecute } from '../agents/ToolsAgent/V2/execute';
// Function used in the inputs expression to figure out which inputs to
// display based on the agent type
function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeInputConfiguration> {
function getInputs(
hasOutputParser?: boolean,
needsFallback?: boolean,
): Array<NodeConnectionType | INodeInputConfiguration> {
interface SpecialInput {
type: NodeConnectionType;
filter?: INodeInputFilter;
displayName: string;
required?: boolean;
}
const getInputData = (
inputs: SpecialInput[],
): Array<NodeConnectionType | INodeInputConfiguration> => {
const displayNames: { [key: string]: string } = {
ai_languageModel: 'Model',
ai_memory: 'Memory',
ai_tool: 'Tool',
ai_outputParser: 'Output Parser',
};
return inputs.map(({ type, filter }) => {
const isModelType = type === 'ai_languageModel';
let displayName = type in displayNames ? displayNames[type] : undefined;
if (isModelType) {
displayName = 'Chat Model';
}
return inputs.map(({ type, filter, displayName, required }) => {
const input: INodeInputConfiguration = {
type,
displayName,
required: isModelType,
required,
maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes(
type as NodeConnectionType,
)
@@ -62,33 +54,40 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
let specialInputs: SpecialInput[] = [
{
type: 'ai_languageModel',
displayName: 'Chat Model',
required: true,
filter: {
nodes: [
'@n8n/n8n-nodes-langchain.lmChatAnthropic',
'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
'@n8n/n8n-nodes-langchain.lmChatAwsBedrock',
'@n8n/n8n-nodes-langchain.lmChatMistralCloud',
'@n8n/n8n-nodes-langchain.lmChatOllama',
'@n8n/n8n-nodes-langchain.lmChatOpenAi',
'@n8n/n8n-nodes-langchain.lmChatGroq',
'@n8n/n8n-nodes-langchain.lmChatGoogleVertex',
'@n8n/n8n-nodes-langchain.lmChatGoogleGemini',
'@n8n/n8n-nodes-langchain.lmChatDeepSeek',
'@n8n/n8n-nodes-langchain.lmChatOpenRouter',
'@n8n/n8n-nodes-langchain.lmChatXAiGrok',
'@n8n/n8n-nodes-langchain.code',
'@n8n/n8n-nodes-langchain.modelSelector',
excludedNodes: [
'@n8n/n8n-nodes-langchain.lmCohere',
'@n8n/n8n-nodes-langchain.lmOllama',
'n8n/n8n-nodes-langchain.lmOpenAi',
'@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
],
},
},
{
type: 'ai_languageModel',
displayName: 'Fallback Model',
required: true,
filter: {
excludedNodes: [
'@n8n/n8n-nodes-langchain.lmCohere',
'@n8n/n8n-nodes-langchain.lmOllama',
'n8n/n8n-nodes-langchain.lmOpenAi',
'@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
],
},
},
{
displayName: 'Memory',
type: 'ai_memory',
},
{
displayName: 'Tool',
type: 'ai_tool',
required: true,
},
{
displayName: 'Output Parser',
type: 'ai_outputParser',
},
];
@@ -96,6 +95,9 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
if (hasOutputParser === false) {
specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser');
}
if (needsFallback === false) {
specialInputs = specialInputs.filter((input) => input.displayName !== 'Fallback Model');
}
return ['main', ...getInputData(specialInputs)];
}
@@ -111,10 +113,10 @@ export class AgentV2 implements INodeType {
color: '#404040',
},
inputs: `={{
((hasOutputParser) => {
((hasOutputParser, needsFallback) => {
${getInputs.toString()};
return getInputs(hasOutputParser)
})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true)
return getInputs(hasOutputParser, needsFallback)
})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true, $parameter.needsFallback === undefined || $parameter.needsFallback === true)
}}`,
outputs: [NodeConnectionTypes.Main],
properties: [
@@ -160,6 +162,25 @@ export class AgentV2 implements INodeType {
},
},
},
{
displayName: 'Enable Fallback Model',
name: 'needsFallback',
type: 'boolean',
default: false,
noDataExpression: true,
},
{
displayName:
'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
name: 'fallbackNotice',
type: 'notice',
default: '',
displayOptions: {
show: {
needsFallback: [true],
},
},
},
...toolsAgentProperties,
],
};

View File

@@ -1,3 +1,4 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { RunnableSequence } from '@langchain/core/runnables';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import omit from 'lodash/omit';
@@ -40,7 +41,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const model = await getChatModel(this);
const model = (await getChatModel(this)) as BaseLanguageModel;
const memory = await getOptionalMemory(this);
const input = getPromptInputByType({

View File

@@ -1,12 +1,19 @@
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { ChatPromptTemplate } from '@langchain/core/prompts';
import { RunnableSequence } from '@langchain/core/runnables';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import type { BaseChatMemory } from 'langchain/memory';
import type { DynamicStructuredTool, Tool } from 'langchain/tools';
import omit from 'lodash/omit';
import { jsonParse, NodeOperationError, sleep } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import assert from 'node:assert';
import { getPromptInputByType } from '@utils/helpers';
import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
import {
getOptionalOutputParser,
type N8nOutputParser,
} from '@utils/output_parsers/N8nOutputParser';
import {
fixEmptyContentMessage,
@@ -19,6 +26,41 @@ import {
} from '../common';
import { SYSTEM_MESSAGE } from '../prompt';
/**
* Creates an agent executor with the given configuration
*/
function createAgentExecutor(
model: BaseChatModel,
tools: Array<DynamicStructuredTool | Tool>,
prompt: ChatPromptTemplate,
options: { maxIterations?: number; returnIntermediateSteps?: boolean },
outputParser?: N8nOutputParser,
memory?: BaseChatMemory,
fallbackModel?: BaseChatModel | null,
) {
const modelWithFallback = fallbackModel ? model.withFallbacks([fallbackModel]) : model;
const agent = createToolCallingAgent({
llm: modelWithFallback,
tools,
prompt,
streamRunnable: false,
});
const runnableAgent = RunnableSequence.from([
agent,
getAgentStepsParser(outputParser, memory),
fixEmptyContentMessage,
]);
return AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
memory,
tools,
returnIntermediateSteps: options.returnIntermediateSteps === true,
maxIterations: options.maxIterations ?? 10,
});
}
/* -----------------------------------------------------------
Main Executor Function
----------------------------------------------------------- */
@@ -42,8 +84,18 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
0,
0,
) as number;
const needsFallback = this.getNodeParameter('needsFallback', 0, false) as boolean;
const memory = await getOptionalMemory(this);
const model = await getChatModel(this);
const model = await getChatModel(this, 0);
assert(model, 'Please connect a model to the Chat Model input');
const fallbackModel = needsFallback ? await getChatModel(this, 1) : null;
if (needsFallback && !fallbackModel) {
throw new NodeOperationError(
this.getNode(),
'Please connect a model to the Fallback Model input or disable the fallback option',
);
}
for (let i = 0; i < items.length; i += batchSize) {
const batch = items.slice(i, i + batchSize);
@@ -57,7 +109,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
promptTypeKey: 'promptType',
});
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');
throw new NodeOperationError(this.getNode(), 'The "text" parameter is empty.');
}
const outputParser = await getOptionalOutputParser(this, itemIndex);
const tools = await getTools(this, outputParser);
@@ -76,38 +128,26 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
});
const prompt: ChatPromptTemplate = preparePrompt(messages);
// Create the base agent that calls tools.
const agent = createToolCallingAgent({
llm: model,
// Create executors for primary and fallback models
const executor = createAgentExecutor(
model,
tools,
prompt,
streamRunnable: false,
});
agent.streamRunnable = false;
// Wrap the agent with parsers and fixes.
const runnableAgent = RunnableSequence.from([
agent,
getAgentStepsParser(outputParser, memory),
fixEmptyContentMessage,
]);
const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
options,
outputParser,
memory,
tools,
returnIntermediateSteps: options.returnIntermediateSteps === true,
maxIterations: options.maxIterations ?? 10,
});
// Invoke the executor with the given input and system message.
return await executor.invoke(
{
input,
system_message: options.systemMessage ?? SYSTEM_MESSAGE,
formatting_instructions:
'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
},
{ signal: this.getExecutionCancelSignal() },
fallbackModel,
);
// Invoke with fallback logic
const invokeParams = {
input,
system_message: options.systemMessage ?? SYSTEM_MESSAGE,
formatting_instructions:
'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
};
const executeOptions = { signal: this.getExecutionCancelSignal() };
return await executor.invoke(invokeParams, executeOptions);
});
const batchResults = await Promise.allSettled(batchPromises);

View File

@@ -263,8 +263,25 @@ export const getAgentStepsParser =
* @param ctx - The execution context
* @returns The validated chat model
*/
export async function getChatModel(ctx: IExecuteFunctions): Promise<BaseChatModel> {
const model = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
export async function getChatModel(
ctx: IExecuteFunctions,
index: number = 0,
): Promise<BaseChatModel | undefined> {
const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
let model;
if (Array.isArray(connectedModels) && index !== undefined) {
if (connectedModels.length <= index) {
return undefined;
}
// We get the models in reversed order from the workflow so we need to reverse them to match the right index
const reversedModels = [...connectedModels].reverse();
model = reversedModels[index] as BaseChatModel;
} else {
model = connectedModels as BaseChatModel;
}
if (!isChatInstance(model) || !model.bindTools) {
throw new NodeOperationError(
ctx.getNode(),

View File

@@ -52,6 +52,7 @@ describe('toolsAgentExecute', () => {
// Mock getNodeParameter to return default values
mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options.batching.batchSize') return defaultValue;
if (param === 'options.batching.delayBetweenBatches') return defaultValue;
if (param === 'options')
@@ -104,6 +105,7 @@ describe('toolsAgentExecute', () => {
if (param === 'options.batching.batchSize') return 2;
if (param === 'options.batching.delayBetweenBatches') return 100;
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options')
return {
systemMessage: 'You are a helpful assistant',
@@ -157,6 +159,7 @@ describe('toolsAgentExecute', () => {
if (param === 'options.batching.batchSize') return 2;
if (param === 'options.batching.delayBetweenBatches') return 0;
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options')
return {
systemMessage: 'You are a helpful assistant',
@@ -206,6 +209,7 @@ describe('toolsAgentExecute', () => {
if (param === 'options.batching.batchSize') return 2;
if (param === 'options.batching.delayBetweenBatches') return 0;
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options')
return {
systemMessage: 'You are a helpful assistant',

View File

@@ -2,6 +2,7 @@ import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { HumanMessage } from '@langchain/core/messages';
import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
import { FakeLLM, FakeStreamingChatModel } from '@langchain/core/utils/testing';
import { Buffer } from 'buffer';
import { mock } from 'jest-mock-extended';
import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
@@ -163,6 +164,72 @@ describe('getChatModel', () => {
mockContext.getNode.mockReturnValue(mock());
await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
});
it('should return the first model when multiple models are connected and no index specified', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeChatModel2 = new FakeStreamingChatModel({});
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
const model = await getChatModel(mockContext);
expect(model).toEqual(fakeChatModel2); // Should return the last model (reversed array)
});
it('should return the model at specified index when multiple models are connected', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeChatModel2 = new FakeStreamingChatModel({});
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
const model = await getChatModel(mockContext, 0);
expect(model).toEqual(fakeChatModel2); // Should return the first model after reversal (index 0)
});
it('should return the fallback model at index 1 when multiple models are connected', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeChatModel2 = new FakeStreamingChatModel({});
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
const model = await getChatModel(mockContext, 1);
expect(model).toEqual(fakeChatModel1); // Should return the second model after reversal (index 1)
});
it('should return undefined when requested index is out of bounds', async () => {
const fakeChatModel1 = mock<BaseChatModel>();
fakeChatModel1.bindTools = jest.fn();
fakeChatModel1.lc_namespace = ['chat_models'];
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1]);
mockContext.getNode.mockReturnValue(mock());
const result = await getChatModel(mockContext, 2);
expect(result).toBeUndefined();
});
it('should throw error when single model does not support tools', async () => {
const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
mockContext.getInputConnectionData.mockResolvedValue(fakeInvalidModel);
mockContext.getNode.mockReturnValue(mock());
await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
await expect(getChatModel(mockContext)).rejects.toThrow(
'Tools Agent requires Chat Model which supports Tools calling',
);
});
it('should throw error when model at specified index does not support tools', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeInvalidModel]);
mockContext.getNode.mockReturnValue(mock());
await expect(getChatModel(mockContext, 0)).rejects.toThrow(NodeOperationError);
});
});
describe('getOptionalMemory', () => {