fix(Reranker Cohere Node): Add 'Top N' parameter to control document return count (#17921)

This commit is contained in:
Andrew Zolotukhin
2025-08-07 10:49:25 +03:00
committed by GitHub
parent 6574c573cf
commit 523a55d5ee
2 changed files with 88 additions and 7 deletions

View File

@@ -70,16 +70,26 @@ export class RerankerCohere implements INodeType {
}, },
], ],
}, },
{
displayName: 'Top N',
name: 'topN',
type: 'number',
description: 'The maximum number of documents to return after reranking',
default: 3,
},
], ],
}; };
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> { async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
this.logger.debug('Supply data for reranking Cohere'); this.logger.debug('Supply data for reranking Cohere');
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string; const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
const topN = this.getNodeParameter('topN', itemIndex, 3) as number;
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi'); const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
const reranker = new CohereRerank({ const reranker = new CohereRerank({
apiKey: credentials.apiKey, apiKey: credentials.apiKey,
model: modelName, model: modelName,
topN,
}); });
return { return {

View File

@@ -55,7 +55,9 @@ describe('RerankerCohere', () => {
it('should create CohereRerank with default model and return wrapped instance', async () => { it('should create CohereRerank with default model and return wrapped instance', async () => {
// Setup mocks // Setup mocks
const mockCredentials = { apiKey: 'test-api-key' }; const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); (mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute // Execute
@@ -66,10 +68,12 @@ describe('RerankerCohere', () => {
0, 0,
'rerank-v3.5', 'rerank-v3.5',
); );
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi'); expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
expect(CohereRerank).toHaveBeenCalledWith({ expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key', apiKey: 'test-api-key',
model: 'rerank-v3.5', model: 'rerank-v3.5',
topN: 3,
}); });
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions); expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
expect(result.response).toEqual({ logWrapped: mockCohereRerank }); expect(result.response).toEqual({ logWrapped: mockCohereRerank });
@@ -78,9 +82,9 @@ describe('RerankerCohere', () => {
it('should create CohereRerank with custom model', async () => { it('should create CohereRerank with custom model', async () => {
// Setup mocks // Setup mocks
const mockCredentials = { apiKey: 'custom-api-key' }; const mockCredentials = { apiKey: 'custom-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue( (mockSupplyDataFunctions.getNodeParameter as jest.Mock)
'rerank-multilingual-v3.0', .mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
); .mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute // Execute
@@ -90,13 +94,16 @@ describe('RerankerCohere', () => {
expect(CohereRerank).toHaveBeenCalledWith({ expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'custom-api-key', apiKey: 'custom-api-key',
model: 'rerank-multilingual-v3.0', model: 'rerank-multilingual-v3.0',
topN: 3,
}); });
}); });
it('should handle different item indices', async () => { it('should handle different item indices', async () => {
// Setup mocks // Setup mocks
const mockCredentials = { apiKey: 'test-api-key' }; const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0'); (mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-english-v3.0') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute with different item index // Execute with different item index
@@ -108,11 +115,14 @@ describe('RerankerCohere', () => {
2, 2,
'rerank-v3.5', 'rerank-v3.5',
); );
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 2, 3);
}); });
it('should throw error when credentials are missing', async () => { it('should throw error when credentials are missing', async () => {
// Setup mocks // Setup mocks
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); (mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue( (mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
new Error('Missing credentials'), new Error('Missing credentials'),
); );
@@ -126,7 +136,9 @@ describe('RerankerCohere', () => {
it('should use fallback model when parameter is not provided', async () => { it('should use fallback model when parameter is not provided', async () => {
// Setup mocks - getNodeParameter returns the fallback value // Setup mocks - getNodeParameter returns the fallback value
const mockCredentials = { apiKey: 'test-api-key' }; const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value (mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName (fallback value)
.mockReturnValueOnce(3); // topN (fallback value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials); (mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute // Execute
@@ -136,6 +148,65 @@ describe('RerankerCohere', () => {
expect(CohereRerank).toHaveBeenCalledWith({ expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key', apiKey: 'test-api-key',
model: 'rerank-v3.5', model: 'rerank-v3.5',
topN: 3,
});
});
it('should create CohereRerank with custom topN value', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName
.mockReturnValueOnce(10); // topN (custom value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
// Verify custom topN is used
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-v3.5',
topN: 10,
});
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
});
it('should create CohereRerank with topN value of 1', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-english-v3.0') // modelName
.mockReturnValueOnce(1); // topN (edge case value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
// Verify edge case topN is used
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-english-v3.0',
topN: 1,
});
});
it('should create CohereRerank with large topN value', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
.mockReturnValueOnce(100); // topN (large value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
// Verify large topN is used
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-multilingual-v3.0',
topN: 100,
}); });
}); });
}); });