mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(Vector Store Retriever Node): Add reranker support to retriever for QA chain (#16051)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import { VectorStore } from '@langchain/core/vectorstores';
|
||||
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
|
||||
import {
|
||||
NodeConnectionTypes,
|
||||
type INodeType,
|
||||
@@ -65,9 +67,23 @@ export class RetrieverVectorStore implements INodeType {
|
||||
const vectorStore = (await this.getInputConnectionData(
|
||||
NodeConnectionTypes.AiVectorStore,
|
||||
itemIndex,
|
||||
)) as VectorStore;
|
||||
)) as
|
||||
| VectorStore
|
||||
| {
|
||||
reranker: BaseDocumentCompressor;
|
||||
vectorStore: VectorStore;
|
||||
};
|
||||
|
||||
const retriever = vectorStore.asRetriever(topK);
|
||||
let retriever = null;
|
||||
|
||||
if (vectorStore instanceof VectorStore) {
|
||||
retriever = vectorStore.asRetriever(topK);
|
||||
} else {
|
||||
retriever = new ContextualCompressionRetriever({
|
||||
baseCompressor: vectorStore.reranker,
|
||||
baseRetriever: vectorStore.vectorStore.asRetriever(topK),
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
response: logWrapper(retriever, this),
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import { VectorStore } from '@langchain/core/vectorstores';
|
||||
import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression';
|
||||
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||
|
||||
import { RetrieverVectorStore } from '../RetrieverVectorStore.node';
|
||||
|
||||
const mockLogger = {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
};
|
||||
|
||||
describe('RetrieverVectorStore', () => {
|
||||
let retrieverNode: RetrieverVectorStore;
|
||||
let mockContext: jest.Mocked<ISupplyDataFunctions>;
|
||||
|
||||
beforeEach(() => {
|
||||
retrieverNode = new RetrieverVectorStore();
|
||||
mockContext = {
|
||||
logger: mockLogger,
|
||||
getNodeParameter: jest.fn(),
|
||||
getInputConnectionData: jest.fn(),
|
||||
} as unknown as jest.Mocked<ISupplyDataFunctions>;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('supplyData', () => {
|
||||
it('should create a retriever from a basic VectorStore', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 4;
|
||||
return defaultValue;
|
||||
});
|
||||
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiVectorStore,
|
||||
0,
|
||||
);
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||
expect(result).toHaveProperty('response', { test: 'retriever' });
|
||||
});
|
||||
|
||||
it('should create a retriever with custom topK parameter', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 10;
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(10);
|
||||
expect(result).toHaveProperty('response', { test: 'retriever' });
|
||||
});
|
||||
|
||||
it('should create a ContextualCompressionRetriever when input contains reranker and vectorStore', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
|
||||
|
||||
const mockReranker = {} as BaseDocumentCompressor;
|
||||
|
||||
const inputWithReranker = {
|
||||
reranker: mockReranker,
|
||||
vectorStore: mockVectorStore,
|
||||
};
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 4;
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiVectorStore,
|
||||
0,
|
||||
);
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
|
||||
});
|
||||
|
||||
it('should create a ContextualCompressionRetriever with custom topK when using reranker', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'base-retriever' });
|
||||
|
||||
const mockReranker = {} as BaseDocumentCompressor;
|
||||
|
||||
const inputWithReranker = {
|
||||
reranker: mockReranker,
|
||||
vectorStore: mockVectorStore,
|
||||
};
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((param, _itemIndex, defaultValue) => {
|
||||
if (param === 'topK') return 8;
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(inputWithReranker);
|
||||
|
||||
const result = await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(8);
|
||||
expect(result.response).toBeInstanceOf(ContextualCompressionRetriever);
|
||||
});
|
||||
|
||||
it('should use default topK value when parameter is not provided', async () => {
|
||||
const mockVectorStore = Object.create(VectorStore.prototype) as VectorStore;
|
||||
mockVectorStore.asRetriever = jest.fn().mockReturnValue({ test: 'retriever' });
|
||||
|
||||
mockContext.getNodeParameter.mockImplementation((_param, _itemIndex, defaultValue) => {
|
||||
return defaultValue;
|
||||
});
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockVectorStore);
|
||||
|
||||
await retrieverNode.supplyData.call(mockContext, 0);
|
||||
|
||||
expect(mockContext.getNodeParameter).toHaveBeenCalledWith('topK', 0, 4);
|
||||
expect(mockVectorStore.asRetriever).toHaveBeenCalledWith(4);
|
||||
});
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user