mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 01:56:46 +00:00
feat: Add Cohere reranking capability to vector stores (#16014)
Co-authored-by: Yiorgis Gozadinos <yiorgis@n8n.io>
Co-authored-by: Mutasem Aldmour <mutasem@n8n.io>
This commit is contained in:
@@ -41,8 +41,13 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
||||
"inputs": "={{
|
||||
((parameters) => {
|
||||
const mode = parameters?.mode;
|
||||
const useReranker = parameters?.useReranker;
|
||||
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
|
||||
|
||||
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
|
||||
}
|
||||
|
||||
if (mode === 'retrieve-as-tool') {
|
||||
return inputs;
|
||||
}
|
||||
@@ -233,6 +238,21 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
||||
"name": "includeDocumentMetadata",
|
||||
"type": "boolean",
|
||||
},
|
||||
{
|
||||
"default": false,
|
||||
"description": "Whether or not to rerank results",
|
||||
"displayName": "Rerank Results",
|
||||
"displayOptions": {
|
||||
"show": {
|
||||
"mode": [
|
||||
"load",
|
||||
"retrieve-as-tool",
|
||||
],
|
||||
},
|
||||
},
|
||||
"name": "useReranker",
|
||||
"type": "boolean",
|
||||
},
|
||||
{
|
||||
"default": "",
|
||||
"description": "ID of an embedding entry",
|
||||
|
||||
@@ -69,8 +69,13 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
||||
inputs: `={{
|
||||
((parameters) => {
|
||||
const mode = parameters?.mode;
|
||||
const useReranker = parameters?.useReranker;
|
||||
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
|
||||
|
||||
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
|
||||
}
|
||||
|
||||
if (mode === 'retrieve-as-tool') {
|
||||
return inputs;
|
||||
}
|
||||
@@ -202,6 +207,18 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
displayName: 'Rerank Results',
|
||||
name: 'useReranker',
|
||||
type: 'boolean',
|
||||
default: false,
|
||||
description: 'Whether or not to rerank results',
|
||||
displayOptions: {
|
||||
show: {
|
||||
mode: ['load', 'retrieve-as-tool'],
|
||||
},
|
||||
},
|
||||
},
|
||||
// ID is always used for update operation
|
||||
{
|
||||
displayName: 'ID',
|
||||
@@ -233,7 +250,6 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
||||
*/
|
||||
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
|
||||
const mode = this.getNodeParameter('mode', 0) as NodeOperationMode;
|
||||
|
||||
// Get the embeddings model connected to this node
|
||||
const embeddings = (await this.getInputConnectionData(
|
||||
NodeConnectionTypes.AiEmbedding,
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
/* eslint-disable @typescript-eslint/unbound-method */
|
||||
import type { Document } from '@langchain/core/documents';
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { MockProxy } from 'jest-mock-extended';
|
||||
import { mock } from 'jest-mock-extended';
|
||||
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||
|
||||
import { logAiEvent } from '@utils/helpers';
|
||||
|
||||
@@ -22,6 +24,7 @@ describe('handleLoadOperation', () => {
|
||||
let mockContext: MockProxy<IExecuteFunctions>;
|
||||
let mockEmbeddings: MockProxy<Embeddings>;
|
||||
let mockVectorStore: MockProxy<VectorStore>;
|
||||
let mockReranker: MockProxy<BaseDocumentCompressor>;
|
||||
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
||||
let nodeParameters: Record<string, any>;
|
||||
|
||||
@@ -30,6 +33,7 @@ describe('handleLoadOperation', () => {
|
||||
prompt: 'test search query',
|
||||
topK: 3,
|
||||
includeDocumentMetadata: true,
|
||||
useReranker: false,
|
||||
};
|
||||
|
||||
mockContext = mock<IExecuteFunctions>();
|
||||
@@ -48,6 +52,24 @@ describe('handleLoadOperation', () => {
|
||||
[{ pageContent: 'test content 3', metadata: { test: 'metadata 3' } } as Document, 0.75],
|
||||
]);
|
||||
|
||||
mockReranker = mock<BaseDocumentCompressor>();
|
||||
mockReranker.compressDocuments.mockResolvedValue([
|
||||
{
|
||||
pageContent: 'test content 2',
|
||||
metadata: { test: 'metadata 2', relevanceScore: 0.98 },
|
||||
} as Document,
|
||||
{
|
||||
pageContent: 'test content 1',
|
||||
metadata: { test: 'metadata 1', relevanceScore: 0.92 },
|
||||
} as Document,
|
||||
{
|
||||
pageContent: 'test content 3',
|
||||
metadata: { test: 'metadata 3', relevanceScore: 0.88 },
|
||||
} as Document,
|
||||
]);
|
||||
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
|
||||
|
||||
mockArgs = {
|
||||
meta: {
|
||||
displayName: 'Test Vector Store',
|
||||
@@ -142,4 +164,82 @@ describe('handleLoadOperation', () => {
|
||||
|
||||
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||
});
|
||||
|
||||
describe('reranking functionality', () => {
|
||||
beforeEach(() => {
|
||||
nodeParameters.useReranker = true;
|
||||
});
|
||||
|
||||
it('should use reranker when useReranker is true', async () => {
|
||||
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiReranker,
|
||||
0,
|
||||
);
|
||||
expect(mockReranker.compressDocuments).toHaveBeenCalledWith(
|
||||
[
|
||||
{ pageContent: 'test content 1', metadata: { test: 'metadata 1' } },
|
||||
{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } },
|
||||
{ pageContent: 'test content 3', metadata: { test: 'metadata 3' } },
|
||||
],
|
||||
'test search query',
|
||||
);
|
||||
expect(result).toHaveLength(3);
|
||||
});
|
||||
|
||||
it('should return reranked documents with relevance scores', async () => {
|
||||
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
// First result should be the reranked first document (was second in original order)
|
||||
expect((result[0].json?.document as IDataObject)?.pageContent).toEqual('test content 2');
|
||||
expect(result[0].json?.score).toEqual(0.98);
|
||||
|
||||
// Second result should be the reranked second document (was first in original order)
|
||||
expect((result[1].json?.document as IDataObject)?.pageContent).toEqual('test content 1');
|
||||
expect(result[1].json?.score).toEqual(0.92);
|
||||
|
||||
// Third result should be the reranked third document
|
||||
expect((result[2].json?.document as IDataObject)?.pageContent).toEqual('test content 3');
|
||||
expect(result[2].json?.score).toEqual(0.88);
|
||||
});
|
||||
|
||||
it('should remove relevanceScore from metadata after reranking', async () => {
|
||||
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
// Check that relevanceScore is not included in the metadata
|
||||
expect((result[0].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 2' });
|
||||
expect((result[1].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 1' });
|
||||
expect((result[2].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 3' });
|
||||
});
|
||||
|
||||
it('should handle reranking with includeDocumentMetadata false', async () => {
|
||||
nodeParameters.includeDocumentMetadata = false;
|
||||
|
||||
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
expect(result[0].json?.document).not.toHaveProperty('metadata');
|
||||
expect((result[0].json?.document as IDataObject)?.pageContent).toEqual('test content 2');
|
||||
expect(result[0].json?.score).toEqual(0.98);
|
||||
});
|
||||
|
||||
it('should not call reranker when useReranker is false', async () => {
|
||||
nodeParameters.useReranker = false;
|
||||
|
||||
await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
|
||||
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
|
||||
expect(mockReranker.compressDocuments).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should release vector store client even if reranking fails', async () => {
|
||||
mockReranker.compressDocuments.mockRejectedValue(new Error('Reranking failed'));
|
||||
|
||||
await expect(handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0)).rejects.toThrow(
|
||||
'Reranking failed',
|
||||
);
|
||||
|
||||
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
/* eslint-disable @typescript-eslint/unbound-method */
|
||||
import type { Document } from '@langchain/core/documents';
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { MockProxy } from 'jest-mock-extended';
|
||||
import { mock } from 'jest-mock-extended';
|
||||
import { DynamicTool } from 'langchain/tools';
|
||||
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||
|
||||
import { logWrapper } from '@utils/logWrapper';
|
||||
|
||||
@@ -26,6 +28,7 @@ describe('handleRetrieveAsToolOperation', () => {
|
||||
let mockContext: MockProxy<ISupplyDataFunctions>;
|
||||
let mockEmbeddings: MockProxy<Embeddings>;
|
||||
let mockVectorStore: MockProxy<VectorStore>;
|
||||
let mockReranker: MockProxy<BaseDocumentCompressor>;
|
||||
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
||||
let nodeParameters: Record<string, any>;
|
||||
|
||||
@@ -35,6 +38,7 @@ describe('handleRetrieveAsToolOperation', () => {
|
||||
toolDescription: 'Search the test knowledge base',
|
||||
topK: 3,
|
||||
includeDocumentMetadata: true,
|
||||
useReranker: false,
|
||||
};
|
||||
|
||||
mockContext = mock<ISupplyDataFunctions>();
|
||||
@@ -60,6 +64,20 @@ describe('handleRetrieveAsToolOperation', () => {
|
||||
[{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } } as Document, 0.85],
|
||||
]);
|
||||
|
||||
mockReranker = mock<BaseDocumentCompressor>();
|
||||
mockReranker.compressDocuments.mockResolvedValue([
|
||||
{
|
||||
pageContent: 'test content 2',
|
||||
metadata: { test: 'metadata 2', relevanceScore: 0.98 },
|
||||
} as Document,
|
||||
{
|
||||
pageContent: 'test content 1',
|
||||
metadata: { test: 'metadata 1', relevanceScore: 0.92 },
|
||||
} as Document,
|
||||
]);
|
||||
|
||||
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
|
||||
|
||||
mockArgs = {
|
||||
meta: {
|
||||
displayName: 'Test Vector Store',
|
||||
@@ -215,4 +233,115 @@ describe('handleRetrieveAsToolOperation', () => {
|
||||
// Should still release the client
|
||||
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||
});
|
||||
|
||||
describe('reranking functionality', () => {
|
||||
beforeEach(() => {
|
||||
nodeParameters.useReranker = true;
|
||||
});
|
||||
|
||||
it('should use reranker when useReranker is true', async () => {
|
||||
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
const tool = result.response as DynamicTool;
|
||||
|
||||
await tool.func('test query');
|
||||
|
||||
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||
NodeConnectionTypes.AiReranker,
|
||||
0,
|
||||
);
|
||||
expect(mockReranker.compressDocuments).toHaveBeenCalledWith(
|
||||
[
|
||||
{ pageContent: 'test content 1', metadata: { test: 'metadata 1' } },
|
||||
{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } },
|
||||
],
|
||||
'test query',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return reranked documents in the correct order', async () => {
|
||||
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
const tool = result.response as DynamicTool;
|
||||
|
||||
const toolResult = await tool.func('test query');
|
||||
|
||||
expect(toolResult).toHaveLength(2);
|
||||
|
||||
// First result should be the reranked first document (was second in original order)
|
||||
const parsedFirst = JSON.parse(toolResult[0].text);
|
||||
expect(parsedFirst.pageContent).toEqual('test content 2');
|
||||
expect(parsedFirst.metadata).toEqual({ test: 'metadata 2' });
|
||||
|
||||
// Second result should be the reranked second document (was first in original order)
|
||||
const parsedSecond = JSON.parse(toolResult[1].text);
|
||||
expect(parsedSecond.pageContent).toEqual('test content 1');
|
||||
expect(parsedSecond.metadata).toEqual({ test: 'metadata 1' });
|
||||
});
|
||||
|
||||
it('should handle reranking with includeDocumentMetadata false', async () => {
|
||||
nodeParameters.includeDocumentMetadata = false;
|
||||
|
||||
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
const tool = result.response as DynamicTool;
|
||||
|
||||
const toolResult = await tool.func('test query');
|
||||
|
||||
// Parse the JSON text to verify it excludes metadata but maintains reranked order
|
||||
const parsedFirst = JSON.parse(toolResult[0].text);
|
||||
expect(parsedFirst).toHaveProperty('pageContent', 'test content 2');
|
||||
expect(parsedFirst).not.toHaveProperty('metadata');
|
||||
|
||||
const parsedSecond = JSON.parse(toolResult[1].text);
|
||||
expect(parsedSecond).toHaveProperty('pageContent', 'test content 1');
|
||||
expect(parsedSecond).not.toHaveProperty('metadata');
|
||||
});
|
||||
|
||||
it('should not call reranker when useReranker is false', async () => {
|
||||
nodeParameters.useReranker = false;
|
||||
|
||||
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
const tool = result.response as DynamicTool;
|
||||
|
||||
await tool.func('test query');
|
||||
|
||||
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
|
||||
expect(mockReranker.compressDocuments).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should release vector store client even if reranking fails', async () => {
|
||||
mockReranker.compressDocuments.mockRejectedValueOnce(new Error('Reranking failed'));
|
||||
|
||||
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
const tool = result.response as DynamicTool;
|
||||
|
||||
await expect(tool.func('test query')).rejects.toThrow('Reranking failed');
|
||||
|
||||
// Should still release the client
|
||||
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||
});
|
||||
|
||||
it('should properly handle relevanceScore from reranker metadata', async () => {
|
||||
// Mock reranker to return documents with relevanceScore in different metadata structure
|
||||
mockReranker.compressDocuments.mockResolvedValueOnce([
|
||||
{
|
||||
pageContent: 'test content 2',
|
||||
metadata: { test: 'metadata 2', relevanceScore: 0.98, otherField: 'value' },
|
||||
} as Document,
|
||||
{
|
||||
pageContent: 'test content 1',
|
||||
metadata: { test: 'metadata 1', relevanceScore: 0.92 },
|
||||
} as Document,
|
||||
]);
|
||||
|
||||
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||
const tool = result.response as DynamicTool;
|
||||
|
||||
const toolResult = await tool.invoke('test query');
|
||||
|
||||
// Check that relevanceScore is used but not included in the final metadata
|
||||
const parsedFirst = JSON.parse(toolResult[0].text);
|
||||
expect(parsedFirst.pageContent).toEqual('test content 2');
|
||||
expect(parsedFirst.metadata).toEqual({ test: 'metadata 2', otherField: 'value' });
|
||||
expect(parsedFirst.metadata).not.toHaveProperty('relevanceScore');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes, type IExecuteFunctions, type INodeExecutionData } from 'n8n-workflow';
|
||||
|
||||
import { getMetadataFiltersValues, logAiEvent } from '@utils/helpers';
|
||||
|
||||
@@ -29,6 +30,8 @@ export async function handleLoadOperation<T extends VectorStore = VectorStore>(
|
||||
// Get the search parameters from the node
|
||||
const prompt = context.getNodeParameter('prompt', itemIndex) as string;
|
||||
const topK = context.getNodeParameter('topK', itemIndex, 4) as number;
|
||||
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
|
||||
|
||||
const includeDocumentMetadata = context.getNodeParameter(
|
||||
'includeDocumentMetadata',
|
||||
itemIndex,
|
||||
@@ -39,7 +42,22 @@ export async function handleLoadOperation<T extends VectorStore = VectorStore>(
|
||||
const embeddedPrompt = await embeddings.embedQuery(prompt);
|
||||
|
||||
// Get the most similar documents to the embedded prompt
|
||||
const docs = await vectorStore.similaritySearchVectorWithScore(embeddedPrompt, topK, filter);
|
||||
let docs = await vectorStore.similaritySearchVectorWithScore(embeddedPrompt, topK, filter);
|
||||
|
||||
// If reranker is used, rerank the documents
|
||||
if (useReranker && docs.length > 0) {
|
||||
const reranker = (await context.getInputConnectionData(
|
||||
NodeConnectionTypes.AiReranker,
|
||||
0,
|
||||
)) as BaseDocumentCompressor;
|
||||
const documents = docs.map(([doc]) => doc);
|
||||
|
||||
const rerankedDocuments = await reranker.compressDocuments(documents, prompt);
|
||||
docs = rerankedDocuments.map((doc) => {
|
||||
const { relevanceScore, ...metadata } = doc.metadata || {};
|
||||
return [{ ...doc, metadata }, relevanceScore];
|
||||
});
|
||||
}
|
||||
|
||||
// Format the documents for the output
|
||||
const serializedDocs = docs.map(([doc, score]) => {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||
import { DynamicTool } from 'langchain/tools';
|
||||
import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow';
|
||||
import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow';
|
||||
|
||||
import { getMetadataFiltersValues, nodeNameToToolName } from '@utils/helpers';
|
||||
import { logWrapper } from '@utils/logWrapper';
|
||||
@@ -29,6 +30,7 @@ export async function handleRetrieveAsToolOperation<T extends VectorStore = Vect
|
||||
: nodeNameToToolName(node);
|
||||
|
||||
const topK = context.getNodeParameter('topK', itemIndex, 4) as number;
|
||||
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
|
||||
const includeDocumentMetadata = context.getNodeParameter(
|
||||
'includeDocumentMetadata',
|
||||
itemIndex,
|
||||
@@ -58,12 +60,27 @@ export async function handleRetrieveAsToolOperation<T extends VectorStore = Vect
|
||||
const embeddedPrompt = await embeddings.embedQuery(input);
|
||||
|
||||
// Search for similar documents
|
||||
const documents = await vectorStore.similaritySearchVectorWithScore(
|
||||
let documents = await vectorStore.similaritySearchVectorWithScore(
|
||||
embeddedPrompt,
|
||||
topK,
|
||||
filter,
|
||||
);
|
||||
|
||||
// If reranker is used, rerank the documents
|
||||
if (useReranker && documents.length > 0) {
|
||||
const reranker = (await context.getInputConnectionData(
|
||||
NodeConnectionTypes.AiReranker,
|
||||
0,
|
||||
)) as BaseDocumentCompressor;
|
||||
|
||||
const docs = documents.map(([doc]) => doc);
|
||||
const rerankedDocuments = await reranker.compressDocuments(docs, input);
|
||||
documents = rerankedDocuments.map((doc) => {
|
||||
const { relevanceScore, ...metadata } = doc.metadata;
|
||||
return [{ ...doc, metadata }, relevanceScore];
|
||||
});
|
||||
}
|
||||
|
||||
// Format the documents for the tool output
|
||||
return documents
|
||||
.map((document) => {
|
||||
|
||||
Reference in New Issue
Block a user