fix(Reranker Cohere Node): Add 'Top N' parameter to control document return count (#17921)

This commit is contained in:
Andrew Zolotukhin
2025-08-07 10:49:25 +03:00
committed by GitHub
parent 6574c573cf
commit 523a55d5ee
2 changed files with 88 additions and 7 deletions

View File

@@ -70,16 +70,26 @@ export class RerankerCohere implements INodeType {
},
],
},
{
displayName: 'Top N',
name: 'topN',
type: 'number',
description: 'The maximum number of documents to return after reranking',
default: 3,
},
],
};
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
this.logger.debug('Supply data for reranking Cohere');
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
const topN = this.getNodeParameter('topN', itemIndex, 3) as number;
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
const reranker = new CohereRerank({
apiKey: credentials.apiKey,
model: modelName,
topN,
});
return {

View File

@@ -55,7 +55,9 @@ describe('RerankerCohere', () => {
it('should create CohereRerank with default model and return wrapped instance', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
@@ -66,10 +68,12 @@ describe('RerankerCohere', () => {
0,
'rerank-v3.5',
);
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-v3.5',
topN: 3,
});
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
expect(result.response).toEqual({ logWrapped: mockCohereRerank });
@@ -78,9 +82,9 @@ describe('RerankerCohere', () => {
it('should create CohereRerank with custom model', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'custom-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue(
'rerank-multilingual-v3.0',
);
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
@@ -90,13 +94,16 @@ describe('RerankerCohere', () => {
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'custom-api-key',
model: 'rerank-multilingual-v3.0',
topN: 3,
});
});
it('should handle different item indices', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0');
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-english-v3.0') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute with different item index
@@ -108,11 +115,14 @@ describe('RerankerCohere', () => {
2,
'rerank-v3.5',
);
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 2, 3);
});
it('should throw error when credentials are missing', async () => {
// Setup mocks
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName
.mockReturnValueOnce(3); // topN (default)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
new Error('Missing credentials'),
);
@@ -126,7 +136,9 @@ describe('RerankerCohere', () => {
it('should use fallback model when parameter is not provided', async () => {
// Setup mocks - getNodeParameter returns the fallback value
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName (fallback value)
.mockReturnValueOnce(3); // topN (fallback value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
@@ -136,6 +148,65 @@ describe('RerankerCohere', () => {
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-v3.5',
topN: 3,
});
});
it('should create CohereRerank with custom topN value', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-v3.5') // modelName
.mockReturnValueOnce(10); // topN (custom value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
// Verify custom topN is used
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-v3.5',
topN: 10,
});
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
});
it('should create CohereRerank with topN value of 1', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-english-v3.0') // modelName
.mockReturnValueOnce(1); // topN (edge case value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
// Verify edge case topN is used
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-english-v3.0',
topN: 1,
});
});
it('should create CohereRerank with large topN value', async () => {
// Setup mocks
const mockCredentials = { apiKey: 'test-api-key' };
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
.mockReturnValueOnce(100); // topN (large value)
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
// Execute
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
// Verify large topN is used
expect(CohereRerank).toHaveBeenCalledWith({
apiKey: 'test-api-key',
model: 'rerank-multilingual-v3.0',
topN: 100,
});
});
});