mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 10:02:05 +00:00
feat: Optimise langchain calls in batching mode (#15243)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers';
|
||||
import { NodeOperationError, NodeConnectionTypes } from 'n8n-workflow';
|
||||
import { NodeOperationError, NodeConnectionTypes, sleep } from 'n8n-workflow';
|
||||
import type {
|
||||
IDataObject,
|
||||
IExecuteFunctions,
|
||||
@@ -13,7 +11,9 @@ import type {
|
||||
} from 'n8n-workflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { getTracingConfig } from '@utils/tracing';
|
||||
import { getBatchingOptionFields } from '@utils/sharedFields';
|
||||
|
||||
import { processItem } from './processItem';
|
||||
|
||||
const SYSTEM_PROMPT_TEMPLATE =
|
||||
"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json.";
|
||||
@@ -35,7 +35,7 @@ export class TextClassifier implements INodeType {
|
||||
icon: 'fa:tags',
|
||||
iconColor: 'black',
|
||||
group: ['transform'],
|
||||
version: 1,
|
||||
version: [1, 1.1],
|
||||
description: 'Classify your text into distinct categories',
|
||||
codex: {
|
||||
categories: ['AI'],
|
||||
@@ -158,6 +158,11 @@ export class TextClassifier implements INodeType {
|
||||
description:
|
||||
'Whether to enable auto-fixing (may trigger an additional LLM call if output is broken)',
|
||||
},
|
||||
getBatchingOptionFields({
|
||||
show: {
|
||||
'@version': [{ _cnd: { gte: 1.1 } }],
|
||||
},
|
||||
}),
|
||||
],
|
||||
},
|
||||
],
|
||||
@@ -165,6 +170,12 @@ export class TextClassifier implements INodeType {
|
||||
|
||||
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
|
||||
const items = this.getInputData();
|
||||
const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number;
|
||||
const delayBetweenBatches = this.getNodeParameter(
|
||||
'options.batching.delayBetweenBatches',
|
||||
0,
|
||||
0,
|
||||
) as number;
|
||||
|
||||
const llm = (await this.getInputConnectionData(
|
||||
NodeConnectionTypes.AiLanguageModel,
|
||||
@@ -223,68 +234,93 @@ export class TextClassifier implements INodeType {
|
||||
{ length: categories.length + (fallback === 'other' ? 1 : 0) },
|
||||
(_) => [],
|
||||
);
|
||||
for (let itemIdx = 0; itemIdx < items.length; itemIdx++) {
|
||||
const item = items[itemIdx];
|
||||
item.pairedItem = { item: itemIdx };
|
||||
const input = this.getNodeParameter('inputText', itemIdx) as string;
|
||||
|
||||
if (input === undefined || input === null) {
|
||||
if (this.continueOnFail()) {
|
||||
returnData[0].push({
|
||||
json: { error: 'Text to classify is not defined' },
|
||||
pairedItem: { item: itemIdx },
|
||||
});
|
||||
continue;
|
||||
} else {
|
||||
throw new NodeOperationError(
|
||||
this.getNode(),
|
||||
`Text to classify for item ${itemIdx} is not defined`,
|
||||
if (this.getNode().typeVersion >= 1.1 && batchSize > 1) {
|
||||
for (let i = 0; i < items.length; i += batchSize) {
|
||||
const batch = items.slice(i, i + batchSize);
|
||||
const batchPromises = batch.map(async (_item, batchItemIndex) => {
|
||||
const itemIndex = i + batchItemIndex;
|
||||
const item = items[itemIndex];
|
||||
item.pairedItem = { item: itemIndex };
|
||||
|
||||
return await processItem(
|
||||
this,
|
||||
itemIndex,
|
||||
item,
|
||||
llm,
|
||||
parser,
|
||||
categories,
|
||||
multiClassPrompt,
|
||||
fallbackPrompt,
|
||||
);
|
||||
});
|
||||
|
||||
const batchResults = await Promise.allSettled(batchPromises);
|
||||
|
||||
batchResults.forEach((response, batchItemIndex) => {
|
||||
const index = i + batchItemIndex;
|
||||
if (response.status === 'rejected') {
|
||||
const error = response.reason as Error;
|
||||
if (this.continueOnFail()) {
|
||||
returnData[0].push({
|
||||
json: { error: error.message },
|
||||
pairedItem: { item: index },
|
||||
});
|
||||
return;
|
||||
} else {
|
||||
throw new NodeOperationError(this.getNode(), error.message);
|
||||
}
|
||||
} else {
|
||||
const output = response.value;
|
||||
const item = items[index];
|
||||
|
||||
categories.forEach((cat, idx) => {
|
||||
if (output[cat.category]) returnData[idx].push(item);
|
||||
});
|
||||
|
||||
if (fallback === 'other' && output.fallback)
|
||||
returnData[returnData.length - 1].push(item);
|
||||
}
|
||||
});
|
||||
|
||||
// Add delay between batches if not the last batch
|
||||
if (i + batchSize < items.length && delayBetweenBatches > 0) {
|
||||
await sleep(delayBetweenBatches);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
|
||||
const item = items[itemIndex];
|
||||
item.pairedItem = { item: itemIndex };
|
||||
|
||||
const inputPrompt = new HumanMessage(input);
|
||||
try {
|
||||
const output = await processItem(
|
||||
this,
|
||||
itemIndex,
|
||||
item,
|
||||
llm,
|
||||
parser,
|
||||
categories,
|
||||
multiClassPrompt,
|
||||
fallbackPrompt,
|
||||
);
|
||||
|
||||
const systemPromptTemplateOpt = this.getNodeParameter(
|
||||
'options.systemPromptTemplate',
|
||||
itemIdx,
|
||||
SYSTEM_PROMPT_TEMPLATE,
|
||||
) as string;
|
||||
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
|
||||
`${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE}
|
||||
{format_instructions}
|
||||
${multiClassPrompt}
|
||||
${fallbackPrompt}`,
|
||||
);
|
||||
|
||||
const messages = [
|
||||
await systemPromptTemplate.format({
|
||||
categories: categories.map((cat) => cat.category).join(', '),
|
||||
format_instructions: parser.getFormatInstructions(),
|
||||
}),
|
||||
inputPrompt,
|
||||
];
|
||||
const prompt = ChatPromptTemplate.fromMessages(messages);
|
||||
const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this));
|
||||
|
||||
try {
|
||||
const output = await chain.invoke(messages);
|
||||
|
||||
categories.forEach((cat, idx) => {
|
||||
if (output[cat.category]) returnData[idx].push(item);
|
||||
});
|
||||
if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
|
||||
} catch (error) {
|
||||
if (this.continueOnFail()) {
|
||||
returnData[0].push({
|
||||
json: { error: error.message },
|
||||
pairedItem: { item: itemIdx },
|
||||
categories.forEach((cat, idx) => {
|
||||
if (output[cat.category]) returnData[idx].push(item);
|
||||
});
|
||||
if (fallback === 'other' && output.fallback) returnData[returnData.length - 1].push(item);
|
||||
} catch (error) {
|
||||
if (this.continueOnFail()) {
|
||||
returnData[0].push({
|
||||
json: { error: error.message },
|
||||
pairedItem: { item: itemIndex },
|
||||
});
|
||||
|
||||
continue;
|
||||
continue;
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
/**
 * Default system prompt for the Text Classifier node, used when the node's
 * `options.systemPromptTemplate` parameter is not set by the user.
 *
 * `{categories}` is a prompt-template placeholder filled in at runtime with the
 * comma-joined category names; the formatting instructions referred to here are
 * appended below this text when the full system message is assembled.
 */
export const SYSTEM_PROMPT_TEMPLATE =
	"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json.";
|
||||
@@ -0,0 +1,57 @@
|
||||
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import { ChatPromptTemplate, SystemMessagePromptTemplate } from '@langchain/core/prompts';
|
||||
import type { OutputFixingParser, StructuredOutputParser } from 'langchain/output_parsers';
|
||||
import { NodeOperationError, type IExecuteFunctions, type INodeExecutionData } from 'n8n-workflow';
|
||||
|
||||
import { getTracingConfig } from '@utils/tracing';
|
||||
|
||||
import { SYSTEM_PROMPT_TEMPLATE } from './constants';
|
||||
|
||||
export async function processItem(
|
||||
ctx: IExecuteFunctions,
|
||||
itemIndex: number,
|
||||
item: INodeExecutionData,
|
||||
llm: BaseLanguageModel,
|
||||
parser: StructuredOutputParser<any> | OutputFixingParser<any>,
|
||||
categories: Array<{ category: string; description: string }>,
|
||||
multiClassPrompt: string,
|
||||
fallbackPrompt: string | undefined,
|
||||
): Promise<Record<string, unknown>> {
|
||||
const input = ctx.getNodeParameter('inputText', itemIndex) as string;
|
||||
|
||||
if (input === undefined || input === null) {
|
||||
throw new NodeOperationError(
|
||||
ctx.getNode(),
|
||||
`Text to classify for item ${itemIndex} is not defined`,
|
||||
);
|
||||
}
|
||||
|
||||
item.pairedItem = { item: itemIndex };
|
||||
|
||||
const inputPrompt = new HumanMessage(input);
|
||||
|
||||
const systemPromptTemplateOpt = ctx.getNodeParameter(
|
||||
'options.systemPromptTemplate',
|
||||
itemIndex,
|
||||
SYSTEM_PROMPT_TEMPLATE,
|
||||
) as string;
|
||||
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate(
|
||||
`${systemPromptTemplateOpt ?? SYSTEM_PROMPT_TEMPLATE}
|
||||
{format_instructions}
|
||||
${multiClassPrompt}
|
||||
${fallbackPrompt}`,
|
||||
);
|
||||
|
||||
const messages = [
|
||||
await systemPromptTemplate.format({
|
||||
categories: categories.map((cat) => cat.category).join(', '),
|
||||
format_instructions: parser.getFormatInstructions(),
|
||||
}),
|
||||
inputPrompt,
|
||||
];
|
||||
const prompt = ChatPromptTemplate.fromMessages(messages);
|
||||
const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(ctx));
|
||||
|
||||
return await chain.invoke(messages);
|
||||
}
|
||||
Reference in New Issue
Block a user