feat: Add fallback mechanism for agent and basic chain llm (#16617)

Benjamin Schroth
2025-06-26 16:14:03 +02:00
committed by GitHub
parent 0b7bca29f8
commit 6408d5a1b0
20 changed files with 476 additions and 140 deletions

View File

@@ -17,33 +17,25 @@ import { toolsAgentExecute } from '../agents/ToolsAgent/V2/execute';
// Function used in the inputs expression to figure out which inputs to
// display based on the agent type
function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeInputConfiguration> {
function getInputs(
hasOutputParser?: boolean,
needsFallback?: boolean,
): Array<NodeConnectionType | INodeInputConfiguration> {
interface SpecialInput {
type: NodeConnectionType;
filter?: INodeInputFilter;
displayName: string;
required?: boolean;
}
const getInputData = (
inputs: SpecialInput[],
): Array<NodeConnectionType | INodeInputConfiguration> => {
const displayNames: { [key: string]: string } = {
ai_languageModel: 'Model',
ai_memory: 'Memory',
ai_tool: 'Tool',
ai_outputParser: 'Output Parser',
};
return inputs.map(({ type, filter }) => {
const isModelType = type === 'ai_languageModel';
let displayName = type in displayNames ? displayNames[type] : undefined;
if (isModelType) {
displayName = 'Chat Model';
}
return inputs.map(({ type, filter, displayName, required }) => {
const input: INodeInputConfiguration = {
type,
displayName,
required: isModelType,
required,
maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes(
type as NodeConnectionType,
)
@@ -62,33 +54,40 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
let specialInputs: SpecialInput[] = [
{
type: 'ai_languageModel',
displayName: 'Chat Model',
required: true,
filter: {
nodes: [
'@n8n/n8n-nodes-langchain.lmChatAnthropic',
'@n8n/n8n-nodes-langchain.lmChatAzureOpenAi',
'@n8n/n8n-nodes-langchain.lmChatAwsBedrock',
'@n8n/n8n-nodes-langchain.lmChatMistralCloud',
'@n8n/n8n-nodes-langchain.lmChatOllama',
'@n8n/n8n-nodes-langchain.lmChatOpenAi',
'@n8n/n8n-nodes-langchain.lmChatGroq',
'@n8n/n8n-nodes-langchain.lmChatGoogleVertex',
'@n8n/n8n-nodes-langchain.lmChatGoogleGemini',
'@n8n/n8n-nodes-langchain.lmChatDeepSeek',
'@n8n/n8n-nodes-langchain.lmChatOpenRouter',
'@n8n/n8n-nodes-langchain.lmChatXAiGrok',
'@n8n/n8n-nodes-langchain.code',
'@n8n/n8n-nodes-langchain.modelSelector',
excludedNodes: [
'@n8n/n8n-nodes-langchain.lmCohere',
'@n8n/n8n-nodes-langchain.lmOllama',
'n8n/n8n-nodes-langchain.lmOpenAi',
'@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
],
},
},
{
type: 'ai_languageModel',
displayName: 'Fallback Model',
required: true,
filter: {
excludedNodes: [
'@n8n/n8n-nodes-langchain.lmCohere',
'@n8n/n8n-nodes-langchain.lmOllama',
'n8n/n8n-nodes-langchain.lmOpenAi',
'@n8n/n8n-nodes-langchain.lmOpenHuggingFaceInference',
],
},
},
{
displayName: 'Memory',
type: 'ai_memory',
},
{
displayName: 'Tool',
type: 'ai_tool',
required: true,
},
{
displayName: 'Output Parser',
type: 'ai_outputParser',
},
];
@@ -96,6 +95,9 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeI
if (hasOutputParser === false) {
specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser');
}
if (needsFallback === false) {
specialInputs = specialInputs.filter((input) => input.displayName !== 'Fallback Model');
}
return ['main', ...getInputData(specialInputs)];
}
@@ -111,10 +113,10 @@ export class AgentV2 implements INodeType {
color: '#404040',
},
inputs: `={{
((hasOutputParser) => {
((hasOutputParser, needsFallback) => {
${getInputs.toString()};
return getInputs(hasOutputParser)
})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true)
return getInputs(hasOutputParser, needsFallback)
})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true, $parameter.needsFallback === undefined || $parameter.needsFallback === true)
}}`,
outputs: [NodeConnectionTypes.Main],
properties: [
@@ -160,6 +162,25 @@ export class AgentV2 implements INodeType {
},
},
},
{
displayName: 'Enable Fallback Model',
name: 'needsFallback',
type: 'boolean',
default: false,
noDataExpression: true,
},
{
displayName:
'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
name: 'fallbackNotice',
type: 'notice',
default: '',
displayOptions: {
show: {
needsFallback: [true],
},
},
},
...toolsAgentProperties,
],
};
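
For orientation, the rewritten getInputs is stringified into the node's `inputs` expression and re-evaluated against `$parameter.hasOutputParser` and `$parameter.needsFallback` whenever the node is rendered. A rough illustration of what it produces, by display name only (the `names` helper below is illustrative, not part of the commit; `getInputs` is the function defined above):

const names = (hasOutputParser: boolean, needsFallback: boolean) =>
	getInputs(hasOutputParser, needsFallback).map((input) =>
		typeof input === 'string' ? input : input.displayName,
	);
// names(true, true)  -> ['main', 'Chat Model', 'Fallback Model', 'Memory', 'Tool', 'Output Parser']
// names(true, false) -> ['main', 'Chat Model', 'Memory', 'Tool', 'Output Parser']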

View File

@@ -1,3 +1,4 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { RunnableSequence } from '@langchain/core/runnables';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import omit from 'lodash/omit';
@@ -40,7 +41,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const model = await getChatModel(this);
const model = (await getChatModel(this)) as BaseLanguageModel;
const memory = await getOptionalMemory(this);
const input = getPromptInputByType({

View File

@@ -1,12 +1,19 @@
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { ChatPromptTemplate } from '@langchain/core/prompts';
import { RunnableSequence } from '@langchain/core/runnables';
import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
import type { BaseChatMemory } from 'langchain/memory';
import type { DynamicStructuredTool, Tool } from 'langchain/tools';
import omit from 'lodash/omit';
import { jsonParse, NodeOperationError, sleep } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import assert from 'node:assert';
import { getPromptInputByType } from '@utils/helpers';
import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
import {
getOptionalOutputParser,
type N8nOutputParser,
} from '@utils/output_parsers/N8nOutputParser';
import {
fixEmptyContentMessage,
@@ -19,6 +26,41 @@ import {
} from '../common';
import { SYSTEM_MESSAGE } from '../prompt';
/**
* Creates an agent executor with the given configuration
*/
function createAgentExecutor(
model: BaseChatModel,
tools: Array<DynamicStructuredTool | Tool>,
prompt: ChatPromptTemplate,
options: { maxIterations?: number; returnIntermediateSteps?: boolean },
outputParser?: N8nOutputParser,
memory?: BaseChatMemory,
fallbackModel?: BaseChatModel | null,
) {
const modelWithFallback = fallbackModel ? model.withFallbacks([fallbackModel]) : model;
const agent = createToolCallingAgent({
llm: modelWithFallback,
tools,
prompt,
streamRunnable: false,
});
const runnableAgent = RunnableSequence.from([
agent,
getAgentStepsParser(outputParser, memory),
fixEmptyContentMessage,
]);
return AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
memory,
tools,
returnIntermediateSteps: options.returnIntermediateSteps === true,
maxIterations: options.maxIterations ?? 10,
});
}
/* -----------------------------------------------------------
Main Executor Function
----------------------------------------------------------- */
@@ -42,8 +84,18 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
0,
0,
) as number;
const needsFallback = this.getNodeParameter('needsFallback', 0, false) as boolean;
const memory = await getOptionalMemory(this);
const model = await getChatModel(this);
const model = await getChatModel(this, 0);
assert(model, 'Please connect a model to the Chat Model input');
const fallbackModel = needsFallback ? await getChatModel(this, 1) : null;
if (needsFallback && !fallbackModel) {
throw new NodeOperationError(
this.getNode(),
'Please connect a model to the Fallback Model input or disable the fallback option',
);
}
for (let i = 0; i < items.length; i += batchSize) {
const batch = items.slice(i, i + batchSize);
@@ -57,7 +109,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
promptTypeKey: 'promptType',
});
if (input === undefined) {
throw new NodeOperationError(this.getNode(), 'The text parameter is empty.');
throw new NodeOperationError(this.getNode(), 'The "text" parameter is empty.');
}
const outputParser = await getOptionalOutputParser(this, itemIndex);
const tools = await getTools(this, outputParser);
@@ -76,38 +128,26 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeE
});
const prompt: ChatPromptTemplate = preparePrompt(messages);
// Create the base agent that calls tools.
const agent = createToolCallingAgent({
llm: model,
// Create executors for primary and fallback models
const executor = createAgentExecutor(
model,
tools,
prompt,
streamRunnable: false,
});
agent.streamRunnable = false;
// Wrap the agent with parsers and fixes.
const runnableAgent = RunnableSequence.from([
agent,
getAgentStepsParser(outputParser, memory),
fixEmptyContentMessage,
]);
const executor = AgentExecutor.fromAgentAndTools({
agent: runnableAgent,
options,
outputParser,
memory,
tools,
returnIntermediateSteps: options.returnIntermediateSteps === true,
maxIterations: options.maxIterations ?? 10,
});
// Invoke the executor with the given input and system message.
return await executor.invoke(
{
input,
system_message: options.systemMessage ?? SYSTEM_MESSAGE,
formatting_instructions:
'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
},
{ signal: this.getExecutionCancelSignal() },
fallbackModel,
);
// Invoke with fallback logic
const invokeParams = {
input,
system_message: options.systemMessage ?? SYSTEM_MESSAGE,
formatting_instructions:
'IMPORTANT: For your response to user, you MUST use the `format_final_json_response` tool with your complete answer formatted according to the required schema. Do not attempt to format the JSON manually - always use this tool. Your response will be rejected if it is not properly formatted through this tool. Only use this tool once you are ready to provide your final answer.',
};
const executeOptions = { signal: this.getExecutionCancelSignal() };
return await executor.invoke(invokeParams, executeOptions);
});
const batchResults = await Promise.allSettled(batchPromises);
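
The fallback wiring itself is a thin wrapper around LangChain's withFallbacks, exactly as used in createAgentExecutor above. A minimal sketch of the pattern (the helper name is illustrative; the withFallbacks call mirrors the commit):

import type { BaseChatModel } from '@langchain/core/language_models/chat_models';

// Wrap only when a fallback model is actually connected. withFallbacks
// returns a runnable that retries the same call on the fallback model
// if the primary model throws.
function withOptionalFallback(model: BaseChatModel, fallbackModel?: BaseChatModel | null) {
	return fallbackModel ? model.withFallbacks([fallbackModel]) : model;
}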

View File

@@ -263,8 +263,25 @@ export const getAgentStepsParser =
* @param ctx - The execution context
* @returns The validated chat model
*/
export async function getChatModel(ctx: IExecuteFunctions): Promise<BaseChatModel> {
const model = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
export async function getChatModel(
ctx: IExecuteFunctions,
index: number = 0,
): Promise<BaseChatModel | undefined> {
const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
let model;
if (Array.isArray(connectedModels) && index !== undefined) {
if (connectedModels.length <= index) {
return undefined;
}
// We get the models in reversed order from the workflow so we need to reverse them to match the right index
const reversedModels = [...connectedModels].reverse();
model = reversedModels[index] as BaseChatModel;
} else {
model = connectedModels as BaseChatModel;
}
if (!isChatInstance(model) || !model.bindTools) {
throw new NodeOperationError(
ctx.getNode(),

View File

@@ -52,6 +52,7 @@ describe('toolsAgentExecute', () => {
// Mock getNodeParameter to return default values
mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options.batching.batchSize') return defaultValue;
if (param === 'options.batching.delayBetweenBatches') return defaultValue;
if (param === 'options')
@@ -104,6 +105,7 @@ describe('toolsAgentExecute', () => {
if (param === 'options.batching.batchSize') return 2;
if (param === 'options.batching.delayBetweenBatches') return 100;
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options')
return {
systemMessage: 'You are a helpful assistant',
@@ -157,6 +159,7 @@ describe('toolsAgentExecute', () => {
if (param === 'options.batching.batchSize') return 2;
if (param === 'options.batching.delayBetweenBatches') return 0;
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options')
return {
systemMessage: 'You are a helpful assistant',
@@ -206,6 +209,7 @@ describe('toolsAgentExecute', () => {
if (param === 'options.batching.batchSize') return 2;
if (param === 'options.batching.delayBetweenBatches') return 0;
if (param === 'text') return 'test input';
if (param === 'needsFallback') return false;
if (param === 'options')
return {
systemMessage: 'You are a helpful assistant',

View File

@@ -2,6 +2,7 @@ import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { HumanMessage } from '@langchain/core/messages';
import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
import { FakeLLM, FakeStreamingChatModel } from '@langchain/core/utils/testing';
import { Buffer } from 'buffer';
import { mock } from 'jest-mock-extended';
import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
@@ -163,6 +164,72 @@ describe('getChatModel', () => {
mockContext.getNode.mockReturnValue(mock());
await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
});
it('should return the first model when multiple models are connected and no index specified', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeChatModel2 = new FakeStreamingChatModel({});
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
const model = await getChatModel(mockContext);
expect(model).toEqual(fakeChatModel2); // Should return the last model (reversed array)
});
it('should return the model at specified index when multiple models are connected', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeChatModel2 = new FakeStreamingChatModel({});
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
const model = await getChatModel(mockContext, 0);
expect(model).toEqual(fakeChatModel2); // Should return the first model after reversal (index 0)
});
it('should return the fallback model at index 1 when multiple models are connected', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeChatModel2 = new FakeStreamingChatModel({});
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
const model = await getChatModel(mockContext, 1);
expect(model).toEqual(fakeChatModel1); // Should return the second model after reversal (index 1)
});
it('should return undefined when requested index is out of bounds', async () => {
const fakeChatModel1 = mock<BaseChatModel>();
fakeChatModel1.bindTools = jest.fn();
fakeChatModel1.lc_namespace = ['chat_models'];
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1]);
mockContext.getNode.mockReturnValue(mock());
const result = await getChatModel(mockContext, 2);
expect(result).toBeUndefined();
});
it('should throw error when single model does not support tools', async () => {
const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
mockContext.getInputConnectionData.mockResolvedValue(fakeInvalidModel);
mockContext.getNode.mockReturnValue(mock());
await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
await expect(getChatModel(mockContext)).rejects.toThrow(
'Tools Agent requires Chat Model which supports Tools calling',
);
});
it('should throw error when model at specified index does not support tools', async () => {
const fakeChatModel1 = new FakeStreamingChatModel({});
const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeInvalidModel]);
mockContext.getNode.mockReturnValue(mock());
await expect(getChatModel(mockContext, 0)).rejects.toThrow(NodeOperationError);
});
});
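
The index convention these tests pin down is worth spelling out: index 0 resolves to the Chat Model input and index 1 to the Fallback Model input. A condensed sketch of the selection logic from getChatModel above (helper name illustrative):

// getInputConnectionData returns multiple connected models in reversed
// order from the workflow, so the array is reversed before indexing;
// a single connected model is returned as-is.
function pickConnectedModel<T>(connected: T | T[], index = 0): T | undefined {
	if (!Array.isArray(connected)) return connected;
	if (connected.length <= index) return undefined;
	return [...connected].reverse()[index];
}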
describe('getOptionalMemory', () => {

View File

@@ -88,15 +88,24 @@ async function executeSimpleChain({
llm,
query,
prompt,
fallbackLlm,
}: {
context: IExecuteFunctions;
llm: BaseLanguageModel;
query: string;
prompt: ChatPromptTemplate | PromptTemplate;
fallbackLlm?: BaseLanguageModel | null;
}) {
const outputParser = getOutputParserForLLM(llm);
let model;
const chain = prompt.pipe(llm).pipe(outputParser).withConfig(getTracingConfig(context));
if (fallbackLlm) {
model = llm.withFallbacks([fallbackLlm]);
} else {
model = llm;
}
const chain = prompt.pipe(model).pipe(outputParser).withConfig(getTracingConfig(context));
// Execute the chain
const response = await chain.invoke({
@@ -118,6 +127,7 @@ export async function executeChain({
llm,
outputParser,
messages,
fallbackLlm,
}: ChainExecutionParams): Promise<unknown[]> {
// If no output parsers provided, use a simple chain with basic prompt template
if (!outputParser) {
@@ -134,6 +144,7 @@ export async function executeChain({
llm,
query,
prompt: promptTemplate,
fallbackLlm,
});
}

View File

@@ -23,6 +23,17 @@ export function getInputs(parameters: IDataObject) {
},
];
const needsFallback = parameters?.needsFallback;
if (needsFallback === undefined || needsFallback === true) {
inputs.push({
displayName: 'Fallback Model',
maxConnections: 1,
type: 'ai_languageModel',
required: true,
});
}
// If `hasOutputParser` is undefined it must be version 1.3 or earlier so we
// always add the output parser input
const hasOutputParser = parameters?.hasOutputParser;
@@ -119,6 +130,18 @@ export const nodeProperties: INodeProperties[] = [
},
},
},
{
displayName: 'Enable Fallback Model',
name: 'needsFallback',
type: 'boolean',
default: false,
noDataExpression: true,
displayOptions: {
hide: {
'@version': [1, 1.1, 1.3],
},
},
},
{
displayName: 'Chat Messages (if Using a Chat Model)',
name: 'messages',
@@ -275,4 +298,16 @@ export const nodeProperties: INodeProperties[] = [
},
},
},
{
displayName:
'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
name: 'fallbackNotice',
type: 'notice',
default: '',
displayOptions: {
show: {
needsFallback: [true],
},
},
},
];

View File

@@ -1,5 +1,6 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
import assert from 'node:assert';
import { getPromptInputByType } from '@utils/helpers';
import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
@@ -7,11 +8,40 @@ import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
import { executeChain } from './chainExecutor';
import { type MessageTemplate } from './types';
async function getChatModel(
ctx: IExecuteFunctions,
index: number = 0,
): Promise<BaseLanguageModel | undefined> {
const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
let model;
if (Array.isArray(connectedModels) && index !== undefined) {
if (connectedModels.length <= index) {
return undefined;
}
// We get the models in reversed order from the workflow so we need to reverse them again to match the right index
const reversedModels = [...connectedModels].reverse();
model = reversedModels[index] as BaseLanguageModel;
} else {
model = connectedModels as BaseLanguageModel;
}
return model;
}
export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) => {
const llm = (await ctx.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
0,
)) as BaseLanguageModel;
const needsFallback = ctx.getNodeParameter('needsFallback', 0, false) as boolean;
const llm = await getChatModel(ctx, 0);
assert(llm, 'Please connect a model to the Chat Model input');
const fallbackLlm = needsFallback ? await getChatModel(ctx, 1) : null;
if (needsFallback && !fallbackLlm) {
throw new NodeOperationError(
ctx.getNode(),
'Please connect a model to the Fallback Model input or disable the fallback option',
);
}
// Get output parser if configured
const outputParser = await getOptionalOutputParser(ctx, itemIndex);
@@ -50,5 +80,6 @@ export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) =>
llm,
outputParser,
messages,
fallbackLlm,
});
};

View File

@@ -38,4 +38,5 @@ export interface ChainExecutionParams {
llm: BaseLanguageModel;
outputParser?: N8nOutputParser;
messages?: MessageTemplate[];
fallbackLlm?: BaseLanguageModel | null;
}

View File

@@ -36,6 +36,7 @@ jest.mock('../methods/responseFormatter', () => ({
describe('ChainLlm Node', () => {
let node: ChainLlm;
let mockExecuteFunction: jest.Mocked<IExecuteFunctions>;
let needsFallback: boolean;
beforeEach(() => {
node = new ChainLlm();
@@ -48,6 +49,8 @@ describe('ChainLlm Node', () => {
error: jest.fn(),
};
needsFallback = false;
mockExecuteFunction.getInputData.mockReturnValue([{ json: {} }]);
mockExecuteFunction.getNode.mockReturnValue({
name: 'Chain LLM',
@@ -57,6 +60,7 @@ describe('ChainLlm Node', () => {
mockExecuteFunction.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
if (param === 'messages.messageValues') return [];
if (param === 'needsFallback') return needsFallback;
return defaultValue;
});
@@ -96,6 +100,7 @@ describe('ChainLlm Node', () => {
context: mockExecuteFunction,
itemIndex: 0,
query: 'Test prompt',
fallbackLlm: null,
llm: expect.any(FakeChatModel),
outputParser: undefined,
messages: [],
@@ -151,6 +156,7 @@ describe('ChainLlm Node', () => {
context: mockExecuteFunction,
itemIndex: 0,
query: 'Old version prompt',
fallbackLlm: null,
llm: expect.any(Object),
outputParser: undefined,
messages: expect.any(Array),
@@ -505,6 +511,35 @@ describe('ChainLlm Node', () => {
expect(responseFormatterModule.formatResponse).toHaveBeenCalledWith(markdownResponse, true);
});
it('should use fallback llm if enabled', async () => {
needsFallback = true;
(helperModule.getPromptInputByType as jest.Mock).mockReturnValue('Test prompt');
(outputParserModule.getOptionalOutputParser as jest.Mock).mockResolvedValue(undefined);
(executeChainModule.executeChain as jest.Mock).mockResolvedValue(['Test response']);
const fakeLLM = new FakeChatModel({});
const fakeFallbackLLM = new FakeChatModel({});
mockExecuteFunction.getInputConnectionData.mockResolvedValue([fakeLLM, fakeFallbackLLM]);
const result = await node.execute.call(mockExecuteFunction);
expect(executeChainModule.executeChain).toHaveBeenCalledWith({
context: mockExecuteFunction,
itemIndex: 0,
query: 'Test prompt',
fallbackLlm: expect.any(FakeChatModel),
llm: expect.any(FakeChatModel),
outputParser: undefined,
messages: [],
});
expect(mockExecuteFunction.logger.debug).toHaveBeenCalledWith('Executing Basic LLM Chain');
expect(result).toEqual([[{ json: expect.any(Object) }]]);
});
it('should pass correct itemIndex to getOptionalOutputParser', async () => {
// Clear any previous calls to the mock
(outputParserModule.getOptionalOutputParser as jest.Mock).mockClear();
@@ -568,6 +603,7 @@ describe('ChainLlm Node', () => {
itemIndex: 0,
query: 'Test prompt 1',
llm: expect.any(Object),
fallbackLlm: null,
outputParser: mockParser1,
messages: [],
});
@@ -576,6 +612,7 @@ describe('ChainLlm Node', () => {
itemIndex: 1,
query: 'Test prompt 2',
llm: expect.any(Object),
fallbackLlm: null,
outputParser: mockParser2,
messages: [],
});
@@ -584,6 +621,7 @@ describe('ChainLlm Node', () => {
itemIndex: 2,
query: 'Test prompt 3',
llm: expect.any(Object),
fallbackLlm: null,
outputParser: mockParser3,
messages: [],
});

View File

@@ -7,26 +7,44 @@ describe('config', () => {
it('should return basic inputs for all parameters', () => {
const inputs = getInputs({});
expect(inputs).toHaveLength(3);
expect(inputs).toHaveLength(4);
expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
});
it('should exclude the OutputParser when hasOutputParser is false', () => {
const inputs = getInputs({ hasOutputParser: false });
expect(inputs).toHaveLength(2);
expect(inputs).toHaveLength(3);
expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
});
it('should include the OutputParser when hasOutputParser is true', () => {
const inputs = getInputs({ hasOutputParser: true });
expect(inputs).toHaveLength(4);
expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
});
it('should exclude the FallbackInput when needsFallback is false', () => {
const inputs = getInputs({ hasOutputParser: true, needsFallback: false });
expect(inputs).toHaveLength(3);
expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
});
it('should include the FallbackInput when needsFallback is true', () => {
const inputs = getInputs({ hasOutputParser: false, needsFallback: true });
expect(inputs).toHaveLength(3);
expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
});
});
describe('nodeProperties', () => {

View File

@@ -29,7 +29,11 @@ const i18n = useI18n();
const { debounce } = useDebounce();
const emit = defineEmits<{
switchSelectedNode: [nodeName: string];
openConnectionNodeCreator: [nodeName: string, connectionType: NodeConnectionType];
openConnectionNodeCreator: [
nodeName: string,
connectionType: NodeConnectionType,
connectionIndex: number,
];
}>();
interface NodeConfig {
@@ -38,6 +42,12 @@ interface NodeConfig {
issues: string[];
}
interface ConnectionContext {
connectionType: NodeConnectionType;
typeIndex: number;
key: string;
}
const possibleConnections = ref<INodeInputConfiguration[]>([]);
const expandedGroups = ref<string[]>([]);
@@ -85,55 +95,60 @@ const connectedNodes = computed<Record<string, NodeConfig[]>>(() => {
);
});
function getConnectionKey(connection: INodeInputConfiguration, globalIndex: number): string {
// Calculate the per-type index for this connection
function getConnectionContext(
connection: INodeInputConfiguration,
globalIndex: number,
): ConnectionContext {
let typeIndex = 0;
for (let i = 0; i < globalIndex; i++) {
if (possibleConnections.value[i].type === connection.type) {
typeIndex++;
}
}
return `${connection.type}-${typeIndex}`;
return {
connectionType: connection.type,
typeIndex,
key: `${connection.type}-${typeIndex}`,
};
}
function getConnectionConfig(connectionKey: string) {
const [type, indexStr] = connectionKey.split('-');
const typeIndex = parseInt(indexStr, 10);
// Find the connection config by type and type-specific index
let currentTypeIndex = 0;
for (const connection of possibleConnections.value) {
if (connection.type === type) {
if (currentTypeIndex === typeIndex) {
return connection;
}
currentTypeIndex++;
}
}
return undefined;
function getConnectionKey(connection: INodeInputConfiguration, globalIndex: number): string {
return getConnectionContext(connection, globalIndex).key;
}
function isMultiConnection(connectionKey: string) {
const connectionConfig = getConnectionConfig(connectionKey);
function getConnectionConfig(connectionType: NodeConnectionType, typeIndex: number) {
return possibleConnections.value
.filter((connection) => connection.type === connectionType)
.at(typeIndex);
}
function isMultiConnection(connectionContext: ConnectionContext) {
const connectionConfig = getConnectionConfig(
connectionContext.connectionType,
connectionContext.typeIndex,
);
return connectionConfig?.maxConnections !== 1;
}
function shouldShowConnectionTooltip(connectionKey: string) {
const [type] = connectionKey.split('-');
return isMultiConnection(connectionKey) && !expandedGroups.value.includes(type);
function shouldShowConnectionTooltip(connectionContext: ConnectionContext) {
return (
isMultiConnection(connectionContext) &&
!expandedGroups.value.includes(connectionContext.connectionType)
);
}
function expandConnectionGroup(connectionKey: string, isExpanded: boolean) {
const [type] = connectionKey.split('-');
function expandConnectionGroup(connectionContext: ConnectionContext, isExpanded: boolean) {
// If the connection is a single connection, we don't need to expand the group
if (!isMultiConnection(connectionKey)) {
if (!isMultiConnection(connectionContext)) {
return;
}
if (isExpanded) {
expandedGroups.value = [...expandedGroups.value, type];
expandedGroups.value = [...expandedGroups.value, connectionContext.connectionType];
} else {
expandedGroups.value = expandedGroups.value.filter((g) => g !== type);
expandedGroups.value = expandedGroups.value.filter(
(g) => g !== connectionContext.connectionType,
);
}
}
@@ -154,9 +169,11 @@ function getINodesFromNames(names: string[]): NodeConfig[] {
.filter((n): n is NodeConfig => n !== null);
}
function hasInputIssues(connectionKey: string) {
const [type] = connectionKey.split('-');
return shouldShowNodeInputIssues.value && (nodeInputIssues.value[type] ?? []).length > 0;
function hasInputIssues(connectionContext: ConnectionContext) {
return (
shouldShowNodeInputIssues.value &&
(nodeInputIssues.value[connectionContext.connectionType] ?? []).length > 0
);
}
function isNodeInputConfiguration(
@@ -181,29 +198,35 @@ function getPossibleSubInputConnections(): INodeInputConfiguration[] {
return nonMainInputs;
}
function onNodeClick(nodeName: string, connectionKey: string) {
const [type] = connectionKey.split('-');
if (isMultiConnection(connectionKey) && !expandedGroups.value.includes(type)) {
expandConnectionGroup(connectionKey, true);
function onNodeClick(nodeName: string, connectionContext: ConnectionContext) {
if (
isMultiConnection(connectionContext) &&
!expandedGroups.value.includes(connectionContext.connectionType)
) {
expandConnectionGroup(connectionContext, true);
return;
}
emit('switchSelectedNode', nodeName);
}
function onPlusClick(connectionKey: string) {
const [type] = connectionKey.split('-');
const connectionNodes = connectedNodes.value[connectionKey];
function onPlusClick(connectionContext: ConnectionContext) {
const connectionNodes = connectedNodes.value[connectionContext.key];
if (
isMultiConnection(connectionKey) &&
!expandedGroups.value.includes(type) &&
isMultiConnection(connectionContext) &&
!expandedGroups.value.includes(connectionContext.connectionType) &&
connectionNodes.length >= 1
) {
expandConnectionGroup(connectionKey, true);
expandConnectionGroup(connectionContext, true);
return;
}
emit('openConnectionNodeCreator', props.rootNode.name, type as NodeConnectionType);
emit(
'openConnectionNodeCreator',
props.rootNode.name,
connectionContext.connectionType,
connectionContext.typeIndex,
);
}
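
A worked example of the new per-type indexing, assuming the possible connections match the Agent node's inputs from earlier in this commit (values illustrative):

// With two ai_languageModel inputs, 'Chat Model' and 'Fallback Model' get
// typeIndex 0 and 1, so the plus button on the fallback input can open the
// node creator for the second model connection instead of always index 0.
const connections = [
	{ type: 'ai_languageModel', displayName: 'Chat Model' },
	{ type: 'ai_languageModel', displayName: 'Fallback Model' },
	{ type: 'ai_memory', displayName: 'Memory' },
];
// getConnectionContext(connections[1], 1) ->
//   { connectionType: 'ai_languageModel', typeIndex: 1, key: 'ai_languageModel-1' }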
function showNodeInputsIssues() {
@@ -247,12 +270,12 @@ defineExpose({
<span
:class="{
[$style.connectionLabel]: true,
[$style.hasIssues]: hasInputIssues(getConnectionKey(connection, index)),
[$style.hasIssues]: hasInputIssues(getConnectionContext(connection, index)),
}"
v-text="`${connection.displayName}${connection.required ? ' *' : ''}`"
/>
<OnClickOutside
@trigger="expandConnectionGroup(getConnectionKey(connection, index), false)"
@trigger="expandConnectionGroup(getConnectionContext(connection, index), false)"
>
<div
ref="connectedNodesWrapper"
@@ -261,7 +284,7 @@ defineExpose({
[$style.connectedNodesWrapperExpanded]: expandedGroups.includes(connection.type),
}"
:style="`--nodes-length: ${connectedNodes[getConnectionKey(connection, index)].length}`"
@click="expandConnectionGroup(getConnectionKey(connection, index), true)"
@click="expandConnectionGroup(getConnectionContext(connection, index), true)"
>
<div
v-if="
@@ -271,9 +294,9 @@ defineExpose({
"
:class="{
[$style.plusButton]: true,
[$style.hasIssues]: hasInputIssues(getConnectionKey(connection, index)),
[$style.hasIssues]: hasInputIssues(getConnectionContext(connection, index)),
}"
@click="onPlusClick(getConnectionKey(connection, index))"
@click="onPlusClick(getConnectionContext(connection, index))"
>
<n8n-tooltip
placement="top"
@@ -281,13 +304,13 @@ defineExpose({
:offset="10"
:show-after="300"
:disabled="
shouldShowConnectionTooltip(getConnectionKey(connection, index)) &&
shouldShowConnectionTooltip(getConnectionContext(connection, index)) &&
connectedNodes[getConnectionKey(connection, index)].length >= 1
"
>
<template #content>
Add {{ connection.displayName }}
<template v-if="hasInputIssues(getConnectionKey(connection, index))">
<template v-if="hasInputIssues(getConnectionContext(connection, index))">
<TitledList
:title="`${i18n.baseText('node.issues')}:`"
:items="nodeInputIssues[connection.type]"
@@ -324,7 +347,7 @@ defineExpose({
:teleported="true"
:offset="10"
:show-after="300"
:disabled="shouldShowConnectionTooltip(getConnectionKey(connection, index))"
:disabled="shouldShowConnectionTooltip(getConnectionContext(connection, index))"
>
<template #content>
{{ node.node.name }}
@@ -338,7 +361,7 @@ defineExpose({
<div
:class="$style.connectedNode"
@click="onNodeClick(node.node.name, getConnectionKey(connection, index))"
@click="onNodeClick(node.node.name, getConnectionContext(connection, index))"
>
<NodeIcon
:node-type="node.nodeType"

View File

@@ -377,6 +377,9 @@ export const useViewStacks = defineStore('nodeCreatorViewStacks', () => {
if (displayNode && filter?.nodes?.length) {
return filter.nodes.includes(i.key);
}
if (displayNode && filter?.excludedNodes?.length) {
return !filter.excludedNodes.includes(i.key);
}
return displayNode;
},

View File

@@ -46,7 +46,11 @@ const emit = defineEmits<{
saveKeyboardShortcut: [event: KeyboardEvent];
valueChanged: [parameterData: IUpdateInformation];
switchSelectedNode: [nodeTypeName: string];
openConnectionNodeCreator: [nodeTypeName: string, connectionType: NodeConnectionType];
openConnectionNodeCreator: [
nodeTypeName: string,
connectionType: NodeConnectionType,
connectionIndex?: number,
];
redrawNode: [nodeName: string];
stopExecution: [];
}>();
@@ -500,8 +504,12 @@ const onSwitchSelectedNode = (nodeTypeName: string) => {
emit('switchSelectedNode', nodeTypeName);
};
const onOpenConnectionNodeCreator = (nodeTypeName: string, connectionType: NodeConnectionType) => {
emit('openConnectionNodeCreator', nodeTypeName, connectionType);
const onOpenConnectionNodeCreator = (
nodeTypeName: string,
connectionType: NodeConnectionType,
connectionIndex: number = 0,
) => {
emit('openConnectionNodeCreator', nodeTypeName, connectionType, connectionIndex);
};
const close = async () => {

View File

@@ -77,7 +77,11 @@ const emit = defineEmits<{
redrawRequired: [];
valueChanged: [value: IUpdateInformation];
switchSelectedNode: [nodeName: string];
openConnectionNodeCreator: [nodeName: string, connectionType: NodeConnectionType];
openConnectionNodeCreator: [
nodeName: string,
connectionType: NodeConnectionType,
connectionIndex?: number,
];
activate: [];
execute: [];
expand: [];
@@ -382,8 +386,12 @@ const onSwitchSelectedNode = (node: string) => {
emit('switchSelectedNode', node);
};
const onOpenConnectionNodeCreator = (nodeName: string, connectionType: NodeConnectionType) => {
emit('openConnectionNodeCreator', nodeName, connectionType);
const onOpenConnectionNodeCreator = (
nodeName: string,
connectionType: NodeConnectionType,
connectionIndex: number = 0,
) => {
emit('openConnectionNodeCreator', nodeName, connectionType, connectionIndex);
};
const populateHiddenIssuesSet = () => {

View File

@@ -1531,7 +1531,10 @@ export function useCanvasOperations() {
if (inputType !== targetConnection.type) return false;
const filter = typeof input === 'object' && 'filter' in input ? input.filter : undefined;
if (filter?.nodes.length && !filter.nodes.includes(sourceNode.type)) {
if (
(filter?.nodes?.length && !filter.nodes?.includes(sourceNode.type)) ||
(filter?.excludedNodes?.length && filter.excludedNodes?.includes(sourceNode.type))
) {
toast.showToast({
title: i18n.baseText('nodeView.showError.nodeNodeCompatible.title'),
message: i18n.baseText('nodeView.showError.nodeNodeCompatible.message', {

View File

@@ -94,10 +94,12 @@ export const useNodeCreatorStore = defineStore(STORES.NODE_CREATOR, () => {
connectionType,
node,
creatorView,
connectionIndex = 0,
}: {
connectionType: NodeConnectionType;
node: string;
creatorView?: NodeFilterType;
connectionIndex?: number;
}) {
const nodeName = node ?? ndvStore.activeNodeName;
const nodeData = nodeName ? workflowsStore.getNodeByName(nodeName) : null;
@@ -118,7 +120,7 @@ export const useNodeCreatorStore = defineStore(STORES.NODE_CREATOR, () => {
sourceHandle: createCanvasConnectionHandleString({
mode: 'inputs',
type: connectionType,
index: 0,
index: connectionIndex,
}),
},
eventSource: NODE_CREATOR_OPEN_SOURCES.NOTICE_ERROR_MESSAGE,

View File

@@ -1189,11 +1189,15 @@ async function onSwitchActiveNode(nodeName: string) {
selectNodes([node.id]);
}
async function onOpenSelectiveNodeCreator(node: string, connectionType: NodeConnectionType) {
nodeCreatorStore.openSelectiveNodeCreator({ node, connectionType });
async function onOpenSelectiveNodeCreator(
node: string,
connectionType: NodeConnectionType,
connectionIndex: number = 0,
) {
nodeCreatorStore.openSelectiveNodeCreator({ node, connectionType, connectionIndex });
}
async function onOpenNodeCreatorForTriggerNodes(source: NodeCreatorOpenSource) {
function onOpenNodeCreatorForTriggerNodes(source: NodeCreatorOpenSource) {
nodeCreatorStore.openNodeCreatorForTriggerNodes(source);
}

View File

@@ -1908,7 +1908,8 @@ export interface INodeInputFilter {
// TODO: Later add more filter options like categories, subcategories,
// regex, allow to exclude certain nodes, ... ?
// Potentially change totally after alpha/beta. Is not a breaking change after all.
nodes: string[]; // Allowed nodes
nodes?: string[]; // Allowed nodes
excludedNodes?: string[];
}
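
Taken together with the canvas check earlier in this commit, the filter now acts as an allow-list plus deny-list. A minimal sketch of the combined predicate, using the interface above (helper name illustrative):

// A node type passes the filter if it is not missing from the `nodes`
// allow-list and not present in the new `excludedNodes` deny-list.
function isNodeTypeAllowed(filter: INodeInputFilter | undefined, nodeType: string): boolean {
	if (filter?.nodes?.length && !filter.nodes.includes(nodeType)) return false;
	if (filter?.excludedNodes?.length && filter.excludedNodes.includes(nodeType)) return false;
	return true;
}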
export interface INodeInputConfiguration {