feat: Optimise langchain calls in batching mode (#15243)

This commit is contained in:
Benjamin Schroth
2025-05-13 13:58:38 +02:00
committed by GitHub
parent 8591c2e0d1
commit ff156930c5
35 changed files with 2946 additions and 1171 deletions

View File

@@ -1,8 +1,3 @@
import type { Document } from '@langchain/core/documents';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { TextSplitter } from '@langchain/textsplitters';
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters';
import { loadSummarizationChain } from 'langchain/chains';
import type {
INodeTypeBaseDescription,
IExecuteFunctions,
@@ -12,14 +7,11 @@ import type {
IDataObject,
INodeInputConfiguration,
} from 'n8n-workflow';
import { NodeConnectionTypes } from 'n8n-workflow';
import { NodeConnectionTypes, sleep } from 'n8n-workflow';
import { N8nBinaryLoader } from '@utils/N8nBinaryLoader';
import { N8nJsonLoader } from '@utils/N8nJsonLoader';
import { getTemplateNoticeField } from '@utils/sharedFields';
import { getTracingConfig } from '@utils/tracing';
import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields';
import { getChainPromptsArgs } from '../helpers';
import { processItem } from './processItem';
import { REFINE_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE } from '../prompt';
function getInputs(parameters: IDataObject) {
@@ -63,7 +55,7 @@ export class ChainSummarizationV2 implements INodeType {
constructor(baseDescription: INodeTypeBaseDescription) {
this.description = {
...baseDescription,
version: [2],
version: [2, 2.1],
defaults: {
name: 'Summarization Chain',
color: '#909298',
@@ -306,6 +298,11 @@ export class ChainSummarizationV2 implements INodeType {
},
],
},
getBatchingOptionFields({
show: {
'@version': [{ _cnd: { gte: 2.1 } }],
},
}),
],
},
],
@@ -325,108 +322,64 @@ export class ChainSummarizationV2 implements INodeType {
const items = this.getInputData();
const returnData: INodeExecutionData[] = [];
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const model = (await this.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
0,
)) as BaseLanguageModel;
const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number;
const delayBetweenBatches = this.getNodeParameter(
'options.batching.delayBetweenBatches',
0,
0,
) as number;
const summarizationMethodAndPrompts = this.getNodeParameter(
'options.summarizationMethodAndPrompts.values',
itemIndex,
{},
) as {
prompt?: string;
refineQuestionPrompt?: string;
refinePrompt?: string;
summarizationMethod: 'map_reduce' | 'stuff' | 'refine';
combineMapPrompt?: string;
};
if (this.getNode().typeVersion >= 2.1 && batchSize > 1) {
// Batch processing
for (let i = 0; i < items.length; i += batchSize) {
const batch = items.slice(i, i + batchSize);
const batchPromises = batch.map(async (item, batchItemIndex) => {
const itemIndex = i + batchItemIndex;
return await processItem(this, itemIndex, item, operationMode, chunkingMode);
});
const chainArgs = getChainPromptsArgs(
summarizationMethodAndPrompts.summarizationMethod ?? 'map_reduce',
summarizationMethodAndPrompts,
);
const chain = loadSummarizationChain(model, chainArgs);
const item = items[itemIndex];
let processedDocuments: Document[];
// Use dedicated document loader input to load documents
if (operationMode === 'documentLoader') {
const documentInput = (await this.getInputConnectionData(
NodeConnectionTypes.AiDocument,
0,
)) as N8nJsonLoader | Array<Document<Record<string, unknown>>>;
const isN8nLoader =
documentInput instanceof N8nJsonLoader || documentInput instanceof N8nBinaryLoader;
processedDocuments = isN8nLoader
? await documentInput.processItem(item, itemIndex)
: documentInput;
const response = await chain.withConfig(getTracingConfig(this)).invoke({
input_documents: processedDocuments,
});
returnData.push({ json: { response } });
}
// Take the input and use binary or json loader
if (['nodeInputJson', 'nodeInputBinary'].includes(operationMode)) {
let textSplitter: TextSplitter | undefined;
switch (chunkingMode) {
// In simple mode we use recursive character splitter with default settings
case 'simple':
const chunkSize = this.getNodeParameter('chunkSize', itemIndex, 1000) as number;
const chunkOverlap = this.getNodeParameter('chunkOverlap', itemIndex, 200) as number;
textSplitter = new RecursiveCharacterTextSplitter({ chunkOverlap, chunkSize });
break;
// In advanced mode user can connect text splitter node so we just retrieve it
case 'advanced':
textSplitter = (await this.getInputConnectionData(
NodeConnectionTypes.AiTextSplitter,
0,
)) as TextSplitter | undefined;
break;
default:
break;
}
let processor: N8nJsonLoader | N8nBinaryLoader;
if (operationMode === 'nodeInputBinary') {
const binaryDataKey = this.getNodeParameter(
'options.binaryDataKey',
itemIndex,
'data',
) as string;
processor = new N8nBinaryLoader(this, 'options.', binaryDataKey, textSplitter);
const batchResults = await Promise.allSettled(batchPromises);
batchResults.forEach((response, index) => {
if (response.status === 'rejected') {
const error = response.reason as Error;
if (this.continueOnFail()) {
returnData.push({
json: { error: error.message },
pairedItem: { item: i + index },
});
} else {
throw error;
}
} else {
processor = new N8nJsonLoader(this, 'options.', textSplitter);
const output = response.value;
returnData.push({ json: { output } });
}
});
const processedItem = await processor.processItem(item, itemIndex);
const response = await chain.invoke(
{
input_documents: processedItem,
},
{ signal: this.getExecutionCancelSignal() },
// Add delay between batches if not the last batch
if (i + batchSize < items.length && delayBetweenBatches > 0) {
await sleep(delayBetweenBatches);
}
}
} else {
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const response = await processItem(
this,
itemIndex,
items[itemIndex],
operationMode,
chunkingMode,
);
returnData.push({ json: { response } });
}
} catch (error) {
if (this.continueOnFail()) {
returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
continue;
}
} catch (error) {
if (this.continueOnFail()) {
returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
continue;
}
throw error;
throw error;
}
}
}