diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts
index 912e06721f..1b302a82f2 100644
--- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts
+++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts
@@ -17,33 +17,25 @@ import { toolsAgentExecute } from '../agents/ToolsAgent/V2/execute';
 
 // Function used in the inputs expression to figure out which inputs to
 // display based on the agent type
-function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeInputConfiguration> {
+function getInputs(
+	hasOutputParser?: boolean,
+	needsFallback?: boolean,
+): Array<NodeConnectionType | INodeInputConfiguration> {
 	interface SpecialInput {
 		type: NodeConnectionType;
 		filter?: INodeInputFilter;
+		displayName: string;
 		required?: boolean;
 	}
 
 	const getInputData = (
 		inputs: SpecialInput[],
 	): Array<NodeConnectionType | INodeInputConfiguration> => {
-		const displayNames: { [key: string]: string } = {
-			ai_languageModel: 'Model',
-			ai_memory: 'Memory',
-			ai_tool: 'Tool',
-			ai_outputParser: 'Output Parser',
-		};
-
-		return inputs.map(({ type, filter }) => {
-			const isModelType = type === 'ai_languageModel';
-			let displayName = type in displayNames ? displayNames[type] : undefined;
-			if (isModelType) {
-				displayName = 'Chat Model';
-			}
+		return inputs.map(({ type, filter, displayName, required }) => {
 			const input: INodeInputConfiguration = {
 				type,
 				displayName,
-				required: isModelType,
+				required,
 				maxConnections: ['ai_languageModel', 'ai_memory', 'ai_outputParser'].includes(
 					type as NodeConnectionType,
 				)
@@ -62,33 +54,40 @@ function getInputs(hasOutputParser?: boolean): Array<NodeConnectionType | INodeInputConfiguration> {
 		specialInputs = specialInputs.filter((input) => input.type !== 'ai_outputParser');
 	}
+	if (needsFallback === false) {
+		specialInputs = specialInputs.filter((input) => input.displayName !== 'Fallback Model');
+	}
 	return ['main', ...getInputData(specialInputs)];
 }
@@ -111,10 +113,10 @@ export class AgentV2 implements INodeType {
 			color: '#404040',
 		},
 		inputs: `={{
-			((hasOutputParser) => {
+			((hasOutputParser, needsFallback) => {
 				${getInputs.toString()};
-				return getInputs(hasOutputParser)
-			})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true)
+				return getInputs(hasOutputParser, needsFallback)
+			})($parameter.hasOutputParser === undefined || $parameter.hasOutputParser === true, $parameter.needsFallback === undefined || $parameter.needsFallback === true)
 		}}`,
 		outputs: [NodeConnectionTypes.Main],
 		properties: [
@@ -160,6 +162,25 @@ export class AgentV2 implements INodeType {
 				},
 			},
 		},
+			{
+				displayName: 'Enable Fallback Model',
+				name: 'needsFallback',
+				type: 'boolean',
+				default: false,
+				noDataExpression: true,
+			},
+			{
+				displayName:
+					'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
+				name: 'fallbackNotice',
+				type: 'notice',
+				default: '',
+				displayOptions: {
+					show: {
+						needsFallback: [true],
+					},
+				},
+			},
 			...toolsAgentProperties,
 		],
 	};
diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts
index 2cefa4ed4b..55b9bc1466 100644
--- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts
+++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V1/execute.ts
@@ -1,3 +1,4 @@
+import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import { RunnableSequence } from '@langchain/core/runnables';
 import { AgentExecutor, createToolCallingAgent } from 'langchain/agents';
 import omit from 'lodash/omit';
@@ -40,7 +41,7 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
+function createAgentExecutor(
+	model: BaseChatModel,
+	tools: Array<DynamicStructuredTool | Tool>,
+	prompt: ChatPromptTemplate,
+	options: { maxIterations?: number; returnIntermediateSteps?: boolean },
+	outputParser?: N8nOutputParser,
+	memory?: BaseChatMemory,
+	fallbackModel?: BaseChatModel | null,
+) {
+	const modelWithFallback = fallbackModel ? model.withFallbacks([fallbackModel]) : model;
+	const agent = createToolCallingAgent({
+		llm: modelWithFallback,
+		tools,
+		prompt,
+		streamRunnable: false,
+	});
+
+	const runnableAgent = RunnableSequence.from([
+		agent,
+		getAgentStepsParser(outputParser, memory),
+		fixEmptyContentMessage,
+	]);
+
+	return AgentExecutor.fromAgentAndTools({
+		agent: runnableAgent,
+		memory,
+		tools,
+		returnIntermediateSteps: options.returnIntermediateSteps === true,
+		maxIterations: options.maxIterations ?? 10,
+	});
+}
+
 /* -----------------------------------------------------------
 	Main Executor Function
 ----------------------------------------------------------- */
@@ -42,8 +84,18 @@ export async function toolsAgentExecute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
-export async function getChatModel(ctx: IExecuteFunctions): Promise<BaseChatModel> {
-	const model = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+export async function getChatModel(
+	ctx: IExecuteFunctions,
+	index: number = 0,
+): Promise<BaseChatModel | undefined> {
+	const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+
+	let model;
+
+	if (Array.isArray(connectedModels) && index !== undefined) {
+		if (connectedModels.length <= index) {
+			return undefined;
+		}
+		// We get the models in reversed order from the workflow so we need to reverse them to match the right index
+		const reversedModels = [...connectedModels].reverse();
+		model = reversedModels[index] as BaseChatModel;
+	} else {
+		model = connectedModels as BaseChatModel;
+	}
+
 	if (!isChatInstance(model) || !model.bindTools) {
 		throw new NodeOperationError(
 			ctx.getNode(),
diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts
index 5dbbd2d939..31532ec24f 100644
--- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts
+++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/ToolsAgentV2.test.ts
@@ -52,6 +52,7 @@ describe('toolsAgentExecute', () => {
 		// Mock getNodeParameter to return default values
 		mockContext.getNodeParameter.mockImplementation((param, _i, defaultValue) => {
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options.batching.batchSize') return defaultValue;
 			if (param === 'options.batching.delayBetweenBatches') return defaultValue;
 			if (param === 'options')
@@ -104,6 +105,7 @@ describe('toolsAgentExecute', () => {
 			if (param === 'options.batching.batchSize') return 2;
 			if (param === 'options.batching.delayBetweenBatches') return 100;
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options')
 				return {
 					systemMessage: 'You are a helpful assistant',
@@ -157,6 +159,7 @@ describe('toolsAgentExecute', () => {
 			if (param === 'options.batching.batchSize') return 2;
 			if (param === 'options.batching.delayBetweenBatches') return 0;
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options')
 				return {
 					systemMessage: 'You are a helpful assistant',
@@ -206,6 +209,7 @@ describe('toolsAgentExecute', () => {
 			if (param === 'options.batching.batchSize') return 2;
 			if (param === 'options.batching.delayBetweenBatches') return 0;
 			if (param === 'text') return 'test input';
+			if (param === 'needsFallback') return false;
 			if (param === 'options')
 				return {
 					systemMessage: 'You are a helpful assistant',
diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts
index d9768dcb43..a89fd3bb2c 100644
--- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts
+++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/test/ToolsAgent/commons.test.ts
@@ -2,6 +2,7 @@ import type { BaseChatMemory } from '@langchain/community/memory/chat_memory';
 import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { HumanMessage } from '@langchain/core/messages';
 import type { BaseMessagePromptTemplateLike } from '@langchain/core/prompts';
+import { FakeLLM, FakeStreamingChatModel } from '@langchain/core/utils/testing';
 import { Buffer } from 'buffer';
 import { mock } from 'jest-mock-extended';
 import type { ToolsAgentAction } from 'langchain/dist/agents/tool_calling/output_parser';
@@ -163,6 +164,72 @@ describe('getChatModel', () => {
 		mockContext.getNode.mockReturnValue(mock());
 		await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
 	});
+
+	it('should return the first model when multiple models are connected and no index specified', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeChatModel2 = new FakeStreamingChatModel({});
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+		const model = await getChatModel(mockContext);
+		expect(model).toEqual(fakeChatModel2); // Should return the last model (reversed array)
+	});
+
+	it('should return the model at specified index when multiple models are connected', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+
+		const fakeChatModel2 = new FakeStreamingChatModel({});
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+		const model = await getChatModel(mockContext, 0);
+		expect(model).toEqual(fakeChatModel2); // Should return the first model after reversal (index 0)
+	});
+
+	it('should return the fallback model at index 1 when multiple models are connected', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeChatModel2 = new FakeStreamingChatModel({});
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeChatModel2]);
+
+		const model = await getChatModel(mockContext, 1);
+		expect(model).toEqual(fakeChatModel1); // Should return the second model after reversal (index 1)
+	});
+
+	it('should return undefined when requested index is out of bounds', async () => {
+		const fakeChatModel1 = mock();
+		fakeChatModel1.bindTools = jest.fn();
+		fakeChatModel1.lc_namespace = ['chat_models'];
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1]);
+		mockContext.getNode.mockReturnValue(mock());
+
+		const result = await getChatModel(mockContext, 2);
+
+		expect(result).toBeUndefined();
+	});
+
+	it('should throw error when single model does not support tools', async () => {
+		const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
+
+		mockContext.getInputConnectionData.mockResolvedValue(fakeInvalidModel);
+		mockContext.getNode.mockReturnValue(mock());
+
+		await expect(getChatModel(mockContext)).rejects.toThrow(NodeOperationError);
+		await expect(getChatModel(mockContext)).rejects.toThrow(
+			'Tools Agent requires Chat Model which supports Tools calling',
+		);
+	});
+
+	it('should throw error when model at specified index does not support tools', async () => {
+		const fakeChatModel1 = new FakeStreamingChatModel({});
+		const fakeInvalidModel = new FakeLLM({}); // doesn't support tool calls
+
+		mockContext.getInputConnectionData.mockResolvedValue([fakeChatModel1, fakeInvalidModel]);
+		mockContext.getNode.mockReturnValue(mock());
+
+		await expect(getChatModel(mockContext, 0)).rejects.toThrow(NodeOperationError);
+	});
 });
 
 describe('getOptionalMemory', () => {
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts
index b16c1a60d4..1b1472251f 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts
@@ -88,15 +88,24 @@ async function executeSimpleChain({
 	llm,
 	query,
 	prompt,
+	fallbackLlm,
 }: {
 	context: IExecuteFunctions;
 	llm: BaseLanguageModel;
 	query: string;
 	prompt: ChatPromptTemplate | PromptTemplate;
+	fallbackLlm?: BaseLanguageModel | null;
 }) {
 	const outputParser = getOutputParserForLLM(llm);
+	let model;
 
-	const chain = prompt.pipe(llm).pipe(outputParser).withConfig(getTracingConfig(context));
+	if (fallbackLlm) {
+		model = llm.withFallbacks([fallbackLlm]);
+	} else {
+		model = llm;
+	}
+
+	const chain = prompt.pipe(model).pipe(outputParser).withConfig(getTracingConfig(context));
 
 	// Execute the chain
 	const response = await chain.invoke({
@@ -118,6 +127,7 @@ export async function executeChain({
 	llm,
 	outputParser,
 	messages,
+	fallbackLlm,
 }: ChainExecutionParams): Promise<unknown[]> {
 	// If no output parsers provided, use a simple chain with basic prompt template
 	if (!outputParser) {
@@ -134,6 +144,7 @@ export async function executeChain({
 			llm,
 			query,
 			prompt: promptTemplate,
+			fallbackLlm,
 		});
 	}
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts
index cf9d08495a..2245cfd2a8 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/config.ts
@@ -23,6 +23,17 @@ export function getInputs(parameters: IDataObject) {
 		},
 	];
 
+	const needsFallback = parameters?.needsFallback;
+
+	if (needsFallback === undefined || needsFallback === true) {
+		inputs.push({
+			displayName: 'Fallback Model',
+			maxConnections: 1,
+			type: 'ai_languageModel',
+			required: true,
+		});
+	}
+
 	// If `hasOutputParser` is undefined it must be version 1.3 or earlier so we
 	// always add the output parser input
 	const hasOutputParser = parameters?.hasOutputParser;
@@ -119,6 +130,18 @@ export const nodeProperties: INodeProperties[] = [
 			},
 		},
 	},
+	{
+		displayName: 'Enable Fallback Model',
+		name: 'needsFallback',
+		type: 'boolean',
+		default: false,
+		noDataExpression: true,
+		displayOptions: {
+			hide: {
+				'@version': [1, 1.1, 1.3],
+			},
+		},
+	},
 	{
 		displayName: 'Chat Messages (if Using a Chat Model)',
 		name: 'messages',
@@ -275,4 +298,16 @@ export const nodeProperties: INodeProperties[] = [
 			},
 		},
 	},
+	{
+		displayName:
+			'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
+		name: 'fallbackNotice',
+		type: 'notice',
+		default: '',
+		displayOptions: {
+			show: {
+				needsFallback: [true],
+			},
+		},
+	},
 ];
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts
index fdd284085c..a5c90fa78f 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/processItem.ts
@@ -1,5 +1,6 @@
 import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
+import assert from 'node:assert';
 
 import { getPromptInputByType } from '@utils/helpers';
 import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
@@ -7,11 +8,40 @@ import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
 import { executeChain } from './chainExecutor';
 import { type MessageTemplate } from './types';
 
+async function getChatModel(
+	ctx: IExecuteFunctions,
+	index: number = 0,
+): Promise<BaseLanguageModel | undefined> {
+	const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
+
+	let model;
+
+	if (Array.isArray(connectedModels) && index !== undefined) {
+		if (connectedModels.length <= index) {
+			return undefined;
+		}
+		// We get the models in reversed order from the workflow so we need to reverse them again to match the right index
+		const reversedModels = [...connectedModels].reverse();
+		model = reversedModels[index] as BaseLanguageModel;
+	} else {
+		model = connectedModels as BaseLanguageModel;
+	}
+
+	return model;
+}
+
 export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) => {
-	const llm = (await ctx.getInputConnectionData(
-		NodeConnectionTypes.AiLanguageModel,
-		0,
-	)) as BaseLanguageModel;
+	const needsFallback = ctx.getNodeParameter('needsFallback', 0, false) as boolean;
+	const llm = await getChatModel(ctx, 0);
+	assert(llm, 'Please connect a model to the Chat Model input');
+
+	const fallbackLlm = needsFallback ? await getChatModel(ctx, 1) : null;
+	if (needsFallback && !fallbackLlm) {
+		throw new NodeOperationError(
+			ctx.getNode(),
+			'Please connect a model to the Fallback Model input or disable the fallback option',
+		);
+	}
 
 	// Get output parser if configured
 	const outputParser = await getOptionalOutputParser(ctx, itemIndex);
@@ -50,5 +80,6 @@ export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) =>
 		llm,
 		outputParser,
 		messages,
+		fallbackLlm,
 	});
 };
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/types.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/types.ts
index ce76a3cf6b..b61a8db094 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/types.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/types.ts
@@ -38,4 +38,5 @@ export interface ChainExecutionParams {
 	llm: BaseLanguageModel;
 	outputParser?: N8nOutputParser;
 	messages?: MessageTemplate[];
+	fallbackLlm?: BaseLanguageModel | null;
 }
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts
index aef85db4a1..5c01548446 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/ChainLlm.node.test.ts
@@ -36,6 +36,7 @@ jest.mock('../methods/responseFormatter', () => ({
 describe('ChainLlm Node', () => {
 	let node: ChainLlm;
 	let mockExecuteFunction: jest.Mocked<IExecuteFunctions>;
+	let needsFallback: boolean;
 
 	beforeEach(() => {
 		node = new ChainLlm();
@@ -48,6 +49,8 @@ describe('ChainLlm Node', () => {
 			error: jest.fn(),
 		};
 
+		needsFallback = false;
+
 		mockExecuteFunction.getInputData.mockReturnValue([{ json: {} }]);
 		mockExecuteFunction.getNode.mockReturnValue({
 			name: 'Chain LLM',
@@ -57,6 +60,7 @@ describe('ChainLlm Node', () => {
 		mockExecuteFunction.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
 			if (param === 'messages.messageValues') return [];
+			if (param === 'needsFallback') return needsFallback;
 			return defaultValue;
 		});
@@ -96,6 +100,7 @@ describe('ChainLlm Node', () => {
 				context: mockExecuteFunction,
 				itemIndex: 0,
 				query: 'Test prompt',
+				fallbackLlm: null,
 				llm: expect.any(FakeChatModel),
 				outputParser: undefined,
 				messages: [],
@@ -151,6 +156,7 @@ describe('ChainLlm Node', () => {
 				context: mockExecuteFunction,
 				itemIndex: 0,
 				query: 'Old version prompt',
+				fallbackLlm: null,
 				llm: expect.any(Object),
 				outputParser: undefined,
 				messages: expect.any(Array),
@@ -505,6 +511,35 @@ describe('ChainLlm Node', () => {
 		expect(responseFormatterModule.formatResponse).toHaveBeenCalledWith(markdownResponse, true);
 	});
 
+	it('should use fallback llm if enabled', async () => {
+		needsFallback = true;
+		(helperModule.getPromptInputByType as jest.Mock).mockReturnValue('Test prompt');
+
+		(outputParserModule.getOptionalOutputParser as jest.Mock).mockResolvedValue(undefined);
+
+		(executeChainModule.executeChain as jest.Mock).mockResolvedValue(['Test response']);
+
+		const fakeLLM = new FakeChatModel({});
+		const fakeFallbackLLM = new FakeChatModel({});
+		mockExecuteFunction.getInputConnectionData.mockResolvedValue([fakeLLM, fakeFallbackLLM]);
+
+		const result = await node.execute.call(mockExecuteFunction);
+
+		expect(executeChainModule.executeChain).toHaveBeenCalledWith({
+			context: mockExecuteFunction,
+			itemIndex: 0,
+			query: 'Test prompt',
+			fallbackLlm: expect.any(FakeChatModel),
+			llm: expect.any(FakeChatModel),
+			outputParser: undefined,
+			messages: [],
+		});
+
+		expect(mockExecuteFunction.logger.debug).toHaveBeenCalledWith('Executing Basic LLM Chain');
+
+		expect(result).toEqual([[{ json: expect.any(Object) }]]);
+	});
+
 	it('should pass correct itemIndex to getOptionalOutputParser', async () => {
 		// Clear any previous calls to the mock
 		(outputParserModule.getOptionalOutputParser as jest.Mock).mockClear();
@@ -568,6 +603,7 @@ describe('ChainLlm Node', () => {
 			itemIndex: 0,
 			query: 'Test prompt 1',
 			llm: expect.any(Object),
+			fallbackLlm: null,
 			outputParser: mockParser1,
 			messages: [],
 		});
@@ -576,6 +612,7 @@ describe('ChainLlm Node', () => {
 			itemIndex: 1,
 			query: 'Test prompt 2',
 			llm: expect.any(Object),
+			fallbackLlm: null,
 			outputParser: mockParser2,
 			messages: [],
 		});
@@ -584,6 +621,7 @@ describe('ChainLlm Node', () => {
 			itemIndex: 2,
 			query: 'Test prompt 3',
 			llm: expect.any(Object),
+			fallbackLlm: null,
 			outputParser: mockParser3,
 			messages: [],
 		});
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/config.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/config.test.ts
index 2ec49822d7..67bc1a72b4 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/config.test.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/config.test.ts
@@ -7,26 +7,44 @@ describe('config', () => {
 	it('should return basic inputs for all parameters', () => {
 		const inputs = getInputs({});
 
-		expect(inputs).toHaveLength(3);
+		expect(inputs).toHaveLength(4);
 		expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
 		expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
-		expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
+		expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
 	});
 
 	it('should exclude the OutputParser when hasOutputParser is false', () => {
 		const inputs = getInputs({ hasOutputParser: false });
 
-		expect(inputs).toHaveLength(2);
+		expect(inputs).toHaveLength(3);
 		expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
 		expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
 	});
 
 	it('should include the OutputParser when hasOutputParser is true', () => {
 		const inputs = getInputs({ hasOutputParser: true });
 
+		expect(inputs).toHaveLength(4);
+		expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
+	});
+
+	it('should exclude the FallbackInput when needsFallback is false', () => {
+		const inputs = getInputs({ hasOutputParser: true, needsFallback: false });
+		expect(inputs).toHaveLength(3);
+		expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
+		expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
 		expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
 	});
+
+	it('should include the FallbackInput when needsFallback is true', () => {
+		const inputs = getInputs({ hasOutputParser: false, needsFallback: true });
+
+		expect(inputs).toHaveLength(3);
+		expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
+	});
 });
 
 describe('nodeProperties', () => {
diff --git a/packages/frontend/editor-ui/src/components/NDVSubConnections.vue b/packages/frontend/editor-ui/src/components/NDVSubConnections.vue
index ce5696bc1c..383ee83fd9 100644
--- a/packages/frontend/editor-ui/src/components/NDVSubConnections.vue
+++ b/packages/frontend/editor-ui/src/components/NDVSubConnections.vue
@@ -29,7 +29,11 @@ const i18n = useI18n();
 const { debounce } = useDebounce();
 
 const emit = defineEmits<{
 	switchSelectedNode: [nodeName: string];
-	openConnectionNodeCreator: [nodeName: string, connectionType: NodeConnectionType];
+	openConnectionNodeCreator: [
+		nodeName: string,
+		connectionType: NodeConnectionType,
+		connectionIndex: number,
+	];
 }>();
 
 interface NodeConfig {
@@ -38,6 +42,12 @@ interface NodeConfig {
 	issues: string[];
 }
 
+interface ConnectionContext {
+	connectionType: NodeConnectionType;
+	typeIndex: number;
+	key: string;
+}
+
 const possibleConnections = ref<INodeInputConfiguration[]>([]);
 const expandedGroups = ref<string[]>([]);
@@ -85,55 +95,60 @@ const connectedNodes = computed<Record<string, NodeConfig[]>>(() => {
 	);
 });
 
-function getConnectionKey(connection: INodeInputConfiguration, globalIndex: number): string {
-	// Calculate the per-type index for this connection
+function getConnectionContext(
+	connection: INodeInputConfiguration,
+	globalIndex: number,
+): ConnectionContext {
 	let typeIndex = 0;
 	for (let i = 0; i < globalIndex; i++) {
 		if (possibleConnections.value[i].type === connection.type) {
 			typeIndex++;
 		}
 	}
-	return `${connection.type}-${typeIndex}`;
+	return {
+		connectionType: connection.type,
+		typeIndex,
+		key: `${connection.type}-${typeIndex}`,
+	};
 }
 
-function getConnectionConfig(connectionKey: string) {
-	const [type, indexStr] = connectionKey.split('-');
-	const typeIndex = parseInt(indexStr, 10);
-
-	// Find the connection config by type and type-specific index
-	let currentTypeIndex = 0;
-	for (const connection of possibleConnections.value) {
-		if (connection.type === type) {
-			if (currentTypeIndex === typeIndex) {
-				return connection;
-			}
-			currentTypeIndex++;
-		}
-	}
-	return undefined;
+function getConnectionKey(connection: INodeInputConfiguration, globalIndex: number): string {
+	return getConnectionContext(connection, globalIndex).key;
 }
 
-function isMultiConnection(connectionKey: string) {
-	const connectionConfig = getConnectionConfig(connectionKey);
+function getConnectionConfig(connectionType: NodeConnectionType, typeIndex: number) {
+	return possibleConnections.value
+		.filter((connection) => connection.type === connectionType)
+		.at(typeIndex);
+}
+
+function isMultiConnection(connectionContext: ConnectionContext) {
+	const connectionConfig = getConnectionConfig(
+		connectionContext.connectionType,
+		connectionContext.typeIndex,
+	);
 	return connectionConfig?.maxConnections !== 1;
 }
 
-function shouldShowConnectionTooltip(connectionKey: string) {
-	const [type] = connectionKey.split('-');
-	return isMultiConnection(connectionKey) && !expandedGroups.value.includes(type);
+function shouldShowConnectionTooltip(connectionContext: ConnectionContext) {
+	return (
+		isMultiConnection(connectionContext) &&
+		!expandedGroups.value.includes(connectionContext.connectionType)
+	);
 }
 
-function expandConnectionGroup(connectionKey: string, isExpanded: boolean) {
-	const [type] = connectionKey.split('-');
+function expandConnectionGroup(connectionContext: ConnectionContext, isExpanded: boolean) {
 	// If the connection is a single connection, we don't need to expand the group
-	if (!isMultiConnection(connectionKey)) {
+	if (!isMultiConnection(connectionContext)) {
 		return;
 	}
 
 	if (isExpanded) {
-		expandedGroups.value = [...expandedGroups.value, type];
+		expandedGroups.value = [...expandedGroups.value, connectionContext.connectionType];
 	} else {
-		expandedGroups.value = expandedGroups.value.filter((g) => g !== type);
+		expandedGroups.value = expandedGroups.value.filter(
+			(g) => g !== connectionContext.connectionType,
+		);
 	}
 }
@@ -154,9 +169,11 @@ function getINodesFromNames(names: string[]): NodeConfig[] {
 		.filter((n): n is NodeConfig => n !== null);
 }
 
-function hasInputIssues(connectionKey: string) {
-	const [type] = connectionKey.split('-');
-	return shouldShowNodeInputIssues.value && (nodeInputIssues.value[type] ?? []).length > 0;
+function hasInputIssues(connectionContext: ConnectionContext) {
+	return (
+		shouldShowNodeInputIssues.value &&
+		(nodeInputIssues.value[connectionContext.connectionType] ?? []).length > 0
+	);
 }
 
 function isNodeInputConfiguration(
@@ -181,29 +198,35 @@ function getPossibleSubInputConnections(): INodeInputConfiguration[] {
 	return nonMainInputs;
 }
 
-function onNodeClick(nodeName: string, connectionKey: string) {
-	const [type] = connectionKey.split('-');
-	if (isMultiConnection(connectionKey) && !expandedGroups.value.includes(type)) {
-		expandConnectionGroup(connectionKey, true);
+function onNodeClick(nodeName: string, connectionContext: ConnectionContext) {
+	if (
+		isMultiConnection(connectionContext) &&
+		!expandedGroups.value.includes(connectionContext.connectionType)
+	) {
+		expandConnectionGroup(connectionContext, true);
 		return;
 	}
 
 	emit('switchSelectedNode', nodeName);
 }
 
-function onPlusClick(connectionKey: string) {
-	const [type] = connectionKey.split('-');
-	const connectionNodes = connectedNodes.value[connectionKey];
+function onPlusClick(connectionContext: ConnectionContext) {
+	const connectionNodes = connectedNodes.value[connectionContext.key];
 	if (
-		isMultiConnection(connectionKey) &&
-		!expandedGroups.value.includes(type) &&
+		isMultiConnection(connectionContext) &&
+		!expandedGroups.value.includes(connectionContext.connectionType) &&
 		connectionNodes.length >= 1
 	) {
-		expandConnectionGroup(connectionKey, true);
+		expandConnectionGroup(connectionContext, true);
 		return;
 	}
 
-	emit('openConnectionNodeCreator', props.rootNode.name, type as NodeConnectionType);
+	emit(
+		'openConnectionNodeCreator',
+		props.rootNode.name,
+		connectionContext.connectionType,
+		connectionContext.typeIndex,
+	);
 }
 
 function showNodeInputsIssues() {
@@ -247,12 +270,12 @@ defineExpose({