From 969552aeae96ca6c27537b32c1618f91c2c1075a Mon Sep 17 00:00:00 2001 From: Benjamin Schroth <68321970+schrothbn@users.noreply.github.com> Date: Fri, 6 Jun 2025 13:16:03 +0200 Subject: [PATCH] feat(Vector Store Retriever Node): Add reranker support to retriever for QA chain (#16051) --- .../RetrieverVectorStore.node.ts | 22 ++- .../test/RetrieverVectorStore.node.test.ts | 133 ++++++++++++++++++ .../createVectorStoreNode.test.ts.snap | 3 +- .../createVectorStoreNode.ts | 4 +- .../__tests__/retrieveOperation.test.ts | 48 +++++++ .../operations/retrieveOperation.ts | 25 +++- 6 files changed, 226 insertions(+), 9 deletions(-) create mode 100644 packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/test/RetrieverVectorStore.node.test.ts diff --git a/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.ts b/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.ts index 7323eacb2d..3be75b289c 100644 --- a/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.ts @@ -1,5 +1,7 @@ /* eslint-disable n8n-nodes-base/node-dirname-against-convention */ -import type { VectorStore } from '@langchain/core/vectorstores'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; +import { VectorStore } from '@langchain/core/vectorstores'; +import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'; import { NodeConnectionTypes, type INodeType, @@ -65,9 +67,23 @@ export class RetrieverVectorStore implements INodeType { const vectorStore = (await this.getInputConnectionData( NodeConnectionTypes.AiVectorStore, itemIndex, - )) as VectorStore; + )) as + | VectorStore + | { + reranker: BaseDocumentCompressor; + vectorStore: VectorStore; + }; - const retriever = vectorStore.asRetriever(topK); + let retriever = null; + + if (vectorStore instanceof VectorStore) { + retriever = vectorStore.asRetriever(topK); + } else { + retriever = new ContextualCompressionRetriever({ + baseCompressor: vectorStore.reranker, + baseRetriever: vectorStore.vectorStore.asRetriever(topK), + }); + } return { response: logWrapper(retriever, this), diff --git a/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/test/RetrieverVectorStore.node.test.ts b/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/test/RetrieverVectorStore.node.test.ts new file mode 100644 index 0000000000..ae6f4e5c30 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/retrievers/RetrieverVectorStore/test/RetrieverVectorStore.node.test.ts @@ -0,0 +1,133 @@ +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; +import { VectorStore } from '@langchain/core/vectorstores'; +import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression'; +import type { ISupplyDataFunctions } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; + +import { RetrieverVectorStore } from '../RetrieverVectorStore.node'; + +const mockLogger = { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), +}; + +describe('RetrieverVectorStore', () => { + let retrieverNode: RetrieverVectorStore; + let mockContext: jest.Mocked; + + beforeEach(() => { + retrieverNode = new RetrieverVectorStore(); + mockContext = { + logger: mockLogger, + getNodeParameter: jest.fn(), + getInputConnectionData: jest.fn(), + } as unknown as jest.Mocked; + jest.clearAllMocks(); + }); + + describe('supplyData', () => { + it('should create a retriever from a basic VectorStore', async () => { + const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore; + mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' }); + + mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => { + if (param === 'topK') return 4; + return defaultValue; + }); + + mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore); + + const result = await retrieverNode.supplyData.call(mockContext, 0); + + expect(mockContext.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiVectorStore, + 0, + ); + expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4); + expect(result).toHaveProperty('response', { test: 'retriever' }); + }); + + it('should create a retriever with custom topK parameter', async () => { + const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore; + mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' }); + + mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => { + if (param === 'topK') return 10; + return defaultValue; + }); + mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore); + + const result = await retrieverNode.supplyData.call(mockContext, 0); + + expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(10); + expect(result).toHaveProperty('response', { test: 'retriever' }); + }); + + it('should create a ContextualCompressionRetriever when input contains reranker and vectorStore', async () => { + const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore; + mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' }); + + const mockReranker = {} as BaseDocumentCompressor; + + const inputWithReranker = { + reranker: mockReranker, + vectorStore: mockVectorStore, + }; + + mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => { + if (param === 'topK') return 4; + return defaultValue; + }); + mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker); + + const result = await retrieverNode.supplyData.call(mockContext, 0); + + expect(mockContext.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiVectorStore, + 0, + ); + expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4); + expect(result.response).toBeInstanceOf(ContextualCompressionRetriever); + }); + + it('should create a ContextualCompressionRetriever with custom topK when using reranker', async () => { + const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore; + mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' }); + + const mockReranker = {} as BaseDocumentCompressor; + + const inputWithReranker = { + reranker: mockReranker, + vectorStore: mockVectorStore, + }; + + mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => { + if (param === 'topK') return 8; + return defaultValue; + }); + mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker); + + const result = await retrieverNode.supplyData.call(mockContext, 0); + + expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(8); + expect(result.response).toBeInstanceOf(ContextualCompressionRetriever); + }); + + it('should use default topK value when parameter is not provided', async () => { + const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore; + mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' }); + + mockContext.getNodeParameter.mockImplementation((_param, _itemIndex, defaultValue) => { + return defaultValue; + }); + mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore); + + await retrieverNode.supplyData.call(mockContext, 0); + + expect(mockContext.getNodeParameter).toHaveBeenCalledWith('topK', 0, 4); + expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4); + }); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap index a63873a4d6..7ab56553a3 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap @@ -44,7 +44,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] = const useReranker = parameters?.useReranker; const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}] - if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) { + if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) { inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1}) } @@ -246,6 +246,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] = "show": { "mode": [ "load", + "retrieve", "retrieve-as-tool", ], }, diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts index 833749b954..4761d8cb8b 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts @@ -72,7 +72,7 @@ export const createVectorStoreNode = ( const useReranker = parameters?.useReranker; const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}] - if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) { + if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) { inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1}) } @@ -215,7 +215,7 @@ export const createVectorStoreNode = ( description: 'Whether or not to rerank results', displayOptions: { show: { - mode: ['load', 'retrieve-as-tool'], + mode: ['load', 'retrieve', 'retrieve-as-tool'], }, }, }, diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveOperation.test.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveOperation.test.ts index 65564885ce..cd6cac1c3a 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveOperation.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveOperation.test.ts @@ -1,8 +1,10 @@ import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { VectorStore } from '@langchain/core/vectorstores'; import type { MockProxy } from 'jest-mock-extended'; import { mock } from 'jest-mock-extended'; import type { ISupplyDataFunctions } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; import { logWrapper } from '@utils/logWrapper'; @@ -22,15 +24,19 @@ describe('handleRetrieveOperation', () => { let mockContext: MockProxy; let mockEmbeddings: MockProxy; let mockVectorStore: MockProxy; + let mockReranker: MockProxy; let mockArgs: VectorStoreNodeConstructorArgs; beforeEach(() => { mockContext = mock(); + mockContext.getNodeParameter.mockReturnValue(false); // Default useReranker to false mockEmbeddings = mock(); mockVectorStore = mock(); + mockReranker = mock(); + mockArgs = { meta: { displayName: 'Test Vector Store', @@ -88,4 +94,46 @@ describe('handleRetrieveOperation', () => { // Call the closeFunction - should not throw error even with no release method await expect(result.closeFunction!()).resolves.not.toThrow(); }); + + it('should retrieve vector store without reranker when useReranker is false', async () => { + mockContext.getNodeParameter.mockReturnValue(false); + + const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0); + + expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false); + + expect(mockArgs.getVectorStoreClient).toHaveBeenCalledWith( + mockContext, + { testFilter: 'value' }, + mockEmbeddings, + 0, + ); + + // Result should contain vector store and close function + expect(result).toHaveProperty('response', mockVectorStore); + expect(result).toHaveProperty('closeFunction'); + + // Should not try to get reranker input connection + expect(mockContext.getInputConnectionData).not.toHaveBeenCalled(); + }); + + it('should retrieve vector store with reranker when useReranker is true', async () => { + mockContext.getNodeParameter.mockReturnValue(true); + mockContext.getInputConnectionData.mockResolvedValue(mockReranker); + + const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0); + + expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false); + + expect(mockContext.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiReranker, + 0, + ); + + expect(result.response).toEqual({ + reranker: mockReranker, + vectorStore: mockVectorStore, + }); + expect(result).toHaveProperty('closeFunction'); + }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveOperation.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveOperation.ts index 847ea9d980..1052cbcbaf 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveOperation.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveOperation.ts @@ -1,6 +1,7 @@ import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { VectorStore } from '@langchain/core/vectorstores'; -import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow'; +import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow'; import { getMetadataFiltersValues } from '@utils/helpers'; import { logWrapper } from '@utils/logWrapper'; @@ -19,13 +20,31 @@ export async function handleRetrieveOperation { // Get metadata filters const filter = getMetadataFiltersValues(context, itemIndex); + const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean; // Get the vector store client const vectorStore = await args.getVectorStoreClient(context, filter, embeddings, itemIndex); + let response: VectorStore | { reranker: BaseDocumentCompressor; vectorStore: VectorStore } = + vectorStore; + + if (useReranker) { + const reranker = (await context.getInputConnectionData( + NodeConnectionTypes.AiReranker, + 0, + )) as BaseDocumentCompressor; + + // Return reranker and vector store with log wrapper + response = { + reranker, + vectorStore: logWrapper(vectorStore, context), + }; + } else { + // Return the vector store with logging wrapper + response = logWrapper(vectorStore, context); + } - // Return the vector store with logging wrapper and cleanup function return { - response: logWrapper(vectorStore, context), + response, closeFunction: async () => { // Release the vector store client if a release method was provided args.releaseVectorStoreClient?.(vectorStore);