feat: Optimise langchain calls in batching mode (#15243)

Author: Benjamin Schroth
Date: 2025-05-13 13:58:38 +02:00
Committed by: GitHub
Parent: 8591c2e0d1
Commit: ff156930c5
35 changed files with 2946 additions and 1171 deletions
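This change adds an opt-in batching mode to the Question and Answer Chain node: from node version 1.6 onward, input items are processed concurrently in batches of a configurable size, with an optional delay between batches, and the per-item chain setup is extracted into a reusable processItem helper. Below is a minimal, self-contained sketch of the batching pattern the diff introduces; runInBatches and processOne are illustrative names that do not appear in the commit, and the inline sleep helper stands in for the one imported from n8n-workflow.

const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));

async function runInBatches<T, R>(
  items: T[],
  batchSize: number,
  delayBetweenBatches: number,
  processOne: (item: T, index: number) => Promise<R>,
): Promise<Array<PromiseSettledResult<R>>> {
  const results: Array<PromiseSettledResult<R>> = [];
  for (let i = 0; i < items.length; i += batchSize) {
    const batch = items.slice(i, i + batchSize);
    // Every item in the batch starts immediately; allSettled captures failures per item
    // so one rejected promise does not abort the rest of the batch.
    const settled = await Promise.allSettled(
      batch.map(async (item, batchItemIndex) => await processOne(item, i + batchItemIndex)),
    );
    results.push(...settled);
    // Wait between batches, but not after the last one.
    if (i + batchSize < items.length && delayBetweenBatches > 0) {
      await sleep(delayBetweenBatches);
    }
  }
  return results;
}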


@@ -1,16 +1,5 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { NodeConnectionTypes, parseErrorMetadata, sleep } from 'n8n-workflow';
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
} from '@langchain/core/prompts';
import type { BaseRetriever } from '@langchain/core/retrievers';
import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
import { createRetrievalChain } from 'langchain/chains/retrieval';
import { NodeConnectionTypes, NodeOperationError, parseErrorMetadata } from 'n8n-workflow';
import {
type INodeProperties,
type IExecuteFunctions,
type INodeExecutionData,
type INodeType,
@@ -18,28 +7,10 @@ import {
} from 'n8n-workflow';
import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions';
import { getPromptInputByType, isChatInstance } from '@utils/helpers';
import { getTemplateNoticeField } from '@utils/sharedFields';
import { getTracingConfig } from '@utils/tracing';
import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields';
const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
Context: {context}`;
// Due to the refactoring in version 1.5, the variable name {question} needed to be changed to {input} in the prompt template.
const LEGACY_INPUT_TEMPLATE_KEY = 'question';
const INPUT_TEMPLATE_KEY = 'input';
const systemPromptOption: INodeProperties = {
displayName: 'System Prompt Template',
name: 'systemPromptTemplate',
type: 'string',
default: SYSTEM_PROMPT_TEMPLATE,
typeOptions: {
rows: 6,
},
};
import { INPUT_TEMPLATE_KEY, LEGACY_INPUT_TEMPLATE_KEY, systemPromptOption } from './constants';
import { processItem } from './processItem';
export class ChainRetrievalQa implements INodeType {
description: INodeTypeDescription = {
@@ -48,7 +19,7 @@ export class ChainRetrievalQa implements INodeType {
icon: 'fa:link',
iconColor: 'black',
group: ['transform'],
version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
description: 'Answer questions about retrieved documents',
defaults: {
name: 'Question and Answer Chain',
@@ -177,6 +148,11 @@ export class ChainRetrievalQa implements INodeType {
},
},
},
getBatchingOptionFields({
show: {
'@version': [{ _cnd: { gte: 1.6 } }],
},
}),
],
},
],
@@ -187,109 +163,78 @@ export class ChainRetrievalQa implements INodeType {
const items = this.getInputData();
const returnData: INodeExecutionData[] = [];
const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number;
const delayBetweenBatches = this.getNodeParameter(
'options.batching.delayBetweenBatches',
0,
0,
) as number;
// Run for each item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const model = (await this.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
0,
)) as BaseLanguageModel;
const retriever = (await this.getInputConnectionData(
NodeConnectionTypes.AiRetriever,
0,
)) as BaseRetriever;
let query;
if (this.getNode().typeVersion <= 1.2) {
query = this.getNodeParameter('query', itemIndex) as string;
} else {
query = getPromptInputByType({
ctx: this,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (query === undefined) {
throw new NodeOperationError(this.getNode(), 'The query parameter is empty.');
}
const options = this.getNodeParameter('options', itemIndex, {}) as {
systemPromptTemplate?: string;
};
let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE;
// Replace legacy input template key for versions 1.4 and below
if (this.getNode().typeVersion < 1.5) {
templateText = templateText.replace(
`{${LEGACY_INPUT_TEMPLATE_KEY}}`,
`{${INPUT_TEMPLATE_KEY}}`,
);
}
// Create prompt template based on model type and user configuration
let promptTemplate;
if (isChatInstance(model)) {
// For chat models, create a chat prompt template with system and human messages
const messages = [
SystemMessagePromptTemplate.fromTemplate(templateText),
HumanMessagePromptTemplate.fromTemplate('{input}'),
];
promptTemplate = ChatPromptTemplate.fromMessages(messages);
} else {
// For non-chat models, create a text prompt template with Question/Answer format
const questionSuffix =
options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : '';
promptTemplate = new PromptTemplate({
template: templateText + questionSuffix,
inputVariables: ['context', 'input'],
});
}
// Create the document chain that combines the retrieved documents
const combineDocsChain = await createStuffDocumentsChain({
llm: model,
prompt: promptTemplate,
if (this.getNode().typeVersion >= 1.6 && batchSize >= 1) {
// Run in batches
for (let i = 0; i < items.length; i += batchSize) {
const batch = items.slice(i, i + batchSize);
const batchPromises = batch.map(async (_item, batchItemIndex) => {
return await processItem(this, i + batchItemIndex);
});
// Create the retrieval chain that handles the retrieval and then passes to the combine docs chain
const retrievalChain = await createRetrievalChain({
combineDocsChain,
retriever,
const batchResults = await Promise.allSettled(batchPromises);
batchResults.forEach((response, index) => {
if (response.status === 'rejected') {
const error = response.reason;
if (this.continueOnFail()) {
const metadata = parseErrorMetadata(error);
returnData.push({
json: { error: error.message },
pairedItem: { item: index },
metadata,
});
return;
} else {
throw error;
}
}
const output = response.value;
const answer = output.answer as string;
if (this.getNode().typeVersion >= 1.5) {
returnData.push({ json: { response: answer } });
} else {
// Legacy format for versions 1.4 and below is { text: string }
returnData.push({ json: { response: { text: answer } } });
}
});
// Execute the chain with tracing config
const tracingConfig = getTracingConfig(this);
const response = await retrievalChain
.withConfig(tracingConfig)
.invoke({ input: query }, { signal: this.getExecutionCancelSignal() });
// Get the answer from the response
const answer: string = response.answer;
if (this.getNode().typeVersion >= 1.5) {
returnData.push({ json: { response: answer } });
} else {
// Legacy format for versions 1.4 and below is { text: string }
returnData.push({ json: { response: { text: answer } } });
}
} catch (error) {
if (this.continueOnFail()) {
const metadata = parseErrorMetadata(error);
returnData.push({
json: { error: error.message },
pairedItem: { item: itemIndex },
metadata,
});
continue;
// Add delay between batches if not the last batch
if (i + batchSize < items.length && delayBetweenBatches > 0) {
await sleep(delayBetweenBatches);
}
}
} else {
// Run for each item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
const response = await processItem(this, itemIndex);
const answer = response.answer as string;
if (this.getNode().typeVersion >= 1.5) {
returnData.push({ json: { response: answer } });
} else {
// Legacy format for versions 1.4 and below is { text: string }
returnData.push({ json: { response: { text: answer } } });
}
} catch (error) {
if (this.continueOnFail()) {
const metadata = parseErrorMetadata(error);
returnData.push({
json: { error: error.message },
pairedItem: { item: itemIndex },
metadata,
});
continue;
}
throw error;
throw error;
}
}
}
return [returnData];
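For reference, the shapes the refactored execute() reads and emits above could be written roughly as the following illustrative types; the type names are not part of the commit. Batching settings come from options.batching with a default batch size of 5 and no delay, and the answer is returned either directly (node version 1.5 and later) or wrapped in the legacy { text } object.

// Illustrative types only; these names are not used in the commit itself.
interface BatchingOptions {
  batchSize: number; // read via getNodeParameter('options.batching.batchSize', 0, 5), so it defaults to 5
  delayBetweenBatches: number; // milliseconds between batches, defaults to 0
}

// Node versions >= 1.5 return the answer directly ...
type ResponseV15Plus = { response: string };
// ... while versions 1.4 and below wrap it as { text: string }.
type ResponseLegacy = { response: { text: string } };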


@@ -0,0 +1,20 @@
import type { INodeProperties } from 'n8n-workflow';
export const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
Context: {context}`;
// Due to the refactoring in version 1.5, the variable name {question} needed to be changed to {input} in the prompt template.
export const LEGACY_INPUT_TEMPLATE_KEY = 'question';
export const INPUT_TEMPLATE_KEY = 'input';
export const systemPromptOption: INodeProperties = {
displayName: 'System Prompt Template',
name: 'systemPromptTemplate',
type: 'string',
default: SYSTEM_PROMPT_TEMPLATE,
typeOptions: {
rows: 6,
},
};


@@ -0,0 +1,100 @@
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
} from '@langchain/core/prompts';
import type { BaseRetriever } from '@langchain/core/retrievers';
import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
import { createRetrievalChain } from 'langchain/chains/retrieval';
import { type IExecuteFunctions, NodeConnectionTypes, NodeOperationError } from 'n8n-workflow';
import { getPromptInputByType, isChatInstance } from '@utils/helpers';
import { getTracingConfig } from '@utils/tracing';
import { INPUT_TEMPLATE_KEY, LEGACY_INPUT_TEMPLATE_KEY, SYSTEM_PROMPT_TEMPLATE } from './constants';
export const processItem = async (
ctx: IExecuteFunctions,
itemIndex: number,
): Promise<Record<string, unknown>> => {
const model = (await ctx.getInputConnectionData(
NodeConnectionTypes.AiLanguageModel,
0,
)) as BaseLanguageModel;
const retriever = (await ctx.getInputConnectionData(
NodeConnectionTypes.AiRetriever,
0,
)) as BaseRetriever;
let query;
if (ctx.getNode().typeVersion <= 1.2) {
query = ctx.getNodeParameter('query', itemIndex) as string;
} else {
query = getPromptInputByType({
ctx,
i: itemIndex,
inputKey: 'text',
promptTypeKey: 'promptType',
});
}
if (query === undefined) {
throw new NodeOperationError(ctx.getNode(), 'The query parameter is empty.');
}
const options = ctx.getNodeParameter('options', itemIndex, {}) as {
systemPromptTemplate?: string;
};
let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE;
// Replace legacy input template key for versions 1.4 and below
if (ctx.getNode().typeVersion < 1.5) {
templateText = templateText.replace(
`{${LEGACY_INPUT_TEMPLATE_KEY}}`,
`{${INPUT_TEMPLATE_KEY}}`,
);
}
// Create prompt template based on model type and user configuration
let promptTemplate;
if (isChatInstance(model)) {
// For chat models, create a chat prompt template with system and human messages
const messages = [
SystemMessagePromptTemplate.fromTemplate(templateText),
HumanMessagePromptTemplate.fromTemplate('{input}'),
];
promptTemplate = ChatPromptTemplate.fromMessages(messages);
} else {
// For non-chat models, create a text prompt template with Question/Answer format
const questionSuffix =
options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : '';
promptTemplate = new PromptTemplate({
template: templateText + questionSuffix,
inputVariables: ['context', 'input'],
});
}
// Create the document chain that combines the retrieved documents
const combineDocsChain = await createStuffDocumentsChain({
llm: model,
prompt: promptTemplate,
});
// Create the retrieval chain that handles the retrieval and then passes to the combine docs chain
const retrievalChain = await createRetrievalChain({
combineDocsChain,
retriever,
});
// Execute the chain with tracing config
const tracingConfig = getTracingConfig(ctx);
return await retrievalChain
.withConfig(tracingConfig)
.invoke({ input: query }, { signal: ctx.getExecutionCancelSignal() });
};
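A short usage sketch of the new helper, assuming the n8n execution context that execute() receives; answerForItem is a hypothetical wrapper added here purely for illustration and is not part of the commit.

import type { IExecuteFunctions } from 'n8n-workflow';
import { processItem } from './processItem';

// Hypothetical helper for illustration: processItem resolves with the raw chain output,
// and the answer text lives on its 'answer' field (see the node's execute() above).
export async function answerForItem(ctx: IExecuteFunctions, itemIndex: number): Promise<string> {
  const output = await processItem(ctx, itemIndex);
  return output.answer as string;
}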


@@ -71,7 +71,7 @@ describe('ChainRetrievalQa', () => {
node = new ChainRetrievalQa();
});
it.each([1.3, 1.4, 1.5])(
it.each([1.3, 1.4, 1.5, 1.6])(
'should process a query using a chat model (version %s)',
async (version) => {
// Mock a chat model that returns a predefined answer
@@ -103,7 +103,7 @@ describe('ChainRetrievalQa', () => {
},
);
it.each([1.3, 1.4, 1.5])(
it.each([1.3, 1.4, 1.5, 1.6])(
'should process a query using a text completion model (version %s)',
async (version) => {
// Mock a text completion model that returns a predefined answer
@@ -143,7 +143,7 @@ describe('ChainRetrievalQa', () => {
},
);
it.each([1.3, 1.4, 1.5])(
it.each([1.3, 1.4, 1.5, 1.6])(
'should use a custom system prompt if provided (version %s)',
async (version) => {
const customSystemPrompt = `You are a geography expert. Use the following context to answer the question.
@@ -177,7 +177,7 @@ describe('ChainRetrievalQa', () => {
},
);
it.each([1.3, 1.4, 1.5])(
it.each([1.3, 1.4, 1.5, 1.6])(
'should throw an error if the query is undefined (version %s)',
async (version) => {
const mockChatModel = new FakeChatModel({});
@@ -196,7 +196,7 @@ describe('ChainRetrievalQa', () => {
},
);
it.each([1.3, 1.4, 1.5])(
it.each([1.3, 1.4, 1.5, 1.6])(
'should add error to json if continueOnFail is true (version %s)',
async (version) => {
// Create a model that will throw an error
@@ -226,4 +226,118 @@ describe('ChainRetrievalQa', () => {
expect(result[0][0].json.error).toContain('Model error');
},
);
it('should process items in batches', async () => {
const mockChatModel = new FakeLLM({ response: 'Paris is the capital of France.' });
const items = [
{ json: { input: 'What is the capital of France?' } },
{ json: { input: 'What is the capital of France?' } },
{ json: { input: 'What is the capital of France?' } },
];
const execMock = createExecuteFunctionsMock(
{
promptType: 'define',
text: '={{ $json.input }}',
options: {
batching: {
batchSize: 2,
delayBetweenBatches: 0,
},
},
},
mockChatModel,
fakeRetriever,
1.6,
);
execMock.getInputData = () => items;
const result = await node.execute.call(execMock);
expect(result).toHaveLength(1);
expect(result[0]).toHaveLength(3);
result[0].forEach((item) => {
expect(item.json.response).toBeDefined();
});
expect(result[0][0].json.response).toContain('Paris is the capital of France.');
expect(result[0][1].json.response).toContain('Paris is the capital of France.');
expect(result[0][2].json.response).toContain('Paris is the capital of France.');
});
it('should handle errors in batches with continueOnFail', async () => {
class ErrorLLM extends FakeLLM {
async _call(): Promise<string> {
throw new UnexpectedError('Model error');
}
}
const errorModel = new ErrorLLM({});
const items = [
{ json: { input: 'What is the capital of France?' } },
{ json: { input: 'What is the population of Paris?' } },
];
const execMock = createExecuteFunctionsMock(
{
promptType: 'define',
text: '={{ $json.input }}',
options: {
batching: {
batchSize: 2,
delayBetweenBatches: 0,
},
},
},
errorModel,
fakeRetriever,
1.6,
);
execMock.getInputData = () => items;
execMock.continueOnFail = () => true;
const result = await node.execute.call(execMock);
expect(result).toHaveLength(1);
expect(result[0]).toHaveLength(2);
result[0].forEach((item) => {
expect(item.json.error).toContain('Model error');
});
});
it('should respect delay between batches', async () => {
const mockChatModel = new FakeChatModel({});
const items = [
{ json: { input: 'What is the capital of France?' } },
{ json: { input: 'What is the population of Paris?' } },
{ json: { input: 'What is France known for?' } },
];
const execMock = createExecuteFunctionsMock(
{
promptType: 'define',
text: '={{ $json.input }}',
options: {
batching: {
batchSize: 2,
delayBetweenBatches: 100,
},
},
},
mockChatModel,
fakeRetriever,
1.6,
);
execMock.getInputData = () => items;
const startTime = Date.now();
await node.execute.call(execMock);
const endTime = Date.now();
// Should take at least 100ms due to delay between batches
expect(endTime - startTime).toBeGreaterThanOrEqual(100);
});
});