feat(Vector Store Retriever Node): Add reranker support to retriever for QA chain (#16051)

This commit is contained in:
Benjamin Schroth
2025-06-06 13:16:03 +02:00
committed by GitHub
parent ac1a1dfbc2
commit 969552aeae
6 changed files with 226 additions and 9 deletions

View File

@@ -1,5 +1,7 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import type { VectorStore } from '@langchain/core/vectorstores';
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
import { VectorStore } from '@langchain/core/vectorstores';
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
import {
NodeConnectionTypes,
type INodeType,
@@ -65,9 +67,23 @@ export class RetrieverVectorStore implements INodeType {
const vectorStore = (await this.getInputConnectionData(
NodeConnectionTypes.AiVectorStore,
itemIndex,
)) as VectorStore;
)) as
| VectorStore
| {
reranker: BaseDocumentCompressor;
vectorStore: VectorStore;
};
const retriever = vectorStore.asRetriever(topK);
let retriever = null;
if (vectorStore instanceof VectorStore) {
retriever = vectorStore.asRetriever(topK);
} else {
retriever = new ContextualCompressionRetriever({
baseCompressor: vectorStore.reranker,
baseRetriever: vectorStore.vectorStore.asRetriever(topK),
});
}
return {
response: logWrapper(retriever, this),

View File

@@ -0,0 +1,133 @@
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
import { VectorStore } from '@langchain/core/vectorstores';
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
import type { ISupplyDataFunctions } from 'n8n-workflow';
import { NodeConnectionTypes } from 'n8n-workflow';
import { RetrieverVectorStore } from '../RetrieverVectorStore.node';
const mockLogger = {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
};
describe('RetrieverVectorStore', () => {
let retrieverNode: RetrieverVectorStore;
let mockContext: jest.Mocked<ISupplyDataFunctions>;
beforeEach(() => {
retrieverNode = new RetrieverVectorStore();
mockContext = {
logger: mockLogger,
getNodeParameter: jest.fn(),
getInputConnectionData: jest.fn(),
} as unknown as jest.Mocked<ISupplyDataFunctions>;
jest.clearAllMocks();
});
describe('supplyData', () => {
it('should create a retriever from a basic VectorStore', async () => {
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
if (param === 'topK') return 4;
return defaultValue;
});
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
const result = await retrieverNode.supplyData.call(mockContext, 0);
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
NodeConnectionTypes.AiVectorStore,
0,
);
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
expect(result).toHaveProperty('response', { test: 'retriever' });
});
it('should create a retriever with custom topK parameter', async () => {
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
if (param === 'topK') return 10;
return defaultValue;
});
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
const result = await retrieverNode.supplyData.call(mockContext, 0);
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(10);
expect(result).toHaveProperty('response', { test: 'retriever' });
});
it('should create a ContextualCompressionRetriever when input contains reranker and vectorStore', async () => {
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
const mockReranker = {} as BaseDocumentCompressor;
const inputWithReranker = {
reranker: mockReranker,
vectorStore: mockVectorStore,
};
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
if (param === 'topK') return 4;
return defaultValue;
});
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
const result = await retrieverNode.supplyData.call(mockContext, 0);
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
NodeConnectionTypes.AiVectorStore,
0,
);
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
});
it('should create a ContextualCompressionRetriever with custom topK when using reranker', async () => {
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
const mockReranker = {} as BaseDocumentCompressor;
const inputWithReranker = {
reranker: mockReranker,
vectorStore: mockVectorStore,
};
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
if (param === 'topK') return 8;
return defaultValue;
});
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
const result = await retrieverNode.supplyData.call(mockContext, 0);
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(8);
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
});
it('should use default topK value when parameter is not provided', async () => {
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
mockContext.getNodeParameter.mockImplementation((_param, _itemIndex, defaultValue) => {
return defaultValue;
});
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
await retrieverNode.supplyData.call(mockContext, 0);
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('topK', 0, 4);
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
});
});
});

View File

@@ -44,7 +44,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
const useReranker = parameters?.useReranker;
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
}
@@ -246,6 +246,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
"show": {
"mode": [
"load",
"retrieve",
"retrieve-as-tool",
],
},

View File

@@ -72,7 +72,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
const useReranker = parameters?.useReranker;
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
}
@@ -215,7 +215,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
description: 'Whether or not to rerank results',
displayOptions: {
show: {
mode: ['load', 'retrieve-as-tool'],
mode: ['load', 'retrieve', 'retrieve-as-tool'],
},
},
},

View File

@@ -1,8 +1,10 @@
import type { Embeddings } from '@langchain/core/embeddings';
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
import type { VectorStore } from '@langchain/core/vectorstores';
import type { MockProxy } from 'jest-mock-extended';
import { mock } from 'jest-mock-extended';
import type { ISupplyDataFunctions } from 'n8n-workflow';
import { NodeConnectionTypes } from 'n8n-workflow';
import { logWrapper } from '@utils/logWrapper';
@@ -22,15 +24,19 @@ describe('handleRetrieveOperation', () => {
let mockContext: MockProxy<ISupplyDataFunctions>;
let mockEmbeddings: MockProxy<Embeddings>;
let mockVectorStore: MockProxy<VectorStore>;
let mockReranker: MockProxy<BaseDocumentCompressor>;
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
beforeEach(() => {
mockContext = mock<ISupplyDataFunctions>();
mockContext.getNodeParameter.mockReturnValue(false); // Default useReranker to false
mockEmbeddings = mock<Embeddings>();
mockVectorStore = mock<VectorStore>();
mockReranker = mock<BaseDocumentCompressor>();
mockArgs = {
meta: {
displayName: 'Test Vector Store',
@@ -88,4 +94,46 @@ describe('handleRetrieveOperation', () => {
// Call the closeFunction - should not throw error even with no release method
await expect(result.closeFunction!()).resolves.not.toThrow();
});
it('should retrieve vector store without reranker when useReranker is false', async () => {
mockContext.getNodeParameter.mockReturnValue(false);
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
expect(mockArgs.getVectorStoreClient).toHaveBeenCalledWith(
mockContext,
{ testFilter: 'value' },
mockEmbeddings,
0,
);
// Result should contain vector store and close function
expect(result).toHaveProperty('response', mockVectorStore);
expect(result).toHaveProperty('closeFunction');
// Should not try to get reranker input connection
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
});
it('should retrieve vector store with reranker when useReranker is true', async () => {
mockContext.getNodeParameter.mockReturnValue(true);
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
NodeConnectionTypes.AiReranker,
0,
);
expect(result.response).toEqual({
reranker: mockReranker,
vectorStore: mockVectorStore,
});
expect(result).toHaveProperty('closeFunction');
});
});

View File

@@ -1,6 +1,7 @@
import type { Embeddings } from '@langchain/core/embeddings';
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
import type { VectorStore } from '@langchain/core/vectorstores';
import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow';
import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow';
import { getMetadataFiltersValues } from '@utils/helpers';
import { logWrapper } from '@utils/logWrapper';
@@ -19,13 +20,31 @@ export async function handleRetrieveOperation<T extends VectorStore = VectorStor
): Promise<SupplyData> {
// Get metadata filters
const filter = getMetadataFiltersValues(context, itemIndex);
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
// Get the vector store client
const vectorStore = await args.getVectorStoreClient(context, filter, embeddings, itemIndex);
let response: VectorStore | { reranker: BaseDocumentCompressor; vectorStore: VectorStore } =
vectorStore;
if (useReranker) {
const reranker = (await context.getInputConnectionData(
NodeConnectionTypes.AiReranker,
0,
)) as BaseDocumentCompressor;
// Return reranker and vector store with log wrapper
response = {
reranker,
vectorStore: logWrapper(vectorStore, context),
};
} else {
// Return the vector store with logging wrapper
response = logWrapper(vectorStore, context);
}
// Return the vector store with logging wrapper and cleanup function
return {
response: logWrapper(vectorStore, context),
response,
closeFunction: async () => {
// Release the vector store client if a release method was provided
args.releaseVectorStoreClient?.(vectorStore);