mirror of https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat: Add fallback mechanism for agent and basic chain llm (#16617)
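
This change adds an optional fallback model to the Agent (V2) and Basic LLM Chain nodes, gated by a new `needsFallback` boolean parameter. When enabled, a second `ai_languageModel` input ("Fallback Model") is exposed on the canvas and the primary model is wrapped with LangChain's `withFallbacks`, so the fallback is only invoked when the primary model throws. A minimal sketch of that underlying primitive, separate from this PR (the concrete model classes here are illustrative, and constructor option names may vary by package version; run inside an async context):

    import { ChatOpenAI } from '@langchain/openai';
    import { ChatAnthropic } from '@langchain/anthropic';

    const primary = new ChatOpenAI({ model: 'gpt-4o' });
    const fallback = new ChatAnthropic({ model: 'claude-3-5-sonnet-latest' });

    // Same call this PR makes in createAgentExecutor and executeSimpleChain:
    // try `primary` first; only invoke `fallback` if `primary` throws.
    const modelWithFallback = primary.withFallbacks([fallback]);
    const reply = await modelWithFallback.invoke('Hello!');
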
@@ -17,33 +17,25 @@ import { toolsAgentExecute } from '../agents/ToolsAgent/V2/execute';

 // Function used in the inputs expression to figure out which inputs to
 // display based on the agent type
-function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeInputConfiguration> {
+function getInputs(
+	hasOutputParser?: boolean,
+	needsFallback?: boolean,
+): Array<NodeConnectionType | INodeInputConfiguration> {
 	interface SpecialInput {
 		type: NodeConnectionType;
 		filter?: INodeInputFilter;
+		displayName: string;
+		required?: boolean;
 	}

 	const getInputData = (
 		inputs: SpecialInput[],
 	): Array<NodeConnectionType | INodeInputConfiguration> => {
-		const displayNames: { [key: string]: string } = {
-			ai_languageModel: 'Model',
-			ai_memory: 'Memory',
-			ai_tool: 'Tool',
-			ai_outputParser: 'Output Parser',
-		};
-
-		return inputs.map(({ type, filter }) => {
-			const isModelType = type === 'ai_languageModel';
-			let displayName = type in displayNames ? displayNames[type] : undefined;
-			if (isModelType) {
-				displayName = 'Chat Model';
-			}
+		return inputs.map(({ type, filter, displayName, required }) => {
 			const input: INodeInputConfiguration = {
 				type,
 				displayName,
-				required: isModelType,
+				required,
 				maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes(
 					type as NodeConnectionType,
 				)
@@ -62,33 +54,40 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
 	let specialInputs: SpecialInput[] = [
 		{
 			type: 'ai_languageModel',
+			displayName: 'Chat Model',
+			required: true,
 			filter: {
-				nodes: [
-					'@n8n/n8n-nodes-langchain.lmChatAnthropic',
-					'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
-					'@n8n/n8n-nodes-langchain.lmChatAwsBedrock',
-					'@n8n/n8n-nodes-langchain.lmChatMistralCloud',
-					'@n8n/n8n-nodes-langchain.lmChatOllama',
-					'@n8n/n8n-nodes-langchain.lmChatOpenAi',
-					'@n8n/n8n-nodes-langchain.lmChatGroq',
-					'@n8n/n8n-nodes-langchain.lmChatGoogleVertex',
-					'@n8n/n8n-nodes-langchain.lmChatGoogleGemini',
-					'@n8n/n8n-nodes-langchain.lmChatDeepSeek',
-					'@n8n/n8n-nodes-langchain.lmChatOpenRouter',
-					'@n8n/n8n-nodes-langchain.lmChatXAiGrok',
-					'@n8n/n8n-nodes-langchain.code',
-					'@n8n/n8n-nodes-langchain.modelSelector',
-				],
+				excludedNodes: [
+					'@n8n/n8n-nodes-langchain.lmCohere',
+					'@n8n/n8n-nodes-langchain.lmOllama',
+					'n8n/n8n-nodes-langchain.lmOpenAi',
+					'@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
+				],
 			},
 		},
+		{
+			type: 'ai_languageModel',
+			displayName: 'Fallback Model',
+			required: true,
+			filter: {
+				excludedNodes: [
+					'@n8n/n8n-nodes-langchain.lmCohere',
+					'@n8n/n8n-nodes-langchain.lmOllama',
+					'n8n/n8n-nodes-langchain.lmOpenAi',
+					'@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
+				],
+			},
+		},
 		{
+			displayName: 'Memory',
 			type: 'ai_memory',
 		},
 		{
+			displayName: 'Tool',
 			type: 'ai_tool',
+			required: true,
 		},
 		{
+			displayName: 'Output Parser',
 			type: 'ai_outputParser',
 		},
 	];
@@ -96,6 +95,9 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
 	if (hasOutputParser === false) {
 		specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser');
 	}
+	if (needsFallback === false) {
+		specialInputs = specialInputs.filter((input) => input.displayName !== 'Fallback Model');
+	}
 	return ['main', ...getInputData(specialInputs)];
 }

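For intuition, the resulting input lists for the main parameter combinations (a sketch derived from the filtering logic above, not code from this diff):

    getInputs(true, false);
    // -> ['main', Chat Model, Memory, Tool, Output Parser]
    getInputs(true, true);
    // -> ['main', Chat Model, Fallback Model, Memory, Tool, Output Parser]
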
@@ -111,10 +113,10 @@ export class AgentV2 implements INodeType {
 			color: '#404040',
 		},
 		inputs: `={{
-			((hasOutputParser) => {
+			((hasOutputParser, needsFallback) => {
 				${getInputs.toString()};
-				return getInputs(hasOutputParser)
-			})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true)
+				return getInputs(hasOutputParser, needsFallback)
+			})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true, $parameter.needsFallback === undefined || $parameter.needsFallback === true)
 		}}`,
 		outputs: [NodeConnectionTypes.Main],
 		properties: [
@@ -160,6 +162,25 @@ export class AgentV2 implements INodeType {
 				},
 			},
 		},
+		{
+			displayName: 'Enable Fallback Model',
+			name: 'needsFallback',
+			type: 'boolean',
+			default: false,
+			noDataExpression: true,
+		},
+		{
+			displayName:
+				'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
+			name: 'fallbackNotice',
+			type: 'notice',
+			default: '',
+			displayOptions: {
+				show: {
+					needsFallback: [true],
+				},
+			},
+		},
 		...toolsAgentProperties,
 	],
 };

@@ -1,3 +1,4 @@
+import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import { RunnableSequence } from '@langchain/core/runnables';
 import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
 import omit from 'lodash/omit';
@@ -40,7 +41,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE

 	for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
 		try {
-			const model = await getChatModel(this);
+			const model = (await getChatModel(this)) as BaseLanguageModel;
 			const memory = await getOptionalMemory(this);

 			const input = getPromptInputByType({

@@ -1,12 +1,19 @@
+import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import type { ChatPromptTemplate } from '@langchain/core/prompts';
 import { RunnableSequence } from '@langchain/core/runnables';
 import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
+import type { BaseChatMemory } from 'langchain/memory';
+import type { DynamicStructuredTool, Tool } from 'langchain/tools';
 import omit from 'lodash/omit';
 import { jsonParse, NodeOperationError, sleep } from 'n8n-workflow';
 import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
+import assert from 'node:assert';

 import { getPromptInputByType } from '@utils/helpers';
-import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
+import {
+	getOptionalOutputParser,
+	type N8nOutputParser,
+} from '@utils/output_parsers/N8nOutputParser';

 import {
 	fixEmptyContentMessage,
@@ -19,6 +26,41 @@ import {
 } from '../common';
 import { SYSTEM_MESSAGE } from '../prompt';

+/**
+ * Creates an agent executor with the given configuration
+ */
+function createAgentExecutor(
+	model: BaseChatModel,
+	tools: Array<DynamicStructuredTool | Tool>,
+	prompt: ChatPromptTemplate,
+	options: { maxIterations?: number; returnIntermediateSteps?: boolean },
+	outputParser?: N8nOutputParser,
+	memory?: BaseChatMemory,
+	fallbackModel?: BaseChatModel | null,
+) {
+	const modelWithFallback = fallbackModel ? model.withFallbacks([fallbackModel]) : model;
+	const agent = createToolCallingAgent({
+		llm: modelWithFallback,
+		tools,
+		prompt,
+		streamRunnable: false,
+	});
+
+	const runnableAgent = RunnableSequence.from([
+		agent,
+		getAgentStepsParser(outputParser, memory),
+		fixEmptyContentMessage,
+	]);
+
+	return AgentExecutor.fromAgentAndTools({
+		agent: runnableAgent,
+		memory,
+		tools,
+		returnIntermediateSteps: options.returnIntermediateSteps === true,
+		maxIterations: options.maxIterations ?? 10,
+	});
+}
+
 /* -----------------------------------------------------------
    Main Executor Function
 ----------------------------------------------------------- */
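A note on the design: the fallback is attached at the model level (`model.withFallbacks([fallbackModel])`) before the agent is built, so both models share the same tools, prompt, step parser, and executor settings. The call-site shape, mirroring the execute hunk further down (all names come from this diff):

    const executor = createAgentExecutor(
    	model,         // primary chat model, input index 0
    	tools,
    	prompt,
    	options,       // { maxIterations, returnIntermediateSteps }
    	outputParser,  // optional
    	memory,        // optional
    	fallbackModel, // input index 1, or null when the toggle is off
    );
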
@@ -42,8 +84,18 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
 		0,
 		0,
 	) as number;
+	const needsFallback = this.getNodeParameter('needsFallback', 0, false) as boolean;
 	const memory = await getOptionalMemory(this);
-	const model = await getChatModel(this);
+	const model = await getChatModel(this, 0);
+	assert(model, 'Please connect a model to the Chat Model input');
+	const fallbackModel = needsFallback ? await getChatModel(this, 1) : null;
+
+	if (needsFallback && !fallbackModel) {
+		throw new NodeOperationError(
+			this.getNode(),
+			'Please connect a model to the Fallback Model input or disable the fallback option',
+		);
+	}

 	for (let i = 0; i < items.length; i += batchSize) {
 		const batch = items.slice(i, i + batchSize);
@@ -57,7 +109,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
 					promptTypeKey: 'promptType',
 				});
 				if (input === undefined) {
-					throw new NodeOperationError(this.getNode(), 'The “text” parameter is empty.');
+					throw new NodeOperationError(this.getNode(), 'The "text" parameter is empty.');
 				}
 				const outputParser = await getOptionalOutputParser(this, itemIndex);
 				const tools = await getTools(this, outputParser);
@@ -76,38 +128,26 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
 				});
 				const prompt: ChatPromptTemplate = preparePrompt(messages);

-				// Create the base agent that calls tools.
-				const agent = createToolCallingAgent({
-					llm: model,
+				// Create executors for primary and fallback models
+				const executor = createAgentExecutor(
+					model,
 					tools,
 					prompt,
-					streamRunnable: false,
-				});
-				agent.streamRunnable = false;
-				// Wrap the agent with parsers and fixes.
-				const runnableAgent = RunnableSequence.from([
-					agent,
-					getAgentStepsParser(outputParser, memory),
-					fixEmptyContentMessage,
-				]);
-				const executor = AgentExecutor.fromAgentAndTools({
-					agent: runnableAgent,
+					options,
+					outputParser,
 					memory,
-					tools,
-					returnIntermediateSteps: options.returnIntermediateSteps === true,
-					maxIterations: options.maxIterations ?? 10,
-				});
+					fallbackModel,
+				);

-				// Invoke the executor with the given input and system message.
-				return await executor.invoke(
-					{
-						input,
-						system_message: options.systemMessage ?? SYSTEM_MESSAGE,
-						formatting_instructions:
-							'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
-					},
-					{ signal: this.getExecutionCancelSignal() },
-				);
+				// Invoke with fallback logic
+				const invokeParams = {
+					input,
+					system_message: options.systemMessage ?? SYSTEM_MESSAGE,
+					formatting_instructions:
+						'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
+				};
+				const executeOptions = { signal: this.getExecutionCancelSignal() };
+
+				return await executor.invoke(invokeParams, executeOptions);
 			});

 			const batchResults = await Promise.allSettled(batchPromises);

@@ -263,8 +263,25 @@ export const getAgentStepsParser =
  * @param ctx - The execution context
  * @returns The validated chat model
  */
-export async function getChatModel(ctx: IExecuteFunctions): Promise<BaseChatModel> {
-	const model = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+export async function getChatModel(
+	ctx: IExecuteFunctions,
+	index: number = 0,
+): Promise<BaseChatModel | undefined> {
+	const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+
+	let model;
+
+	if (Array.isArray(connectedModels) && index !== undefined) {
+		if (connectedModels.length <= index) {
+			return undefined;
+		}
+		// We get the models in reversed order from the workflow so we need to reverse them to match the right index
+		const reversedModels = [...connectedModels].reverse();
+		model = reversedModels[index] as BaseChatModel;
+	} else {
+		model = connectedModels as BaseChatModel;
+	}
+
 	if (!isChatInstance(model) || !model.bindTools) {
 		throw new NodeOperationError(
 			ctx.getNode(),
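The reversal deserves a note: `getInputConnectionData` returns the connected models in reverse canvas order, so reversing the array restores "input index 0 = Chat Model, index 1 = Fallback Model". A toy illustration (plain strings standing in for model instances), consistent with the tests below where `[fakeChatModel1, fakeChatModel2]` at index 0 yields `fakeChatModel2`:

    const connectedModels = ['fallbackModel', 'chatModel']; // as delivered by n8n
    const reversed = [...connectedModels].reverse();
    reversed[0]; // 'chatModel'     -> primary (index 0)
    reversed[1]; // 'fallbackModel' -> fallback (index 1)
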
@@ -52,6 +52,7 @@ describe('toolsAgentExecute', () => {
 		// Mock getNodeParameter to return default values
 		mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options.batching.batchSize') return defaultValue;
 			if (param === 'options.batching.delayBetweenBatches') return defaultValue;
 			if (param === 'options')
@@ -104,6 +105,7 @@ describe('toolsAgentExecute', () => {
 			if (param === 'options.batching.batchSize') return 2;
 			if (param === 'options.batching.delayBetweenBatches') return 100;
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options')
 				return {
 					systemMessage: 'You are a helpful assistant',
@@ -157,6 +159,7 @@ describe('toolsAgentExecute', () => {
 			if (param === 'options.batching.batchSize') return 2;
 			if (param === 'options.batching.delayBetweenBatches') return 0;
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options')
 				return {
 					systemMessage: 'You are a helpful assistant',
@@ -206,6 +209,7 @@ describe('toolsAgentExecute', () => {
 			if (param === 'options.batching.batchSize') return 2;
 			if (param === 'options.batching.delayBetweenBatches') return 0;
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options')
 				return {
 					systemMessage: 'You are a helpful assistant',

@@ -2,6 +2,7 @@ import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { HumanMessage } from '@langchain/core/messages';
 import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
+import { FakeLLM, FakeStreamingChatModel } from '@langchain/core/utils/testing';
 import { Buffer } from 'buffer';
 import { mock } from 'jest-mock-extended';
 import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
@@ -163,6 +164,72 @@ describe('getChatModel', () => {
 		mockContext.getNode.mockReturnValue(mock());
 		await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
 	});
+
+	it('should return the first model when multiple models are connected and no index specified', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeChatModel2 = new FakeStreamingChatModel({});
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+		const model = await getChatModel(mockContext);
+		expect(model).toEqual(fakeChatModel2); // Should return the last model (reversed array)
+	});
+
+	it('should return the model at specified index when multiple models are connected', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeChatModel2 = new FakeStreamingChatModel({});
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+		const model = await getChatModel(mockContext, 0);
+		expect(model).toEqual(fakeChatModel2); // Should return the first model after reversal (index 0)
+	});
+
+	it('should return the fallback model at index 1 when multiple models are connected', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeChatModel2 = new FakeStreamingChatModel({});
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+		const model = await getChatModel(mockContext, 1);
+		expect(model).toEqual(fakeChatModel1); // Should return the second model after reversal (index 1)
+	});
+
+	it('should return undefined when requested index is out of bounds', async () => {
+		const fakeChatModel1 = mock<BaseChatModel>();
+		fakeChatModel1.bindTools = jest.fn();
+		fakeChatModel1.lc_namespace = ['chat_models'];
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1]);
+		mockContext.getNode.mockReturnValue(mock());
+
+		const result = await getChatModel(mockContext, 2);
+
+		expect(result).toBeUndefined();
+	});
+
+	it('should throw error when single model does not support tools', async () => {
+		const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
+
+		mockContext.getInputConnectionData.mockResolvedValue(fakeInvalidModel);
+		mockContext.getNode.mockReturnValue(mock());
+
+		await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
+		await expect(getChatModel(mockContext)).rejects.toThrow(
+			'Tools Agent requires Chat Model which supports Tools calling',
+		);
+	});
+
+	it('should throw error when model at specified index does not support tools', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeInvalidModel]);
+		mockContext.getNode.mockReturnValue(mock());
+
+		await expect(getChatModel(mockContext, 0)).rejects.toThrow(NodeOperationError);
+	});
 });

 describe('getOptionalMemory', () => {

@@ -88,15 +88,24 @@ async function executeSimpleChain({
 	llm,
 	query,
 	prompt,
+	fallbackLlm,
 }: {
 	context: IExecuteFunctions;
 	llm: BaseLanguageModel;
 	query: string;
 	prompt: ChatPromptTemplate | PromptTemplate;
+	fallbackLlm?: BaseLanguageModel | null;
 }) {
 	const outputParser = getOutputParserForLLM(llm);
-
-	const chain = prompt.pipe(llm).pipe(outputParser).withConfig(getTracingConfig(context));
+	let model;
+
+	if (fallbackLlm) {
+		model = llm.withFallbacks([fallbackLlm]);
+	} else {
+		model = llm;
+	}
+
+	const chain = prompt.pipe(model).pipe(outputParser).withConfig(getTracingConfig(context));

 	// Execute the chain
 	const response = await chain.invoke({
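The same pattern in a self-contained, runnable form (a sketch, not code from this PR; the fake models from @langchain/core avoid needing credentials, and the prompt text is made up for illustration):

    import { ChatPromptTemplate } from '@langchain/core/prompts';
    import { StringOutputParser } from '@langchain/core/output_parsers';
    import { FakeListChatModel } from '@langchain/core/utils/testing';

    const primary = new FakeListChatModel({ responses: ['primary answer'] });
    const fallback = new FakeListChatModel({ responses: ['fallback answer'] });

    // Mirrors executeSimpleChain after this change: the fallback is attached
    // to the model before it is piped into the chain, so prompt formatting
    // and output parsing are shared by both models.
    const chain = ChatPromptTemplate.fromTemplate('Answer briefly: {query}')
    	.pipe(primary.withFallbacks([fallback]))
    	.pipe(new StringOutputParser());

    console.log(await chain.invoke({ query: 'What is n8n?' })); // 'primary answer'
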
@@ -118,6 +127,7 @@ export async function executeChain({
 	llm,
 	outputParser,
 	messages,
+	fallbackLlm,
 }: ChainExecutionParams): Promise<unknown[]> {
 	// If no output parsers provided, use a simple chain with basic prompt template
 	if (!outputParser) {
@@ -134,6 +144,7 @@ export async function executeChain({
 			llm,
 			query,
 			prompt: promptTemplate,
+			fallbackLlm,
 		});
 	}

@@ -23,6 +23,17 @@ export function getInputs(parameters: IDataObject) {
 		},
 	];

+	const needsFallback = parameters?.needsFallback;
+
+	if (needsFallback === undefined || needsFallback === true) {
+		inputs.push({
+			displayName: 'Fallback Model',
+			maxConnections: 1,
+			type: 'ai_languageModel',
+			required: true,
+		});
+	}
+
 	// If `hasOutputParser` is undefined it must be version 1.3 or earlier so we
 	// always add the output parser input
 	const hasOutputParser = parameters?.hasOutputParser;
@@ -119,6 +130,18 @@ export const nodeProperties: INodeProperties[] = [
 			},
 		},
 	},
+	{
+		displayName: 'Enable Fallback Model',
+		name: 'needsFallback',
+		type: 'boolean',
+		default: false,
+		noDataExpression: true,
+		displayOptions: {
+			hide: {
+				'@version': [1, 1.1, 1.3],
+			},
+		},
+	},
 	{
 		displayName: 'Chat Messages (if Using a Chat Model)',
 		name: 'messages',
@@ -275,4 +298,16 @@ export const nodeProperties: INodeProperties[] = [
 			},
 		},
 	},
+	{
+		displayName:
+			'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
+		name: 'fallbackNotice',
+		type: 'notice',
+		default: '',
+		displayOptions: {
+			show: {
+				needsFallback: [true],
+			},
+		},
+	},
 ];

@@ -1,5 +1,6 @@
 import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
+import assert from 'node:assert';

 import { getPromptInputByType } from '@utils/helpers';
 import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
@@ -7,11 +8,40 @@ import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
 import { executeChain } from './chainExecutor';
 import { type MessageTemplate } from './types';

+async function getChatModel(
+	ctx: IExecuteFunctions,
+	index: number = 0,
+): Promise<BaseLanguageModel | undefined> {
+	const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+
+	let model;
+
+	if (Array.isArray(connectedModels) && index !== undefined) {
+		if (connectedModels.length <= index) {
+			return undefined;
+		}
+		// We get the models in reversed order from the workflow so we need to reverse them again to match the right index
+		const reversedModels = [...connectedModels].reverse();
+		model = reversedModels[index] as BaseLanguageModel;
+	} else {
+		model = connectedModels as BaseLanguageModel;
+	}
+
+	return model;
+}
+
 export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) => {
-	const llm = (await ctx.getInputConnectionData(
-		NodeConnectionTypes.AiLanguageModel,
-		0,
-	)) as BaseLanguageModel;
+	const needsFallback = ctx.getNodeParameter('needsFallback', 0, false) as boolean;
+	const llm = await getChatModel(ctx, 0);
+	assert(llm, 'Please connect a model to the Chat Model input');
+
+	const fallbackLlm = needsFallback ? await getChatModel(ctx, 1) : null;
+	if (needsFallback && !fallbackLlm) {
+		throw new NodeOperationError(
+			ctx.getNode(),
+			'Please connect a model to the Fallback Model input or disable the fallback option',
+		);
+	}

 	// Get output parser if configured
 	const outputParser = await getOptionalOutputParser(ctx, itemIndex);
@@ -50,5 +80,6 @@ export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) =>
 		llm,
 		outputParser,
 		messages,
+		fallbackLlm,
 	});
 };

@@ -38,4 +38,5 @@ export interface ChainExecutionParams {
 	llm: BaseLanguageModel;
 	outputParser?: N8nOutputParser;
 	messages?: MessageTemplate[];
+	fallbackLlm?: BaseLanguageModel | null;
 }

@@ -36,6 +36,7 @@ jest.mock('../methods/responseFormatter', () => ({
 describe('ChainLlm Node', () => {
 	let node: ChainLlm;
 	let mockExecuteFunction: jest.Mocked<IExecuteFunctions>;
+	let needsFallback: boolean;

 	beforeEach(() => {
 		node = new ChainLlm();
@@ -48,6 +49,8 @@ describe('ChainLlm Node', () => {
 			error: jest.fn(),
 		};

+		needsFallback = false;
+
 		mockExecuteFunction.getInputData.mockReturnValue([{ json: {} }]);
 		mockExecuteFunction.getNode.mockReturnValue({
 			name: 'Chain LLM',
@@ -57,6 +60,7 @@ describe('ChainLlm Node', () => {

 		mockExecuteFunction.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
 			if (param === 'messages.messageValues') return [];
+			if (param === 'needsFallback') return needsFallback;
 			return defaultValue;
 		});

@@ -96,6 +100,7 @@ describe('ChainLlm Node', () => {
 			context: mockExecuteFunction,
 			itemIndex: 0,
 			query: 'Test prompt',
+			fallbackLlm: null,
 			llm: expect.any(FakeChatModel),
 			outputParser: undefined,
 			messages: [],
@@ -151,6 +156,7 @@ describe('ChainLlm Node', () => {
 			context: mockExecuteFunction,
 			itemIndex: 0,
 			query: 'Old version prompt',
+			fallbackLlm: null,
 			llm: expect.any(Object),
 			outputParser: undefined,
 			messages: expect.any(Array),
@@ -505,6 +511,35 @@ describe('ChainLlm Node', () => {
 		expect(responseFormatterModule.formatResponse).toHaveBeenCalledWith(markdownResponse, true);
 	});

+	it('should use fallback llm if enabled', async () => {
+		needsFallback = true;
+		(helperModule.getPromptInputByType as jest.Mock).mockReturnValue('Test prompt');
+
+		(outputParserModule.getOptionalOutputParser as jest.Mock).mockResolvedValue(undefined);
+
+		(executeChainModule.executeChain as jest.Mock).mockResolvedValue(['Test response']);
+
+		const fakeLLM = new FakeChatModel({});
+		const fakeFallbackLLM = new FakeChatModel({});
+		mockExecuteFunction.getInputConnectionData.mockResolvedValue([fakeLLM, fakeFallbackLLM]);
+
+		const result = await node.execute.call(mockExecuteFunction);
+
+		expect(executeChainModule.executeChain).toHaveBeenCalledWith({
+			context: mockExecuteFunction,
+			itemIndex: 0,
+			query: 'Test prompt',
+			fallbackLlm: expect.any(FakeChatModel),
+			llm: expect.any(FakeChatModel),
+			outputParser: undefined,
+			messages: [],
+		});
+
+		expect(mockExecuteFunction.logger.debug).toHaveBeenCalledWith('Executing Basic LLM Chain');
+
+		expect(result).toEqual([[{ json: expect.any(Object) }]]);
+	});
+
 	it('should pass correct itemIndex to getOptionalOutputParser', async () => {
 		// Clear any previous calls to the mock
 		(outputParserModule.getOptionalOutputParser as jest.Mock).mockClear();
@@ -568,6 +603,7 @@ describe('ChainLlm Node', () => {
 			itemIndex: 0,
 			query: 'Test prompt 1',
 			llm: expect.any(Object),
+			fallbackLlm: null,
 			outputParser: mockParser1,
 			messages: [],
 		});
@@ -576,6 +612,7 @@ describe('ChainLlm Node', () => {
 			itemIndex: 1,
 			query: 'Test prompt 2',
 			llm: expect.any(Object),
+			fallbackLlm: null,
 			outputParser: mockParser2,
 			messages: [],
 		});
@@ -584,6 +621,7 @@ describe('ChainLlm Node', () => {
 			itemIndex: 2,
 			query: 'Test prompt 3',
 			llm: expect.any(Object),
+			fallbackLlm: null,
 			outputParser: mockParser3,
 			messages: [],
 		});

@@ -7,26 +7,44 @@ describe('config', () => {
 	it('should return basic inputs for all parameters', () => {
 		const inputs = getInputs({});

-		expect(inputs).toHaveLength(3);
+		expect(inputs).toHaveLength(4);
 		expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
 		expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
-		expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
+		expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
 	});

 	it('should exclude the OutputParser when hasOutputParser is false', () => {
 		const inputs = getInputs({ hasOutputParser: false });

-		expect(inputs).toHaveLength(2);
+		expect(inputs).toHaveLength(3);
 		expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
 		expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
 	});

 	it('should include the OutputParser when hasOutputParser is true', () => {
 		const inputs = getInputs({ hasOutputParser: true });

 		expect(inputs).toHaveLength(4);
 		expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
 	});

+	it('should exclude the FallbackInput when needsFallback is false', () => {
+		const inputs = getInputs({ hasOutputParser: true, needsFallback: false });
+
+		expect(inputs).toHaveLength(3);
+		expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
+		expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
+	});
+
+	it('should include the FallbackInput when needsFallback is true', () => {
+		const inputs = getInputs({ hasOutputParser: false, needsFallback: true });
+
+		expect(inputs).toHaveLength(3);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
+	});
 });

 describe('nodeProperties', () => {