mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(Vector Store Retriever Node): Add reranker support to retriever for QA chain (#16051)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import { VectorStore } from '@langchain/core/vectorstores';
|
||||
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
|
||||
import {
|
||||
NodeConnectionTypes,
|
||||
type INodeType,
|
||||
@@ -65,9 +67,23 @@ export class RetrieverVectorStore implements INodeType {
|
||||
const vectorStore = (await this.getInputConnectionData(
|
||||
NodeConnectionTypes.AiVectorStore,
|
||||
itemIndex,
|
||||
)) as VectorStore;
|
||||
)) as
|
||||
| VectorStore
|
||||
| {
|
||||
reranker: BaseDocumentCompressor;
|
||||
vectorStore: VectorStore;
|
||||
};
|
||||
|
||||
const retriever = vectorStore.asRetriever(topK);
|
||||
let retriever = null;
|
||||
|
||||
if (vectorStore instanceof VectorStore) {
|
||||
retriever = vectorStore.asRetriever(topK);
|
||||
} else {
|
||||
retriever = new ContextualCompressionRetriever({
|
||||
baseCompressor: vectorStore.reranker,
|
||||
baseRetriever: vectorStore.vectorStore.asRetriever(topK),
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
response: logWrapper(retriever, this),
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import { VectorStore } from '@langchain/core/vectorstores';
|
||||
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
|
||||
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||
|
||||
import { RetrieverVectorStore } from '../RetrieverVectorStore.node';
|
||||
|
||||
const mockLogger = {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
};
|
||||
|
||||
describe('RetrieverVectorStore', () => {
|
||||
let retrieverNode: RetrieverVectorStore;
|
||||
let mockContext: jest.Mocked<ISupplyDataFunctions>;
|
||||
|
||||
beforeEach(() => {
|
||||
retrieverNode = new RetrieverVectorStore();
|
||||
mockContext = {
|
||||
logger: mockLogger,
|
||||
getNodeParameter: jest.fn(),
|
||||
getInputConnectionData: jest.fn(),
|
||||
} as unknown as jest.Mocked<ISupplyDataFunctions>;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('supplyData', () => {
|
||||
it('should create a retriever from a basic VectorStore', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 4;
|
||||
return defaultValue;
|
||||
});
|
||||
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiVectorStore,
|
||||
0,
|
||||
);
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||
expect(result).toHaveProperty('response', { test: 'retriever' });
|
||||
});
|
||||
|
||||
it('should create a retriever with custom topK parameter', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 10;
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(10);
|
||||
expect(result).toHaveProperty('response', { test: 'retriever' });
|
||||
});
|
||||
|
||||
it('should create a ContextualCompressionRetriever when input contains reranker and vectorStore', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
|
||||
|
||||
const mockReranker = {} as BaseDocumentCompressor;
|
||||
|
||||
const inputWithReranker = {
|
||||
reranker: mockReranker,
|
||||
vectorStore: mockVectorStore,
|
||||
};
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 4;
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiVectorStore,
|
||||
0,
|
||||
);
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
|
||||
});
|
||||
|
||||
it('should create a ContextualCompressionRetriever with custom topK when using reranker', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
|
||||
|
||||
const mockReranker = {} as BaseDocumentCompressor;
|
||||
|
||||
const inputWithReranker = {
|
||||
reranker: mockReranker,
|
||||
vectorStore: mockVectorStore,
|
||||
};
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 8;
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(8);
|
||||
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
|
||||
});
|
||||
|
||||
it('should use default topK value when parameter is not provided', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((_param, _itemIndex, defaultValue) => {
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||
|
||||
await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('topK', 0, 4);
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -44,7 +44,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
||||
const useReranker = parameters?.useReranker;
|
||||
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
|
||||
|
||||
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
|
||||
}
|
||||
|
||||
@@ -246,6 +246,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
||||
"show": {
|
||||
"mode": [
|
||||
"load",
|
||||
"retrieve",
|
||||
"retrieve-as-tool",
|
||||
],
|
||||
},
|
||||
|
||||
@@ -72,7 +72,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
||||
const useReranker = parameters?.useReranker;
|
||||
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
|
||||
|
||||
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
|
||||
}
|
||||
|
||||
@@ -215,7 +215,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
||||
description: 'Whether or not to rerank results',
|
||||
displayOptions: {
|
||||
show: {
|
||||
mode: ['load', 'retrieve-as-tool'],
|
||||
mode: ['load', 'retrieve', 'retrieve-as-tool'],
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { MockProxy } from 'jest-mock-extended';
|
||||
import { mock } from 'jest-mock-extended';
|
||||
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||
|
||||
import { logWrapper } from '@utils/logWrapper';
|
||||
|
||||
@@ -22,15 +24,19 @@ describe('handleRetrieveOperation', () => {
|
||||
let mockContext: MockProxy<ISupplyDataFunctions>;
|
||||
let mockEmbeddings: MockProxy<Embeddings>;
|
||||
let mockVectorStore: MockProxy<VectorStore>;
|
||||
let mockReranker: MockProxy<BaseDocumentCompressor>;
|
||||
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockContext = mock<ISupplyDataFunctions>();
|
||||
mockContext.getNodeParameter.mockReturnValue(false); // Default useReranker to false
|
||||
|
||||
mockEmbeddings = mock<Embeddings>();
|
||||
|
||||
mockVectorStore = mock<VectorStore>();
|
||||
|
||||
mockReranker = mock<BaseDocumentCompressor>();
|
||||
|
||||
mockArgs = {
|
||||
meta: {
|
||||
displayName: 'Test Vector Store',
|
||||
@@ -88,4 +94,46 @@ describe('handleRetrieveOperation', () => {
|
||||
// Call the closeFunction - should not throw error even with no release method
|
||||
await expect(result.closeFunction!()).resolves.not.toThrow();
|
||||
});
|
||||
|
||||
it('should retrieve vector store without reranker when useReranker is false', async () => {
|
||||
mockContext.getNodeParameter.mockReturnValue(false);
|
||||
|
||||
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
|
||||
|
||||
expect(mockArgs.getVectorStoreClient).toHaveBeenCalledWith(
|
||||
mockContext,
|
||||
{ testFilter: 'value' },
|
||||
mockEmbeddings,
|
||||
0,
|
||||
);
|
||||
|
||||
// Result should contain vector store and close function
|
||||
expect(result).toHaveProperty('response', mockVectorStore);
|
||||
expect(result).toHaveProperty('closeFunction');
|
||||
|
||||
// Should not try to get reranker input connection
|
||||
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should retrieve vector store with reranker when useReranker is true', async () => {
|
||||
mockContext.getNodeParameter.mockReturnValue(true);
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
|
||||
|
||||
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiReranker,
|
||||
0,
|
||||
);
|
||||
|
||||
expect(result.response).toEqual({
|
||||
reranker: mockReranker,
|
||||
vectorStore: mockVectorStore,
|
||||
});
|
||||
expect(result).toHaveProperty('closeFunction');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow';
|
||||
|
||||
import { getMetadataFiltersValues } from '@utils/helpers';
|
||||
import { logWrapper } from '@utils/logWrapper';
|
||||
@@ -19,13 +20,31 @@ export async function handleRetrieveOperation<T extends VectorStore = VectorStor
|
||||
): Promise<SupplyData> {
|
||||
// Get metadata filters
|
||||
const filter = getMetadataFiltersValues(context, itemIndex);
|
||||
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
|
||||
|
||||
// Get the vector store client
|
||||
const vectorStore = await args.getVectorStoreClient(context, filter, embeddings, itemIndex);
|
||||
let response: VectorStore | { reranker: BaseDocumentCompressor; vectorStore: VectorStore } =
|
||||
vectorStore;
|
||||
|
||||
if (useReranker) {
|
||||
const reranker = (await context.getInputConnectionData(
|
||||
NodeConnectionTypes.AiReranker,
|
||||
0,
|
||||
)) as BaseDocumentCompressor;
|
||||
|
||||
// Return reranker and vector store with log wrapper
|
||||
response = {
|
||||
reranker,
|
||||
vectorStore: logWrapper(vectorStore, context),
|
||||
};
|
||||
} else {
|
||||
// Return the vector store with logging wrapper
|
||||
response = logWrapper(vectorStore, context);
|
||||
}
|
||||
|
||||
// Return the vector store with logging wrapper and cleanup function
|
||||
return {
|
||||
response: logWrapper(vectorStore, context),
|
||||
response,
|
||||
closeFunction: async () => {
|
||||
// Release the vector store client if a release method was provided
|
||||
args.releaseVectorStoreClient?.(vectorStore);
|
||||
|
||||
Reference in New Issue
Block a user