diff --git a/cypress/e2e/5-ndv.cy.ts b/cypress/e2e/5-ndv.cy.ts index c2c256a109..cae0921634 100644 --- a/cypress/e2e/5-ndv.cy.ts +++ b/cypress/e2e/5-ndv.cy.ts @@ -565,18 +565,20 @@ describe('NDV', () => { { title: 'Language Models', id: 'ai_languageModel', + index: 0, }, { title: 'Tools', id: 'ai_tool', + index: 0, }, ]; workflowPage.actions.addInitialNodeToCanvas('AI Agent', { keepNdvOpen: true }); connectionGroups.forEach((group) => { - cy.getByTestId(`add-subnode-${group.id}`).should('exist'); - cy.getByTestId(`add-subnode-${group.id}`).click(); + cy.getByTestId(`add-subnode-${group.id}-${group.index}`).should('exist'); + cy.getByTestId(`add-subnode-${group.id}-${group.index}`).click(); cy.getByTestId('nodes-list-header').contains(group.title).should('exist'); // Add HTTP Request tool @@ -585,16 +587,16 @@ describe('NDV', () => { getFloatingNodeByPosition('outputSub').click({ force: true }); if (group.id === 'ai_languageModel') { - cy.getByTestId(`add-subnode-${group.id}`).should('not.exist'); + cy.getByTestId(`add-subnode-${group.id}-${group.index}`).should('not.exist'); } else { - cy.getByTestId(`add-subnode-${group.id}`).should('exist'); + cy.getByTestId(`add-subnode-${group.id}-${group.index}`).should('exist'); // Expand the subgroup - cy.getByTestId('subnode-connection-group-ai_tool').click(); - cy.getByTestId(`add-subnode-${group.id}`).click(); + cy.getByTestId('subnode-connection-group-ai_tool-0').click(); + cy.getByTestId(`add-subnode-${group.id}-${group.index}`).click(); // Add HTTP Request tool nodeCreator.getters.getNthCreatorItem(2).click(); getFloatingNodeByPosition('outputSub').click({ force: true }); - cy.getByTestId('subnode-connection-group-ai_tool') + cy.getByTestId('subnode-connection-group-ai_tool-0') .findChildByTestId('floating-subnode') .should('have.length', 2); } diff --git a/packages/@n8n/nodes-langchain/nodes/ModelSelector/ModelSelector.node.ts b/packages/@n8n/nodes-langchain/nodes/ModelSelector/ModelSelector.node.ts new file mode 100644 index 0000000000..17d531438d --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/ModelSelector/ModelSelector.node.ts @@ -0,0 +1,204 @@ +/* eslint-disable n8n-nodes-base/node-param-description-wrong-for-dynamic-options */ +/* eslint-disable n8n-nodes-base/node-param-display-name-wrong-for-dynamic-options */ +import type { BaseCallbackHandler, CallbackHandlerMethods } from '@langchain/core/callbacks/base'; +import type { Callbacks } from '@langchain/core/callbacks/manager'; +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { + NodeConnectionTypes, + type INodeType, + type INodeTypeDescription, + type ISupplyDataFunctions, + type SupplyData, + type ILoadOptionsFunctions, + NodeOperationError, +} from 'n8n-workflow'; + +import { numberInputsProperty, configuredInputs } from './helpers'; +import { N8nLlmTracing } from '../llms/N8nLlmTracing'; +import { N8nNonEstimatingTracing } from '../llms/N8nNonEstimatingTracing'; + +interface ModeleSelectionRule { + modelIndex: number; + conditions: { + options: { + caseSensitive: boolean; + typeValidation: 'strict' | 'loose'; + leftValue: string; + version: 1 | 2; + }; + conditions: Array<{ + id: string; + leftValue: string; + rightValue: string; + operator: { + type: string; + operation: string; + name: string; + }; + }>; + combinator: 'and' | 'or'; + }; +} + +function getCallbacksArray( + callbacks: Callbacks | undefined, +): Array { + if (!callbacks) return []; + + if (Array.isArray(callbacks)) { + return callbacks; + } + + // If it's a CallbackManager, extract its handlers + return callbacks.handlers || []; +} + +export class ModelSelector implements INodeType { + description: INodeTypeDescription = { + displayName: 'Model Selector', + name: 'modelSelector', + icon: 'fa:map-signs', + iconColor: 'green', + defaults: { + name: 'Model Selector', + }, + version: 1, + group: ['transform'], + description: + 'Use this node to select one of the connected models to this node based on workflow data', + inputs: `={{ + ((parameters) => { + ${configuredInputs.toString()}; + return configuredInputs(parameters) + })($parameter) + }}`, + outputs: [NodeConnectionTypes.AiLanguageModel], + requiredInputs: 1, + properties: [ + numberInputsProperty, + { + displayName: 'Rules', + name: 'rules', + placeholder: 'Add Rule', + type: 'fixedCollection', + typeOptions: { + multipleValues: true, + sortable: true, + }, + description: 'Rules to map workflow data to specific models', + default: {}, + options: [ + { + displayName: 'Rule', + name: 'rule', + values: [ + { + displayName: 'Model', + name: 'modelIndex', + type: 'options', + description: 'Choose model input from the list', + default: 1, + required: true, + placeholder: 'Choose model input from the list', + typeOptions: { + loadOptionsMethod: 'getModels', + }, + }, + { + displayName: 'Conditions', + name: 'conditions', + placeholder: 'Add Condition', + type: 'filter', + default: {}, + typeOptions: { + filter: { + caseSensitive: true, + typeValidation: 'strict', + version: 2, + }, + }, + description: 'Conditions that must be met to select this model', + }, + ], + }, + ], + }, + ], + }; + + methods = { + loadOptions: { + async getModels(this: ILoadOptionsFunctions) { + const numberInputs = this.getCurrentNodeParameter('numberInputs') as number; + + return Array.from({ length: numberInputs ?? 2 }, (_, i) => ({ + value: i + 1, + name: `Model ${(i + 1).toString()}`, + })); + }, + }, + }; + + async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise { + const models = (await this.getInputConnectionData( + NodeConnectionTypes.AiLanguageModel, + itemIndex, + )) as unknown[]; + + if (!models || models.length === 0) { + throw new NodeOperationError(this.getNode(), 'No models connected', { + itemIndex, + description: 'No models found in input connections', + }); + } + models.reverse(); + + const rules = this.getNodeParameter('rules.rule', itemIndex, []) as ModeleSelectionRule[]; + + if (!rules || rules.length === 0) { + throw new NodeOperationError(this.getNode(), 'No rules defined', { + itemIndex, + description: 'At least one rule must be defined to select a model', + }); + } + + for (let i = 0; i < rules.length; i++) { + const rule = rules[i]; + const modelIndex = rule.modelIndex; + + if (modelIndex <= 0 || modelIndex > models.length) { + throw new NodeOperationError(this.getNode(), `Invalid model index ${modelIndex}`, { + itemIndex, + description: `Model index must be between 1 and ${models.length}`, + }); + } + + const conditionsMet = this.getNodeParameter(`rules.rule[${i}].conditions`, itemIndex, false, { + extractValue: true, + }) as boolean; + + if (conditionsMet) { + const selectedModel = models[modelIndex - 1] as BaseChatModel; + + const originalCallbacks = getCallbacksArray(selectedModel.callbacks); + + for (const currentCallback of originalCallbacks) { + if (currentCallback instanceof N8nLlmTracing) { + currentCallback.setParentRunIndex(this.getNextRunIndex()); + } + } + const modelSelectorTracing = new N8nNonEstimatingTracing(this); + selectedModel.callbacks = [...originalCallbacks, modelSelectorTracing]; + + return { + response: selectedModel, + }; + } + } + + throw new NodeOperationError(this.getNode(), 'No matching rule found', { + itemIndex, + description: 'None of the defined rules matched the workflow data', + }); + } +} diff --git a/packages/@n8n/nodes-langchain/nodes/ModelSelector/helpers.ts b/packages/@n8n/nodes-langchain/nodes/ModelSelector/helpers.ts new file mode 100644 index 0000000000..ab00a3743d --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/ModelSelector/helpers.ts @@ -0,0 +1,59 @@ +import type { INodeInputConfiguration, INodeParameters, INodeProperties } from 'n8n-workflow'; + +export const numberInputsProperty: INodeProperties = { + displayName: 'Number of Inputs', + name: 'numberInputs', + type: 'options', + noDataExpression: true, + default: 2, + options: [ + { + name: '2', + value: 2, + }, + { + name: '3', + value: 3, + }, + { + name: '4', + value: 4, + }, + { + name: '5', + value: 5, + }, + { + name: '6', + value: 6, + }, + { + name: '7', + value: 7, + }, + { + name: '8', + value: 8, + }, + { + name: '9', + value: 9, + }, + { + name: '10', + value: 10, + }, + ], + validateType: 'number', + description: + 'The number of data inputs you want to merge. The node waits for all connected inputs to be executed.', +}; + +export function configuredInputs(parameters: INodeParameters): INodeInputConfiguration[] { + return Array.from({ length: (parameters.numberInputs as number) || 2 }, (_, i) => ({ + type: 'ai_languageModel', + displayName: `Model ${(i + 1).toString()}`, + required: true, + maxConnections: 1, + })); +} diff --git a/packages/@n8n/nodes-langchain/nodes/ModelSelector/test/ModelSelector.node.test.ts b/packages/@n8n/nodes-langchain/nodes/ModelSelector/test/ModelSelector.node.test.ts new file mode 100644 index 0000000000..9e61476764 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/ModelSelector/test/ModelSelector.node.test.ts @@ -0,0 +1,298 @@ +/* eslint-disable @typescript-eslint/no-unsafe-return */ +/* eslint-disable @typescript-eslint/no-unsafe-assignment */ +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { mock } from 'jest-mock-extended'; +import type { ISupplyDataFunctions, INode, ILoadOptionsFunctions } from 'n8n-workflow'; +import { NodeOperationError, NodeConnectionTypes } from 'n8n-workflow'; + +import { ModelSelector } from '../ModelSelector.node'; + +// Mock the N8nLlmTracing module completely to avoid module resolution issues +jest.mock('../../llms/N8nLlmTracing', () => ({ + N8nLlmTracing: jest.fn().mockImplementation(() => ({ + handleLLMStart: jest.fn(), + handleLLMEnd: jest.fn(), + })), +})); + +describe('ModelSelector Node', () => { + let node: ModelSelector; + let mockSupplyDataFunction: jest.Mocked; + let mockLoadOptionsFunction: jest.Mocked; + + beforeEach(() => { + node = new ModelSelector(); + mockSupplyDataFunction = mock(); + mockLoadOptionsFunction = mock(); + + mockSupplyDataFunction.getNode.mockReturnValue({ + name: 'Model Selector', + typeVersion: 1, + parameters: {}, + } as INode); + + jest.clearAllMocks(); + }); + + describe('description', () => { + it('should have the expected properties', () => { + expect(node.description).toBeDefined(); + expect(node.description.name).toBe('modelSelector'); + expect(node.description.displayName).toBe('Model Selector'); + expect(node.description.version).toBe(1); + expect(node.description.group).toEqual(['transform']); + expect(node.description.outputs).toEqual([NodeConnectionTypes.AiLanguageModel]); + expect(node.description.requiredInputs).toBe(1); + }); + + it('should have the correct properties defined', () => { + expect(node.description.properties).toHaveLength(2); + expect(node.description.properties[0].name).toBe('numberInputs'); + expect(node.description.properties[1].name).toBe('rules'); + }); + }); + + describe('loadOptions methods', () => { + describe('getModels', () => { + it('should return correct number of models based on numberInputs parameter', async () => { + mockLoadOptionsFunction.getCurrentNodeParameter.mockReturnValue(3); + + const result = await node.methods.loadOptions.getModels.call(mockLoadOptionsFunction); + + expect(result).toEqual([ + { value: 1, name: 'Model 1' }, + { value: 2, name: 'Model 2' }, + { value: 3, name: 'Model 3' }, + ]); + }); + + it('should default to 2 models when numberInputs is undefined', async () => { + mockLoadOptionsFunction.getCurrentNodeParameter.mockReturnValue(undefined); + + const result = await node.methods.loadOptions.getModels.call(mockLoadOptionsFunction); + + expect(result).toEqual([ + { value: 1, name: 'Model 1' }, + { value: 2, name: 'Model 2' }, + ]); + }); + }); + }); + + describe('supplyData', () => { + const mockModel1: Partial = { + _llmType: () => 'fake-llm', + callbacks: [], + }; + const mockModel2: Partial = { + _llmType: () => 'fake-llm-2', + callbacks: undefined, + }; + const mockModel3: Partial = { + _llmType: () => 'fake-llm-3', + callbacks: [{ handleLLMStart: jest.fn() }], + }; + + beforeEach(() => { + // Note: models array gets reversed in supplyData, so [model1, model2, model3] becomes [model3, model2, model1] + mockSupplyDataFunction.getInputConnectionData.mockResolvedValue([ + mockModel1, + mockModel2, + mockModel3, + ]); + }); + + it('should throw error when no models are connected', async () => { + mockSupplyDataFunction.getInputConnectionData.mockResolvedValue([]); + + await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow( + NodeOperationError, + ); + }); + + it('should throw error when no rules are defined', async () => { + mockSupplyDataFunction.getNodeParameter.mockReturnValue([]); + + await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow( + NodeOperationError, + ); + }); + + it('should return the correct model when rule conditions are met', async () => { + const rules = [ + { + modelIndex: '2', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + const result = await node.supplyData.call(mockSupplyDataFunction, 0); + + // After reverse: [model3, model2, model1], so index 2 (1-based) = model2 + expect(result.response).toBe(mockModel2); + }); + + it('should add N8nLlmTracing callback to selected model', async () => { + const rules = [ + { + modelIndex: '1', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + const result = await node.supplyData.call(mockSupplyDataFunction, 0); + + // After reverse: [model3, model2, model1], so index 1 (1-based) = model3 + expect(result.response).toBe(mockModel3); + expect((result.response as BaseChatModel).callbacks).toHaveLength(2); // original + N8nLlmTracing + }); + + it('should handle models with undefined callbacks', async () => { + const rules = [ + { + modelIndex: '2', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + const result = await node.supplyData.call(mockSupplyDataFunction, 0); + + // After reverse: [model3, model2, model1], so index 2 (1-based) = model2 + expect(result.response).toBe(mockModel2); + // Should have 1 callback added (N8nLlmTracing) + expect(Array.isArray((result.response as BaseChatModel).callbacks)).toBe(true); + expect((result.response as BaseChatModel).callbacks).toHaveLength(2); + }); + + it('should evaluate multiple rules and return first matching model', async () => { + const rules = [ + { + modelIndex: '1', + conditions: {}, + }, + { + modelIndex: '3', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(false) // first rule conditions evaluation + .mockReturnValueOnce(true); // second rule conditions evaluation + + const result = await node.supplyData.call(mockSupplyDataFunction, 0); + + // After reverse: [model3, model2, model1], so index 3 (1-based) = model1 + expect(result.response).toBe(mockModel1); + }); + + it('should throw error when no rules match', async () => { + const rules = [ + { + modelIndex: '1', + conditions: {}, + }, + { + modelIndex: '2', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(false) // first rule conditions evaluation + .mockReturnValueOnce(false); // second rule conditions evaluation + + await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow( + NodeOperationError, + ); + }); + + it('should throw error when model index is invalid (too low)', async () => { + const rules = [ + { + modelIndex: '0', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow( + NodeOperationError, + ); + }); + + it('should throw error when model index is invalid (too high)', async () => { + const rules = [ + { + modelIndex: '5', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow( + NodeOperationError, + ); + }); + + it('should handle string model indices correctly', async () => { + const rules = [ + { + modelIndex: '3', + conditions: {}, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + const result = await node.supplyData.call(mockSupplyDataFunction, 0); + + // After reverse: [model3, model2, model1], so index 3 (1-based) = model1 + expect(result.response).toBe(mockModel1); + }); + + it('should call getNodeParameter with correct parameters for condition evaluation', async () => { + const rules = [ + { + modelIndex: '1', + conditions: { field: 'value' }, + }, + ]; + + mockSupplyDataFunction.getNodeParameter + .mockReturnValueOnce(rules) // rules.rule parameter + .mockReturnValueOnce(true); // conditions evaluation + + await node.supplyData.call(mockSupplyDataFunction, 0); + + expect(mockSupplyDataFunction.getNodeParameter).toHaveBeenCalledWith( + 'rules.rule[0].conditions', + 0, + false, + { extractValue: true }, + ); + }); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/ModelSelector/test/helpers.test.ts b/packages/@n8n/nodes-langchain/nodes/ModelSelector/test/helpers.test.ts new file mode 100644 index 0000000000..084801001c --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/ModelSelector/test/helpers.test.ts @@ -0,0 +1,68 @@ +import type { INodeParameters, INodePropertyOptions } from 'n8n-workflow'; + +// Import the function and property +import { numberInputsProperty, configuredInputs } from '../helpers'; + +// We need to extract the configuredInputs function for testing +// Since it's not exported, we'll test it indirectly through the node's inputs property + +describe('ModelSelector Configuration', () => { + describe('numberInputsProperty', () => { + it('should have correct configuration', () => { + expect(numberInputsProperty.displayName).toBe('Number of Inputs'); + expect(numberInputsProperty.name).toBe('numberInputs'); + expect(numberInputsProperty.type).toBe('options'); + expect(numberInputsProperty.default).toBe(2); + expect(numberInputsProperty.validateType).toBe('number'); + }); + + it('should have options from 2 to 10', () => { + const options = numberInputsProperty.options as INodePropertyOptions[]; + expect(options).toHaveLength(9); + expect(options[0]).toEqual({ name: '2', value: 2 }); + expect(options[8]).toEqual({ name: '10', value: 10 }); + }); + + it('should have all sequential values from 2 to 10', () => { + const expectedValues = [2, 3, 4, 5, 6, 7, 8, 9, 10]; + const options = numberInputsProperty.options as INodePropertyOptions[]; + const actualValues = options.map((option) => option.value); + expect(actualValues).toEqual(expectedValues); + }); + }); + + describe('configuredInputs function', () => { + it('should generate correct input configuration for default value', () => { + const parameters: INodeParameters = { numberInputs: 2 }; + const result = configuredInputs(parameters); + + expect(result).toEqual([ + { type: 'ai_languageModel', displayName: 'Model 1', required: true, maxConnections: 1 }, + { type: 'ai_languageModel', displayName: 'Model 2', required: true, maxConnections: 1 }, + ]); + }); + + it('should generate correct input configuration for custom value', () => { + const parameters: INodeParameters = { numberInputs: 5 }; + const result = configuredInputs(parameters); + + expect(result).toEqual([ + { type: 'ai_languageModel', displayName: 'Model 1', required: true, maxConnections: 1 }, + { type: 'ai_languageModel', displayName: 'Model 2', required: true, maxConnections: 1 }, + { type: 'ai_languageModel', displayName: 'Model 3', required: true, maxConnections: 1 }, + { type: 'ai_languageModel', displayName: 'Model 4', required: true, maxConnections: 1 }, + { type: 'ai_languageModel', displayName: 'Model 5', required: true, maxConnections: 1 }, + ]); + }); + + it('should handle undefined numberInputs parameter', () => { + const parameters: INodeParameters = {}; + const result = configuredInputs(parameters); + + expect(result).toEqual([ + { type: 'ai_languageModel', displayName: 'Model 1', required: true, maxConnections: 1 }, + { type: 'ai_languageModel', displayName: 'Model 2', required: true, maxConnections: 1 }, + ]); + }); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts index d4cb66b2fe..ac0f4fa5b6 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V1/AgentV1.node.ts @@ -102,6 +102,7 @@ function getInputs( '@n8n/n8n-nodes-langchain.lmChatDeepSeek', '@n8n/n8n-nodes-langchain.lmChatOpenRouter', '@n8n/n8n-nodes-langchain.lmChatXAiGrok', + '@n8n/n8n-nodes-langchain.modelSelector', ], }, }, diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts index 66adb25da1..912e06721f 100644 --- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/V2/AgentV2.node.ts @@ -77,6 +77,7 @@ function getInputs(hasOutputParser?: boolean): Array = {}; + + options = { + // Default(OpenAI format) parser + errorDescriptionMapper: (error: NodeError) => error.description, + }; + + constructor( + private executionFunctions: ISupplyDataFunctions, + options?: { + errorDescriptionMapper?: (error: NodeError) => string; + }, + ) { + super(); + this.options = { ...this.options, ...options }; + } + + async handleLLMEnd(output: LLMResult, runId: string) { + // The fallback should never happen since handleLLMStart should always set the run details + // but just in case, we set the index to the length of the runsMap + const runDetails = this.runsMap[runId] ?? { index: Object.keys(this.runsMap).length }; + + output.generations = output.generations.map((gen) => + gen.map((g) => pick(g, ['text', 'generationInfo'])), + ); + + const tokenUsageEstimate = { + completionTokens: 0, + promptTokens: 0, + totalTokens: 0, + }; + const response: { + response: { generations: LLMResult['generations'] }; + tokenUsageEstimate?: typeof tokenUsageEstimate; + } = { + response: { generations: output.generations }, + }; + + response.tokenUsageEstimate = tokenUsageEstimate; + + const parsedMessages = + typeof runDetails.messages === 'string' + ? runDetails.messages + : runDetails.messages.map((message) => { + if (typeof message === 'string') return message; + if (typeof message?.toJSON === 'function') return message.toJSON(); + + return message; + }); + + const sourceNodeRunIndex = + this.#parentRunIndex !== undefined ? this.#parentRunIndex + runDetails.index : undefined; + + this.executionFunctions.addOutputData( + this.connectionType, + runDetails.index, + [[{ json: { ...response } }]], + undefined, + sourceNodeRunIndex, + ); + + logAiEvent(this.executionFunctions, 'ai-llm-generated-output', { + messages: parsedMessages, + options: runDetails.options, + response, + }); + } + + async handleLLMStart(llm: Serialized, prompts: string[], runId: string) { + const estimatedTokens = 0; + const sourceNodeRunIndex = + this.#parentRunIndex !== undefined + ? this.#parentRunIndex + this.executionFunctions.getNextRunIndex() + : undefined; + + const options = llm.type === 'constructor' ? llm.kwargs : llm; + const { index } = this.executionFunctions.addInputData( + this.connectionType, + [ + [ + { + json: { + messages: prompts, + estimatedTokens, + options, + }, + }, + ], + ], + sourceNodeRunIndex, + ); + + // Save the run details for later use when processing `handleLLMEnd` event + this.runsMap[runId] = { + index, + options, + messages: prompts, + }; + } + + async handleLLMError( + error: IDataObject | Error, + runId: string, + parentRunId?: string | undefined, + ) { + const runDetails = this.runsMap[runId] ?? { index: Object.keys(this.runsMap).length }; + + // Filter out non-x- headers to avoid leaking sensitive information in logs + if (typeof error === 'object' && error?.hasOwnProperty('headers')) { + const errorWithHeaders = error as { headers: Record }; + + Object.keys(errorWithHeaders.headers).forEach((key) => { + if (!key.startsWith('x-')) { + delete errorWithHeaders.headers[key]; + } + }); + } + + if (error instanceof NodeError) { + if (this.options.errorDescriptionMapper) { + error.description = this.options.errorDescriptionMapper(error); + } + + this.executionFunctions.addOutputData(this.connectionType, runDetails.index, error); + } else { + // If the error is not a NodeError, we wrap it in a NodeOperationError + this.executionFunctions.addOutputData( + this.connectionType, + runDetails.index, + new NodeOperationError(this.executionFunctions.getNode(), error as JsonObject, { + functionality: 'configuration-node', + }), + ); + } + + logAiEvent(this.executionFunctions, 'ai-llm-errored', { + error: Object.keys(error).length === 0 ? error.toString() : error, + runId, + parentRunId, + }); + } + + // Used to associate subsequent runs with the correct parent run in subnodes of subnodes + setParentRunIndex(runIndex: number) { + this.#parentRunIndex = runIndex; + } +} diff --git a/packages/@n8n/nodes-langchain/package.json b/packages/@n8n/nodes-langchain/package.json index 895d7f7ac4..7a22d77f0b 100644 --- a/packages/@n8n/nodes-langchain/package.json +++ b/packages/@n8n/nodes-langchain/package.json @@ -135,7 +135,8 @@ "dist/nodes/vector_store/VectorStoreZep/VectorStoreZep.node.js", "dist/nodes/vector_store/VectorStoreZepInsert/VectorStoreZepInsert.node.js", "dist/nodes/vector_store/VectorStoreZepLoad/VectorStoreZepLoad.node.js", - "dist/nodes/ToolExecutor/ToolExecutor.node.js" + "dist/nodes/ToolExecutor/ToolExecutor.node.js", + "dist/nodes/ModelSelector/ModelSelector.node.js" ] }, "devDependencies": { diff --git a/packages/core/src/execution-engine/node-execution-context/node-execution-context.ts b/packages/core/src/execution-engine/node-execution-context/node-execution-context.ts index 1c154d254a..5b247199a3 100644 --- a/packages/core/src/execution-engine/node-execution-context/node-execution-context.ts +++ b/packages/core/src/execution-engine/node-execution-context/node-execution-context.ts @@ -17,6 +17,7 @@ import type { IRunExecutionData, IWorkflowExecuteAdditionalData, NodeConnectionType, + NodeInputConnections, NodeParameterValueType, NodeTypeAndVersion, Workflow, @@ -153,6 +154,10 @@ export abstract class NodeExecutionContext implements Omit node.disabled !== true); } + getConnections(destination: INode, connectionType: NodeConnectionType): NodeInputConnections { + return this.workflow.connectionsByDestinationNode[destination.name]?.[connectionType] ?? []; + } + getNodeOutputs(): INodeOutputConfiguration[] { return this.nodeOutputs; } diff --git a/packages/core/src/execution-engine/node-execution-context/supply-data-context.ts b/packages/core/src/execution-engine/node-execution-context/supply-data-context.ts index 48d29ee3d6..9840bff171 100644 --- a/packages/core/src/execution-engine/node-execution-context/supply-data-context.ts +++ b/packages/core/src/execution-engine/node-execution-context/supply-data-context.ts @@ -176,6 +176,7 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData addInputData( connectionType: AINodeConnectionType, data: INodeExecutionData[][], + runIndex?: number, ): { index: number } { const nodeName = this.node.name; const currentNodeRunIndex = this.getNextRunIndex(); @@ -186,6 +187,8 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData connectionType, nodeName, currentNodeRunIndex, + undefined, + runIndex, ).catch((error) => { this.logger.warn( `There was a problem logging input data of node "${nodeName}": ${ @@ -204,6 +207,7 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData currentNodeRunIndex: number, data: INodeExecutionData[][] | ExecutionBaseError, metadata?: ITaskMetadata, + sourceNodeRunIndex?: number, ): void { const nodeName = this.node.name; this.addExecutionDataFunctions( @@ -213,6 +217,7 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData nodeName, currentNodeRunIndex, metadata, + sourceNodeRunIndex, ).catch((error) => { this.logger.warn( `There was a problem logging output data of node "${nodeName}": ${ @@ -230,17 +235,23 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData sourceNodeName: string, currentNodeRunIndex: number, metadata?: ITaskMetadata, + sourceNodeRunIndex?: number, ): Promise { const { additionalData, runExecutionData, - runIndex: sourceNodeRunIndex, + runIndex: currentRunIndex, node: { name: nodeName }, } = this; let taskData: ITaskData | undefined; const source: ISourceData[] = this.parentNode - ? [{ previousNode: this.parentNode.name, previousNodeRun: sourceNodeRunIndex }] + ? [ + { + previousNode: this.parentNode.name, + previousNodeRun: sourceNodeRunIndex ?? currentRunIndex, + }, + ] : []; if (type === 'input') { @@ -313,14 +324,13 @@ export class SupplyDataContext extends BaseExecuteContext implements ISupplyData runExecutionData.executionData!.metadata[sourceNodeName] = []; sourceTaskData = runExecutionData.executionData!.metadata[sourceNodeName]; } - - if (!sourceTaskData[sourceNodeRunIndex]) { - sourceTaskData[sourceNodeRunIndex] = { + if (!sourceTaskData[currentNodeRunIndex]) { + sourceTaskData[currentNodeRunIndex] = { subRun: [], }; } - sourceTaskData[sourceNodeRunIndex].subRun!.push({ + sourceTaskData[currentNodeRunIndex].subRun!.push({ node: nodeName, runIndex: currentNodeRunIndex, }); diff --git a/packages/core/src/execution-engine/node-execution-context/utils/__tests__/get-input-connection-data.test.ts b/packages/core/src/execution-engine/node-execution-context/utils/__tests__/get-input-connection-data.test.ts index 0047d23ae8..f464c3f102 100644 --- a/packages/core/src/execution-engine/node-execution-context/utils/__tests__/get-input-connection-data.test.ts +++ b/packages/core/src/execution-engine/node-execution-context/utils/__tests__/get-input-connection-data.test.ts @@ -68,6 +68,10 @@ describe('getInputConnectionData', () => { nodeTypes.getByNameAndVersion .calledWith(agentNode.type, expect.anything()) .mockReturnValue(agentNodeType); + + // Mock getConnections method used by validateInputConfiguration + // This will be overridden in individual tests as needed + jest.spyOn(executeContext, 'getConnections').mockReturnValue([]); }); describe.each([ @@ -88,7 +92,7 @@ describe('getInputConnectionData', () => { type: 'test.type', disabled: false, }); - const secondNode = mock({ name: 'Second Node', disabled: false }); + const secondNode = mock({ name: 'Second Node', type: 'test.type', disabled: false }); const supplyData = jest.fn().mockResolvedValue({ response }); const nodeType = mock({ supplyData }); @@ -121,6 +125,7 @@ describe('getInputConnectionData', () => { }, ]; workflow.getParentNodes.mockReturnValueOnce([]); + (executeContext.getConnections as jest.Mock).mockReturnValueOnce([]); const result = await executeContext.getInputConnectionData(connectionType, 0); expect(result).toBeUndefined(); @@ -136,6 +141,12 @@ describe('getInputConnectionData', () => { }, ]; workflow.getParentNodes.mockReturnValueOnce([node.name, secondNode.name]); + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: node.name, type: connectionType, index: 0 }], + [{ node: secondNode.name, type: connectionType, index: 0 }], + ]); await expect(executeContext.getInputConnectionData(connectionType, 0)).rejects.toThrow( `Only 1 ${connectionType} sub-nodes are/is allowed to be connected`, @@ -151,6 +162,7 @@ describe('getInputConnectionData', () => { }, ]; workflow.getParentNodes.mockReturnValueOnce([]); + jest.spyOn(executeContext, 'getConnections').mockReturnValueOnce([]); await expect(executeContext.getInputConnectionData(connectionType, 0)).rejects.toThrow( 'must be connected and enabled', @@ -173,6 +185,10 @@ describe('getInputConnectionData', () => { }); workflow.getParentNodes.mockReturnValueOnce([disabledNode.name]); workflow.getNode.calledWith(disabledNode.name).mockReturnValue(disabledNode); + // Mock connections that include the disabled node, but getConnectedNodes will filter it out + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([[{ node: disabledNode.name, type: connectionType, index: 0 }]]); await expect(executeContext.getInputConnectionData(connectionType, 0)).rejects.toThrow( 'must be connected and enabled', @@ -187,6 +203,9 @@ describe('getInputConnectionData', () => { required: true, }, ]; + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([[{ node: node.name, type: connectionType, index: 0 }]]); supplyData.mockRejectedValueOnce(new Error('supplyData error')); @@ -203,6 +222,9 @@ describe('getInputConnectionData', () => { required: true, }, ]; + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([[{ node: node.name, type: connectionType, index: 0 }]]); const configError = new NodeOperationError(node, 'Config Error in node', { functionality: 'configuration-node', @@ -223,6 +245,9 @@ describe('getInputConnectionData', () => { required: true, }, ]; + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([[{ node: node.name, type: connectionType, index: 0 }]]); const closeFunction = jest.fn(); supplyData.mockResolvedValueOnce({ response, closeFunction }); @@ -233,6 +258,127 @@ describe('getInputConnectionData', () => { // @ts-expect-error private property expect(executeContext.closeFunctions).toContain(closeFunction); }); + + it('should handle multiple input configurations of the same type with different max connections', async () => { + agentNodeType.description.inputs = [ + { + type: connectionType, + maxConnections: 2, + required: true, + }, + { + type: connectionType, + maxConnections: 1, + required: false, + }, + ]; + + const thirdNode = mock({ name: 'Third Node', type: 'test.type', disabled: false }); + + // Mock node types for all connected nodes + nodeTypes.getByNameAndVersion + .calledWith(secondNode.type, expect.anything()) + .mockReturnValue(nodeType); + nodeTypes.getByNameAndVersion + .calledWith(thirdNode.type, expect.anything()) + .mockReturnValue(nodeType); + + workflow.getParentNodes.mockReturnValueOnce([node.name, secondNode.name, thirdNode.name]); + workflow.getNode.calledWith(thirdNode.name).mockReturnValue(thirdNode); + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: node.name, type: connectionType, index: 0 }], + [{ node: secondNode.name, type: connectionType, index: 0 }], + [{ node: thirdNode.name, type: connectionType, index: 0 }], + ]); + + const result = await executeContext.getInputConnectionData(connectionType, 0); + expect(result).toEqual([response, response, response]); + expect(supplyData).toHaveBeenCalledTimes(3); + }); + + it('should throw when exceeding total max connections across multiple input configurations', async () => { + agentNodeType.description.inputs = [ + { + type: connectionType, + maxConnections: 1, + required: true, + }, + { + type: connectionType, + maxConnections: 1, + required: false, + }, + ]; + + const thirdNode = mock({ name: 'Third Node', type: 'test.type', disabled: false }); + + // Mock node types for all connected nodes + nodeTypes.getByNameAndVersion + .calledWith(secondNode.type, expect.anything()) + .mockReturnValue(nodeType); + nodeTypes.getByNameAndVersion + .calledWith(thirdNode.type, expect.anything()) + .mockReturnValue(nodeType); + + workflow.getParentNodes.mockReturnValueOnce([node.name, secondNode.name, thirdNode.name]); + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: node.name, type: connectionType, index: 0 }], + [{ node: secondNode.name, type: connectionType, index: 0 }], + [{ node: thirdNode.name, type: connectionType, index: 0 }], + ]); + + await expect(executeContext.getInputConnectionData(connectionType, 0)).rejects.toThrow( + `Only 2 ${connectionType} sub-nodes are/is allowed to be connected`, + ); + expect(supplyData).not.toHaveBeenCalled(); + }); + + it('should return array when multiple input configurations exist even with single connection', async () => { + agentNodeType.description.inputs = [ + { + type: connectionType, + maxConnections: 1, + required: true, + }, + { + type: connectionType, + maxConnections: 2, + required: false, + }, + ]; + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([[{ node: node.name, type: connectionType, index: 0 }]]); + + const result = await executeContext.getInputConnectionData(connectionType, 0); + expect(result).toEqual([response]); + expect(supplyData).toHaveBeenCalledTimes(1); + }); + + it('should return empty array when no connections and multiple optional inputs', async () => { + agentNodeType.description.inputs = [ + { + type: connectionType, + maxConnections: 1, + required: false, + }, + { + type: connectionType, + maxConnections: 1, + required: false, + }, + ]; + workflow.getParentNodes.mockReturnValueOnce([]); + jest.spyOn(executeContext, 'getConnections').mockReturnValueOnce([]); + + const result = await executeContext.getInputConnectionData(connectionType, 0); + expect(result).toEqual([]); + expect(supplyData).not.toHaveBeenCalled(); + }); }); describe(NodeConnectionTypes.AiTool, () => { @@ -270,6 +416,7 @@ describe('getInputConnectionData', () => { }, ]; workflow.getParentNodes.mockReturnValueOnce([]); + jest.spyOn(executeContext, 'getConnections').mockReturnValueOnce([]); const result = await executeContext.getInputConnectionData(NodeConnectionTypes.AiTool, 0); expect(result).toEqual([]); @@ -284,6 +431,7 @@ describe('getInputConnectionData', () => { }, ]; workflow.getParentNodes.mockReturnValueOnce([]); + jest.spyOn(executeContext, 'getConnections').mockReturnValueOnce([]); await expect( executeContext.getInputConnectionData(NodeConnectionTypes.AiTool, 0), @@ -309,6 +457,11 @@ describe('getInputConnectionData', () => { .calledWith(agentNode.name, NodeConnectionTypes.AiTool) .mockReturnValue([disabledToolNode.name]); workflow.getNode.calledWith(disabledToolNode.name).mockReturnValue(disabledToolNode); + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: disabledToolNode.name, type: NodeConnectionTypes.AiTool, index: 0 }], + ]); await expect( executeContext.getInputConnectionData(NodeConnectionTypes.AiTool, 0), @@ -331,6 +484,12 @@ describe('getInputConnectionData', () => { workflow.getParentNodes .calledWith(agentNode.name, NodeConnectionTypes.AiTool) .mockReturnValue([toolNode.name, secondToolNode.name]); + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: toolNode.name, type: NodeConnectionTypes.AiTool, index: 0 }], + [{ node: secondToolNode.name, type: NodeConnectionTypes.AiTool, index: 0 }], + ]); const result = await executeContext.getInputConnectionData(NodeConnectionTypes.AiTool, 0); expect(result).toEqual([mockTool, secondMockTool]); @@ -347,6 +506,11 @@ describe('getInputConnectionData', () => { required: true, }, ]; + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: toolNode.name, type: NodeConnectionTypes.AiTool, index: 0 }], + ]); await expect( executeContext.getInputConnectionData(NodeConnectionTypes.AiTool, 0), @@ -361,6 +525,11 @@ describe('getInputConnectionData', () => { required: true, }, ]; + jest + .spyOn(executeContext, 'getConnections') + .mockReturnValueOnce([ + [{ node: toolNode.name, type: NodeConnectionTypes.AiTool, index: 0 }], + ]); const result = await executeContext.getInputConnectionData(NodeConnectionTypes.AiTool, 0); expect(result).toEqual([mockTool]); diff --git a/packages/core/src/execution-engine/node-execution-context/utils/get-input-connection-data.ts b/packages/core/src/execution-engine/node-execution-context/utils/get-input-connection-data.ts index 89309a3d56..b62f4b5f15 100644 --- a/packages/core/src/execution-engine/node-execution-context/utils/get-input-connection-data.ts +++ b/packages/core/src/execution-engine/node-execution-context/utils/get-input-connection-data.ts @@ -15,12 +15,15 @@ import type { ISupplyDataFunctions, INodeType, INode, + INodeInputConfiguration, + NodeConnectionType, } from 'n8n-workflow'; import { NodeConnectionTypes, NodeOperationError, ExecutionBaseError, ApplicationError, + UserError, } from 'n8n-workflow'; import { createNodeAsTool } from './create-node-as-tool'; @@ -85,6 +88,42 @@ export function makeHandleToolInvocation( }; } +function validateInputConfiguration( + context: ExecuteContext | WebhookContext | SupplyDataContext, + connectionType: NodeConnectionType, + nodeInputs: INodeInputConfiguration[], + connectedNodes: INode[], +) { + const parentNode = context.getNode(); + + const connections = context.getConnections(parentNode, connectionType); + + // Validate missing required connections + for (let index = 0; index < nodeInputs.length; index++) { + const inputConfiguration = nodeInputs[index]; + + if (inputConfiguration.required) { + // For required inputs, we need at least one enabled connected node + if ( + connections.length === 0 || + connections.length <= index || + connections.at(index)?.length === 0 || + !connectedNodes.find((node) => + connections + .at(index) + ?.map((value) => value.node) + .includes(node.name), + ) + ) { + throw new NodeOperationError( + parentNode, + `A ${inputConfiguration?.displayName ?? connectionType} sub-node must be connected and enabled`, + ); + } + } + } +} + export async function getInputConnectionData( this: ExecuteContext | WebhookContext | SupplyDataContext, workflow: Workflow, @@ -101,32 +140,37 @@ export async function getInputConnectionData( abortSignal?: AbortSignal, ): Promise { const parentNode = this.getNode(); + const inputConfigurations = this.nodeInputs.filter((input) => input.type === connectionType); - const inputConfiguration = this.nodeInputs.find((input) => input.type === connectionType); - if (inputConfiguration === undefined) { - throw new ApplicationError('Node does not have input of type', { + if (inputConfigurations === undefined || inputConfigurations.length === 0) { + throw new UserError('Node does not have input of type', { extra: { nodeName: parentNode.name, connectionType }, }); } + const maxConnections = inputConfigurations.reduce( + (acc, currentItem) => + currentItem.maxConnections !== undefined ? acc + currentItem.maxConnections : acc, + 0, + ); + const connectedNodes = this.getConnectedNodes(connectionType); + validateInputConfiguration(this, connectionType, inputConfigurations, connectedNodes); + + // Nothing is connected or required if (connectedNodes.length === 0) { - if (inputConfiguration.required) { - throw new NodeOperationError( - parentNode, - `A ${inputConfiguration?.displayName ?? connectionType} sub-node must be connected and enabled`, - ); - } - return inputConfiguration.maxConnections === 1 ? undefined : []; + return maxConnections === 1 ? undefined : []; } + // Too many connections if ( - inputConfiguration.maxConnections !== undefined && - connectedNodes.length > inputConfiguration.maxConnections + maxConnections !== undefined && + maxConnections !== 0 && + connectedNodes.length > maxConnections ) { throw new NodeOperationError( parentNode, - `Only ${inputConfiguration.maxConnections} ${connectionType} sub-nodes are/is allowed to be connected`, + `Only ${maxConnections} ${connectionType} sub-nodes are/is allowed to be connected`, ); } @@ -214,7 +258,5 @@ export async function getInputConnectionData( } } - return inputConfiguration.maxConnections === 1 - ? (nodes || [])[0]?.response - : nodes.map((node) => node.response); + return maxConnections === 1 ? (nodes || [])[0]?.response : nodes.map((node) => node.response); } diff --git a/packages/core/src/execution-engine/partial-execution-utils/get-source-data-groups.ts b/packages/core/src/execution-engine/partial-execution-utils/get-source-data-groups.ts index d96aba3062..64bd685dfe 100644 --- a/packages/core/src/execution-engine/partial-execution-utils/get-source-data-groups.ts +++ b/packages/core/src/execution-engine/partial-execution-utils/get-source-data-groups.ts @@ -1,4 +1,4 @@ -import { type INode, type IPinData, type IRunData } from 'n8n-workflow'; +import { NodeConnectionTypes, type INode, type IPinData, type IRunData } from 'n8n-workflow'; import type { GraphConnection, DirectedGraph } from './directed-graph'; @@ -99,7 +99,7 @@ export function getSourceDataGroups( if (hasData) { sortedConnectionsWithData.push(connection); - } else { + } else if (connection.type === NodeConnectionTypes.Main) { sortedConnectionsWithoutData.push(connection); } } diff --git a/packages/frontend/editor-ui/src/components/NDVSubConnections.test.ts b/packages/frontend/editor-ui/src/components/NDVSubConnections.test.ts index d557c766eb..f9fe640deb 100644 --- a/packages/frontend/editor-ui/src/components/NDVSubConnections.test.ts +++ b/packages/frontend/editor-ui/src/components/NDVSubConnections.test.ts @@ -5,6 +5,7 @@ import { createTestingPinia } from '@pinia/testing'; import type { INodeUi } from '@/Interface'; import type { INodeTypeDescription, WorkflowParameters } from 'n8n-workflow'; import { NodeConnectionTypes, Workflow } from 'n8n-workflow'; +import { nextTick } from 'vue'; const nodeType: INodeTypeDescription = { displayName: 'OpenAI', @@ -57,6 +58,8 @@ const workflow: WorkflowParameters = { }; const getNodeType = vi.fn(); +let mockWorkflowData = workflow; +let mockGetNodeByName = vi.fn(() => node); vi.mock('@/stores/nodeTypes.store', () => ({ useNodeTypesStore: vi.fn(() => ({ @@ -66,8 +69,8 @@ vi.mock('@/stores/nodeTypes.store', () => ({ vi.mock('@/stores/workflows.store', () => ({ useWorkflowsStore: vi.fn(() => ({ - getCurrentWorkflow: vi.fn(() => new Workflow(workflow)), - getNodeByName: vi.fn(() => node), + getCurrentWorkflow: vi.fn(() => new Workflow(mockWorkflowData)), + getNodeByName: mockGetNodeByName, })), })); @@ -88,17 +91,17 @@ describe('NDVSubConnections', () => { vi.advanceTimersByTime(1000); // Event debounce time await waitFor(() => {}); - expect(getByTestId('subnode-connection-group-ai_tool')).toBeVisible(); + expect(getByTestId('subnode-connection-group-ai_tool-0')).toBeVisible(); expect(html()).toEqual( `
-
+
Tools
- +
@@ -123,4 +126,101 @@ describe('NDVSubConnections', () => { await waitFor(() => {}); expect(component.html()).toEqual(''); }); + + it('should render multiple connections of the same type separately', async () => { + // Mock a ModelSelector-like node with multiple ai_languageModel connections + const multiConnectionNodeType: INodeTypeDescription = { + displayName: 'Model Selector', + name: 'modelSelector', + version: [1], + inputs: [ + { type: NodeConnectionTypes.Main }, + { + type: NodeConnectionTypes.AiLanguageModel, + displayName: 'Model 1', + required: true, + maxConnections: 1, + }, + { + type: NodeConnectionTypes.AiLanguageModel, + displayName: 'Model 2', + required: true, + maxConnections: 1, + }, + { + type: NodeConnectionTypes.AiLanguageModel, + displayName: 'Model 3', + required: true, + maxConnections: 1, + }, + ], + outputs: [NodeConnectionTypes.AiLanguageModel], + properties: [], + defaults: { color: '', name: '' }, + group: [], + description: '', + }; + + const multiConnectionNode: INodeUi = { + ...node, + name: 'ModelSelector', + type: 'modelSelector', + }; + + // Mock connected nodes + const mockWorkflow = { + ...workflow, + nodes: [multiConnectionNode], + connectionsByDestinationNode: { + ModelSelector: { + [NodeConnectionTypes.AiLanguageModel]: [ + null, // Main input (index 0) - no ai_languageModel connection + [{ node: 'OpenAI1', type: NodeConnectionTypes.AiLanguageModel, index: 0 }], // Model 1 (index 1) + [{ node: 'Claude', type: NodeConnectionTypes.AiLanguageModel, index: 0 }], // Model 2 (index 2) + [], // Model 3 (index 3) - no connection + ], + }, + }, + }; + + // Mock additional nodes + const openAI1Node: INodeUi = { + ...node, + name: 'OpenAI1', + type: '@n8n/n8n-nodes-langchain.openAi', + }; + const claudeNode: INodeUi = { + ...node, + name: 'Claude', + type: '@n8n/n8n-nodes-langchain.claude', + }; + + getNodeType.mockReturnValue(multiConnectionNodeType); + + // Update mock data for this test + mockWorkflowData = mockWorkflow; + mockGetNodeByName = vi.fn((name: string) => { + if (name === 'ModelSelector') return multiConnectionNode; + if (name === 'OpenAI1') return openAI1Node; + if (name === 'Claude') return claudeNode; + return null; + }); + + const { getByTestId } = render(NDVSubConnections, { + props: { + rootNode: multiConnectionNode, + }, + }); + vi.advanceTimersByTime(1); + + await nextTick(); + + expect(getByTestId('subnode-connection-group-ai_languageModel-0')).toBeVisible(); // Model 1 + expect(getByTestId('subnode-connection-group-ai_languageModel-1')).toBeVisible(); // Model 2 + expect(getByTestId('subnode-connection-group-ai_languageModel-2')).toBeVisible(); // Model 3 + + expect(getByTestId('add-subnode-ai_languageModel-0')).toBeVisible(); + expect(getByTestId('add-subnode-ai_languageModel-1')).toBeVisible(); + expect(getByTestId('add-subnode-ai_languageModel-2')).toBeVisible(); + }); }); diff --git a/packages/frontend/editor-ui/src/components/NDVSubConnections.vue b/packages/frontend/editor-ui/src/components/NDVSubConnections.vue index 35eeb20d9c..ce5696bc1c 100644 --- a/packages/frontend/editor-ui/src/components/NDVSubConnections.vue +++ b/packages/frontend/editor-ui/src/components/NDVSubConnections.vue @@ -40,7 +40,7 @@ interface NodeConfig { const possibleConnections = ref([]); -const expandedGroups = ref([]); +const expandedGroups = ref([]); const shouldShowNodeInputIssues = ref(false); const nodeType = computed(() => @@ -61,41 +61,79 @@ const nodeInputIssues = computed(() => { return issues?.input ?? {}; }); -const connectedNodes = computed>(() => { +const connectedNodes = computed>(() => { + const typeIndexCounters: Record = {}; + return possibleConnections.value.reduce( (acc, connection) => { - const nodes = getINodesFromNames( - workflow.value.getParentNodes(props.rootNode.name, connection.type), - ); - return { ...acc, [connection.type]: nodes }; + // Track index per connection type + const typeIndex = typeIndexCounters[connection.type] ?? 0; + typeIndexCounters[connection.type] = typeIndex + 1; + + // Get input-index-specific connections using the per-type index + const nodeConnections = + workflow.value.connectionsByDestinationNode[props.rootNode.name]?.[connection.type] ?? []; + const inputConnections = nodeConnections[typeIndex] ?? []; + const nodeNames = inputConnections.map((conn) => conn.node); + const nodes = getINodesFromNames(nodeNames); + + // Use a unique key that combines connection type and per-type index + const connectionKey = `${connection.type}-${typeIndex}`; + return { ...acc, [connectionKey]: nodes }; }, - {} as Record, + {} as Record, ); }); -function getConnectionConfig(connectionType: NodeConnectionType) { - return possibleConnections.value.find((c) => c.type === connectionType); +function getConnectionKey(connection: INodeInputConfiguration, globalIndex: number): string { + // Calculate the per-type index for this connection + let typeIndex = 0; + for (let i = 0; i < globalIndex; i++) { + if (possibleConnections.value[i].type === connection.type) { + typeIndex++; + } + } + return `${connection.type}-${typeIndex}`; } -function isMultiConnection(connectionType: NodeConnectionType) { - const connectionConfig = getConnectionConfig(connectionType); +function getConnectionConfig(connectionKey: string) { + const [type, indexStr] = connectionKey.split('-'); + const typeIndex = parseInt(indexStr, 10); + + // Find the connection config by type and type-specific index + let currentTypeIndex = 0; + for (const connection of possibleConnections.value) { + if (connection.type === type) { + if (currentTypeIndex === typeIndex) { + return connection; + } + currentTypeIndex++; + } + } + return undefined; +} + +function isMultiConnection(connectionKey: string) { + const connectionConfig = getConnectionConfig(connectionKey); return connectionConfig?.maxConnections !== 1; } -function shouldShowConnectionTooltip(connectionType: NodeConnectionType) { - return isMultiConnection(connectionType) && !expandedGroups.value.includes(connectionType); +function shouldShowConnectionTooltip(connectionKey: string) { + const [type] = connectionKey.split('-'); + return isMultiConnection(connectionKey) && !expandedGroups.value.includes(type); } -function expandConnectionGroup(connectionType: NodeConnectionType, isExpanded: boolean) { +function expandConnectionGroup(connectionKey: string, isExpanded: boolean) { + const [type] = connectionKey.split('-'); // If the connection is a single connection, we don't need to expand the group - if (!isMultiConnection(connectionType)) { + if (!isMultiConnection(connectionKey)) { return; } if (isExpanded) { - expandedGroups.value = [...expandedGroups.value, connectionType]; + expandedGroups.value = [...expandedGroups.value, type]; } else { - expandedGroups.value = expandedGroups.value.filter((g) => g !== connectionType); + expandedGroups.value = expandedGroups.value.filter((g) => g !== type); } } @@ -116,10 +154,9 @@ function getINodesFromNames(names: string[]): NodeConfig[] { .filter((n): n is NodeConfig => n !== null); } -function hasInputIssues(connectionType: NodeConnectionType) { - return ( - shouldShowNodeInputIssues.value && (nodeInputIssues.value[connectionType] ?? []).length > 0 - ); +function hasInputIssues(connectionKey: string) { + const [type] = connectionKey.split('-'); + return shouldShowNodeInputIssues.value && (nodeInputIssues.value[type] ?? []).length > 0; } function isNodeInputConfiguration( @@ -144,27 +181,29 @@ function getPossibleSubInputConnections(): INodeInputConfiguration[] { return nonMainInputs; } -function onNodeClick(nodeName: string, connectionType: NodeConnectionType) { - if (isMultiConnection(connectionType) && !expandedGroups.value.includes(connectionType)) { - expandConnectionGroup(connectionType, true); +function onNodeClick(nodeName: string, connectionKey: string) { + const [type] = connectionKey.split('-'); + if (isMultiConnection(connectionKey) && !expandedGroups.value.includes(type)) { + expandConnectionGroup(connectionKey, true); return; } emit('switchSelectedNode', nodeName); } -function onPlusClick(connectionType: NodeConnectionType) { - const connectionNodes = connectedNodes.value[connectionType]; +function onPlusClick(connectionKey: string) { + const [type] = connectionKey.split('-'); + const connectionNodes = connectedNodes.value[connectionKey]; if ( - isMultiConnection(connectionType) && - !expandedGroups.value.includes(connectionType) && + isMultiConnection(connectionKey) && + !expandedGroups.value.includes(type) && connectionNodes.length >= 1 ) { - expandConnectionGroup(connectionType, true); + expandConnectionGroup(connectionKey, true); return; } - emit('openConnectionNodeCreator', props.rootNode.name, connectionType); + emit('openConnectionNodeCreator', props.rootNode.name, type as NodeConnectionType); } function showNodeInputsIssues() { @@ -200,39 +239,41 @@ defineExpose({ :style="`--possible-connections: ${possibleConnections.length}`" >
- +