mirror of https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
feat: Optimise langchain calls in batching mode (#15243)
@@ -1,16 +1,5 @@
-import type { BaseLanguageModel } from '@langchain/core/language_models/base';
-import {
-	ChatPromptTemplate,
-	SystemMessagePromptTemplate,
-	HumanMessagePromptTemplate,
-	PromptTemplate,
-} from '@langchain/core/prompts';
-import type { BaseRetriever } from '@langchain/core/retrievers';
-import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
-import { createRetrievalChain } from 'langchain/chains/retrieval';
-import { NodeConnectionTypes, NodeOperationError, parseErrorMetadata } from 'n8n-workflow';
+import { NodeConnectionTypes, parseErrorMetadata, sleep } from 'n8n-workflow';
 import {
-	type INodeProperties,
 	type IExecuteFunctions,
 	type INodeExecutionData,
 	type INodeType,
@@ -18,28 +7,10 @@ import {
 } from 'n8n-workflow';
 
 import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions';
-import { getPromptInputByType, isChatInstance } from '@utils/helpers';
-import { getTemplateNoticeField } from '@utils/sharedFields';
-import { getTracingConfig } from '@utils/tracing';
-
-const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
-If you don't know the answer, just say that you don't know, don't try to make up an answer.
-----------------
-Context: {context}`;
-
-// Due to the refactoring in version 1.5, the variable name {question} needed to be changed to {input} in the prompt template.
-const LEGACY_INPUT_TEMPLATE_KEY = 'question';
-const INPUT_TEMPLATE_KEY = 'input';
-
-const systemPromptOption: INodeProperties = {
-	displayName: 'System Prompt Template',
-	name: 'systemPromptTemplate',
-	type: 'string',
-	default: SYSTEM_PROMPT_TEMPLATE,
-	typeOptions: {
-		rows: 6,
-	},
-};
+import { getBatchingOptionFields, getTemplateNoticeField } from '@utils/sharedFields';
+
+import { INPUT_TEMPLATE_KEY, LEGACY_INPUT_TEMPLATE_KEY, systemPromptOption } from './constants';
+import { processItem } from './processItem';
 
 export class ChainRetrievalQa implements INodeType {
 	description: INodeTypeDescription = {
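
The hunk above also reveals the refactor's two new modules: the prompt constants move to `./constants`, and the per-item chain construction moves to `./processItem` so both execution paths below can share it. That file is not part of this diff; what follows is a minimal hypothetical sketch of what it presumably contains, with the signature inferred from the call sites `await processItem(this, itemIndex)` and `output.answer` in the last hunk, and the body pieced together from the code this commit deletes from `execute()`.

```ts
// Hypothetical sketch of ./processItem — the real module is not shown in this diff.
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import type { BaseRetriever } from '@langchain/core/retrievers';
import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
import { createRetrievalChain } from 'langchain/chains/retrieval';
import { NodeConnectionTypes, type IExecuteFunctions } from 'n8n-workflow';

export async function processItem(
	ctx: IExecuteFunctions,
	itemIndex: number,
): Promise<{ answer: string }> {
	// Resolve the connected model and retriever, as the deleted execute() body did.
	const model = (await ctx.getInputConnectionData(
		NodeConnectionTypes.AiLanguageModel,
		0,
	)) as BaseLanguageModel;
	const retriever = (await ctx.getInputConnectionData(
		NodeConnectionTypes.AiRetriever,
		0,
	)) as BaseRetriever;

	// Simplified: the real module presumably keeps the version-aware query and
	// prompt-template handling that the last hunk removes from execute().
	const query = ctx.getNodeParameter('query', itemIndex) as string;
	const prompt = ChatPromptTemplate.fromMessages([
		['system', 'Use the following retrieved context to answer.\n\nContext: {context}'],
		['human', '{input}'],
	]);

	// Stuff the retrieved documents into the prompt, then wire up retrieval.
	const combineDocsChain = await createStuffDocumentsChain({ llm: model, prompt });
	const retrievalChain = await createRetrievalChain({ combineDocsChain, retriever });
	const output = await retrievalChain.invoke(
		{ input: query },
		{ signal: ctx.getExecutionCancelSignal() },
	);
	return { answer: String(output.answer) };
}
```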
@@ -48,7 +19,7 @@ export class ChainRetrievalQa implements INodeType {
 		icon: 'fa:link',
 		iconColor: 'black',
 		group: ['transform'],
-		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
+		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
 		description: 'Answer questions about retrieved documents',
 		defaults: {
 			name: 'Question and Answer Chain',
@@ -177,6 +148,11 @@ export class ChainRetrievalQa implements INodeType {
 					},
 				},
 			},
+			getBatchingOptionFields({
+				show: {
+					'@version': [{ _cnd: { gte: 1.6 } }],
+				},
+			}),
 		],
 	},
 ],
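
The new `getBatchingOptionFields(...)` call adds the batching controls to the node's options, and the `'@version': [{ _cnd: { gte: 1.6 } }]` display condition hides them on nodes created before version 1.6. The helper's output is not part of this diff; below is a hedged sketch of the shape it plausibly returns, inferred from the parameter paths `execute()` reads in the last hunk (`options.batching.batchSize`, `options.batching.delayBetweenBatches`) and their fallback values (5 and 0).

```ts
// Hypothetical sketch — the real getBatchingOptionFields lives in @utils/sharedFields
// and is not shown in this diff.
import type { IDisplayOptions, INodeProperties } from 'n8n-workflow';

function batchingOptionFieldsSketch(displayOptions?: IDisplayOptions): INodeProperties {
	return {
		displayName: 'Batch Processing',
		name: 'batching', // matches the 'options.batching.*' paths read in execute()
		type: 'collection',
		placeholder: 'Add Batch Processing Option',
		default: {},
		displayOptions, // e.g. { show: { '@version': [{ _cnd: { gte: 1.6 } }] } }
		options: [
			{
				displayName: 'Batch Size',
				name: 'batchSize',
				type: 'number',
				default: 5, // same fallback execute() passes to getNodeParameter
				description: 'How many items to process in parallel',
			},
			{
				displayName: 'Delay Between Batches',
				name: 'delayBetweenBatches',
				type: 'number',
				default: 0, // milliseconds; 0 disables the inter-batch sleep
				description: 'Delay in milliseconds between batches',
			},
		],
	};
}
```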
@@ -187,109 +163,78 @@ export class ChainRetrievalQa implements INodeType {
 
 		const items = this.getInputData();
 		const returnData: INodeExecutionData[] = [];
+		const batchSize = this.getNodeParameter('options.batching.batchSize', 0, 5) as number;
+		const delayBetweenBatches = this.getNodeParameter(
+			'options.batching.delayBetweenBatches',
+			0,
+			0,
+		) as number;
 
-		// Run for each item
-		for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
-			try {
-				const model = (await this.getInputConnectionData(
-					NodeConnectionTypes.AiLanguageModel,
-					0,
-				)) as BaseLanguageModel;
-
-				const retriever = (await this.getInputConnectionData(
-					NodeConnectionTypes.AiRetriever,
-					0,
-				)) as BaseRetriever;
-
-				let query;
-
-				if (this.getNode().typeVersion <= 1.2) {
-					query = this.getNodeParameter('query', itemIndex) as string;
-				} else {
-					query = getPromptInputByType({
-						ctx: this,
-						i: itemIndex,
-						inputKey: 'text',
-						promptTypeKey: 'promptType',
-					});
-				}
-
-				if (query === undefined) {
-					throw new NodeOperationError(this.getNode(), 'The ‘query‘ parameter is empty.');
-				}
-
-				const options = this.getNodeParameter('options', itemIndex, {}) as {
-					systemPromptTemplate?: string;
-				};
-
-				let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE;
-
-				// Replace legacy input template key for versions 1.4 and below
-				if (this.getNode().typeVersion < 1.5) {
-					templateText = templateText.replace(
-						`{${LEGACY_INPUT_TEMPLATE_KEY}}`,
-						`{${INPUT_TEMPLATE_KEY}}`,
-					);
-				}
-
-				// Create prompt template based on model type and user configuration
-				let promptTemplate;
-				if (isChatInstance(model)) {
-					// For chat models, create a chat prompt template with system and human messages
-					const messages = [
-						SystemMessagePromptTemplate.fromTemplate(templateText),
-						HumanMessagePromptTemplate.fromTemplate('{input}'),
-					];
-					promptTemplate = ChatPromptTemplate.fromMessages(messages);
-				} else {
-					// For non-chat models, create a text prompt template with Question/Answer format
-					const questionSuffix =
-						options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : '';
-
-					promptTemplate = new PromptTemplate({
-						template: templateText + questionSuffix,
-						inputVariables: ['context', 'input'],
-					});
-				}
-
-				// Create the document chain that combines the retrieved documents
-				const combineDocsChain = await createStuffDocumentsChain({
-					llm: model,
-					prompt: promptTemplate,
-				});
-
-				// Create the retrieval chain that handles the retrieval and then passes to the combine docs chain
-				const retrievalChain = await createRetrievalChain({
-					combineDocsChain,
-					retriever,
-				});
-
-				// Execute the chain with tracing config
-				const tracingConfig = getTracingConfig(this);
-				const response = await retrievalChain
-					.withConfig(tracingConfig)
-					.invoke({ input: query }, { signal: this.getExecutionCancelSignal() });
-
-				// Get the answer from the response
-				const answer: string = response.answer;
-				if (this.getNode().typeVersion >= 1.5) {
-					returnData.push({ json: { response: answer } });
-				} else {
-					// Legacy format for versions 1.4 and below is { text: string }
-					returnData.push({ json: { response: { text: answer } } });
-				}
-			} catch (error) {
-				if (this.continueOnFail()) {
-					const metadata = parseErrorMetadata(error);
-					returnData.push({
-						json: { error: error.message },
-						pairedItem: { item: itemIndex },
-						metadata,
-					});
-					continue;
-				}
-
-				throw error;
-			}
-		}
+		if (this.getNode().typeVersion >= 1.6 && batchSize >= 1) {
+			// Run in batches
+			for (let i = 0; i < items.length; i += batchSize) {
+				const batch = items.slice(i, i + batchSize);
+				const batchPromises = batch.map(async (_item, batchItemIndex) => {
+					return await processItem(this, i + batchItemIndex);
+				});
+
+				const batchResults = await Promise.allSettled(batchPromises);
+
+				batchResults.forEach((response, index) => {
+					if (response.status === 'rejected') {
+						const error = response.reason;
+						if (this.continueOnFail()) {
+							const metadata = parseErrorMetadata(error);
+							returnData.push({
+								json: { error: error.message },
+								pairedItem: { item: index },
+								metadata,
+							});
+							return;
+						} else {
+							throw error;
+						}
+					}
+					const output = response.value;
+					const answer = output.answer as string;
+					if (this.getNode().typeVersion >= 1.5) {
+						returnData.push({ json: { response: answer } });
+					} else {
+						// Legacy format for versions 1.4 and below is { text: string }
+						returnData.push({ json: { response: { text: answer } } });
+					}
+				});
+
+				// Add delay between batches if not the last batch
+				if (i + batchSize < items.length && delayBetweenBatches > 0) {
+					await sleep(delayBetweenBatches);
+				}
+			}
+		} else {
+			// Run for each item
+			for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
+				try {
+					const response = await processItem(this, itemIndex);
+					const answer = response.answer as string;
+					if (this.getNode().typeVersion >= 1.5) {
+						returnData.push({ json: { response: answer } });
+					} else {
+						// Legacy format for versions 1.4 and below is { text: string }
+						returnData.push({ json: { response: { text: answer } } });
+					}
+				} catch (error) {
+					if (this.continueOnFail()) {
+						const metadata = parseErrorMetadata(error);
+						returnData.push({
+							json: { error: error.message },
+							pairedItem: { item: itemIndex },
+							metadata,
+						});
+						continue;
+					}
+
+					throw error;
+				}
+			}
+		}
 		return [returnData];
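
The heart of the change is the first branch above: slice the items into fixed-size batches, run each batch concurrently, and pause between batches. `Promise.allSettled` (rather than `Promise.all`) is what lets a rejected item become a per-item error entry under `continueOnFail` instead of aborting its whole batch. The same pattern, distilled into a self-contained sketch — the names here are illustrative, not from the codebase:

```ts
// Illustrative sketch of the batch/settle/delay pattern used by execute() above.
// `sleep` is a local stand-in for the n8n-workflow helper imported in this commit.
const sleep = async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms));

async function processInBatches<T, R>(
	items: T[],
	batchSize: number,
	delayBetweenBatches: number,
	processOne: (item: T, index: number) => Promise<R>,
): Promise<Array<PromiseSettledResult<Awaited<R>>>> {
	const results: Array<PromiseSettledResult<Awaited<R>>> = [];
	for (let i = 0; i < items.length; i += batchSize) {
		const batch = items.slice(i, i + batchSize);
		// Start every call in the batch at once; allSettled collects both
		// fulfillments and rejections without short-circuiting.
		const settled = await Promise.allSettled(
			batch.map(async (item, j) => await processOne(item, i + j)),
		);
		results.push(...settled);
		// Pause only between batches, never after the last one.
		if (i + batchSize < items.length && delayBetweenBatches > 0) {
			await sleep(delayBetweenBatches);
		}
	}
	return results;
}

// Usage: batches of 5, 1000 ms apart; inspect each result's status afterwards.
// const settled = await processInBatches(items, 5, 1000, async (item, i) => callModel(item, i));
```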