diff --git a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts index 929f32013c..3e81db4546 100644 --- a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/RerankerCohere.node.ts @@ -70,16 +70,26 @@ export class RerankerCohere implements INodeType { }, ], }, + { + displayName: 'Top N', + name: 'topN', + type: 'number', + description: 'The maximum number of documents to return after reranking', + default: 3, + }, ], }; async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise { this.logger.debug('Supply data for reranking Cohere'); const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string; + const topN = this.getNodeParameter('topN', itemIndex, 3) as number; const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi'); + const reranker = new CohereRerank({ apiKey: credentials.apiKey, model: modelName, + topN, }); return { diff --git a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts index 98ecea2950..92ffdee20e 100644 --- a/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/rerankers/RerankerCohere/test/RerankerCohere.node.test.ts @@ -55,7 +55,9 @@ describe('RerankerCohere', () => { it('should create CohereRerank with default model and return wrapped instance', async () => { // Setup mocks const mockCredentials = { apiKey: 'test-api-key' }; - (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-v3.5') // modelName + .mockReturnValueOnce(3); // topN (default) (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); // Execute @@ -66,10 +68,12 @@ describe('RerankerCohere', () => { 0, 'rerank-v3.5', ); + expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3); expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi'); expect(CohereRerank).toHaveBeenCalledWith({ apiKey: 'test-api-key', model: 'rerank-v3.5', + topN: 3, }); expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions); expect(result.response).toEqual({ logWrapped: mockCohereRerank }); @@ -78,9 +82,9 @@ describe('RerankerCohere', () => { it('should create CohereRerank with custom model', async () => { // Setup mocks const mockCredentials = { apiKey: 'custom-api-key' }; - (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue( - 'rerank-multilingual-v3.0', - ); + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-multilingual-v3.0') // modelName + .mockReturnValueOnce(3); // topN (default) (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); // Execute @@ -90,13 +94,16 @@ describe('RerankerCohere', () => { expect(CohereRerank).toHaveBeenCalledWith({ apiKey: 'custom-api-key', model: 'rerank-multilingual-v3.0', + topN: 3, }); }); it('should handle different item indices', async () => { // Setup mocks const mockCredentials = { apiKey: 'test-api-key' }; - (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0'); + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-english-v3.0') // modelName + .mockReturnValueOnce(3); // topN (default) (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); // Execute with different item index @@ -108,11 +115,14 @@ describe('RerankerCohere', () => { 2, 'rerank-v3.5', ); + expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 2, 3); }); it('should throw error when credentials are missing', async () => { // Setup mocks - (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-v3.5') // modelName + .mockReturnValueOnce(3); // topN (default) (mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue( new Error('Missing credentials'), ); @@ -126,7 +136,9 @@ describe('RerankerCohere', () => { it('should use fallback model when parameter is not provided', async () => { // Setup mocks - getNodeParameter returns the fallback value const mockCredentials = { apiKey: 'test-api-key' }; - (mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-v3.5') // modelName (fallback value) + .mockReturnValueOnce(3); // topN (fallback value) (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); // Execute @@ -136,6 +148,65 @@ describe('RerankerCohere', () => { expect(CohereRerank).toHaveBeenCalledWith({ apiKey: 'test-api-key', model: 'rerank-v3.5', + topN: 3, + }); + }); + + it('should create CohereRerank with custom topN value', async () => { + // Setup mocks + const mockCredentials = { apiKey: 'test-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-v3.5') // modelName + .mockReturnValueOnce(10); // topN (custom value) + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute + await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0); + + // Verify custom topN is used + expect(CohereRerank).toHaveBeenCalledWith({ + apiKey: 'test-api-key', + model: 'rerank-v3.5', + topN: 10, + }); + expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3); + }); + + it('should create CohereRerank with topN value of 1', async () => { + // Setup mocks + const mockCredentials = { apiKey: 'test-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-english-v3.0') // modelName + .mockReturnValueOnce(1); // topN (edge case value) + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute + await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0); + + // Verify edge case topN is used + expect(CohereRerank).toHaveBeenCalledWith({ + apiKey: 'test-api-key', + model: 'rerank-english-v3.0', + topN: 1, + }); + }); + + it('should create CohereRerank with large topN value', async () => { + // Setup mocks + const mockCredentials = { apiKey: 'test-api-key' }; + (mockSupplyDataFunctions.getNodeParameter as jest.Mock) + .mockReturnValueOnce('rerank-multilingual-v3.0') // modelName + .mockReturnValueOnce(100); // topN (large value) + (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); + + // Execute + await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0); + + // Verify large topN is used + expect(CohereRerank).toHaveBeenCalledWith({ + apiKey: 'test-api-key', + model: 'rerank-multilingual-v3.0', + topN: 100, }); }); });