feat: Add model selector node (#16371)

This commit is contained in:
Benjamin Schroth
2025-06-20 15:30:33 +02:00
committed by GitHub
parent a9688b101f
commit 79650ea55a
20 changed files with 1321 additions and 113 deletions

View File

@@ -0,0 +1,204 @@
/* eslint-disable n8n-nodes-base/node-param-description-wrong-for-dynamic-options */
/* eslint-disable n8n-nodes-base/node-param-display-name-wrong-for-dynamic-options */
import type { BaseCallbackHandler, CallbackHandlerMethods } from '@langchain/core/callbacks/base';
import type { Callbacks } from '@langchain/core/callbacks/manager';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import {
NodeConnectionTypes,
type INodeType,
type INodeTypeDescription,
type ISupplyDataFunctions,
type SupplyData,
type ILoadOptionsFunctions,
NodeOperationError,
} from 'n8n-workflow';
import { numberInputsProperty, configuredInputs } from './helpers';
import { N8nLlmTracing } from '../llms/N8nLlmTracing';
import { N8nNonEstimatingTracing } from '../llms/N8nNonEstimatingTracing';
interface ModeleSelectionRule {
modelIndex: number;
conditions: {
options: {
caseSensitive: boolean;
typeValidation: 'strict' | 'loose';
leftValue: string;
version: 1 | 2;
};
conditions: Array<{
id: string;
leftValue: string;
rightValue: string;
operator: {
type: string;
operation: string;
name: string;
};
}>;
combinator: 'and' | 'or';
};
}
function getCallbacksArray(
callbacks: Callbacks | undefined,
): Array<BaseCallbackHandler | CallbackHandlerMethods> {
if (!callbacks) return [];
if (Array.isArray(callbacks)) {
return callbacks;
}
// If it's a CallbackManager, extract its handlers
return callbacks.handlers || [];
}
export class ModelSelector implements INodeType {
description: INodeTypeDescription = {
displayName: 'Model Selector',
name: 'modelSelector',
icon: 'fa:map-signs',
iconColor: 'green',
defaults: {
name: 'Model Selector',
},
version: 1,
group: ['transform'],
description:
'Use this node to select one of the connected models to this node based on workflow data',
inputs: `={{
((parameters) => {
${configuredInputs.toString()};
return configuredInputs(parameters)
})($parameter)
}}`,
outputs: [NodeConnectionTypes.AiLanguageModel],
requiredInputs: 1,
properties: [
numberInputsProperty,
{
displayName: 'Rules',
name: 'rules',
placeholder: 'Add Rule',
type: 'fixedCollection',
typeOptions: {
multipleValues: true,
sortable: true,
},
description: 'Rules to map workflow data to specific models',
default: {},
options: [
{
displayName: 'Rule',
name: 'rule',
values: [
{
displayName: 'Model',
name: 'modelIndex',
type: 'options',
description: 'Choose model input from the list',
default: 1,
required: true,
placeholder: 'Choose model input from the list',
typeOptions: {
loadOptionsMethod: 'getModels',
},
},
{
displayName: 'Conditions',
name: 'conditions',
placeholder: 'Add Condition',
type: 'filter',
default: {},
typeOptions: {
filter: {
caseSensitive: true,
typeValidation: 'strict',
version: 2,
},
},
description: 'Conditions that must be met to select this model',
},
],
},
],
},
],
};
methods = {
loadOptions: {
async getModels(this: ILoadOptionsFunctions) {
const numberInputs = this.getCurrentNodeParameter('numberInputs') as number;
return Array.from({ length: numberInputs ?? 2 }, (_, i) => ({
value: i + 1,
name: `Model ${(i + 1).toString()}`,
}));
},
},
};
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
const models = (await this.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
itemIndex,
)) as unknown[];
if (!models || models.length === 0) {
throw new NodeOperationError(this.getNode(), 'No models connected', {
itemIndex,
description: 'No models found in input connections',
});
}
models.reverse();
const rules = this.getNodeParameter('rules.rule', itemIndex, []) as ModeleSelectionRule[];
if (!rules || rules.length === 0) {
throw new NodeOperationError(this.getNode(), 'No rules defined', {
itemIndex,
description: 'At least one rule must be defined to select a model',
});
}
for (let i = 0; i < rules.length; i++) {
const rule = rules[i];
const modelIndex = rule.modelIndex;
if (modelIndex <= 0 || modelIndex > models.length) {
throw new NodeOperationError(this.getNode(), `Invalid model index ${modelIndex}`, {
itemIndex,
description: `Model index must be between 1 and ${models.length}`,
});
}
const conditionsMet = this.getNodeParameter(`rules.rule[${i}].conditions`, itemIndex, false, {
extractValue: true,
}) as boolean;
if (conditionsMet) {
const selectedModel = models[modelIndex - 1] as BaseChatModel;
const originalCallbacks = getCallbacksArray(selectedModel.callbacks);
for (const currentCallback of originalCallbacks) {
if (currentCallback instanceof N8nLlmTracing) {
currentCallback.setParentRunIndex(this.getNextRunIndex());
}
}
const modelSelectorTracing = new N8nNonEstimatingTracing(this);
selectedModel.callbacks = [...originalCallbacks, modelSelectorTracing];
return {
response: selectedModel,
};
}
}
throw new NodeOperationError(this.getNode(), 'No matching rule found', {
itemIndex,
description: 'None of the defined rules matched the workflow data',
});
}
}

View File

@@ -0,0 +1,59 @@
import type { INodeInputConfiguration, INodeParameters, INodeProperties } from 'n8n-workflow';
export const numberInputsProperty: INodeProperties = {
displayName: 'Number of Inputs',
name: 'numberInputs',
type: 'options',
noDataExpression: true,
default: 2,
options: [
{
name: '2',
value: 2,
},
{
name: '3',
value: 3,
},
{
name: '4',
value: 4,
},
{
name: '5',
value: 5,
},
{
name: '6',
value: 6,
},
{
name: '7',
value: 7,
},
{
name: '8',
value: 8,
},
{
name: '9',
value: 9,
},
{
name: '10',
value: 10,
},
],
validateType: 'number',
description:
'The number of data inputs you want to merge. The node waits for all connected inputs to be executed.',
};
export function configuredInputs(parameters: INodeParameters): INodeInputConfiguration[] {
return Array.from({ length: (parameters.numberInputs as number) || 2 }, (_, i) => ({
type: 'ai_languageModel',
displayName: `Model ${(i + 1).toString()}`,
required: true,
maxConnections: 1,
}));
}

View File

@@ -0,0 +1,298 @@
/* eslint-disable @typescript-eslint/no-unsafe-return */
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { mock } from 'jest-mock-extended';
import type { ISupplyDataFunctions, INode, ILoadOptionsFunctions } from 'n8n-workflow';
import { NodeOperationError, NodeConnectionTypes } from 'n8n-workflow';
import { ModelSelector } from '../ModelSelector.node';
// Mock the N8nLlmTracing module completely to avoid module resolution issues
jest.mock('../../llms/N8nLlmTracing', () => ({
N8nLlmTracing: jest.fn().mockImplementation(() => ({
handleLLMStart: jest.fn(),
handleLLMEnd: jest.fn(),
})),
}));
describe('ModelSelector Node', () => {
let node: ModelSelector;
let mockSupplyDataFunction: jest.Mocked<ISupplyDataFunctions>;
let mockLoadOptionsFunction: jest.Mocked<ILoadOptionsFunctions>;
beforeEach(() => {
node = new ModelSelector();
mockSupplyDataFunction = mock<ISupplyDataFunctions>();
mockLoadOptionsFunction = mock<ILoadOptionsFunctions>();
mockSupplyDataFunction.getNode.mockReturnValue({
name: 'Model Selector',
typeVersion: 1,
parameters: {},
} as INode);
jest.clearAllMocks();
});
describe('description', () => {
it('should have the expected properties', () => {
expect(node.description).toBeDefined();
expect(node.description.name).toBe('modelSelector');
expect(node.description.displayName).toBe('Model Selector');
expect(node.description.version).toBe(1);
expect(node.description.group).toEqual(['transform']);
expect(node.description.outputs).toEqual([NodeConnectionTypes.AiLanguageModel]);
expect(node.description.requiredInputs).toBe(1);
});
it('should have the correct properties defined', () => {
expect(node.description.properties).toHaveLength(2);
expect(node.description.properties[0].name).toBe('numberInputs');
expect(node.description.properties[1].name).toBe('rules');
});
});
describe('loadOptions methods', () => {
describe('getModels', () => {
it('should return correct number of models based on numberInputs parameter', async () => {
mockLoadOptionsFunction.getCurrentNodeParameter.mockReturnValue(3);
const result = await node.methods.loadOptions.getModels.call(mockLoadOptionsFunction);
expect(result).toEqual([
{ value: 1, name: 'Model 1' },
{ value: 2, name: 'Model 2' },
{ value: 3, name: 'Model 3' },
]);
});
it('should default to 2 models when numberInputs is undefined', async () => {
mockLoadOptionsFunction.getCurrentNodeParameter.mockReturnValue(undefined);
const result = await node.methods.loadOptions.getModels.call(mockLoadOptionsFunction);
expect(result).toEqual([
{ value: 1, name: 'Model 1' },
{ value: 2, name: 'Model 2' },
]);
});
});
});
describe('supplyData', () => {
const mockModel1: Partial<BaseChatModel> = {
_llmType: () => 'fake-llm',
callbacks: [],
};
const mockModel2: Partial<BaseChatModel> = {
_llmType: () => 'fake-llm-2',
callbacks: undefined,
};
const mockModel3: Partial<BaseChatModel> = {
_llmType: () => 'fake-llm-3',
callbacks: [{ handleLLMStart: jest.fn() }],
};
beforeEach(() => {
// Note: models array gets reversed in supplyData, so [model1, model2, model3] becomes [model3, model2, model1]
mockSupplyDataFunction.getInputConnectionData.mockResolvedValue([
mockModel1,
mockModel2,
mockModel3,
]);
});
it('should throw error when no models are connected', async () => {
mockSupplyDataFunction.getInputConnectionData.mockResolvedValue([]);
await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow(
NodeOperationError,
);
});
it('should throw error when no rules are defined', async () => {
mockSupplyDataFunction.getNodeParameter.mockReturnValue([]);
await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow(
NodeOperationError,
);
});
it('should return the correct model when rule conditions are met', async () => {
const rules = [
{
modelIndex: '2',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
const result = await node.supplyData.call(mockSupplyDataFunction, 0);
// After reverse: [model3, model2, model1], so index 2 (1-based) = model2
expect(result.response).toBe(mockModel2);
});
it('should add N8nLlmTracing callback to selected model', async () => {
const rules = [
{
modelIndex: '1',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
const result = await node.supplyData.call(mockSupplyDataFunction, 0);
// After reverse: [model3, model2, model1], so index 1 (1-based) = model3
expect(result.response).toBe(mockModel3);
expect((result.response as BaseChatModel).callbacks).toHaveLength(2); // original + N8nLlmTracing
});
it('should handle models with undefined callbacks', async () => {
const rules = [
{
modelIndex: '2',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
const result = await node.supplyData.call(mockSupplyDataFunction, 0);
// After reverse: [model3, model2, model1], so index 2 (1-based) = model2
expect(result.response).toBe(mockModel2);
// Should have 1 callback added (N8nLlmTracing)
expect(Array.isArray((result.response as BaseChatModel).callbacks)).toBe(true);
expect((result.response as BaseChatModel).callbacks).toHaveLength(2);
});
it('should evaluate multiple rules and return first matching model', async () => {
const rules = [
{
modelIndex: '1',
conditions: {},
},
{
modelIndex: '3',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(false) // first rule conditions evaluation
.mockReturnValueOnce(true); // second rule conditions evaluation
const result = await node.supplyData.call(mockSupplyDataFunction, 0);
// After reverse: [model3, model2, model1], so index 3 (1-based) = model1
expect(result.response).toBe(mockModel1);
});
it('should throw error when no rules match', async () => {
const rules = [
{
modelIndex: '1',
conditions: {},
},
{
modelIndex: '2',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(false) // first rule conditions evaluation
.mockReturnValueOnce(false); // second rule conditions evaluation
await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow(
NodeOperationError,
);
});
it('should throw error when model index is invalid (too low)', async () => {
const rules = [
{
modelIndex: '0',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow(
NodeOperationError,
);
});
it('should throw error when model index is invalid (too high)', async () => {
const rules = [
{
modelIndex: '5',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
await expect(node.supplyData.call(mockSupplyDataFunction, 0)).rejects.toThrow(
NodeOperationError,
);
});
it('should handle string model indices correctly', async () => {
const rules = [
{
modelIndex: '3',
conditions: {},
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
const result = await node.supplyData.call(mockSupplyDataFunction, 0);
// After reverse: [model3, model2, model1], so index 3 (1-based) = model1
expect(result.response).toBe(mockModel1);
});
it('should call getNodeParameter with correct parameters for condition evaluation', async () => {
const rules = [
{
modelIndex: '1',
conditions: { field: 'value' },
},
];
mockSupplyDataFunction.getNodeParameter
.mockReturnValueOnce(rules) // rules.rule parameter
.mockReturnValueOnce(true); // conditions evaluation
await node.supplyData.call(mockSupplyDataFunction, 0);
expect(mockSupplyDataFunction.getNodeParameter).toHaveBeenCalledWith(
'rules.rule[0].conditions',
0,
false,
{ extractValue: true },
);
});
});
});

View File

@@ -0,0 +1,68 @@
import type { INodeParameters, INodePropertyOptions } from 'n8n-workflow';
// Import the function and property
import { numberInputsProperty, configuredInputs } from '../helpers';
// We need to extract the configuredInputs function for testing
// Since it's not exported, we'll test it indirectly through the node's inputs property
describe('ModelSelector Configuration', () => {
describe('numberInputsProperty', () => {
it('should have correct configuration', () => {
expect(numberInputsProperty.displayName).toBe('Number of Inputs');
expect(numberInputsProperty.name).toBe('numberInputs');
expect(numberInputsProperty.type).toBe('options');
expect(numberInputsProperty.default).toBe(2);
expect(numberInputsProperty.validateType).toBe('number');
});
it('should have options from 2 to 10', () => {
const options = numberInputsProperty.options as INodePropertyOptions[];
expect(options).toHaveLength(9);
expect(options[0]).toEqual({ name: '2', value: 2 });
expect(options[8]).toEqual({ name: '10', value: 10 });
});
it('should have all sequential values from 2 to 10', () => {
const expectedValues = [2, 3, 4, 5, 6, 7, 8, 9, 10];
const options = numberInputsProperty.options as INodePropertyOptions[];
const actualValues = options.map((option) => option.value);
expect(actualValues).toEqual(expectedValues);
});
});
describe('configuredInputs function', () => {
it('should generate correct input configuration for default value', () => {
const parameters: INodeParameters = { numberInputs: 2 };
const result = configuredInputs(parameters);
expect(result).toEqual([
{ type: 'ai_languageModel', displayName: 'Model 1', required: true, maxConnections: 1 },
{ type: 'ai_languageModel', displayName: 'Model 2', required: true, maxConnections: 1 },
]);
});
it('should generate correct input configuration for custom value', () => {
const parameters: INodeParameters = { numberInputs: 5 };
const result = configuredInputs(parameters);
expect(result).toEqual([
{ type: 'ai_languageModel', displayName: 'Model 1', required: true, maxConnections: 1 },
{ type: 'ai_languageModel', displayName: 'Model 2', required: true, maxConnections: 1 },
{ type: 'ai_languageModel', displayName: 'Model 3', required: true, maxConnections: 1 },
{ type: 'ai_languageModel', displayName: 'Model 4', required: true, maxConnections: 1 },
{ type: 'ai_languageModel', displayName: 'Model 5', required: true, maxConnections: 1 },
]);
});
it('should handle undefined numberInputs parameter', () => {
const parameters: INodeParameters = {};
const result = configuredInputs(parameters);
expect(result).toEqual([
{ type: 'ai_languageModel', displayName: 'Model 1', required: true, maxConnections: 1 },
{ type: 'ai_languageModel', displayName: 'Model 2', required: true, maxConnections: 1 },
]);
});
});
});