diff --git a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts new file mode 100644 index 0000000000..dc2b785d30 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts @@ -0,0 +1,90 @@ +/* eslint-disable n8n-nodes-base/node-dirname-against-convention */ +import { CohereRerank } from '@langchain/cohere'; +import { + NodeConnectionTypes, + type INodeType, + type INodeTypeDescription, + type ISupplyDataFunctions, + type SupplyData, +} from 'n8n-workflow'; + +import { logWrapper } from '@utils/logWrapper'; + +export class RerankerCohere implements INodeType { + description: INodeTypeDescription = { + displayName: 'Reranker Cohere', + name: 'rerankerCohere', + icon: { light: 'file:cohere.svg', dark: 'file:cohere.dark.svg' }, + group: ['transform'], + version: 1, + description: + 'Use Cohere Reranker to reorder documents after retrieval from a vector store by relevance to the given query.', + defaults: { + name: 'Reranker Cohere', + }, + requestDefaults: { + ignoreHttpStatusErrors: true, + baseURL: '={{ $credentials.host }}', + }, + credentials: [ + { + name: 'cohereApi', + required: true, + }, + ], + codex: { + categories: ['AI'], + subcategories: { + AI: ['Rerankers'], + }, + resources: { + primaryDocumentation: [ + { + url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.rerankercohere/', + }, + ], + }, + }, + inputs: [], + outputs: [NodeConnectionTypes.AiReranker], + outputNames: ['Reranker'], + properties: [ + { + displayName: 'Model', + name: 'modelName', + type: 'options', + description: + 'The model that should be used to rerank the documents. Learn more.', + default: 'rerank-v3.5', + options: [ + { + name: 'rerank-v3.5', + value: 'rerank-v3.5', + }, + { + name: 'rerank-english-v3.0', + value: 'rerank-english-v3.0', + }, + { + name: 'rerank-multilingual-v3.0', + value: 'rerank-multilingual-v3.0', + }, + ], + }, + ], + }; + + async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise { + this.logger.debug('Supply data for reranking Cohere'); + const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string; + const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi'); + const reranker = new CohereRerank({ + apiKey: credentials.apiKey, + model: modelName, + }); + + return { + response: logWrapper(reranker, this), + }; + } +} diff --git a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/cohere.dark.svg b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/cohere.dark.svg new file mode 100644 index 0000000000..796fe1bcbc --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/cohere.dark.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/cohere.svg b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/cohere.svg new file mode 100644 index 0000000000..c54ba34ee8 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/cohere.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts new file mode 100644 index 0000000000..98ecea2950 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts @@ -0,0 +1,141 @@ +import { CohereRerank } from '@langchain/cohere'; +import { mock } from 'jest-mock-extended'; +import type { ISupplyDataFunctions } from 'n8n-workflow'; + +import { logWrapper } from '@utils/logWrapper'; + +import { RerankerCohere } from '../RerankerCohere.node'; + +// Mock the CohereRerank class +jest.mock('@langchain/cohere', () => ({ + CohereRerank: jest.fn(), +})); + +// Mock the logWrapper utility +jest.mock('@utils/logWrapper', () => ({ + logWrapper: jest.fn().mockImplementation((obj) => ({ logWrapped: obj })), +})); + +describe('RerankerCohere', () => { + let rerankerCohere: RerankerCohere; + let mockSupplyDataFunctions: ISupplyDataFunctions; + let mockCohereRerank: jest.Mocked; + + beforeEach(() => { + rerankerCohere = new RerankerCohere(); + + // Reset the mock + jest.clearAllMocks(); + + // Create a mock CohereRerank instance + mockCohereRerank = { + compressDocuments: jest.fn(), + } as unknown as jest.Mocked; + + // Make the CohereRerank constructor return our mock instance + (CohereRerank as jest.MockedClass).mockImplementation( + () => mockCohereRerank, + ); + + // Create mock supply data functions + mockSupplyDataFunctions = mock({ + logger: { + debug: jest.fn(), + error: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + }, + }); + + // Mock specific methods with proper jest functions + mockSupplyDataFunctions.getNodeParameter = jest.fn(); + mockSupplyDataFunctions.getCredentials = jest.fn(); + }); + + it('should create CohereRerank with default model and return wrapped instance', async () => { + // Setup mocks + const mockCredentials = { apiKey: 'test-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute + const result = await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0); + + expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith( + 'modelName', + 0, + 'rerank-v3.5', + ); + expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi'); + expect(CohereRerank).toHaveBeenCalledWith({ + apiKey: 'test-api-key', + model: 'rerank-v3.5', + }); + expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions); + expect(result.response).toEqual({ logWrapped: mockCohereRerank }); + }); + + it('should create CohereRerank with custom model', async () => { + // Setup mocks + const mockCredentials = { apiKey: 'custom-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue( + 'rerank-multilingual-v3.0', + ); + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute + await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0); + + // Verify + expect(CohereRerank).toHaveBeenCalledWith({ + apiKey: 'custom-api-key', + model: 'rerank-multilingual-v3.0', + }); + }); + + it('should handle different item indices', async () => { + // Setup mocks + const mockCredentials = { apiKey: 'test-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0'); + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute with different item index + await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 2); + + // Verify the correct item index is passed + expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith( + 'modelName', + 2, + 'rerank-v3.5', + ); + }); + + it('should throw error when credentials are missing', async () => { + // Setup mocks + (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue( + new Error('Missing credentials'), + ); + + // Execute and verify error + await expect(rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0)).rejects.toThrow( + 'Missing credentials', + ); + }); + + it('should use fallback model when parameter is not provided', async () => { + // Setup mocks - getNodeParameter returns the fallback value + const mockCredentials = { apiKey: 'test-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute + await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0); + + // Verify fallback is used + expect(CohereRerank).toHaveBeenCalledWith({ + apiKey: 'test-api-key', + model: 'rerank-v3.5', + }); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap index af9ebebce5..a63873a4d6 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/__snapshots__/createVectorStoreNode.test.ts.snap @@ -41,8 +41,13 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] = "inputs": "={{ ((parameters) => { const mode = parameters?.mode; + const useReranker = parameters?.useReranker; const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}] + if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) { + inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1}) + } + if (mode === 'retrieve-as-tool') { return inputs; } @@ -233,6 +238,21 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] = "name": "includeDocumentMetadata", "type": "boolean", }, + { + "default": false, + "description": "Whether or not to rerank results", + "displayName": "Rerank Results", + "displayOptions": { + "show": { + "mode": [ + "load", + "retrieve-as-tool", + ], + }, + }, + "name": "useReranker", + "type": "boolean", + }, { "default": "", "description": "ID of an embedding entry", diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts index 91c7fea47c..833749b954 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/createVectorStoreNode.ts @@ -69,8 +69,13 @@ export const createVectorStoreNode = ( inputs: `={{ ((parameters) => { const mode = parameters?.mode; + const useReranker = parameters?.useReranker; const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}] + if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) { + inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1}) + } + if (mode === 'retrieve-as-tool') { return inputs; } @@ -202,6 +207,18 @@ export const createVectorStoreNode = ( }, }, }, + { + displayName: 'Rerank Results', + name: 'useReranker', + type: 'boolean', + default: false, + description: 'Whether or not to rerank results', + displayOptions: { + show: { + mode: ['load', 'retrieve-as-tool'], + }, + }, + }, // ID is always used for update operation { displayName: 'ID', @@ -233,7 +250,6 @@ export const createVectorStoreNode = ( */ async execute(this: IExecuteFunctions): Promise { const mode = this.getNodeParameter('mode', 0) as NodeOperationMode; - // Get the embeddings model connected to this node const embeddings = (await this.getInputConnectionData( NodeConnectionTypes.AiEmbedding, diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/loadOperation.test.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/loadOperation.test.ts index 75ff118e2e..e35599aab7 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/loadOperation.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/loadOperation.test.ts @@ -2,10 +2,12 @@ /* eslint-disable @typescript-eslint/unbound-method */ import type { Document } from '@langchain/core/documents'; import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { VectorStore } from '@langchain/core/vectorstores'; import type { MockProxy } from 'jest-mock-extended'; import { mock } from 'jest-mock-extended'; import type { IDataObject, IExecuteFunctions } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; import { logAiEvent } from '@utils/helpers'; @@ -22,6 +24,7 @@ describe('handleLoadOperation', () => { let mockContext: MockProxy; let mockEmbeddings: MockProxy; let mockVectorStore: MockProxy; + let mockReranker: MockProxy; let mockArgs: VectorStoreNodeConstructorArgs; let nodeParameters: Record; @@ -30,6 +33,7 @@ describe('handleLoadOperation', () => { prompt: 'test search query', topK: 3, includeDocumentMetadata: true, + useReranker: false, }; mockContext = mock(); @@ -48,6 +52,24 @@ describe('handleLoadOperation', () => { [{ pageContent: 'test content 3', metadata: { test: 'metadata 3' } } as Document, 0.75], ]); + mockReranker = mock(); + mockReranker.compressDocuments.mockResolvedValue([ + { + pageContent: 'test content 2', + metadata: { test: 'metadata 2', relevanceScore: 0.98 }, + } as Document, + { + pageContent: 'test content 1', + metadata: { test: 'metadata 1', relevanceScore: 0.92 }, + } as Document, + { + pageContent: 'test content 3', + metadata: { test: 'metadata 3', relevanceScore: 0.88 }, + } as Document, + ]); + + mockContext.getInputConnectionData.mockResolvedValue(mockReranker); + mockArgs = { meta: { displayName: 'Test Vector Store', @@ -142,4 +164,82 @@ describe('handleLoadOperation', () => { expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore); }); + + describe('reranking functionality', () => { + beforeEach(() => { + nodeParameters.useReranker = true; + }); + + it('should use reranker when useReranker is true', async () => { + const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0); + + expect(mockContext.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiReranker, + 0, + ); + expect(mockReranker.compressDocuments).toHaveBeenCalledWith( + [ + { pageContent: 'test content 1', metadata: { test: 'metadata 1' } }, + { pageContent: 'test content 2', metadata: { test: 'metadata 2' } }, + { pageContent: 'test content 3', metadata: { test: 'metadata 3' } }, + ], + 'test search query', + ); + expect(result).toHaveLength(3); + }); + + it('should return reranked documents with relevance scores', async () => { + const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0); + + // First result should be the reranked first document (was second in original order) + expect((result[0].json?.document as IDataObject)?.pageContent).toEqual('test content 2'); + expect(result[0].json?.score).toEqual(0.98); + + // Second result should be the reranked second document (was first in original order) + expect((result[1].json?.document as IDataObject)?.pageContent).toEqual('test content 1'); + expect(result[1].json?.score).toEqual(0.92); + + // Third result should be the reranked third document + expect((result[2].json?.document as IDataObject)?.pageContent).toEqual('test content 3'); + expect(result[2].json?.score).toEqual(0.88); + }); + + it('should remove relevanceScore from metadata after reranking', async () => { + const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0); + + // Check that relevanceScore is not included in the metadata + expect((result[0].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 2' }); + expect((result[1].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 1' }); + expect((result[2].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 3' }); + }); + + it('should handle reranking with includeDocumentMetadata false', async () => { + nodeParameters.includeDocumentMetadata = false; + + const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0); + + expect(result[0].json?.document).not.toHaveProperty('metadata'); + expect((result[0].json?.document as IDataObject)?.pageContent).toEqual('test content 2'); + expect(result[0].json?.score).toEqual(0.98); + }); + + it('should not call reranker when useReranker is false', async () => { + nodeParameters.useReranker = false; + + await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0); + + expect(mockContext.getInputConnectionData).not.toHaveBeenCalled(); + expect(mockReranker.compressDocuments).not.toHaveBeenCalled(); + }); + + it('should release vector store client even if reranking fails', async () => { + mockReranker.compressDocuments.mockRejectedValue(new Error('Reranking failed')); + + await expect(handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0)).rejects.toThrow( + 'Reranking failed', + ); + + expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore); + }); + }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveAsToolOperation.test.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveAsToolOperation.test.ts index 8699eb526c..137be11bcf 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveAsToolOperation.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/__tests__/retrieveAsToolOperation.test.ts @@ -1,11 +1,13 @@ /* eslint-disable @typescript-eslint/unbound-method */ import type { Document } from '@langchain/core/documents'; import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { VectorStore } from '@langchain/core/vectorstores'; import type { MockProxy } from 'jest-mock-extended'; import { mock } from 'jest-mock-extended'; import { DynamicTool } from 'langchain/tools'; import type { ISupplyDataFunctions } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; import { logWrapper } from '@utils/logWrapper'; @@ -26,6 +28,7 @@ describe('handleRetrieveAsToolOperation', () => { let mockContext: MockProxy; let mockEmbeddings: MockProxy; let mockVectorStore: MockProxy; + let mockReranker: MockProxy; let mockArgs: VectorStoreNodeConstructorArgs; let nodeParameters: Record; @@ -35,6 +38,7 @@ describe('handleRetrieveAsToolOperation', () => { toolDescription: 'Search the test knowledge base', topK: 3, includeDocumentMetadata: true, + useReranker: false, }; mockContext = mock(); @@ -60,6 +64,20 @@ describe('handleRetrieveAsToolOperation', () => { [{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } } as Document, 0.85], ]); + mockReranker = mock(); + mockReranker.compressDocuments.mockResolvedValue([ + { + pageContent: 'test content 2', + metadata: { test: 'metadata 2', relevanceScore: 0.98 }, + } as Document, + { + pageContent: 'test content 1', + metadata: { test: 'metadata 1', relevanceScore: 0.92 }, + } as Document, + ]); + + mockContext.getInputConnectionData.mockResolvedValue(mockReranker); + mockArgs = { meta: { displayName: 'Test Vector Store', @@ -215,4 +233,115 @@ describe('handleRetrieveAsToolOperation', () => { // Should still release the client expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore); }); + + describe('reranking functionality', () => { + beforeEach(() => { + nodeParameters.useReranker = true; + }); + + it('should use reranker when useReranker is true', async () => { + const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0); + const tool = result.response as DynamicTool; + + await tool.func('test query'); + + expect(mockContext.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiReranker, + 0, + ); + expect(mockReranker.compressDocuments).toHaveBeenCalledWith( + [ + { pageContent: 'test content 1', metadata: { test: 'metadata 1' } }, + { pageContent: 'test content 2', metadata: { test: 'metadata 2' } }, + ], + 'test query', + ); + }); + + it('should return reranked documents in the correct order', async () => { + const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0); + const tool = result.response as DynamicTool; + + const toolResult = await tool.func('test query'); + + expect(toolResult).toHaveLength(2); + + // First result should be the reranked first document (was second in original order) + const parsedFirst = JSON.parse(toolResult[0].text); + expect(parsedFirst.pageContent).toEqual('test content 2'); + expect(parsedFirst.metadata).toEqual({ test: 'metadata 2' }); + + // Second result should be the reranked second document (was first in original order) + const parsedSecond = JSON.parse(toolResult[1].text); + expect(parsedSecond.pageContent).toEqual('test content 1'); + expect(parsedSecond.metadata).toEqual({ test: 'metadata 1' }); + }); + + it('should handle reranking with includeDocumentMetadata false', async () => { + nodeParameters.includeDocumentMetadata = false; + + const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0); + const tool = result.response as DynamicTool; + + const toolResult = await tool.func('test query'); + + // Parse the JSON text to verify it excludes metadata but maintains reranked order + const parsedFirst = JSON.parse(toolResult[0].text); + expect(parsedFirst).toHaveProperty('pageContent', 'test content 2'); + expect(parsedFirst).not.toHaveProperty('metadata'); + + const parsedSecond = JSON.parse(toolResult[1].text); + expect(parsedSecond).toHaveProperty('pageContent', 'test content 1'); + expect(parsedSecond).not.toHaveProperty('metadata'); + }); + + it('should not call reranker when useReranker is false', async () => { + nodeParameters.useReranker = false; + + const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0); + const tool = result.response as DynamicTool; + + await tool.func('test query'); + + expect(mockContext.getInputConnectionData).not.toHaveBeenCalled(); + expect(mockReranker.compressDocuments).not.toHaveBeenCalled(); + }); + + it('should release vector store client even if reranking fails', async () => { + mockReranker.compressDocuments.mockRejectedValueOnce(new Error('Reranking failed')); + + const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0); + const tool = result.response as DynamicTool; + + await expect(tool.func('test query')).rejects.toThrow('Reranking failed'); + + // Should still release the client + expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore); + }); + + it('should properly handle relevanceScore from reranker metadata', async () => { + // Mock reranker to return documents with relevanceScore in different metadata structure + mockReranker.compressDocuments.mockResolvedValueOnce([ + { + pageContent: 'test content 2', + metadata: { test: 'metadata 2', relevanceScore: 0.98, otherField: 'value' }, + } as Document, + { + pageContent: 'test content 1', + metadata: { test: 'metadata 1', relevanceScore: 0.92 }, + } as Document, + ]); + + const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0); + const tool = result.response as DynamicTool; + + const toolResult = await tool.invoke('test query'); + + // Check that relevanceScore is used but not included in the final metadata + const parsedFirst = JSON.parse(toolResult[0].text); + expect(parsedFirst.pageContent).toEqual('test content 2'); + expect(parsedFirst.metadata).toEqual({ test: 'metadata 2', otherField: 'value' }); + expect(parsedFirst.metadata).not.toHaveProperty('relevanceScore'); + }); + }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/loadOperation.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/loadOperation.ts index 3eb5e58076..e8a33dd158 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/loadOperation.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/loadOperation.ts @@ -1,6 +1,7 @@ import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { VectorStore } from '@langchain/core/vectorstores'; -import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow'; +import { NodeConnectionTypes, type IExecuteFunctions, type INodeExecutionData } from 'n8n-workflow'; import { getMetadataFiltersValues, logAiEvent } from '@utils/helpers'; @@ -29,6 +30,8 @@ export async function handleLoadOperation( // Get the search parameters from the node const prompt = context.getNodeParameter('prompt', itemIndex) as string; const topK = context.getNodeParameter('topK', itemIndex, 4) as number; + const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean; + const includeDocumentMetadata = context.getNodeParameter( 'includeDocumentMetadata', itemIndex, @@ -39,7 +42,22 @@ export async function handleLoadOperation( const embeddedPrompt = await embeddings.embedQuery(prompt); // Get the most similar documents to the embedded prompt - const docs = await vectorStore.similaritySearchVectorWithScore(embeddedPrompt, topK, filter); + let docs = await vectorStore.similaritySearchVectorWithScore(embeddedPrompt, topK, filter); + + // If reranker is used, rerank the documents + if (useReranker && docs.length > 0) { + const reranker = (await context.getInputConnectionData( + NodeConnectionTypes.AiReranker, + 0, + )) as BaseDocumentCompressor; + const documents = docs.map(([doc]) => doc); + + const rerankedDocuments = await reranker.compressDocuments(documents, prompt); + docs = rerankedDocuments.map((doc) => { + const { relevanceScore, ...metadata } = doc.metadata || {}; + return [{ ...doc, metadata }, relevanceScore]; + }); + } // Format the documents for the output const serializedDocs = docs.map(([doc, score]) => { diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveAsToolOperation.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveAsToolOperation.ts index f77c0c3ae2..69d95e54db 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveAsToolOperation.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/shared/createVectorStoreNode/operations/retrieveAsToolOperation.ts @@ -1,7 +1,8 @@ import type { Embeddings } from '@langchain/core/embeddings'; +import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { VectorStore } from '@langchain/core/vectorstores'; import { DynamicTool } from 'langchain/tools'; -import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow'; +import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow'; import { getMetadataFiltersValues, nodeNameToToolName } from '@utils/helpers'; import { logWrapper } from '@utils/logWrapper'; @@ -29,6 +30,7 @@ export async function handleRetrieveAsToolOperation 0) { + const reranker = (await context.getInputConnectionData( + NodeConnectionTypes.AiReranker, + 0, + )) as BaseDocumentCompressor; + + const docs = documents.map(([doc]) => doc); + const rerankedDocuments = await reranker.compressDocuments(docs, input); + documents = rerankedDocuments.map((doc) => { + const { relevanceScore, ...metadata } = doc.metadata; + return [{ ...doc, metadata }, relevanceScore]; + }); + } + // Format the documents for the tool output return documents .map((document) => { diff --git a/packages/@n8n/nodes-langchain/package.json b/packages/@n8n/nodes-langchain/package.json index 00aa52f141..0eb74d0e50 100644 --- a/packages/@n8n/nodes-langchain/package.json +++ b/packages/@n8n/nodes-langchain/package.json @@ -99,6 +99,7 @@ "dist/nodes/output_parser/OutputParserAutofixing/OutputParserAutofixing.node.js", "dist/nodes/output_parser/OutputParserItemList/OutputParserItemList.node.js", "dist/nodes/output_parser/OutputParserStructured/OutputParserStructured.node.js", + "dist/nodes/rerankers/RerankerCohere/RerankerCohere.node.js", "dist/nodes/retrievers/RetrieverContextualCompression/RetrieverContextualCompression.node.js", "dist/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.js", "dist/nodes/retrievers/RetrieverMultiQuery/RetrieverMultiQuery.node.js", diff --git a/packages/@n8n/nodes-langchain/utils/logWrapper.ts b/packages/@n8n/nodes-langchain/utils/logWrapper.ts index 17e123009b..0089b53419 100644 --- a/packages/@n8n/nodes-langchain/utils/logWrapper.ts +++ b/packages/@n8n/nodes-langchain/utils/logWrapper.ts @@ -6,6 +6,7 @@ import { Embeddings } from '@langchain/core/embeddings'; import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory'; import type { BaseMessage } from '@langchain/core/messages'; import { BaseRetriever } from '@langchain/core/retrievers'; +import { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors'; import type { StructuredTool, Tool } from '@langchain/core/tools'; import { VectorStore } from '@langchain/core/vectorstores'; import { TextSplitter } from '@langchain/textsplitters'; @@ -18,7 +19,12 @@ import type { ITaskMetadata, NodeConnectionType, } from 'n8n-workflow'; -import { NodeOperationError, NodeConnectionTypes, parseErrorMetadata } from 'n8n-workflow'; +import { + NodeOperationError, + NodeConnectionTypes, + parseErrorMetadata, + deepCopy, +} from 'n8n-workflow'; import { logAiEvent, isToolsInstance, isBaseChatMemory, isBaseChatMessageHistory } from './helpers'; import { N8nBinaryLoader } from './N8nBinaryLoader'; @@ -102,6 +108,7 @@ export function logWrapper< | BaseChatMemory | BaseChatMessageHistory | BaseRetriever + | BaseDocumentCompressor | Embeddings | Document[] | Document @@ -297,6 +304,32 @@ export function logWrapper< } } + // ========== Rerankers ========== + if (originalInstance instanceof BaseDocumentCompressor) { + if (prop === 'compressDocuments' && 'compressDocuments' in target) { + return async (documents: Document[], query: string): Promise => { + connectionType = NodeConnectionTypes.AiReranker; + const { index } = executeFunctions.addInputData(connectionType, [ + [{ json: { query, documents } }], + ]); + + const response = (await callMethodAsync.call(target, { + executeFunctions, + connectionType, + currentNodeRunIndex: index, + method: target[prop], + // compressDocuments mutates the original object + // messing up the input data logging + arguments: [deepCopy(documents), query], + })) as Document[]; + + logAiEvent(executeFunctions, 'ai-document-reranked', { query }); + executeFunctions.addOutputData(connectionType, index, [[{ json: { response } }]]); + return response; + }; + } + } + // ========== N8n Loaders Process All ========== if ( originalInstance instanceof N8nJsonLoader || diff --git a/packages/frontend/editor-ui/src/utils/aiUtils.ts b/packages/frontend/editor-ui/src/utils/aiUtils.ts index fc2fa8eda1..5d5c5af840 100644 --- a/packages/frontend/editor-ui/src/utils/aiUtils.ts +++ b/packages/frontend/editor-ui/src/utils/aiUtils.ts @@ -155,6 +155,7 @@ const outputTypeParsers: { }, [NodeConnectionTypes.AiOutputParser]: fallbackParser, [NodeConnectionTypes.AiRetriever]: fallbackParser, + [NodeConnectionTypes.AiReranker]: fallbackParser, [NodeConnectionTypes.AiVectorStore](execData: IDataObject) { if (execData.documents) { return { diff --git a/packages/workflow/src/interfaces.ts b/packages/workflow/src/interfaces.ts index 4ae5867044..9832586d46 100644 --- a/packages/workflow/src/interfaces.ts +++ b/packages/workflow/src/interfaces.ts @@ -1871,6 +1871,7 @@ export const NodeConnectionTypes = { AiMemory: 'ai_memory', AiOutputParser: 'ai_outputParser', AiRetriever: 'ai_retriever', + AiReranker: 'ai_reranker', AiTextSplitter: 'ai_textSplitter', AiTool: 'ai_tool', AiVectorStore: 'ai_vectorStore', @@ -1881,20 +1882,7 @@ export type NodeConnectionType = (typeof NodeConnectionTypes)[keyof typeof NodeC export type AINodeConnectionType = Exclude; -export const nodeConnectionTypes: NodeConnectionType[] = [ - NodeConnectionTypes.AiAgent, - NodeConnectionTypes.AiChain, - NodeConnectionTypes.AiDocument, - NodeConnectionTypes.AiEmbedding, - NodeConnectionTypes.AiLanguageModel, - NodeConnectionTypes.AiMemory, - NodeConnectionTypes.AiOutputParser, - NodeConnectionTypes.AiRetriever, - NodeConnectionTypes.AiTextSplitter, - NodeConnectionTypes.AiTool, - NodeConnectionTypes.AiVectorStore, - NodeConnectionTypes.Main, -]; +export const nodeConnectionTypes: NodeConnectionType[] = Object.values(NodeConnectionTypes); export interface INodeInputFilter { // TODO: Later add more filter options like categories, subcatogries, @@ -2333,6 +2321,7 @@ export type AiEvent = | 'ai-message-added-to-memory' | 'ai-output-parsed' | 'ai-documents-retrieved' + | 'ai-document-reranked' | 'ai-document-embedded' | 'ai-query-embedded' | 'ai-document-processed'