Mirror of https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git (synced 2025-12-20 19:32:15 +00:00)
feat: Optimise langchain calls in batching mode (#15243)
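In short: for node version 1.7 and up, the Basic LLM Chain now dispatches items in concurrent batches instead of strictly one at a time, with an optional pause between batches. The sketch below is a minimal reduction of that pattern for orientation only; the helper name `runInBatches` and its signature are illustrative and not part of the diff, while the defaults (batch size 5, delay 0) and the use of `Promise.allSettled` and `sleep` mirror the code added below.

import { sleep } from 'n8n-workflow';

// Illustrative sketch of the batching pattern this commit introduces (not the node's actual code).
// `processOne` stands in for processItem(ctx, itemIndex); per-item error mapping is omitted here.
async function runInBatches<T, R>(
	items: T[],
	processOne: (item: T, index: number) => Promise<R>,
	batchSize = 5, // mirrors the 'batching.batchSize' default read by the node
	delayBetweenBatches = 0, // mirrors 'batching.delayBetweenBatches' (milliseconds)
): Promise<Array<PromiseSettledResult<R>>> {
	const results: Array<PromiseSettledResult<R>> = [];
	for (let i = 0; i < items.length; i += batchSize) {
		// Dispatch one batch concurrently; allSettled keeps individual failures isolated.
		const settled = await Promise.allSettled(
			items.slice(i, i + batchSize).map(async (item, batchItemIndex) => {
				return await processOne(item, i + batchItemIndex);
			}),
		);
		results.push(...settled);
		// Optional throttle between batches, e.g. to stay under provider rate limits.
		if (i + batchSize < items.length && delayBetweenBatches > 0) {
			await sleep(delayBetweenBatches);
		}
	}
	return results;
}

Using Promise.allSettled rather than Promise.all is what lets the node report a failure per item (and honour continueOnFail) without aborting the rest of the batch.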
@@ -1,23 +1,16 @@
 import type { BaseLanguageModel } from '@langchain/core/language_models/base';
 import type {
 	IExecuteFunctions,
 	INodeExecutionData,
 	INodeType,
 	INodeTypeDescription,
 } from 'n8n-workflow';
-import { NodeApiError, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
+import { NodeApiError, NodeConnectionTypes, NodeOperationError, sleep } from 'n8n-workflow';
 
 import { getPromptInputByType } from '@utils/helpers';
 import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
 
-// Import from centralized module
-import {
-	executeChain,
-	formatResponse,
-	getInputs,
-	nodeProperties,
-	type MessageTemplate,
-} from './methods';
+import { formatResponse, getInputs, nodeProperties } from './methods';
+import { processItem } from './methods/processItem';
 import {
 	getCustomErrorMessage as getCustomOpenAiErrorMessage,
 	isOpenAiError,
@@ -34,7 +27,7 @@ export class ChainLlm implements INodeType {
 		icon: 'fa:link',
 		iconColor: 'black',
 		group: ['transform'],
-		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
+		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7],
 		description: 'A simple chain to prompt a large language model',
 		defaults: {
 			name: 'Basic LLM Chain',
@@ -67,83 +60,97 @@ export class ChainLlm implements INodeType {
 		this.logger.debug('Executing Basic LLM Chain');
		const items = this.getInputData();
		const returnData: INodeExecutionData[] = [];
-
-		// Process each input item
-		for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
-			try {
-				// Get the language model
-				const llm = (await this.getInputConnectionData(
-					NodeConnectionTypes.AiLanguageModel,
-					0,
-				)) as BaseLanguageModel;
-
-				// Get output parser if configured
-				const outputParser = await getOptionalOutputParser(this);
-
-				// Get user prompt based on node version
-				let prompt: string;
-
-				if (this.getNode().typeVersion <= 1.3) {
-					prompt = this.getNodeParameter('prompt', itemIndex) as string;
-				} else {
-					prompt = getPromptInputByType({
-						ctx: this,
-						i: itemIndex,
-						inputKey: 'text',
-						promptTypeKey: 'promptType',
-					});
-				}
-
-				// Validate prompt
-				if (prompt === undefined) {
-					throw new NodeOperationError(this.getNode(), "The 'prompt' parameter is empty.");
-				}
-
-				// Get chat messages if configured
-				const messages = this.getNodeParameter(
-					'messages.messageValues',
-					itemIndex,
-					[],
-				) as MessageTemplate[];
-
-				// Execute the chain
-				const responses = await executeChain({
-					context: this,
-					itemIndex,
-					query: prompt,
-					llm,
-					outputParser,
-					messages,
-				});
-
-				// If the node version is 1.6(and LLM is using `response_format: json_object`) or higher or an output parser is configured,
-				// we unwrap the response and return the object directly as JSON
-				const shouldUnwrapObjects = this.getNode().typeVersion >= 1.6 || !!outputParser;
-				// Process each response and add to return data
-				responses.forEach((response) => {
-					returnData.push({
-						json: formatResponse(response, shouldUnwrapObjects),
-					});
-				});
-			} catch (error) {
-				// Handle OpenAI specific rate limit errors
-				if (error instanceof NodeApiError && isOpenAiError(error.cause)) {
-					const openAiErrorCode: string | undefined = (error.cause as any).error?.code;
-					if (openAiErrorCode) {
-						const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode);
-						if (customMessage) {
-							error.message = customMessage;
-						}
-					}
-				}
-
-				// Continue on failure if configured
-				if (this.continueOnFail()) {
-					returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
-					continue;
-				}
-
-				throw error;
-			}
-		}
+		const outputParser = await getOptionalOutputParser(this);
+		// If the node version is 1.6(and LLM is using `response_format: json_object`) or higher or an output parser is configured,
+		// we unwrap the response and return the object directly as JSON
+		const shouldUnwrapObjects = this.getNode().typeVersion >= 1.6 || !!outputParser;
+		const batchSize = this.getNodeParameter('batching.batchSize', 0, 5) as number;
+		const delayBetweenBatches = this.getNodeParameter(
+			'batching.delayBetweenBatches',
+			0,
+			0,
+		) as number;
+
+		if (this.getNode().typeVersion >= 1.7 && batchSize >= 1) {
+			// Process items in batches
+			for (let i = 0; i < items.length; i += batchSize) {
+				const batch = items.slice(i, i + batchSize);
+				const batchPromises = batch.map(async (_item, batchItemIndex) => {
+					return await processItem(this, i + batchItemIndex);
+				});
+
+				const batchResults = await Promise.allSettled(batchPromises);
+
+				batchResults.forEach((promiseResult, batchItemIndex) => {
+					const itemIndex = i + batchItemIndex;
+					if (promiseResult.status === 'rejected') {
+						const error = promiseResult.reason as Error;
+						// Handle OpenAI specific rate limit errors
+						if (error instanceof NodeApiError && isOpenAiError(error.cause)) {
+							const openAiErrorCode: string | undefined = (error.cause as any).error?.code;
+							if (openAiErrorCode) {
+								const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode);
+								if (customMessage) {
+									error.message = customMessage;
+								}
+							}
+						}
+
+						if (this.continueOnFail()) {
+							returnData.push({
+								json: { error: error.message },
+								pairedItem: { item: itemIndex },
+							});
+							return;
+						}
+						throw new NodeOperationError(this.getNode(), error);
+					}
+
+					const responses = promiseResult.value;
+					responses.forEach((response: unknown) => {
+						returnData.push({
+							json: formatResponse(response, shouldUnwrapObjects),
+						});
+					});
+				});
+
+				if (i + batchSize < items.length && delayBetweenBatches > 0) {
+					await sleep(delayBetweenBatches);
+				}
+			}
+		} else {
+			// Process each input item
+			for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
+				try {
+					const responses = await processItem(this, itemIndex);
+
+					// Process each response and add to return data
+					responses.forEach((response) => {
+						returnData.push({
+							json: formatResponse(response, shouldUnwrapObjects),
+						});
+					});
+				} catch (error) {
+					// Handle OpenAI specific rate limit errors
+					if (error instanceof NodeApiError && isOpenAiError(error.cause)) {
+						const openAiErrorCode: string | undefined = (error.cause as any).error?.code;
+						if (openAiErrorCode) {
+							const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode);
+							if (customMessage) {
+								error.message = customMessage;
+							}
+						}
+					}

+					// Continue on failure if configured
+					if (this.continueOnFail()) {
+						returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
+						continue;
+					}
+
+					throw error;
+				}
+			}
+		}
@@ -7,7 +7,7 @@ import type { IDataObject, INodeInputConfiguration, INodeProperties } from 'n8n-
 import { NodeConnectionTypes } from 'n8n-workflow';
 
 import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions';
-import { getTemplateNoticeField } from '@utils/sharedFields';
+import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields';
 
 /**
  * Dynamic input configuration generation based on node parameters

@@ -259,6 +259,11 @@ export const nodeProperties: INodeProperties[] = [
 			},
 		],
 	},
+	getBatchingOptionFields({
+		show: {
+			'@version': [{ _cnd: { gte: 1.7 } }],
+		},
+	}),
 	{
 		displayName: `Connect an <a data-action='openSelectiveNodeCreator' data-action-parameter-connectiontype='${NodeConnectionTypes.AiOutputParser}'>output parser</a> on the canvas to specify the output format you require`,
 		name: 'notice',
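For orientation, `getBatchingOptionFields` is a shared n8n helper; judging from the parameter paths the node reads ('batching.batchSize', default 5, and 'batching.delayBetweenBatches', default 0) and the display condition passed above, the collection it contributes looks roughly like the sketch below. Field copy and descriptions here are assumptions for illustration, not taken from this diff.

import type { INodeProperties } from 'n8n-workflow';

// Approximate shape of the 'batching' collection added to the node's properties.
// Inferred from the parameter paths and defaults used in execute(); display names
// and descriptions are illustrative, not copied from the helper.
const batchingOptions: INodeProperties = {
	displayName: 'Batch Processing',
	name: 'batching',
	type: 'collection',
	placeholder: 'Add option',
	default: {},
	// The show condition is the argument forwarded to getBatchingOptionFields above.
	displayOptions: {
		show: {
			'@version': [{ _cnd: { gte: 1.7 } }],
		},
	},
	options: [
		{
			displayName: 'Batch Size',
			name: 'batchSize',
			type: 'number',
			default: 5,
			description: 'How many items to process concurrently per batch',
		},
		{
			displayName: 'Delay Between Batches',
			name: 'delayBetweenBatches',
			type: 'number',
			default: 0,
			description: 'Time to wait between batches, in milliseconds',
		},
	],
};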
@@ -0,0 +1,54 @@
+import type { BaseLanguageModel } from '@langchain/core/language_models/base';
+import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
+
+import { getPromptInputByType } from '@utils/helpers';
+import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
+
+import { executeChain } from './chainExecutor';
+import { type MessageTemplate } from './types';
+
+export const processItem = async (ctx: IExecuteFunctions, itemIndex: number) => {
+	const llm = (await ctx.getInputConnectionData(
+		NodeConnectionTypes.AiLanguageModel,
+		0,
+	)) as BaseLanguageModel;
+
+	// Get output parser if configured
+	const outputParser = await getOptionalOutputParser(ctx);
+
+	// Get user prompt based on node version
+	let prompt: string;
+
+	if (ctx.getNode().typeVersion <= 1.3) {
+		prompt = ctx.getNodeParameter('prompt', itemIndex) as string;
+	} else {
+		prompt = getPromptInputByType({
+			ctx,
+			i: itemIndex,
+			inputKey: 'text',
+			promptTypeKey: 'promptType',
+		});
+	}
+
+	// Validate prompt
+	if (prompt === undefined) {
+		throw new NodeOperationError(ctx.getNode(), "The 'prompt' parameter is empty.");
+	}
+
+	// Get chat messages if configured
+	const messages = ctx.getNodeParameter(
+		'messages.messageValues',
+		itemIndex,
+		[],
+	) as MessageTemplate[];
+
+	// Execute the chain
+	return await executeChain({
+		context: ctx,
+		itemIndex,
+		query: prompt,
+		llm,
+		outputParser,
+		messages,
+	});
+};
@@ -3,7 +3,7 @@
 import { FakeChatModel } from '@langchain/core/utils/testing';
 import { mock } from 'jest-mock-extended';
 import type { IExecuteFunctions, INode } from 'n8n-workflow';
-import { NodeConnectionTypes } from 'n8n-workflow';
+import { NodeApiError, NodeConnectionTypes } from 'n8n-workflow';
 
 import * as helperModule from '@utils/helpers';
 import * as outputParserModule from '@utils/output_parsers/N8nOutputParser';
@@ -191,6 +191,148 @@ describe('ChainLlm Node', () => {
 		expect(result[0]).toHaveLength(2);
 	});
 
+	describe('batching (version 1.7+)', () => {
+		beforeEach(() => {
+			mockExecuteFunction.getNode.mockReturnValue({
+				name: 'Chain LLM',
+				typeVersion: 1.7,
+				parameters: {},
+			} as INode);
+		});
+
+		it('should process items in batches with default settings', async () => {
+			mockExecuteFunction.getInputData.mockReturnValue([
+				{ json: { item: 1 } },
+				{ json: { item: 2 } },
+				{ json: { item: 3 } },
+			]);
+
+			mockExecuteFunction.getNodeParameter.mockImplementation(
+				(param, _itemIndex, defaultValue) => {
+					if (param === 'messages.messageValues') return [];
+					return defaultValue;
+				},
+			);
+
+			(helperModule.getPromptInputByType as jest.Mock)
+				.mockReturnValueOnce('Test prompt 1')
+				.mockReturnValueOnce('Test prompt 2')
+				.mockReturnValueOnce('Test prompt 3');
+
+			(executeChainModule.executeChain as jest.Mock)
+				.mockResolvedValueOnce(['Response 1'])
+				.mockResolvedValueOnce(['Response 2'])
+				.mockResolvedValueOnce(['Response 3']);
+
+			const result = await node.execute.call(mockExecuteFunction);
+
+			expect(executeChainModule.executeChain).toHaveBeenCalledTimes(3);
+			expect(result[0]).toHaveLength(3);
+		});
+
+		it('should process items in smaller batches', async () => {
+			mockExecuteFunction.getInputData.mockReturnValue([
+				{ json: { item: 1 } },
+				{ json: { item: 2 } },
+				{ json: { item: 3 } },
+				{ json: { item: 4 } },
+			]);
+
+			mockExecuteFunction.getNodeParameter.mockImplementation(
+				(param, _itemIndex, defaultValue) => {
+					if (param === 'batching.batchSize') return 2;
+					if (param === 'batching.delayBetweenBatches') return 0;
+					if (param === 'messages.messageValues') return [];
+					return defaultValue;
+				},
+			);
+
+			(helperModule.getPromptInputByType as jest.Mock)
+				.mockReturnValueOnce('Test prompt 1')
+				.mockReturnValueOnce('Test prompt 2')
+				.mockReturnValueOnce('Test prompt 3')
+				.mockReturnValueOnce('Test prompt 4');
+
+			(executeChainModule.executeChain as jest.Mock)
+				.mockResolvedValueOnce(['Response 1'])
+				.mockResolvedValueOnce(['Response 2'])
+				.mockResolvedValueOnce(['Response 3'])
+				.mockResolvedValueOnce(['Response 4']);
+
+			const result = await node.execute.call(mockExecuteFunction);
+
+			expect(executeChainModule.executeChain).toHaveBeenCalledTimes(4);
+			expect(result[0]).toHaveLength(4);
+		});
+
+		it('should handle errors in batches with continueOnFail', async () => {
+			mockExecuteFunction.getInputData.mockReturnValue([
+				{ json: { item: 1 } },
+				{ json: { item: 2 } },
+			]);
+
+			mockExecuteFunction.getNodeParameter.mockImplementation(
+				(param, _itemIndex, defaultValue) => {
+					if (param === 'batching.batchSize') return 2;
+					if (param === 'batching.delayBetweenBatches') return 0;
+					if (param === 'messages.messageValues') return [];
+					return defaultValue;
+				},
+			);
+
+			mockExecuteFunction.continueOnFail.mockReturnValue(true);
+
+			(helperModule.getPromptInputByType as jest.Mock)
+				.mockReturnValueOnce('Test prompt 1')
+				.mockReturnValueOnce('Test prompt 2');
+
+			(executeChainModule.executeChain as jest.Mock)
+				.mockResolvedValueOnce(['Response 1'])
+				.mockRejectedValueOnce(new Error('Test error'));
+
+			const result = await node.execute.call(mockExecuteFunction);
+
+			expect(result[0]).toHaveLength(2);
+			expect(result[0][1].json).toEqual({ error: 'Test error' });
+		});
+
+		it('should handle OpenAI rate limit errors in batches', async () => {
+			mockExecuteFunction.getInputData.mockReturnValue([
+				{ json: { item: 1 } },
+				{ json: { item: 2 } },
+			]);
+
+			mockExecuteFunction.getNodeParameter.mockImplementation(
+				(param, _itemIndex, defaultValue) => {
+					if (param === 'batching.batchSize') return 2;
+					if (param === 'batching.delayBetweenBatches') return 0;
+					if (param === 'messages.messageValues') return [];
+					return defaultValue;
+				},
+			);
+
+			mockExecuteFunction.continueOnFail.mockReturnValue(true);
+
+			(helperModule.getPromptInputByType as jest.Mock)
+				.mockReturnValueOnce('Test prompt 1')
+				.mockReturnValueOnce('Test prompt 2');
+
+			const openAiError = new NodeApiError(mockExecuteFunction.getNode(), {
+				message: 'Rate limit exceeded',
+				cause: { error: { code: 'rate_limit_exceeded' } },
+			});
+
+			(executeChainModule.executeChain as jest.Mock)
+				.mockResolvedValueOnce(['Response 1'])
+				.mockRejectedValueOnce(openAiError);
+
+			const result = await node.execute.call(mockExecuteFunction);
+
+			expect(result[0]).toHaveLength(2);
+			expect(result[0][1].json).toEqual({ error: expect.stringContaining('Rate limit') });
+		});
+	});
+
 	it('should unwrap object responses when node version is 1.6 or higher', async () => {
 		mockExecuteFunction.getNode.mockReturnValue({
 			name: 'Chain LLM',