feat: Optimise langchain calls in batching mode (#15243)
@@ -1,8 +1,3 @@
-import type { Document } from '@langchain/core/documents';
-import type { BaseLanguageModel } from '@langchain/core/language_models/base';
-import type { TextSplitter } from '@langchain/textsplitters';
-import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters';
-import { loadSummarizationChain } from 'langchain/chains';
 import type {
   INodeTypeBaseDescription,
   IExecuteFunctions,
@@ -12,14 +7,11 @@ import type {
   IDataObject,
   INodeInputConfiguration,
 } from 'n8n-workflow';
-import { NodeConnectionTypes } from 'n8n-workflow';
+import { NodeConnectionTypes, sleep } from 'n8n-workflow';
 
-import { N8nBinaryLoader } from '@utils/N8nBinaryLoader';
-import { N8nJsonLoader } from '@utils/N8nJsonLoader';
-import { getTemplateNoticeField } from '@utils/sharedFields';
-import { getTracingConfig } from '@utils/tracing';
+import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields';
 
-import { getChainPromptsArgs } from '../helpers';
-import { REFINE_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE } from '../prompt';
+import { processItem } from './processItem';
 
 function getInputs(parameters: IDataObject) {
@@ -63,7 +55,7 @@ export class ChainSummarizationV2 implements INodeType {
   constructor(baseDescription: INodeTypeBaseDescription) {
     this.description = {
       ...baseDescription,
-      version: [2],
+      version: [2, 2.1],
       defaults: {
         name: 'Summarization Chain',
         color: '#909298',
@@ -306,6 +298,11 @@ export class ChainSummarizationV2 implements INodeType {
           },
         ],
       },
+      getBatchingOptionFields({
+        show: {
+          '@version': [{ _cnd: { gte: 2.1 } }],
+        },
+      }),
     ],
   },
 ],
@@ -325,108 +322,64 @@ export class ChainSummarizationV2 implements INodeType {
     const items = this.getInputData();
     const returnData: INodeExecutionData[] = [];
 
-    for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
-      try {
-        const model = (await this.getInputConnectionData(
-          NodeConnectionTypes.AiLanguageModel,
-          0,
-        )) as BaseLanguageModel;
-
-        const summarizationMethodAndPrompts = this.getNodeParameter(
-          'options.summarizationMethodAndPrompts.values',
-          itemIndex,
-          {},
-        ) as {
-          prompt?: string;
-          refineQuestionPrompt?: string;
-          refinePrompt?: string;
-          summarizationMethod: 'map_reduce' | 'stuff' | 'refine';
-          combineMapPrompt?: string;
-        };
-
-        const chainArgs = getChainPromptsArgs(
-          summarizationMethodAndPrompts.summarizationMethod ?? 'map_reduce',
-          summarizationMethodAndPrompts,
-        );
-
-        const chain = loadSummarizationChain(model, chainArgs);
-        const item = items[itemIndex];
-
-        let processedDocuments: Document[];
-
-        // Use dedicated document loader input to load documents
-        if (operationMode === 'documentLoader') {
-          const documentInput = (await this.getInputConnectionData(
-            NodeConnectionTypes.AiDocument,
-            0,
-          )) as N8nJsonLoader | Array<Document<Record<string, unknown>>>;
-
-          const isN8nLoader =
-            documentInput instanceof N8nJsonLoader || documentInput instanceof N8nBinaryLoader;
-
-          processedDocuments = isN8nLoader
-            ? await documentInput.processItem(item, itemIndex)
-            : documentInput;
-
-          const response = await chain.withConfig(getTracingConfig(this)).invoke({
-            input_documents: processedDocuments,
-          });
-
-          returnData.push({ json: { response } });
-        }
-
-        // Take the input and use binary or json loader
-        if (['nodeInputJson', 'nodeInputBinary'].includes(operationMode)) {
-          let textSplitter: TextSplitter | undefined;
-
-          switch (chunkingMode) {
-            // In simple mode we use recursive character splitter with default settings
-            case 'simple':
-              const chunkSize = this.getNodeParameter('chunkSize', itemIndex, 1000) as number;
-              const chunkOverlap = this.getNodeParameter('chunkOverlap', itemIndex, 200) as number;
-
-              textSplitter = new RecursiveCharacterTextSplitter({ chunkOverlap, chunkSize });
-              break;
-
-            // In advanced mode user can connect text splitter node so we just retrieve it
-            case 'advanced':
-              textSplitter = (await this.getInputConnectionData(
-                NodeConnectionTypes.AiTextSplitter,
-                0,
-              )) as TextSplitter | undefined;
-              break;
-            default:
-              break;
-          }
-
-          let processor: N8nJsonLoader | N8nBinaryLoader;
-          if (operationMode === 'nodeInputBinary') {
-            const binaryDataKey = this.getNodeParameter(
-              'options.binaryDataKey',
-              itemIndex,
-              'data',
-            ) as string;
-            processor = new N8nBinaryLoader(this, 'options.', binaryDataKey, textSplitter);
-          } else {
-            processor = new N8nJsonLoader(this, 'options.', textSplitter);
-          }
-
-          const processedItem = await processor.processItem(item, itemIndex);
-          const response = await chain.invoke(
-            {
-              input_documents: processedItem,
-            },
-            { signal: this.getExecutionCancelSignal() },
-          );
-          returnData.push({ json: { response } });
-        }
-      } catch (error) {
-        if (this.continueOnFail()) {
-          returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
-          continue;
-        }
-
-        throw error;
-      }
-    }
+    const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number;
+    const delayBetweenBatches = this.getNodeParameter(
+      'options.batching.delayBetweenBatches',
+      0,
+      0,
+    ) as number;
+
+    if (this.getNode().typeVersion >= 2.1 && batchSize > 1) {
+      // Batch processing
+      for (let i = 0; i < items.length; i += batchSize) {
+        const batch = items.slice(i, i + batchSize);
+        const batchPromises = batch.map(async (item, batchItemIndex) => {
+          const itemIndex = i + batchItemIndex;
+          return await processItem(this, itemIndex, item, operationMode, chunkingMode);
+        });
+
+        const batchResults = await Promise.allSettled(batchPromises);
+
+        batchResults.forEach((response, index) => {
+          if (response.status === 'rejected') {
+            const error = response.reason as Error;
+            if (this.continueOnFail()) {
+              returnData.push({
+                json: { error: error.message },
+                pairedItem: { item: i + index },
+              });
+            } else {
+              throw error;
+            }
+          } else {
+            const output = response.value;
+            returnData.push({ json: { output } });
+          }
+        });
+
+        // Add delay between batches if not the last batch
+        if (i + batchSize < items.length && delayBetweenBatches > 0) {
+          await sleep(delayBetweenBatches);
+        }
+      }
+    } else {
+      for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
+        try {
+          const response = await processItem(
+            this,
+            itemIndex,
+            items[itemIndex],
+            operationMode,
+            chunkingMode,
+          );
+          returnData.push({ json: { response } });
+        } catch (error) {
+          if (this.continueOnFail()) {
+            returnData.push({ json: { error: error.message }, pairedItem: { item: itemIndex } });
+            continue;
+          }

+          throw error;
+        }
+      }
+    }
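The core of the change is the batching pattern in execute(): items are processed in slices of batchSize via Promise.allSettled, so one failing item cannot reject the whole slice, and sleep(delayBetweenBatches) throttles the language-model calls between slices. Below is a minimal standalone sketch of the same pattern; the Item type, processWork, and the setTimeout-based sleep are stand-ins invented for this sketch, not n8n's actual processItem or sleep helpers.

type Item = { json: Record<string, unknown> };

// Stand-in for n8n-workflow's sleep helper.
const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

// Stand-in for processItem(); any per-item async work (e.g. an LLM call) fits here.
async function processWork(item: Item, itemIndex: number): Promise<string> {
	return `summary of ${JSON.stringify(item.json)} (#${itemIndex})`;
}

async function processInBatches(
	items: Item[],
	batchSize: number,
	delayBetweenBatches: number,
): Promise<Item[]> {
	const returnData: Item[] = [];

	for (let i = 0; i < items.length; i += batchSize) {
		// Kick off the whole slice concurrently.
		const batch = items.slice(i, i + batchSize);
		const results = await Promise.allSettled(
			batch.map(async (item, batchItemIndex) => await processWork(item, i + batchItemIndex)),
		);

		// allSettled preserves per-item outcomes, so a failure maps back to its index.
		results.forEach((result, index) => {
			if (result.status === 'rejected') {
				returnData.push({ json: { error: String(result.reason), item: i + index } });
			} else {
				returnData.push({ json: { output: result.value } });
			}
		});

		// Throttle between slices, but not after the last one.
		if (i + batchSize < items.length && delayBetweenBatches > 0) {
			await sleep(delayBetweenBatches);
		}
	}

	return returnData;
}

Unlike the node, this sketch always collects errors; in the diff above a rejected result is rethrown unless continueOnFail() is set.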
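On the UI side, the batching controls are contributed by getBatchingOptionFields, gated with '@version': [{ _cnd: { gte: 2.1 } }] so they only render on nodes saved at version 2.1 or later. The parameter paths read in execute() ('options.batching.batchSize' with fallback 5, 'options.batching.delayBetweenBatches' with fallback 0) imply a collection roughly like the sketch below; the exact definition returned by getBatchingOptionFields is not part of this diff, so the display names, placeholder, and descriptions here are assumptions.

import type { INodeProperties } from 'n8n-workflow';

// Assumed shape of the field getBatchingOptionFields() contributes; reconstructed
// from the parameter paths and defaults used in execute(), not from n8n source.
const batchingOption: INodeProperties = {
	displayName: 'Batch Processing',
	name: 'batching',
	type: 'collection',
	placeholder: 'Add Batching Option',
	default: {},
	displayOptions: {
		show: {
			// Only rendered for nodes at version >= 2.1.
			'@version': [{ _cnd: { gte: 2.1 } }],
		},
	},
	options: [
		{
			displayName: 'Batch Size',
			name: 'batchSize',
			type: 'number',
			default: 5, // matches the fallback in getNodeParameter('options.batching.batchSize', 0, 5)
			description: 'How many items to process concurrently',
		},
		{
			displayName: 'Delay Between Batches',
			name: 'delayBetweenBatches',
			type: 'number',
			default: 0, // matches the fallback of 0 in execute(); consumed as milliseconds by sleep()
			description: 'Delay in milliseconds between batches',
		},
	],
};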