mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-22 12:19:09 +00:00
feat: Add fallback mechanism for agent and basic chain llm (#16617)
This commit is contained in:
@@ -88,15 +88,24 @@ async function executeSimpleChain({
|
||||
llm,
|
||||
query,
|
||||
prompt,
|
||||
fallbackLlm,
|
||||
}: {
|
||||
context: IExecuteFunctions;
|
||||
llm: BaseLanguageModel;
|
||||
query: string;
|
||||
prompt: ChatPromptTemplate | PromptTemplate;
|
||||
fallbackLlm?: BaseLanguageModel | null;
|
||||
}) {
|
||||
const outputParser = getOutputParserForLLM(llm);
|
||||
let model;
|
||||
|
||||
const chain = prompt.pipe(llm).pipe(outputParser).withConfig(getTracingConfig(context));
|
||||
if (fallbackLlm) {
|
||||
model = llm.withFallbacks([fallbackLlm]);
|
||||
} else {
|
||||
model = llm;
|
||||
}
|
||||
|
||||
const chain = prompt.pipe(model).pipe(outputParser).withConfig(getTracingConfig(context));
|
||||
|
||||
// Execute the chain
|
||||
const response = await chain.invoke({
|
||||
@@ -118,6 +127,7 @@ export async function executeChain({
|
||||
llm,
|
||||
outputParser,
|
||||
messages,
|
||||
fallbackLlm,
|
||||
}: ChainExecutionParams): Promise<unknown[]> {
|
||||
// If no output parsers provided, use a simple chain with basic prompt template
|
||||
if (!outputParser) {
|
||||
@@ -134,6 +144,7 @@ export async function executeChain({
|
||||
llm,
|
||||
query,
|
||||
prompt: promptTemplate,
|
||||
fallbackLlm,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,17 @@ export function getInputs(parameters: IDataObject) {
|
||||
},
|
||||
];
|
||||
|
||||
const needsFallback = parameters?.needsFallback;
|
||||
|
||||
if (needsFallback === undefined || needsFallback === true) {
|
||||
inputs.push({
|
||||
displayName: 'Fallback Model',
|
||||
maxConnections: 1,
|
||||
type: 'ai_languageModel',
|
||||
required: true,
|
||||
});
|
||||
}
|
||||
|
||||
// If `hasOutputParser` is undefined it must be version 1.3 or earlier so we
|
||||
// always add the output parser input
|
||||
const hasOutputParser = parameters?.hasOutputParser;
|
||||
@@ -119,6 +130,18 @@ export const nodeProperties: INodeProperties[] = [
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName: 'Enable Fallback Model',
|
||||
name: 'needsFallback',
|
||||
type: 'boolean',
|
||||
default: false,
|
||||
noDataExpression: true,
|
||||
displayOptions: {
|
||||
hide: {
|
||||
'@version': [1, 1.1, 1.3],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName: 'Chat Messages (if Using a Chat Model)',
|
||||
name: 'messages',
|
||||
@@ -275,4 +298,16 @@ export const nodeProperties: INodeProperties[] = [
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName:
|
||||
'Connect an additional language model on the canvas to use it as a fallback if the main model fails',
|
||||
name: 'fallbackNotice',
|
||||
type: 'notice',
|
||||
default: '',
|
||||
displayOptions: {
|
||||
show: {
|
||||
needsFallback: [true],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||||
import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
|
||||
import assert from 'node:assert';
|
||||
|
||||
import { getPromptInputByType } from '@utils/helpers';
|
||||
import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
|
||||
@@ -7,11 +8,40 @@ import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
|
||||
import { executeChain } from './chainExecutor';
|
||||
import { type MessageTemplate } from './types';
|
||||
|
||||
async function getChatModel(
|
||||
ctx: IExecuteFunctions,
|
||||
index: number = 0,
|
||||
): Promise<BaseLanguageModel | undefined> {
|
||||
const connectedModels = await ctx.getInputConnectionData(NodeConnectionTypes.AiLanguageModel, 0);
|
||||
|
||||
let model;
|
||||
|
||||
if (Array.isArray(connectedModels) && index !== undefined) {
|
||||
if (connectedModels.length <= index) {
|
||||
return undefined;
|
||||
}
|
||||
// We get the models in reversed order from the workflow so we need to reverse them again to match the right index
|
||||
const reversedModels = [...connectedModels].reverse();
|
||||
model = reversedModels[index] as BaseLanguageModel;
|
||||
} else {
|
||||
model = connectedModels as BaseLanguageModel;
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
|
||||
export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) => {
|
||||
const llm = (await ctx.getInputConnectionData(
|
||||
NodeConnectionTypes.AiLanguageModel,
|
||||
0,
|
||||
)) as BaseLanguageModel;
|
||||
const needsFallback = ctx.getNodeParameter('needsFallback', 0, false) as boolean;
|
||||
const llm = await getChatModel(ctx, 0);
|
||||
assert(llm, 'Please connect a model to the Chat Model input');
|
||||
|
||||
const fallbackLlm = needsFallback ? await getChatModel(ctx, 1) : null;
|
||||
if (needsFallback && !fallbackLlm) {
|
||||
throw new NodeOperationError(
|
||||
ctx.getNode(),
|
||||
'Please connect a model to the Fallback Model input or disable the fallback option',
|
||||
);
|
||||
}
|
||||
|
||||
// Get output parser if configured
|
||||
const outputParser = await getOptionalOutputParser(ctx, itemIndex);
|
||||
@@ -50,5 +80,6 @@ export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) =>
|
||||
llm,
|
||||
outputParser,
|
||||
messages,
|
||||
fallbackLlm,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -38,4 +38,5 @@ export interface ChainExecutionParams {
|
||||
llm: BaseLanguageModel;
|
||||
outputParser?: N8nOutputParser;
|
||||
messages?: MessageTemplate[];
|
||||
fallbackLlm?: BaseLanguageModel | null;
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ jest.mock('../methods/responseFormatter', () => ({
|
||||
describe('ChainLlm Node', () => {
|
||||
let node: ChainLlm;
|
||||
let mockExecuteFunction: jest.Mocked<IExecuteFunctions>;
|
||||
let needsFallback: boolean;
|
||||
|
||||
beforeEach(() => {
|
||||
node = new ChainLlm();
|
||||
@@ -48,6 +49,8 @@ describe('ChainLlm Node', () => {
|
||||
error: jest.fn(),
|
||||
};
|
||||
|
||||
needsFallback = false;
|
||||
|
||||
mockExecuteFunction.getInputData.mockReturnValue([{ json: {} }]);
|
||||
mockExecuteFunction.getNode.mockReturnValue({
|
||||
name: 'Chain LLM',
|
||||
@@ -57,6 +60,7 @@ describe('ChainLlm Node', () => {
|
||||
|
||||
mockExecuteFunction.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'messages.messageValues') return [];
|
||||
if (param === 'needsFallback') return needsFallback;
|
||||
return defaultValue;
|
||||
});
|
||||
|
||||
@@ -96,6 +100,7 @@ describe('ChainLlm Node', () => {
|
||||
context: mockExecuteFunction,
|
||||
itemIndex: 0,
|
||||
query: 'Test prompt',
|
||||
fallbackLlm: null,
|
||||
llm: expect.any(FakeChatModel),
|
||||
outputParser: undefined,
|
||||
messages: [],
|
||||
@@ -151,6 +156,7 @@ describe('ChainLlm Node', () => {
|
||||
context: mockExecuteFunction,
|
||||
itemIndex: 0,
|
||||
query: 'Old version prompt',
|
||||
fallbackLlm: null,
|
||||
llm: expect.any(Object),
|
||||
outputParser: undefined,
|
||||
messages: expect.any(Array),
|
||||
@@ -505,6 +511,35 @@ describe('ChainLlm Node', () => {
|
||||
expect(responseFormatterModule.formatResponse).toHaveBeenCalledWith(markdownResponse, true);
|
||||
});
|
||||
|
||||
it('should use fallback llm if enabled', async () => {
|
||||
needsFallback = true;
|
||||
(helperModule.getPromptInputByType as jest.Mock).mockReturnValue('Test prompt');
|
||||
|
||||
(outputParserModule.getOptionalOutputParser as jest.Mock).mockResolvedValue(undefined);
|
||||
|
||||
(executeChainModule.executeChain as jest.Mock).mockResolvedValue(['Test response']);
|
||||
|
||||
const fakeLLM = new FakeChatModel({});
|
||||
const fakeFallbackLLM = new FakeChatModel({});
|
||||
mockExecuteFunction.getInputConnectionData.mockResolvedValue([fakeLLM, fakeFallbackLLM]);
|
||||
|
||||
const result = await node.execute.call(mockExecuteFunction);
|
||||
|
||||
expect(executeChainModule.executeChain).toHaveBeenCalledWith({
|
||||
context: mockExecuteFunction,
|
||||
itemIndex: 0,
|
||||
query: 'Test prompt',
|
||||
fallbackLlm: expect.any(FakeChatModel),
|
||||
llm: expect.any(FakeChatModel),
|
||||
outputParser: undefined,
|
||||
messages: [],
|
||||
});
|
||||
|
||||
expect(mockExecuteFunction.logger.debug).toHaveBeenCalledWith('Executing Basic LLM Chain');
|
||||
|
||||
expect(result).toEqual([[{ json: expect.any(Object) }]]);
|
||||
});
|
||||
|
||||
it('should pass correct itemIndex to getOptionalOutputParser', async () => {
|
||||
// Clear any previous calls to the mock
|
||||
(outputParserModule.getOptionalOutputParser as jest.Mock).mockClear();
|
||||
@@ -568,6 +603,7 @@ describe('ChainLlm Node', () => {
|
||||
itemIndex: 0,
|
||||
query: 'Test prompt 1',
|
||||
llm: expect.any(Object),
|
||||
fallbackLlm: null,
|
||||
outputParser: mockParser1,
|
||||
messages: [],
|
||||
});
|
||||
@@ -576,6 +612,7 @@ describe('ChainLlm Node', () => {
|
||||
itemIndex: 1,
|
||||
query: 'Test prompt 2',
|
||||
llm: expect.any(Object),
|
||||
fallbackLlm: null,
|
||||
outputParser: mockParser2,
|
||||
messages: [],
|
||||
});
|
||||
@@ -584,6 +621,7 @@ describe('ChainLlm Node', () => {
|
||||
itemIndex: 2,
|
||||
query: 'Test prompt 3',
|
||||
llm: expect.any(Object),
|
||||
fallbackLlm: null,
|
||||
outputParser: mockParser3,
|
||||
messages: [],
|
||||
});
|
||||
|
||||
@@ -7,26 +7,44 @@ describe('config', () => {
|
||||
it('should return basic inputs for all parameters', () => {
|
||||
const inputs = getInputs({});
|
||||
|
||||
expect(inputs).toHaveLength(3);
|
||||
expect(inputs).toHaveLength(4);
|
||||
expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
|
||||
expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
|
||||
expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
|
||||
expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
|
||||
expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
|
||||
});
|
||||
|
||||
it('should exclude the OutputParser when hasOutputParser is false', () => {
|
||||
const inputs = getInputs({ hasOutputParser: false });
|
||||
|
||||
expect(inputs).toHaveLength(2);
|
||||
expect(inputs).toHaveLength(3);
|
||||
expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
|
||||
expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
|
||||
expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
|
||||
});
|
||||
|
||||
it('should include the OutputParser when hasOutputParser is true', () => {
|
||||
const inputs = getInputs({ hasOutputParser: true });
|
||||
|
||||
expect(inputs).toHaveLength(4);
|
||||
expect(inputs[3].type).toBe(NodeConnectionTypes.AiOutputParser);
|
||||
});
|
||||
|
||||
it('should exclude the FallbackInput when needsFallback is false', () => {
|
||||
const inputs = getInputs({ hasOutputParser: true, needsFallback: false });
|
||||
|
||||
expect(inputs).toHaveLength(3);
|
||||
expect(inputs[0].type).toBe(NodeConnectionTypes.Main);
|
||||
expect(inputs[1].type).toBe(NodeConnectionTypes.AiLanguageModel);
|
||||
expect(inputs[2].type).toBe(NodeConnectionTypes.AiOutputParser);
|
||||
});
|
||||
|
||||
it('should include the FallbackInput when needsFallback is true', () => {
|
||||
const inputs = getInputs({ hasOutputParser: false, needsFallback: true });
|
||||
|
||||
expect(inputs).toHaveLength(3);
|
||||
expect(inputs[2].type).toBe(NodeConnectionTypes.AiLanguageModel);
|
||||
});
|
||||
});
|
||||
|
||||
describe('nodeProperties', () => {
|
||||
|
||||
Reference in New Issue
Block a user