feat: Optimise langchain calls in batching mode (#15243)

Author: Benjamin Schroth
Date: 2025-05-13 13:58:38 +02:00
Committed by: GitHub
Parent: 8591c2e0d1
Commit: ff156930c5
35 changed files with 2946 additions and 1171 deletions


@@ -1,8 +1,6 @@
 import type { BaseLanguageModel } from '@langchain/core/language_models/base';
-import { HumanMessage } from '@langchain/core/messages';
-import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
 import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers';
-import { NodeOperationError, NodeConnectionTypes } from 'n8n-workflow';
+import { NodeOperationError, NodeConnectionTypes, sleep } from 'n8n-workflow';
 import type {
 	IDataObject,
 	IExecuteFunctions,
@@ -13,7 +11,9 @@ import type {
 } from 'n8n-workflow';
 import { z } from 'zod';
 
 import { getTracingConfig } from '@utils/tracing';
+import { getBatchingOptionFields } from '@utils/sharedFields';
+import { processItem } from './processItem';
 
 const SYSTEM_PROMPT_TEMPLATE =
 	"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json.";
@@ -35,7 +35,7 @@ export class TextClassifier implements INodeType {
 		icon: 'fa:tags',
 		iconColor: 'black',
 		group: ['transform'],
-		version: 1,
+		version: [1, 1.1],
 		description: 'Classify your text into distinct categories',
 		codex: {
 			categories: ['AI'],
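Note: changing `version: 1` to `version: [1, 1.1]` makes this a versioned node. Workflows saved with version 1 keep their previous behaviour exactly, while nodes added from now on are created as 1.1; the batching options introduced in the next hunk are only displayed for 1.1 and later via the `@version` display condition.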
@@ -158,6 +158,11 @@
 					description:
 						'Whether to enable auto-fixing (may trigger an additional LLM call if output is broken)',
 				},
+				getBatchingOptionFields({
+					show: {
+						'@version': [{ _cnd: { gte: 1.1 } }],
+					},
+				}),
 			],
 		},
 	],
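`getBatchingOptionFields` lives in the shared `@utils/sharedFields` module, which this page does not show. Judging from the `options.batching.*` parameter paths and fallback values read in `execute()` below, it plausibly returns a collection property along these lines — a sketch of the expected shape, not the helper's actual source:

// Hypothetical shape of the shared helper's return value, inferred from the
// 'options.batching.*' parameter paths and fallbacks this commit reads below.
import type { IDisplayOptions, INodeProperties } from 'n8n-workflow';

export function getBatchingOptionFields(displayOptions?: IDisplayOptions): INodeProperties {
	return {
		displayName: 'Batch Processing',
		name: 'batching',
		type: 'collection',
		placeholder: 'Add Batch Processing Option',
		default: {},
		displayOptions, // e.g. { show: { '@version': [{ _cnd: { gte: 1.1 } }] } }
		options: [
			{
				displayName: 'Batch Size',
				name: 'batchSize',
				type: 'number',
				default: 5, // matches the fallback passed to getNodeParameter below
				description: 'How many items to process concurrently',
			},
			{
				displayName: 'Delay Between Batches',
				name: 'delayBetweenBatches',
				type: 'number',
				default: 0, // milliseconds; matches the fallback below
				description: 'Time to wait between batches',
			},
		],
	};
}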
@@ -165,6 +170,12 @@
 	async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
 		const items = this.getInputData();
+		const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number;
+		const delayBetweenBatches = this.getNodeParameter(
+			'options.batching.delayBetweenBatches',
+			0,
+			0,
+		) as number;
 
 		const llm = (await this.getInputConnectionData(
 			NodeConnectionTypes.AiLanguageModel,
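A note on the two reads above: `getNodeParameter`'s third argument is a fallback returned when the parameter does not exist, so version 1 workflows (which never display the batching collection) resolve to `batchSize = 5` and `delayBetweenBatches = 0`. Those fallbacks alone would enable batching for old workflows, which is why the execution path in the next hunk also checks `typeVersion >= 1.1` before taking the concurrent branch.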
@@ -223,68 +234,93 @@
 			{ length: categories.length + (fallback === 'other' ? 1 : 0) },
 			(_) => [],
 		);
 
-		for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
-			const item = items[itemIdx];
-			item.pairedItem = { item: itemIdx };
-			const input = this.getNodeParameter('inputText', itemIdx) as string;
-
-			if (input === undefined || input === null) {
-				if (this.continueOnFail()) {
-					returnData[0].push({
-						json: { error: 'Text to classify is not defined' },
-						pairedItem: { item: itemIdx },
-					});
-					continue;
-				} else {
-					throw new NodeOperationError(
-						this.getNode(),
-						`Text to classify for item ${itemIdx} is not defined`,
-					);
-				}
-			}
-
-			const inputPrompt = new HumanMessage(input);
-
-			const systemPromptTemplateOpt = this.getNodeParameter(
-				'options.systemPromptTemplate',
-				itemIdx,
-				SYSTEM_PROMPT_TEMPLATE,
-			) as string;
-			const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
-				`${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE}
-{format_instructions}
-${multiClassPrompt}
-${fallbackPrompt}`,
-			);
-
-			const messages = [
-				await systemPromptTemplate.format({
-					categories: categories.map((cat) => cat.category).join(', '),
-					format_instructions: parser.getFormatInstructions(),
-				}),
-				inputPrompt,
-			];
-			const prompt = ChatPromptTemplate.fromMessages(messages);
-			const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));
-
-			try {
-				const output = await chain.invoke(messages);
-				categories.forEach((cat, idx) => {
-					if (output[cat.category]) returnData[idx].push(item);
-				});
-				if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
-			} catch (error) {
-				if (this.continueOnFail()) {
-					returnData[0].push({
-						json: { error: error.message },
-						pairedItem: { item: itemIdx },
-					});
-					continue;
-				}
-
-				throw error;
-			}
-		}
+		if (this.getNode().typeVersion >= 1.1 && batchSize > 1) {
+			// Process items in batches
+			for (let i = 0; i < items.length; i += batchSize) {
+				const batch = items.slice(i, i + batchSize);
+				const batchPromises = batch.map(async (_item, batchItemIndex) => {
+					const itemIndex = i + batchItemIndex;
+					const item = items[itemIndex];
+					item.pairedItem = { item: itemIndex };
+
+					return await processItem(
+						this,
+						itemIndex,
+						item,
+						llm,
+						parser,
+						categories,
+						multiClassPrompt,
+						fallbackPrompt,
+					);
+				});
+
+				const batchResults = await Promise.allSettled(batchPromises);
+
+				batchResults.forEach((response, batchItemIndex) => {
+					const index = i + batchItemIndex;
+					if (response.status === 'rejected') {
+						const error = response.reason as Error;
+						if (this.continueOnFail()) {
+							returnData[0].push({
+								json: { error: error.message },
+								pairedItem: { item: index },
+							});
+							return;
+						} else {
+							throw new NodeOperationError(this.getNode(), error.message);
+						}
+					} else {
+						const output = response.value;
+						const item = items[index];
+						categories.forEach((cat, idx) => {
+							if (output[cat.category]) returnData[idx].push(item);
+						});
+						if (fallback === 'other' && output.fallback)
+							returnData[returnData.length - 1].push(item);
+					}
+				});
+
+				// Add delay between batches if not the last batch
+				if (i + batchSize < items.length && delayBetweenBatches > 0) {
+					await sleep(delayBetweenBatches);
+				}
+			}
+		} else {
+			for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
+				const item = items[itemIndex];
+				item.pairedItem = { item: itemIndex };
+
+				try {
+					const output = await processItem(
+						this,
+						itemIndex,
+						item,
+						llm,
+						parser,
+						categories,
+						multiClassPrompt,
+						fallbackPrompt,
+					);
+					categories.forEach((cat, idx) => {
+						if (output[cat.category]) returnData[idx].push(item);
+					});
+					if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
+				} catch (error) {
+					if (this.continueOnFail()) {
+						returnData[0].push({
+							json: { error: error.message },
+							pairedItem: { item: itemIndex },
+						});
+						continue;
+					}
+
+					throw error;
+				}
+			}
+		}
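Taken together, the new branch is a general throttled fan-out: slice the items into fixed-size batches, run each batch concurrently, collect results with `Promise.allSettled` so one rejected item cannot abort its siblings, and optionally sleep between batches. The same pattern works standalone; here is a minimal generic sketch (`processBatches` and its names are illustrative, not part of this commit):

// Standalone illustration of the batching pattern used above; processBatches
// and this local sleep helper are illustrative, not n8n code.
const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

async function processBatches<T, R>(
	items: T[],
	batchSize: number,
	delayBetweenBatches: number,
	processOne: (item: T, index: number) => Promise<R>,
): Promise<Array<PromiseSettledResult<R>>> {
	const results: Array<PromiseSettledResult<R>> = [];
	for (let i = 0; i < items.length; i += batchSize) {
		const batch = items.slice(i, i + batchSize);
		// Start the whole batch at once; allSettled isolates per-item failures,
		// mirroring the node's continueOnFail handling.
		const settled = await Promise.allSettled(
			batch.map(async (item, j) => await processOne(item, i + j)),
		);
		results.push(...settled);
		// Throttle between batches, but not after the last one.
		if (i + batchSize < items.length && delayBetweenBatches > 0) {
			await sleep(delayBetweenBatches);
		}
	}
	return results;
}

// Usage: classify texts five at a time, pausing 1s between batches.
// const settled = await processBatches(texts, 5, 1000, async (t) => classify(t));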