feat: Optimize langchain calls in batching mode (#15011)

Co-authored-by: कारतोफ्फेलस्क्रिप्ट™ <aditya@netroy.in>
Benjamin Schroth
2025-05-02 17:09:31 +02:00
committed by GitHub
parent a4290dcb78
commit f3e29d25ed
12 changed files with 632 additions and 205 deletions

View File

@@ -5,7 +5,7 @@ import type {
INodeType,
INodeTypeDescription,
} from 'n8n-workflow';
import { NodeApiError, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
import { NodeApiError, NodeConnectionTypes, NodeOperationError, sleep } from 'n8n-workflow';
import { getPromptInputByType } from '@utils/helpers';
import { getOptionalOutputParser } from '@utils/output_parsers/N8nOutputParser';
@@ -67,19 +67,28 @@ export class ChainLlm implements INodeType {
this.logger.debug('Executing Basic LLM Chain');
const items = this.getInputData();
const returnData: INodeExecutionData[] = [];
const { batchSize, delayBetweenBatches } = this.getNodeParameter('batching', 0, {
batchSize: 100,
delayBetweenBatches: 0,
}) as {
batchSize: number;
delayBetweenBatches: number;
};
// Get output parser if configured
const outputParser = await getOptionalOutputParser(this);
// Process items in batches
for (let i = 0; i < items.length; i += batchSize) {
const batch = items.slice(i, i + batchSize);
const batchPromises = batch.map(async (_item, batchItemIndex) => {
const itemIndex = i + batchItemIndex;
// Process each input item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
// Get the language model
const llm = (await this.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
0,
)) as BaseLanguageModel;
// Get output parser if configured
const outputParser = await getOptionalOutputParser(this);
// Get user prompt based on node version
let prompt: string;
@@ -106,44 +115,53 @@ export class ChainLlm implements INodeType {
[],
) as MessageTemplate[];
// Execute the chain
const responses = await executeChain({
return (await executeChain({
context: this,
itemIndex,
query: prompt,
llm,
outputParser,
messages,
});
})) as object[];
});
// If the node version is 1.6(and LLM is using `response_format: json_object`) or higher or an output parser is configured,
// we unwrap the response and return the object directly as JSON
const shouldUnwrapObjects = this.getNode().typeVersion >= 1.6 || !!outputParser;
// Process each response and add to return data
responses.forEach((response) => {
returnData.push({
json: formatResponse(response, shouldUnwrapObjects),
});
});
} catch (error) {
// Handle OpenAI specific rate limit errors
if (error instanceof NodeApiError && isOpenAiError(error.cause)) {
const openAiErrorCode: string | undefined = (error.cause as any).error?.code;
if (openAiErrorCode) {
const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode);
if (customMessage) {
error.message = customMessage;
const batchResults = await Promise.allSettled(batchPromises);
batchResults.forEach((promiseResult, batchItemIndex) => {
const itemIndex = i + batchItemIndex;
if (promiseResult.status === 'rejected') {
const error = promiseResult.reason as Error;
// Handle OpenAI specific rate limit errors
if (error instanceof NodeApiError && isOpenAiError(error.cause)) {
const openAiErrorCode: string | undefined = (error.cause as any).error?.code;
if (openAiErrorCode) {
const customMessage = getCustomOpenAiErrorMessage(openAiErrorCode);
if (customMessage) {
error.message = customMessage;
}
}
}
if (this.continueOnFail()) {
returnData.push({
json: { error: error.message },
pairedItem: { item: itemIndex },
});
return;
}
throw new NodeOperationError(this.getNode(), error);
}
// Continue on failure if configured
if (this.continueOnFail()) {
returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
continue;
}
const responses = promiseResult.value;
responses.forEach((response: object) => {
returnData.push({
json: formatResponse(response, this.getNode().typeVersion >= 1.6 || !!outputParser),
});
});
});
throw error;
if (i + batchSize < items.length && delayBetweenBatches > 0) {
await sleep(delayBetweenBatches);
}
}
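
The diff above replaces the old one-item-at-a-time loop with batched execution. As a condensed, self-contained sketch of that pattern (not the node code itself): slice the items into batches, settle each batch in parallel, then wait before starting the next batch. `processInBatches` and `processItem` are hypothetical stand-ins for the node's loop and its per-item chain execution.

```ts
async function processInBatches<T>(
  items: T[],
  processItem: (item: T, itemIndex: number) => Promise<unknown>,
  batchSize: number,
  delayBetweenBatches: number,
): Promise<PromiseSettledResult<unknown>[]> {
  const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));
  const results: PromiseSettledResult<unknown>[] = [];

  for (let i = 0; i < items.length; i += batchSize) {
    const batch = items.slice(i, i + batchSize);
    // Promise.allSettled never rejects, so one failing item cannot abort its batch;
    // rejected entries are inspected afterwards (continueOnFail vs. rethrow).
    const settled = await Promise.allSettled(
      batch.map(async (item, batchItemIndex) => await processItem(item, i + batchItemIndex)),
    );
    results.push(...settled);

    // No delay after the final batch.
    if (i + batchSize < items.length && delayBetweenBatches > 0) {
      await sleep(delayBetweenBatches);
    }
  }

  return results;
}
```

Because each settled result keeps its position inside the batch, `i + batchItemIndex` recovers the original item index, which is what the commit uses to set `pairedItem` on both successes and errors.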

View File

@@ -270,4 +270,29 @@ export const nodeProperties: INodeProperties[] = [
},
},
},
{
displayName: 'Batch Processing',
name: 'batching',
type: 'collection',
placeholder: 'Add Batch Processing Option',
description: 'Batch processing options for rate limiting',
default: {},
options: [
{
displayName: 'Batch Size',
name: 'batchSize',
default: 100,
type: 'number',
description:
'How many items to process in parallel. This is useful for rate limiting, but will impact the agent\'s log output.',
},
{
displayName: 'Delay Between Batches',
name: 'delayBetweenBatches',
default: 1000,
type: 'number',
description: 'Delay in milliseconds between batches. This is useful for rate limiting.',
},
],
},
];
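
To make the two options concrete, here is a small self-contained sketch (illustrative only, not part of the node) of the arithmetic they imply: items are split into ceil(n / batchSize) batches, and the delay is applied only between consecutive batches.

```ts
function batchingOverhead(itemCount: number, batchSize: number, delayBetweenBatches: number) {
  const batches = Math.ceil(itemCount / batchSize);
  const totalDelayMs = Math.max(0, batches - 1) * delayBetweenBatches;
  return { batches, totalDelayMs };
}

// 250 items with the property defaults above (batchSize 100, delay 1000 ms):
// 3 batches and 2 inter-batch pauses, i.e. 2000 ms of added delay.
console.log(batchingOverhead(250, 100, 1000)); // { batches: 3, totalDelayMs: 2000 }
```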

View File

@@ -3,7 +3,7 @@
import { FakeChatModel } from '@langchain/core/utils/testing';
import { mock } from 'jest-mock-extended';
import type { IExecuteFunctions, INode } from 'n8n-workflow';
import { NodeConnectionTypes } from 'n8n-workflow';
import { NodeConnectionTypes, UnexpectedError } from 'n8n-workflow';
import * as helperModule from '@utils/helpers';
import * as outputParserModule from '@utils/output_parsers/N8nOutputParser';
@@ -12,6 +12,11 @@ import { ChainLlm } from '../ChainLlm.node';
import * as executeChainModule from '../methods/chainExecutor';
import * as responseFormatterModule from '../methods/responseFormatter';
jest.mock('n8n-workflow', () => ({
...jest.requireActual('n8n-workflow'),
sleep: jest.fn(),
}));
jest.mock('@utils/helpers', () => ({
getPromptInputByType: jest.fn(),
}));
@@ -25,12 +30,7 @@ jest.mock('../methods/chainExecutor', () => ({
}));
jest.mock('../methods/responseFormatter', () => ({
formatResponse: jest.fn().mockImplementation((response) => {
if (typeof response === 'string') {
return { text: response.trim() };
}
return response;
}),
formatResponse: jest.fn(),
}));
describe('ChainLlm Node', () => {
@@ -38,6 +38,8 @@ describe('ChainLlm Node', () => {
let mockExecuteFunction: jest.Mocked<IExecuteFunctions>;
beforeEach(() => {
jest.resetAllMocks();
node = new ChainLlm();
mockExecuteFunction = mock<IExecuteFunctions>();
@@ -63,7 +65,12 @@ describe('ChainLlm Node', () => {
const fakeLLM = new FakeChatModel({});
mockExecuteFunction.getInputConnectionData.mockResolvedValue(fakeLLM);
jest.clearAllMocks();
(responseFormatterModule.formatResponse as jest.Mock).mockImplementation((response) => {
if (typeof response === 'string') {
return { text: response.trim() };
}
return response;
});
});
describe('description', () => {
@@ -164,15 +171,14 @@ describe('ChainLlm Node', () => {
});
it('should continue on failure when configured', async () => {
mockExecuteFunction.continueOnFail.mockReturnValue(true);
(helperModule.getPromptInputByType as jest.Mock).mockReturnValue('Test prompt');
const error = new Error('Test error');
(executeChainModule.executeChain as jest.Mock).mockRejectedValue(error);
mockExecuteFunction.continueOnFail.mockReturnValue(true);
(executeChainModule.executeChain as jest.Mock).mockRejectedValueOnce(
new UnexpectedError('Test error'),
);
const result = await node.execute.call(mockExecuteFunction);
expect(result).toEqual([[{ json: { error: 'Test error' }, pairedItem: { item: 0 } }]]);
});
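
As a rough illustration of why the test file now mocks `sleep` from n8n-workflow, the following is a minimal, hypothetical jest sketch (not part of the diff) asserting that the delay is applied between batches and not after the last one; `runBatches` stands in for the node's batching loop.

```ts
import { sleep } from 'n8n-workflow';

jest.mock('n8n-workflow', () => ({
  ...jest.requireActual('n8n-workflow'),
  sleep: jest.fn(),
}));

// Hypothetical stand-in for the node's batching loop.
async function runBatches(items: unknown[], batchSize: number, delayMs: number) {
  for (let i = 0; i < items.length; i += batchSize) {
    await Promise.allSettled(items.slice(i, i + batchSize).map(async () => {}));
    if (i + batchSize < items.length && delayMs > 0) {
      await sleep(delayMs);
    }
  }
}

test('sleeps between batches but not after the last one', async () => {
  await runBatches([1, 2, 3, 4, 5], 2, 1000);
  expect(sleep).toHaveBeenCalledTimes(2); // 3 batches -> 2 delays
  expect(sleep).toHaveBeenCalledWith(1000);
});
```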