refactor(Question and Answer Chain Node): Use LangChain's new syntax (#13868)

Eugene
2025-03-14 13:17:11 +03:00
committed by GitHub
parent 3518c14f7f
commit 311553926a
2 changed files with 318 additions and 40 deletions
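
The refactor replaces the deprecated RetrievalQAChain with the LCEL-style composition of createStuffDocumentsChain and createRetrievalChain. Below is a minimal sketch of the new pattern, assuming a model and a retriever have already been constructed elsewhere; variable names are illustrative, not the node's actual code:

    import { ChatPromptTemplate } from '@langchain/core/prompts';
    import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
    import { createRetrievalChain } from 'langchain/chains/retrieval';

    // Before: const chain = RetrievalQAChain.fromLLM(model, retriever, { prompt });
    const prompt = ChatPromptTemplate.fromMessages([
      ['system', 'Use the retrieved context to answer the question.\n\nContext: {context}'],
      ['human', '{input}'],
    ]);

    // Stuffs all retrieved documents into {context} and calls the model
    const combineDocsChain = await createStuffDocumentsChain({ llm: model, prompt });

    // Runs the retriever on {input}, then hands the documents to combineDocsChain
    const retrievalChain = await createRetrievalChain({ combineDocsChain, retriever });

    const result = await retrievalChain.invoke({ input: 'What is the capital of France?' });
    // result.answer is the model output; result.context is the list of retrieved documents

Note the key changes: RetrievalQAChain took { query } and produced { text }, while createRetrievalChain takes { input } and produces { answer, context } — which is why the diff below renames the {question} prompt variable and version-gates the output shape.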

View File

@@ -6,15 +6,15 @@ import {
PromptTemplate,
} from '@langchain/core/prompts';
import type { BaseRetriever } from '@langchain/core/retrievers';
-import { RetrievalQAChain } from 'langchain/chains';
+import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
+import { createRetrievalChain } from 'langchain/chains/retrieval';
-import { NodeConnectionType, NodeOperationError, parseErrorMetadata } from 'n8n-workflow';
+import {
+NodeConnectionType,
+type INodeProperties,
+type IExecuteFunctions,
+type INodeExecutionData,
+type INodeType,
+type INodeTypeDescription,
+NodeOperationError,
+parseErrorMetadata,
+} from 'n8n-workflow';
import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions';
@@ -22,10 +22,24 @@ import { getPromptInputByType, isChatInstance } from '@utils/helpers';
import { getTemplateNoticeField } from '@utils/sharedFields';
import { getTracingConfig } from '@utils/tracing';
-const SYSTEM_PROMPT_TEMPLATE = `Use the following pieces of context to answer the users question.
+const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
-{context}`;
+Context: {context}`;
+// With the refactor in version 1.5, the prompt-template variable {question} was renamed to {input}.
+const LEGACY_INPUT_TEMPLATE_KEY = 'question';
+const INPUT_TEMPLATE_KEY = 'input';
+const systemPromptOption: INodeProperties = {
+displayName: 'System Prompt Template',
+name: 'systemPromptTemplate',
+type: 'string',
+default: SYSTEM_PROMPT_TEMPLATE,
+typeOptions: {
+rows: 6,
+},
+};
export class ChainRetrievalQa implements INodeType {
description: INodeTypeDescription = {
@@ -34,7 +48,7 @@ export class ChainRetrievalQa implements INodeType {
icon: 'fa:link',
iconColor: 'black',
group: ['transform'],
-version: [1, 1.1, 1.2, 1.3, 1.4],
+version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
description: 'Answer questions about retrieved documents',
defaults: {
name: 'Question and Answer Chain',
@@ -146,14 +160,21 @@ export class ChainRetrievalQa implements INodeType {
placeholder: 'Add Option',
options: [
{
-displayName: 'System Prompt Template',
-name: 'systemPromptTemplate',
-type: 'string',
-default: SYSTEM_PROMPT_TEMPLATE,
-description:
-'Template string used for the system prompt. This should include the variable `{context}` for the provided context. For text completion models, you should also include the variable `{question}` for the users query.',
-typeOptions: {
-rows: 6,
+...systemPromptOption,
+description: `Template string used for the system prompt. This should include the variable \`{context}\` for the provided context. For text completion models, you should also include the variable \`{${LEGACY_INPUT_TEMPLATE_KEY}}\` for the user's query.`,
+displayOptions: {
+show: {
+'@version': [{ _cnd: { lt: 1.5 } }],
+},
+},
},
+{
+...systemPromptOption,
+description: `Template string used for the system prompt. This should include the variable \`{context}\` for the provided context. For text completion models, you should also include the variable \`{${INPUT_TEMPLATE_KEY}}\` for the user's query.`,
+displayOptions: {
+show: {
+'@version': [{ _cnd: { gte: 1.5 } }],
+},
+},
+},
],
@@ -166,6 +187,7 @@ export class ChainRetrievalQa implements INodeType {
const items = this.getInputData();
const returnData: INodeExecutionData[] = [];
+// Run for each item
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
try {
@@ -200,35 +222,62 @@ export class ChainRetrievalQa implements INodeType {
systemPromptTemplate?: string;
};
-const chainParameters = {} as {
-prompt?: PromptTemplate | ChatPromptTemplate;
-};
+let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE;
-if (options.systemPromptTemplate !== undefined) {
-if (isChatInstance(model)) {
-const messages = [
-SystemMessagePromptTemplate.fromTemplate(options.systemPromptTemplate),
-HumanMessagePromptTemplate.fromTemplate('{question}'),
-];
-const chatPromptTemplate = ChatPromptTemplate.fromMessages(messages);
-chainParameters.prompt = chatPromptTemplate;
-} else {
-const completionPromptTemplate = new PromptTemplate({
-template: options.systemPromptTemplate,
-inputVariables: ['context', 'question'],
-});
-chainParameters.prompt = completionPromptTemplate;
-}
+// Replace the legacy input template key for versions 1.4 and below
+if (this.getNode().typeVersion < 1.5) {
+templateText = templateText.replace(
+`{${LEGACY_INPUT_TEMPLATE_KEY}}`,
+`{${INPUT_TEMPLATE_KEY}}`,
+);
+}
-const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters);
+// Create the prompt template based on the model type and user configuration
+let promptTemplate;
+if (isChatInstance(model)) {
+// For chat models, create a chat prompt template with system and human messages
+const messages = [
+SystemMessagePromptTemplate.fromTemplate(templateText),
+HumanMessagePromptTemplate.fromTemplate('{input}'),
+];
+promptTemplate = ChatPromptTemplate.fromMessages(messages);
+} else {
+// For non-chat models, create a text prompt template in Question/Answer format
+const questionSuffix =
+options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : '';
-const response = await chain
-.withConfig(getTracingConfig(this))
-.invoke({ query }, { signal: this.getExecutionCancelSignal() });
-returnData.push({ json: { response } });
+promptTemplate = new PromptTemplate({
+template: templateText + questionSuffix,
+inputVariables: ['context', 'input'],
+});
+}
+// Create the document chain that combines the retrieved documents
+const combineDocsChain = await createStuffDocumentsChain({
+llm: model,
+prompt: promptTemplate,
+});
+// Create the retrieval chain that retrieves documents and passes them to the combine-docs chain
+const retrievalChain = await createRetrievalChain({
+combineDocsChain,
+retriever,
+});
+// Execute the chain with the tracing config
+const tracingConfig = getTracingConfig(this);
+const response = await retrievalChain
+.withConfig(tracingConfig)
+.invoke({ input: query }, { signal: this.getExecutionCancelSignal() });
+// Get the answer from the response
+const answer: string = response.answer;
+if (this.getNode().typeVersion >= 1.5) {
+returnData.push({ json: { response: answer } });
+} else {
+// Legacy format for versions 1.4 and below is { text: string }
+returnData.push({ json: { response: { text: answer } } });
+}
} catch (error) {
if (this.continueOnFail()) {
const metadata = parseErrorMetadata(error);
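
Taken together, the node keeps two compatibility gates for workflows created before 1.5: stored templates get their {question} placeholder rewritten to {input}, and the answer keeps the legacy { text: string } wrapper. A condensed, hypothetical illustration of both gates — typeVersion and answer stand in for the real values; this mirrors the version checks above, not the node's actual code:

    const typeVersion = 1.4; // anything below 1.5 follows the legacy conventions
    let templateText = 'Use {context} to answer. Question: {question}';

    // Pre-1.5 templates may still contain {question}; String.replace with a
    // string pattern rewrites only the first occurrence, as in the node's code.
    if (typeVersion < 1.5) {
      templateText = templateText.replace('{question}', '{input}');
    }

    const answer = 'Paris is the capital of France.';
    const json =
      typeVersion >= 1.5
        ? { response: answer } // 1.5+: flat string response
        : { response: { text: answer } }; // legacy RetrievalQAChain-style shape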

View File

@@ -0,0 +1,229 @@
import { Document } from '@langchain/core/documents';
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { BaseRetriever } from '@langchain/core/retrievers';
import { FakeChatModel, FakeLLM, FakeRetriever } from '@langchain/core/utils/testing';
import get from 'lodash/get';
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
import { NodeConnectionType, NodeOperationError, UnexpectedError } from 'n8n-workflow';
import { ChainRetrievalQa } from '../ChainRetrievalQa.node';
const createExecuteFunctionsMock = (
parameters: IDataObject,
fakeLlm: BaseLanguageModel,
fakeRetriever: BaseRetriever,
version: number,
) => {
return {
getExecutionCancelSignal() {
return new AbortController().signal;
},
getNodeParameter(parameter: string) {
return get(parameters, parameter);
},
getNode() {
return {
typeVersion: version,
};
},
getInputConnectionData(type: NodeConnectionType) {
if (type === NodeConnectionType.AiLanguageModel) {
return fakeLlm;
}
if (type === NodeConnectionType.AiRetriever) {
return fakeRetriever;
}
return null;
},
getInputData() {
return [{ json: {} }];
},
getWorkflow() {
return {
name: 'Test Workflow',
};
},
getExecutionId() {
return 'test_execution_id';
},
continueOnFail() {
return false;
},
logger: { debug: jest.fn() },
} as unknown as IExecuteFunctions;
};
describe('ChainRetrievalQa', () => {
let node: ChainRetrievalQa;
const testDocs = [
new Document({
pageContent: 'The capital of France is Paris. It is known for the Eiffel Tower.',
}),
new Document({
pageContent:
'Paris is the largest city in France with a population of over 2 million people.',
}),
];
const fakeRetriever = new FakeRetriever({ output: testDocs });
beforeEach(() => {
node = new ChainRetrievalQa();
});
it.each([1.3, 1.4, 1.5])(
'should process a query using a chat model (version %s)',
async (version) => {
// Mock chat model; FakeChatModel echoes the messages it receives
const mockChatModel = new FakeChatModel({});
const params = {
promptType: 'define',
text: 'What is the capital of France?',
options: {},
};
const result = await node.execute.call(
createExecuteFunctionsMock(params, mockChatModel, fakeRetriever, version),
);
// Check that the result contains the expected response (FakeChatModel returns the query as response)
expect(result).toHaveLength(1);
expect(result[0]).toHaveLength(1);
expect(result[0][0].json.response).toBeDefined();
let responseText = result[0][0].json.response;
if (version < 1.5 && typeof responseText === 'object') {
responseText = (responseText as { text: string }).text;
}
expect(responseText).toContain('You are an assistant for question-answering tasks'); // system prompt
expect(responseText).toContain('The capital of France is Paris.'); // context
expect(responseText).toContain('What is the capital of France?'); // query
},
);
it.each([1.3, 1.4, 1.5])(
'should process a query using a text completion model (version %s)',
async (version) => {
// Mock a text completion model that returns a predefined answer
const mockTextModel = new FakeLLM({ response: 'Paris is the capital of France.' });
const modelCallSpy = jest.spyOn(mockTextModel, '_call');
const params = {
promptType: 'define',
text: 'What is the capital of France?',
options: {},
};
const result = await node.execute.call(
createExecuteFunctionsMock(params, mockTextModel, fakeRetriever, version),
);
// Check model was called with the correct query
expect(modelCallSpy).toHaveBeenCalled();
expect(modelCallSpy.mock.calls[0][0]).toEqual(
expect.stringContaining('Question: What is the capital of France?'),
);
// Check that the result contains the expected response
expect(result).toHaveLength(1);
expect(result[0]).toHaveLength(1);
if (version < 1.5) {
expect((result[0][0].json.response as { text: string }).text).toContain(
'Paris is the capital of France.',
);
} else {
expect(result[0][0].json).toEqual({
response: 'Paris is the capital of France.',
});
}
},
);
it.each([1.3, 1.4, 1.5])(
'should use a custom system prompt if provided (version %s)',
async (version) => {
const customSystemPrompt = `You are a geography expert. Use the following context to answer the question.
----------------
Context: {context}`;
// The chat model will return a response indicating it received the custom prompt
const mockChatModel = new FakeChatModel({});
const params = {
promptType: 'define',
text: 'What is the capital of France?',
options: {
systemPromptTemplate: customSystemPrompt,
},
};
const result = await node.execute.call(
createExecuteFunctionsMock(params, mockChatModel, fakeRetriever, version),
);
expect(result).toHaveLength(1);
expect(result[0]).toHaveLength(1);
if (version < 1.5) {
expect((result[0][0].json.response as { text: string }).text).toContain(
'You are a geography expert.',
);
} else {
expect(result[0][0].json.response).toContain('You are a geography expert.');
}
},
);
it.each([1.3, 1.4, 1.5])(
'should throw an error if the query is undefined (version %s)',
async (version) => {
const mockChatModel = new FakeChatModel({});
const params = {
promptType: 'define',
text: undefined, // undefined query
options: {},
};
await expect(
node.execute.call(
createExecuteFunctionsMock(params, mockChatModel, fakeRetriever, version),
),
).rejects.toThrow(NodeOperationError);
},
);
it.each([1.3, 1.4, 1.5])(
'should add error to json if continueOnFail is true (version %s)',
async (version) => {
// Create a model that will throw an error
class ErrorLLM extends FakeLLM {
async _call(): Promise<string> {
throw new UnexpectedError('Model error');
}
}
const errorModel = new ErrorLLM({});
const params = {
promptType: 'define',
text: 'What is the capital of France?',
options: {},
};
// Override continueOnFail to return true
const execMock = createExecuteFunctionsMock(params, errorModel, fakeRetriever, version);
execMock.continueOnFail = () => true;
const result = await node.execute.call(execMock);
expect(result).toHaveLength(1);
expect(result[0]).toHaveLength(1);
expect(result[0][0].json).toHaveProperty('error');
expect(result[0][0].json.error).toContain('Model error');
},
);
});