mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
fix(Reranker Cohere Node): Add 'Top N' parameter to control document return count (#17921)
This commit is contained in:
committed by
GitHub
parent
6574c573cf
commit
523a55d5ee
@@ -70,16 +70,26 @@ export class RerankerCohere implements INodeType {
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
displayName: 'Top N',
|
||||
name: 'topN',
|
||||
type: 'number',
|
||||
description: 'The maximum number of documents to return after reranking',
|
||||
default: 3,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
||||
this.logger.debug('Supply data for reranking Cohere');
|
||||
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
|
||||
const topN = this.getNodeParameter('topN', itemIndex, 3) as number;
|
||||
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
|
||||
|
||||
const reranker = new CohereRerank({
|
||||
apiKey: credentials.apiKey,
|
||||
model: modelName,
|
||||
topN,
|
||||
});
|
||||
|
||||
return {
|
||||
|
||||
@@ -55,7 +55,9 @@ describe('RerankerCohere', () => {
|
||||
it('should create CohereRerank with default model and return wrapped instance', async () => {
|
||||
// Setup mocks
|
||||
const mockCredentials = { apiKey: 'test-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-v3.5') // modelName
|
||||
.mockReturnValueOnce(3); // topN (default)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute
|
||||
@@ -66,10 +68,12 @@ describe('RerankerCohere', () => {
|
||||
0,
|
||||
'rerank-v3.5',
|
||||
);
|
||||
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
|
||||
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
|
||||
expect(CohereRerank).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
model: 'rerank-v3.5',
|
||||
topN: 3,
|
||||
});
|
||||
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
|
||||
expect(result.response).toEqual({ logWrapped: mockCohereRerank });
|
||||
@@ -78,9 +82,9 @@ describe('RerankerCohere', () => {
|
||||
it('should create CohereRerank with custom model', async () => {
|
||||
// Setup mocks
|
||||
const mockCredentials = { apiKey: 'custom-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue(
|
||||
'rerank-multilingual-v3.0',
|
||||
);
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
|
||||
.mockReturnValueOnce(3); // topN (default)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute
|
||||
@@ -90,13 +94,16 @@ describe('RerankerCohere', () => {
|
||||
expect(CohereRerank).toHaveBeenCalledWith({
|
||||
apiKey: 'custom-api-key',
|
||||
model: 'rerank-multilingual-v3.0',
|
||||
topN: 3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle different item indices', async () => {
|
||||
// Setup mocks
|
||||
const mockCredentials = { apiKey: 'test-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0');
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-english-v3.0') // modelName
|
||||
.mockReturnValueOnce(3); // topN (default)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute with different item index
|
||||
@@ -108,11 +115,14 @@ describe('RerankerCohere', () => {
|
||||
2,
|
||||
'rerank-v3.5',
|
||||
);
|
||||
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 2, 3);
|
||||
});
|
||||
|
||||
it('should throw error when credentials are missing', async () => {
|
||||
// Setup mocks
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-v3.5') // modelName
|
||||
.mockReturnValueOnce(3); // topN (default)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
|
||||
new Error('Missing credentials'),
|
||||
);
|
||||
@@ -126,7 +136,9 @@ describe('RerankerCohere', () => {
|
||||
it('should use fallback model when parameter is not provided', async () => {
|
||||
// Setup mocks - getNodeParameter returns the fallback value
|
||||
const mockCredentials = { apiKey: 'test-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-v3.5') // modelName (fallback value)
|
||||
.mockReturnValueOnce(3); // topN (fallback value)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute
|
||||
@@ -136,6 +148,65 @@ describe('RerankerCohere', () => {
|
||||
expect(CohereRerank).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
model: 'rerank-v3.5',
|
||||
topN: 3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should create CohereRerank with custom topN value', async () => {
|
||||
// Setup mocks
|
||||
const mockCredentials = { apiKey: 'test-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-v3.5') // modelName
|
||||
.mockReturnValueOnce(10); // topN (custom value)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute
|
||||
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||
|
||||
// Verify custom topN is used
|
||||
expect(CohereRerank).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
model: 'rerank-v3.5',
|
||||
topN: 10,
|
||||
});
|
||||
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
|
||||
});
|
||||
|
||||
it('should create CohereRerank with topN value of 1', async () => {
|
||||
// Setup mocks
|
||||
const mockCredentials = { apiKey: 'test-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-english-v3.0') // modelName
|
||||
.mockReturnValueOnce(1); // topN (edge case value)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute
|
||||
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||
|
||||
// Verify edge case topN is used
|
||||
expect(CohereRerank).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
model: 'rerank-english-v3.0',
|
||||
topN: 1,
|
||||
});
|
||||
});
|
||||
|
||||
it('should create CohereRerank with large topN value', async () => {
|
||||
// Setup mocks
|
||||
const mockCredentials = { apiKey: 'test-api-key' };
|
||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
|
||||
.mockReturnValueOnce(100); // topN (large value)
|
||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||
|
||||
// Execute
|
||||
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||
|
||||
// Verify large topN is used
|
||||
expect(CohereRerank).toHaveBeenCalledWith({
|
||||
apiKey: 'test-api-key',
|
||||
model: 'rerank-multilingual-v3.0',
|
||||
topN: 100,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user