feat: Optimize langchain calls in batching mode (#15011)

Author: Benjamin Schroth (committed by GitHub)
Date: 2025-05-02 17:09:31 +02:00
Co-authored-by: कारतोफ्फेलस्क्रिप्ट™ <aditya@netroy.in>
Commit: f3e29d25ed (parent: a4290dcb78)
12 changed files with 632 additions and 205 deletions


@@ -8,7 +8,7 @@ import {
import type { BaseRetriever } from '@langchain/core/retrievers';
import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
import { createRetrievalChain } from 'langchain/chains/retrieval';
-import { NodeConnectionTypes, NodeOperationError, parseErrorMetadata } from 'n8n-workflow';
+import { NodeConnectionTypes, NodeOperationError, parseErrorMetadata, sleep } from 'n8n-workflow';
import {
type INodeProperties,
type IExecuteFunctions,
@@ -177,6 +177,31 @@ export class ChainRetrievalQa implements INodeType {
},
},
},
{
displayName: 'Batch Processing',
name: 'batching',
type: 'collection',
description: 'Batch processing options for rate limiting',
default: {},
options: [
{
displayName: 'Batch Size',
name: 'batchSize',
default: 100,
type: 'number',
description:
'How many items to process in parallel. This is useful for rate limiting.',
},
{
displayName: 'Delay Between Batches',
name: 'delayBetweenBatches',
default: 0,
type: 'number',
description:
'Delay in milliseconds between batches. This is useful for rate limiting.',
},
],
},
],
},
],
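For orientation, the two fields collected by this "Batch Processing" collection surface at runtime under the node's `options.batching` parameter. The interface below is only an illustrative sketch of that shape and its defaults, not code from this commit:

```ts
// Illustrative only: shape of options.batching as the execute() hunk below reads it.
interface BatchingOptions {
	/** Number of items processed in parallel per batch (default 100). */
	batchSize: number;
	/** Milliseconds to wait between batches; 0 disables the delay (default 0). */
	delayBetweenBatches: number;
}

const batchingDefaults: BatchingOptions = { batchSize: 100, delayBetweenBatches: 0 };
```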
@@ -186,11 +211,20 @@ export class ChainRetrievalQa implements INodeType {
this.logger.debug('Executing Retrieval QA Chain');
const items = this.getInputData();
const { batchSize, delayBetweenBatches } = this.getNodeParameter('options.batching', 0, {
batchSize: 100,
delayBetweenBatches: 0,
}) as {
batchSize: number;
delayBetweenBatches: number;
};
const returnData: INodeExecutionData[] = [];
// Run for each item
-for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
-try {
+for (let i = 0; i < items.length; i += batchSize) {
+const batch = items.slice(i, i + batchSize);
+const batchPromises = batch.map(async (_item, batchIndex) => {
+const itemIndex = i + batchIndex;
const model = (await this.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
0,
@@ -266,32 +300,47 @@ export class ChainRetrievalQa implements INodeType {
// Execute the chain with tracing config
const tracingConfig = getTracingConfig(this);
-const response = await retrievalChain
+const result = await retrievalChain
.withConfig(tracingConfig)
.invoke({ input: query }, { signal: this.getExecutionCancelSignal() });
-// Get the answer from the response
-const answer: string = response.answer;
+return result;
+});
const batchResults = await Promise.allSettled(batchPromises);
batchResults.forEach((response, index) => {
if (response.status === 'rejected') {
const error = response.reason;
if (this.continueOnFail()) {
const metadata = parseErrorMetadata(error);
returnData.push({
json: { error: error.message },
pairedItem: { item: index },
metadata,
});
return;
} else {
throw error;
}
}
const output = response.value;
const answer: string = output.answer;
if (this.getNode().typeVersion >= 1.5) {
returnData.push({ json: { response: answer } });
} else {
// Legacy format for versions 1.4 and below is { text: string }
returnData.push({ json: { response: { text: answer } } });
}
-} catch (error) {
-if (this.continueOnFail()) {
-const metadata = parseErrorMetadata(error);
-returnData.push({
-json: { error: error.message },
-pairedItem: { item: itemIndex },
-metadata,
-});
-continue;
-}
+});
-throw error;
+// Add delay between batches if not the last batch
+if (i + batchSize < items.length && delayBetweenBatches > 0) {
+await sleep(delayBetweenBatches);
+}
}
return [returnData];
}
}
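Stripped of the node-specific pieces, the control flow introduced in this hunk reduces to the pattern below. This is a minimal, standalone sketch under stated assumptions: `processInBatches` and `processItem` are hypothetical names, and a local `sleep` stands in for the helper imported from n8n-workflow.

```ts
// Minimal sketch of the batching pattern used above; not the node's actual implementation.
const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

async function processInBatches<T, R>(
	items: T[],
	processItem: (item: T, index: number) => Promise<R>,
	batchSize = 100,
	delayBetweenBatches = 0,
): Promise<Array<PromiseSettledResult<Awaited<R>>>> {
	const results: Array<PromiseSettledResult<Awaited<R>>> = [];

	for (let i = 0; i < items.length; i += batchSize) {
		// Start every item of the current batch in parallel and wait for all of them to
		// settle, so one rejected item does not cancel the rest of the batch.
		const batch = items.slice(i, i + batchSize);
		const settled = await Promise.allSettled(
			batch.map((item, batchIndex) => processItem(item, i + batchIndex)),
		);
		results.push(...settled);

		// Throttle: pause before the next batch, but never after the last one.
		if (i + batchSize < items.length && delayBetweenBatches > 0) {
			await sleep(delayBetweenBatches);
		}
	}

	return results;
}
```

In the node itself, rejected entries of the settled results are either re-thrown or, when `continueOnFail()` is enabled, emitted as `{ json: { error: error.message } }` items with the parsed error metadata attached.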


@@ -8,6 +8,11 @@ import { NodeConnectionTypes, NodeOperationError, UnexpectedError } from 'n8n-wo
import { ChainRetrievalQa } from '../ChainRetrievalQa.node';
jest.mock('n8n-workflow', () => ({
...jest.requireActual('n8n-workflow'),
sleep: jest.fn(),
}));
const createExecuteFunctionsMock = (
parameters: IDataObject,
fakeLlm: BaseLanguageModel,
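Because the mock factory above spreads the real module and replaces only `sleep` with `jest.fn()`, the batching tests below (which set `delayBetweenBatches: 100`) finish without real waiting, and the stub can still be inspected. A hypothetical assertion sketch, assuming a run that produces more than one batch (the full execution setup from this file is elided):

```ts
import { sleep } from 'n8n-workflow'; // resolves to the jest.fn() from the mock factory above

it('requests a delay between batches', async () => {
	// ...run node.execute.call(createExecuteFunctionsMock(params, fakeLlm)) with two input
	// items and batching: { batchSize: 1, delayBetweenBatches: 100 }...
	expect(jest.mocked(sleep)).toHaveBeenCalledWith(100);
});
```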
@@ -80,7 +85,12 @@ describe('ChainRetrievalQa', () => {
const params = {
promptType: 'define',
text: 'What is the capital of France?',
-options: {},
+options: {
+batching: {
+batchSize: 1,
+delayBetweenBatches: 100,
+},
+},
};
const result = await node.execute.call(
@@ -114,7 +124,12 @@ describe('ChainRetrievalQa', () => {
const params = {
promptType: 'define',
text: 'What is the capital of France?',
-options: {},
+options: {
+batching: {
+batchSize: 1,
+delayBetweenBatches: 100,
+},
+},
};
const result = await node.execute.call(
@@ -158,6 +173,10 @@ describe('ChainRetrievalQa', () => {
text: 'What is the capital of France?',
options: {
systemPromptTemplate: customSystemPrompt,
batching: {
batchSize: 1,
delayBetweenBatches: 100,
},
},
};
@@ -185,7 +204,12 @@ describe('ChainRetrievalQa', () => {
const params = {
promptType: 'define',
text: undefined, // undefined query
-options: {},
+options: {
+batching: {
+batchSize: 1,
+delayBetweenBatches: 100,
+},
+},
};
await expect(
@@ -211,7 +235,12 @@ describe('ChainRetrievalQa', () => {
const params = {
promptType: 'define',
text: 'What is the capital of France?',
-options: {},
+options: {
+batching: {
+batchSize: 1,
+delayBetweenBatches: 100,
+},
+},
};
// Override continueOnFail to return true