mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-18 02:21:13 +00:00
refactor(Question and Answer Chain Node): Use new LangChain's syntax (#13868)
This commit is contained in:
@@ -6,15 +6,15 @@ import {
|
|||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
} from '@langchain/core/prompts';
|
} from '@langchain/core/prompts';
|
||||||
import type { BaseRetriever } from '@langchain/core/retrievers';
|
import type { BaseRetriever } from '@langchain/core/retrievers';
|
||||||
import { RetrievalQAChain } from 'langchain/chains';
|
import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
|
||||||
|
import { createRetrievalChain } from 'langchain/chains/retrieval';
|
||||||
|
import { NodeConnectionType, NodeOperationError, parseErrorMetadata } from 'n8n-workflow';
|
||||||
import {
|
import {
|
||||||
NodeConnectionType,
|
type INodeProperties,
|
||||||
type IExecuteFunctions,
|
type IExecuteFunctions,
|
||||||
type INodeExecutionData,
|
type INodeExecutionData,
|
||||||
type INodeType,
|
type INodeType,
|
||||||
type INodeTypeDescription,
|
type INodeTypeDescription,
|
||||||
NodeOperationError,
|
|
||||||
parseErrorMetadata,
|
|
||||||
} from 'n8n-workflow';
|
} from 'n8n-workflow';
|
||||||
|
|
||||||
import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions';
|
import { promptTypeOptions, textFromPreviousNode } from '@utils/descriptions';
|
||||||
@@ -22,10 +22,24 @@ import { getPromptInputByType, isChatInstance } from '@utils/helpers';
|
|||||||
import { getTemplateNoticeField } from '@utils/sharedFields';
|
import { getTemplateNoticeField } from '@utils/sharedFields';
|
||||||
import { getTracingConfig } from '@utils/tracing';
|
import { getTracingConfig } from '@utils/tracing';
|
||||||
|
|
||||||
const SYSTEM_PROMPT_TEMPLATE = `Use the following pieces of context to answer the users question.
|
const SYSTEM_PROMPT_TEMPLATE = `You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
|
||||||
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||||
----------------
|
----------------
|
||||||
{context}`;
|
Context: {context}`;
|
||||||
|
|
||||||
|
// Due to the refactoring in version 1.5, the variable name {question} needed to be changed to {input} in the prompt template.
|
||||||
|
const LEGACY_INPUT_TEMPLATE_KEY = 'question';
|
||||||
|
const INPUT_TEMPLATE_KEY = 'input';
|
||||||
|
|
||||||
|
const systemPromptOption: INodeProperties = {
|
||||||
|
displayName: 'System Prompt Template',
|
||||||
|
name: 'systemPromptTemplate',
|
||||||
|
type: 'string',
|
||||||
|
default: SYSTEM_PROMPT_TEMPLATE,
|
||||||
|
typeOptions: {
|
||||||
|
rows: 6,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
export class ChainRetrievalQa implements INodeType {
|
export class ChainRetrievalQa implements INodeType {
|
||||||
description: INodeTypeDescription = {
|
description: INodeTypeDescription = {
|
||||||
@@ -34,7 +48,7 @@ export class ChainRetrievalQa implements INodeType {
|
|||||||
icon: 'fa:link',
|
icon: 'fa:link',
|
||||||
iconColor: 'black',
|
iconColor: 'black',
|
||||||
group: ['transform'],
|
group: ['transform'],
|
||||||
version: [1, 1.1, 1.2, 1.3, 1.4],
|
version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
|
||||||
description: 'Answer questions about retrieved documents',
|
description: 'Answer questions about retrieved documents',
|
||||||
defaults: {
|
defaults: {
|
||||||
name: 'Question and Answer Chain',
|
name: 'Question and Answer Chain',
|
||||||
@@ -146,14 +160,21 @@ export class ChainRetrievalQa implements INodeType {
|
|||||||
placeholder: 'Add Option',
|
placeholder: 'Add Option',
|
||||||
options: [
|
options: [
|
||||||
{
|
{
|
||||||
displayName: 'System Prompt Template',
|
...systemPromptOption,
|
||||||
name: 'systemPromptTemplate',
|
description: `Template string used for the system prompt. This should include the variable \`{context}\` for the provided context. For text completion models, you should also include the variable \`{${LEGACY_INPUT_TEMPLATE_KEY}}\` for the user’s query.`,
|
||||||
type: 'string',
|
displayOptions: {
|
||||||
default: SYSTEM_PROMPT_TEMPLATE,
|
show: {
|
||||||
description:
|
'@version': [{ _cnd: { lt: 1.5 } }],
|
||||||
'Template string used for the system prompt. This should include the variable `{context}` for the provided context. For text completion models, you should also include the variable `{question}` for the user’s query.',
|
},
|
||||||
typeOptions: {
|
},
|
||||||
rows: 6,
|
},
|
||||||
|
{
|
||||||
|
...systemPromptOption,
|
||||||
|
description: `Template string used for the system prompt. This should include the variable \`{context}\` for the provided context. For text completion models, you should also include the variable \`{${INPUT_TEMPLATE_KEY}}\` for the user’s query.`,
|
||||||
|
displayOptions: {
|
||||||
|
show: {
|
||||||
|
'@version': [{ _cnd: { gte: 1.5 } }],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -166,6 +187,7 @@ export class ChainRetrievalQa implements INodeType {
|
|||||||
|
|
||||||
const items = this.getInputData();
|
const items = this.getInputData();
|
||||||
const returnData: INodeExecutionData[] = [];
|
const returnData: INodeExecutionData[] = [];
|
||||||
|
|
||||||
// Run for each item
|
// Run for each item
|
||||||
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
|
for (let itemIndex = 0; itemIndex < items.length; itemIndex++) {
|
||||||
try {
|
try {
|
||||||
@@ -200,35 +222,62 @@ export class ChainRetrievalQa implements INodeType {
|
|||||||
systemPromptTemplate?: string;
|
systemPromptTemplate?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const chainParameters = {} as {
|
let templateText = options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE;
|
||||||
prompt?: PromptTemplate | ChatPromptTemplate;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (options.systemPromptTemplate !== undefined) {
|
// Replace legacy input template key for versions 1.4 and below
|
||||||
if (isChatInstance(model)) {
|
if (this.getNode().typeVersion < 1.5) {
|
||||||
const messages = [
|
templateText = templateText.replace(
|
||||||
SystemMessagePromptTemplate.fromTemplate(options.systemPromptTemplate),
|
`{${LEGACY_INPUT_TEMPLATE_KEY}}`,
|
||||||
HumanMessagePromptTemplate.fromTemplate('{question}'),
|
`{${INPUT_TEMPLATE_KEY}}`,
|
||||||
];
|
);
|
||||||
const chatPromptTemplate = ChatPromptTemplate.fromMessages(messages);
|
|
||||||
|
|
||||||
chainParameters.prompt = chatPromptTemplate;
|
|
||||||
} else {
|
|
||||||
const completionPromptTemplate = new PromptTemplate({
|
|
||||||
template: options.systemPromptTemplate,
|
|
||||||
inputVariables: ['context', 'question'],
|
|
||||||
});
|
|
||||||
|
|
||||||
chainParameters.prompt = completionPromptTemplate;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const chain = RetrievalQAChain.fromLLM(model, retriever, chainParameters);
|
// Create prompt template based on model type and user configuration
|
||||||
|
let promptTemplate;
|
||||||
|
if (isChatInstance(model)) {
|
||||||
|
// For chat models, create a chat prompt template with system and human messages
|
||||||
|
const messages = [
|
||||||
|
SystemMessagePromptTemplate.fromTemplate(templateText),
|
||||||
|
HumanMessagePromptTemplate.fromTemplate('{input}'),
|
||||||
|
];
|
||||||
|
promptTemplate = ChatPromptTemplate.fromMessages(messages);
|
||||||
|
} else {
|
||||||
|
// For non-chat models, create a text prompt template with Question/Answer format
|
||||||
|
const questionSuffix =
|
||||||
|
options.systemPromptTemplate === undefined ? '\n\nQuestion: {input}\nAnswer:' : '';
|
||||||
|
|
||||||
const response = await chain
|
promptTemplate = new PromptTemplate({
|
||||||
.withConfig(getTracingConfig(this))
|
template: templateText + questionSuffix,
|
||||||
.invoke({ query }, { signal: this.getExecutionCancelSignal() });
|
inputVariables: ['context', 'input'],
|
||||||
returnData.push({ json: { response } });
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the document chain that combines the retrieved documents
|
||||||
|
const combineDocsChain = await createStuffDocumentsChain({
|
||||||
|
llm: model,
|
||||||
|
prompt: promptTemplate,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create the retrieval chain that handles the retrieval and then passes to the combine docs chain
|
||||||
|
const retrievalChain = await createRetrievalChain({
|
||||||
|
combineDocsChain,
|
||||||
|
retriever,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Execute the chain with tracing config
|
||||||
|
const tracingConfig = getTracingConfig(this);
|
||||||
|
const response = await retrievalChain
|
||||||
|
.withConfig(tracingConfig)
|
||||||
|
.invoke({ input: query }, { signal: this.getExecutionCancelSignal() });
|
||||||
|
|
||||||
|
// Get the answer from the response
|
||||||
|
const answer: string = response.answer;
|
||||||
|
if (this.getNode().typeVersion >= 1.5) {
|
||||||
|
returnData.push({ json: { response: answer } });
|
||||||
|
} else {
|
||||||
|
// Legacy format for versions 1.4 and below is { text: string }
|
||||||
|
returnData.push({ json: { response: { text: answer } } });
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (this.continueOnFail()) {
|
if (this.continueOnFail()) {
|
||||||
const metadata = parseErrorMetadata(error);
|
const metadata = parseErrorMetadata(error);
|
||||||
|
|||||||
@@ -0,0 +1,229 @@
|
|||||||
|
import { Document } from '@langchain/core/documents';
|
||||||
|
import type { BaseLanguageModel } from '@langchain/core/language_models/base';
|
||||||
|
import type { BaseRetriever } from '@langchain/core/retrievers';
|
||||||
|
import { FakeChatModel, FakeLLM, FakeRetriever } from '@langchain/core/utils/testing';
|
||||||
|
import get from 'lodash/get';
|
||||||
|
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
|
||||||
|
import { NodeConnectionType, NodeOperationError, UnexpectedError } from 'n8n-workflow';
|
||||||
|
|
||||||
|
import { ChainRetrievalQa } from '../ChainRetrievalQa.node';
|
||||||
|
|
||||||
|
const createExecuteFunctionsMock = (
|
||||||
|
parameters: IDataObject,
|
||||||
|
fakeLlm: BaseLanguageModel,
|
||||||
|
fakeRetriever: BaseRetriever,
|
||||||
|
version: number,
|
||||||
|
) => {
|
||||||
|
return {
|
||||||
|
getExecutionCancelSignal() {
|
||||||
|
return new AbortController().signal;
|
||||||
|
},
|
||||||
|
getNodeParameter(parameter: string) {
|
||||||
|
return get(parameters, parameter);
|
||||||
|
},
|
||||||
|
getNode() {
|
||||||
|
return {
|
||||||
|
typeVersion: version,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
getInputConnectionData(type: NodeConnectionType) {
|
||||||
|
if (type === NodeConnectionType.AiLanguageModel) {
|
||||||
|
return fakeLlm;
|
||||||
|
}
|
||||||
|
if (type === NodeConnectionType.AiRetriever) {
|
||||||
|
return fakeRetriever;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
getInputData() {
|
||||||
|
return [{ json: {} }];
|
||||||
|
},
|
||||||
|
getWorkflow() {
|
||||||
|
return {
|
||||||
|
name: 'Test Workflow',
|
||||||
|
};
|
||||||
|
},
|
||||||
|
getExecutionId() {
|
||||||
|
return 'test_execution_id';
|
||||||
|
},
|
||||||
|
continueOnFail() {
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
logger: { debug: jest.fn() },
|
||||||
|
} as unknown as IExecuteFunctions;
|
||||||
|
};
|
||||||
|
|
||||||
|
describe('ChainRetrievalQa', () => {
|
||||||
|
let node: ChainRetrievalQa;
|
||||||
|
const testDocs = [
|
||||||
|
new Document({
|
||||||
|
pageContent: 'The capital of France is Paris. It is known for the Eiffel Tower.',
|
||||||
|
}),
|
||||||
|
new Document({
|
||||||
|
pageContent:
|
||||||
|
'Paris is the largest city in France with a population of over 2 million people.',
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
|
||||||
|
const fakeRetriever = new FakeRetriever({ output: testDocs });
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
node = new ChainRetrievalQa();
|
||||||
|
});
|
||||||
|
|
||||||
|
it.each([1.3, 1.4, 1.5])(
|
||||||
|
'should process a query using a chat model (version %s)',
|
||||||
|
async (version) => {
|
||||||
|
// Mock a chat model that returns a predefined answer
|
||||||
|
const mockChatModel = new FakeChatModel({});
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
promptType: 'define',
|
||||||
|
text: 'What is the capital of France?',
|
||||||
|
options: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await node.execute.call(
|
||||||
|
createExecuteFunctionsMock(params, mockChatModel, fakeRetriever, version),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check that the result contains the expected response (FakeChatModel returns the query as response)
|
||||||
|
expect(result).toHaveLength(1);
|
||||||
|
expect(result[0]).toHaveLength(1);
|
||||||
|
expect(result[0][0].json.response).toBeDefined();
|
||||||
|
|
||||||
|
let responseText = result[0][0].json.response;
|
||||||
|
if (version < 1.5 && typeof responseText === 'object') {
|
||||||
|
responseText = (responseText as { text: string }).text;
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(responseText).toContain('You are an assistant for question-answering tasks'); // system prompt
|
||||||
|
expect(responseText).toContain('The capital of France is Paris.'); // context
|
||||||
|
expect(responseText).toContain('What is the capital of France?'); // query
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
it.each([1.3, 1.4, 1.5])(
|
||||||
|
'should process a query using a text completion model (version %s)',
|
||||||
|
async (version) => {
|
||||||
|
// Mock a text completion model that returns a predefined answer
|
||||||
|
const mockTextModel = new FakeLLM({ response: 'Paris is the capital of France.' });
|
||||||
|
|
||||||
|
const modelCallSpy = jest.spyOn(mockTextModel, '_call');
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
promptType: 'define',
|
||||||
|
text: 'What is the capital of France?',
|
||||||
|
options: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await node.execute.call(
|
||||||
|
createExecuteFunctionsMock(params, mockTextModel, fakeRetriever, version),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check model was called with the correct query
|
||||||
|
expect(modelCallSpy).toHaveBeenCalled();
|
||||||
|
expect(modelCallSpy.mock.calls[0][0]).toEqual(
|
||||||
|
expect.stringContaining('Question: What is the capital of France?'),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check that the result contains the expected response
|
||||||
|
expect(result).toHaveLength(1);
|
||||||
|
expect(result[0]).toHaveLength(1);
|
||||||
|
|
||||||
|
if (version < 1.5) {
|
||||||
|
expect((result[0][0].json.response as { text: string }).text).toContain(
|
||||||
|
'Paris is the capital of France.',
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
expect(result[0][0].json).toEqual({
|
||||||
|
response: 'Paris is the capital of France.',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
it.each([1.3, 1.4, 1.5])(
|
||||||
|
'should use a custom system prompt if provided (version %s)',
|
||||||
|
async (version) => {
|
||||||
|
const customSystemPrompt = `You are a geography expert. Use the following context to answer the question.
|
||||||
|
----------------
|
||||||
|
Context: {context}`;
|
||||||
|
|
||||||
|
// The chat model will return a response indicating it received the custom prompt
|
||||||
|
const mockChatModel = new FakeChatModel({});
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
promptType: 'define',
|
||||||
|
text: 'What is the capital of France?',
|
||||||
|
options: {
|
||||||
|
systemPromptTemplate: customSystemPrompt,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await node.execute.call(
|
||||||
|
createExecuteFunctionsMock(params, mockChatModel, fakeRetriever, version),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).toHaveLength(1);
|
||||||
|
expect(result[0]).toHaveLength(1);
|
||||||
|
if (version < 1.5) {
|
||||||
|
expect((result[0][0].json.response as { text: string }).text).toContain(
|
||||||
|
'You are a geography expert.',
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
expect(result[0][0].json.response).toContain('You are a geography expert.');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
it.each([1.3, 1.4, 1.5])(
|
||||||
|
'should throw an error if the query is undefined (version %s)',
|
||||||
|
async (version) => {
|
||||||
|
const mockChatModel = new FakeChatModel({});
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
promptType: 'define',
|
||||||
|
text: undefined, // undefined query
|
||||||
|
options: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
node.execute.call(
|
||||||
|
createExecuteFunctionsMock(params, mockChatModel, fakeRetriever, version),
|
||||||
|
),
|
||||||
|
).rejects.toThrow(NodeOperationError);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
it.each([1.3, 1.4, 1.5])(
|
||||||
|
'should add error to json if continueOnFail is true (version %s)',
|
||||||
|
async (version) => {
|
||||||
|
// Create a model that will throw an error
|
||||||
|
class ErrorLLM extends FakeLLM {
|
||||||
|
async _call(): Promise<string> {
|
||||||
|
throw new UnexpectedError('Model error');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const errorModel = new ErrorLLM({});
|
||||||
|
|
||||||
|
const params = {
|
||||||
|
promptType: 'define',
|
||||||
|
text: 'What is the capital of France?',
|
||||||
|
options: {},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Override continueOnFail to return true
|
||||||
|
const execMock = createExecuteFunctionsMock(params, errorModel, fakeRetriever, version);
|
||||||
|
execMock.continueOnFail = () => true;
|
||||||
|
|
||||||
|
const result = await node.execute.call(execMock);
|
||||||
|
|
||||||
|
expect(result).toHaveLength(1);
|
||||||
|
expect(result[0]).toHaveLength(1);
|
||||||
|
expect(result[0][0].json).toHaveProperty('error');
|
||||||
|
expect(result[0][0].json.error).toContain('Model error');
|
||||||
|
},
|
||||||
|
);
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user