mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(Vector Store Retriever Node): Add reranker support to retriever for QA chain (#16051)
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
|
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
|
import { VectorStore } from '@langchain/core/vectorstores';
|
||||||
|
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
|
||||||
import {
|
import {
|
||||||
NodeConnectionTypes,
|
NodeConnectionTypes,
|
||||||
type INodeType,
|
type INodeType,
|
||||||
@@ -65,9 +67,23 @@ export class RetrieverVectorStore implements INodeType {
|
|||||||
const vectorStore = (await this.getInputConnectionData(
|
const vectorStore = (await this.getInputConnectionData(
|
||||||
NodeConnectionTypes.AiVectorStore,
|
NodeConnectionTypes.AiVectorStore,
|
||||||
itemIndex,
|
itemIndex,
|
||||||
)) as VectorStore;
|
)) as
|
||||||
|
| VectorStore
|
||||||
|
| {
|
||||||
|
reranker: BaseDocumentCompressor;
|
||||||
|
vectorStore: VectorStore;
|
||||||
|
};
|
||||||
|
|
||||||
const retriever = vectorStore.asRetriever(topK);
|
let retriever = null;
|
||||||
|
|
||||||
|
if (vectorStore instanceof VectorStore) {
|
||||||
|
retriever = vectorStore.asRetriever(topK);
|
||||||
|
} else {
|
||||||
|
retriever = new ContextualCompressionRetriever({
|
||||||
|
baseCompressor: vectorStore.reranker,
|
||||||
|
baseRetriever: vectorStore.vectorStore.asRetriever(topK),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
response: logWrapper(retriever, this),
|
response: logWrapper(retriever, this),
|
||||||
|
|||||||
@@ -0,0 +1,133 @@
|
|||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
|
import { VectorStore } from '@langchain/core/vectorstores';
|
||||||
|
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
|
||||||
|
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||||
|
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||||
|
|
||||||
|
import { RetrieverVectorStore } from '../RetrieverVectorStore.node';
|
||||||
|
|
||||||
|
const mockLogger = {
|
||||||
|
debug: jest.fn(),
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
describe('RetrieverVectorStore', () => {
|
||||||
|
let retrieverNode: RetrieverVectorStore;
|
||||||
|
let mockContext: jest.Mocked<ISupplyDataFunctions>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
retrieverNode = new RetrieverVectorStore();
|
||||||
|
mockContext = {
|
||||||
|
logger: mockLogger,
|
||||||
|
getNodeParameter: jest.fn(),
|
||||||
|
getInputConnectionData: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<ISupplyDataFunctions>;
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('supplyData', () => {
|
||||||
|
it('should create a retriever from a basic VectorStore', async () => {
|
||||||
|
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||||
|
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||||
|
|
||||||
|
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||||
|
if (param === 'topK') return 4;
|
||||||
|
return defaultValue;
|
||||||
|
});
|
||||||
|
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||||
|
|
||||||
|
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||||
|
NodeConnectionTypes.AiVectorStore,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||||
|
expect(result).toHaveProperty('response', { test: 'retriever' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create a retriever with custom topK parameter', async () => {
|
||||||
|
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||||
|
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||||
|
|
||||||
|
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||||
|
if (param === 'topK') return 10;
|
||||||
|
return defaultValue;
|
||||||
|
});
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||||
|
|
||||||
|
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||||
|
|
||||||
|
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(10);
|
||||||
|
expect(result).toHaveProperty('response', { test: 'retriever' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create a ContextualCompressionRetriever when input contains reranker and vectorStore', async () => {
|
||||||
|
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||||
|
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
|
||||||
|
|
||||||
|
const mockReranker = {} as BaseDocumentCompressor;
|
||||||
|
|
||||||
|
const inputWithReranker = {
|
||||||
|
reranker: mockReranker,
|
||||||
|
vectorStore: mockVectorStore,
|
||||||
|
};
|
||||||
|
|
||||||
|
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||||
|
if (param === 'topK') return 4;
|
||||||
|
return defaultValue;
|
||||||
|
});
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
|
||||||
|
|
||||||
|
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||||
|
NodeConnectionTypes.AiVectorStore,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||||
|
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create a ContextualCompressionRetriever with custom topK when using reranker', async () => {
|
||||||
|
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||||
|
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
|
||||||
|
|
||||||
|
const mockReranker = {} as BaseDocumentCompressor;
|
||||||
|
|
||||||
|
const inputWithReranker = {
|
||||||
|
reranker: mockReranker,
|
||||||
|
vectorStore: mockVectorStore,
|
||||||
|
};
|
||||||
|
|
||||||
|
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||||
|
if (param === 'topK') return 8;
|
||||||
|
return defaultValue;
|
||||||
|
});
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
|
||||||
|
|
||||||
|
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||||
|
|
||||||
|
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(8);
|
||||||
|
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use default topK value when parameter is not provided', async () => {
|
||||||
|
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||||
|
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||||
|
|
||||||
|
mockContext.getNodeParameter.mockImplementation((_param, _itemIndex, defaultValue) => {
|
||||||
|
return defaultValue;
|
||||||
|
});
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||||
|
|
||||||
|
await retrieverNode.supplyData.call(mockContext, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('topK', 0, 4);
|
||||||
|
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -44,7 +44,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
|||||||
const useReranker = parameters?.useReranker;
|
const useReranker = parameters?.useReranker;
|
||||||
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
|
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
|
||||||
|
|
||||||
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||||
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
|
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,6 +246,7 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
|||||||
"show": {
|
"show": {
|
||||||
"mode": [
|
"mode": [
|
||||||
"load",
|
"load",
|
||||||
|
"retrieve",
|
||||||
"retrieve-as-tool",
|
"retrieve-as-tool",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
|||||||
const useReranker = parameters?.useReranker;
|
const useReranker = parameters?.useReranker;
|
||||||
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
|
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
|
||||||
|
|
||||||
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
if (['load', 'retrieve', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||||
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
|
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +215,7 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
|||||||
description: 'Whether or not to rerank results',
|
description: 'Whether or not to rerank results',
|
||||||
displayOptions: {
|
displayOptions: {
|
||||||
show: {
|
show: {
|
||||||
mode: ['load', 'retrieve-as-tool'],
|
mode: ['load', 'retrieve', 'retrieve-as-tool'],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import type { Embeddings } from '@langchain/core/embeddings';
|
import type { Embeddings } from '@langchain/core/embeddings';
|
||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import type { MockProxy } from 'jest-mock-extended';
|
import type { MockProxy } from 'jest-mock-extended';
|
||||||
import { mock } from 'jest-mock-extended';
|
import { mock } from 'jest-mock-extended';
|
||||||
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||||
|
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||||
|
|
||||||
import { logWrapper } from '@utils/logWrapper';
|
import { logWrapper } from '@utils/logWrapper';
|
||||||
|
|
||||||
@@ -22,15 +24,19 @@ describe('handleRetrieveOperation', () => {
|
|||||||
let mockContext: MockProxy<ISupplyDataFunctions>;
|
let mockContext: MockProxy<ISupplyDataFunctions>;
|
||||||
let mockEmbeddings: MockProxy<Embeddings>;
|
let mockEmbeddings: MockProxy<Embeddings>;
|
||||||
let mockVectorStore: MockProxy<VectorStore>;
|
let mockVectorStore: MockProxy<VectorStore>;
|
||||||
|
let mockReranker: MockProxy<BaseDocumentCompressor>;
|
||||||
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockContext = mock<ISupplyDataFunctions>();
|
mockContext = mock<ISupplyDataFunctions>();
|
||||||
|
mockContext.getNodeParameter.mockReturnValue(false); // Default useReranker to false
|
||||||
|
|
||||||
mockEmbeddings = mock<Embeddings>();
|
mockEmbeddings = mock<Embeddings>();
|
||||||
|
|
||||||
mockVectorStore = mock<VectorStore>();
|
mockVectorStore = mock<VectorStore>();
|
||||||
|
|
||||||
|
mockReranker = mock<BaseDocumentCompressor>();
|
||||||
|
|
||||||
mockArgs = {
|
mockArgs = {
|
||||||
meta: {
|
meta: {
|
||||||
displayName: 'Test Vector Store',
|
displayName: 'Test Vector Store',
|
||||||
@@ -88,4 +94,46 @@ describe('handleRetrieveOperation', () => {
|
|||||||
// Call the closeFunction - should not throw error even with no release method
|
// Call the closeFunction - should not throw error even with no release method
|
||||||
await expect(result.closeFunction!()).resolves.not.toThrow();
|
await expect(result.closeFunction!()).resolves.not.toThrow();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should retrieve vector store without reranker when useReranker is false', async () => {
|
||||||
|
mockContext.getNodeParameter.mockReturnValue(false);
|
||||||
|
|
||||||
|
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
|
||||||
|
|
||||||
|
expect(mockArgs.getVectorStoreClient).toHaveBeenCalledWith(
|
||||||
|
mockContext,
|
||||||
|
{ testFilter: 'value' },
|
||||||
|
mockEmbeddings,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Result should contain vector store and close function
|
||||||
|
expect(result).toHaveProperty('response', mockVectorStore);
|
||||||
|
expect(result).toHaveProperty('closeFunction');
|
||||||
|
|
||||||
|
// Should not try to get reranker input connection
|
||||||
|
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should retrieve vector store with reranker when useReranker is true', async () => {
|
||||||
|
mockContext.getNodeParameter.mockReturnValue(true);
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
|
||||||
|
|
||||||
|
const result = await handleRetrieveOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('useReranker', 0, false);
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||||
|
NodeConnectionTypes.AiReranker,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.response).toEqual({
|
||||||
|
reranker: mockReranker,
|
||||||
|
vectorStore: mockVectorStore,
|
||||||
|
});
|
||||||
|
expect(result).toHaveProperty('closeFunction');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import type { Embeddings } from '@langchain/core/embeddings';
|
import type { Embeddings } from '@langchain/core/embeddings';
|
||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow';
|
import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow';
|
||||||
|
|
||||||
import { getMetadataFiltersValues } from '@utils/helpers';
|
import { getMetadataFiltersValues } from '@utils/helpers';
|
||||||
import { logWrapper } from '@utils/logWrapper';
|
import { logWrapper } from '@utils/logWrapper';
|
||||||
@@ -19,13 +20,31 @@ export async function handleRetrieveOperation<T extends VectorStore = VectorStor
|
|||||||
): Promise<SupplyData> {
|
): Promise<SupplyData> {
|
||||||
// Get metadata filters
|
// Get metadata filters
|
||||||
const filter = getMetadataFiltersValues(context, itemIndex);
|
const filter = getMetadataFiltersValues(context, itemIndex);
|
||||||
|
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
|
||||||
|
|
||||||
// Get the vector store client
|
// Get the vector store client
|
||||||
const vectorStore = await args.getVectorStoreClient(context, filter, embeddings, itemIndex);
|
const vectorStore = await args.getVectorStoreClient(context, filter, embeddings, itemIndex);
|
||||||
|
let response: VectorStore | { reranker: BaseDocumentCompressor; vectorStore: VectorStore } =
|
||||||
|
vectorStore;
|
||||||
|
|
||||||
|
if (useReranker) {
|
||||||
|
const reranker = (await context.getInputConnectionData(
|
||||||
|
NodeConnectionTypes.AiReranker,
|
||||||
|
0,
|
||||||
|
)) as BaseDocumentCompressor;
|
||||||
|
|
||||||
|
// Return reranker and vector store with log wrapper
|
||||||
|
response = {
|
||||||
|
reranker,
|
||||||
|
vectorStore: logWrapper(vectorStore, context),
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
// Return the vector store with logging wrapper
|
||||||
|
response = logWrapper(vectorStore, context);
|
||||||
|
}
|
||||||
|
|
||||||
// Return the vector store with logging wrapper and cleanup function
|
|
||||||
return {
|
return {
|
||||||
response: logWrapper(vectorStore, context),
|
response,
|
||||||
closeFunction: async () => {
|
closeFunction: async () => {
|
||||||
// Release the vector store client if a release method was provided
|
// Release the vector store client if a release method was provided
|
||||||
args.releaseVectorStoreClient?.(vectorStore);
|
args.releaseVectorStoreClient?.(vectorStore);
|
||||||
|
|||||||
Reference in New Issue
Block a user