mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-21 20:00:02 +00:00
fix(Reranker Cohere Node): Add 'Top N' parameter to control document return count (#17921)
This commit is contained in:
committed by
GitHub
parent
6574c573cf
commit
523a55d5ee
@@ -70,16 +70,26 @@ export class RerankerCohere implements INodeType {
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Top N',
|
||||||
|
name: 'topN',
|
||||||
|
type: 'number',
|
||||||
|
description: 'The maximum number of documents to return after reranking',
|
||||||
|
default: 3,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
||||||
this.logger.debug('Supply data for reranking Cohere');
|
this.logger.debug('Supply data for reranking Cohere');
|
||||||
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
|
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
|
||||||
|
const topN = this.getNodeParameter('topN', itemIndex, 3) as number;
|
||||||
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
|
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
|
||||||
|
|
||||||
const reranker = new CohereRerank({
|
const reranker = new CohereRerank({
|
||||||
apiKey: credentials.apiKey,
|
apiKey: credentials.apiKey,
|
||||||
model: modelName,
|
model: modelName,
|
||||||
|
topN,
|
||||||
});
|
});
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -55,7 +55,9 @@ describe('RerankerCohere', () => {
|
|||||||
it('should create CohereRerank with default model and return wrapped instance', async () => {
|
it('should create CohereRerank with default model and return wrapped instance', async () => {
|
||||||
// Setup mocks
|
// Setup mocks
|
||||||
const mockCredentials = { apiKey: 'test-api-key' };
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-v3.5') // modelName
|
||||||
|
.mockReturnValueOnce(3); // topN (default)
|
||||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
@@ -66,10 +68,12 @@ describe('RerankerCohere', () => {
|
|||||||
0,
|
0,
|
||||||
'rerank-v3.5',
|
'rerank-v3.5',
|
||||||
);
|
);
|
||||||
|
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
|
||||||
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
|
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
|
||||||
expect(CohereRerank).toHaveBeenCalledWith({
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
apiKey: 'test-api-key',
|
apiKey: 'test-api-key',
|
||||||
model: 'rerank-v3.5',
|
model: 'rerank-v3.5',
|
||||||
|
topN: 3,
|
||||||
});
|
});
|
||||||
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
|
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
|
||||||
expect(result.response).toEqual({ logWrapped: mockCohereRerank });
|
expect(result.response).toEqual({ logWrapped: mockCohereRerank });
|
||||||
@@ -78,9 +82,9 @@ describe('RerankerCohere', () => {
|
|||||||
it('should create CohereRerank with custom model', async () => {
|
it('should create CohereRerank with custom model', async () => {
|
||||||
// Setup mocks
|
// Setup mocks
|
||||||
const mockCredentials = { apiKey: 'custom-api-key' };
|
const mockCredentials = { apiKey: 'custom-api-key' };
|
||||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue(
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
'rerank-multilingual-v3.0',
|
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
|
||||||
);
|
.mockReturnValueOnce(3); // topN (default)
|
||||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
@@ -90,13 +94,16 @@ describe('RerankerCohere', () => {
|
|||||||
expect(CohereRerank).toHaveBeenCalledWith({
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
apiKey: 'custom-api-key',
|
apiKey: 'custom-api-key',
|
||||||
model: 'rerank-multilingual-v3.0',
|
model: 'rerank-multilingual-v3.0',
|
||||||
|
topN: 3,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle different item indices', async () => {
|
it('should handle different item indices', async () => {
|
||||||
// Setup mocks
|
// Setup mocks
|
||||||
const mockCredentials = { apiKey: 'test-api-key' };
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0');
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-english-v3.0') // modelName
|
||||||
|
.mockReturnValueOnce(3); // topN (default)
|
||||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
// Execute with different item index
|
// Execute with different item index
|
||||||
@@ -108,11 +115,14 @@ describe('RerankerCohere', () => {
|
|||||||
2,
|
2,
|
||||||
'rerank-v3.5',
|
'rerank-v3.5',
|
||||||
);
|
);
|
||||||
|
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 2, 3);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should throw error when credentials are missing', async () => {
|
it('should throw error when credentials are missing', async () => {
|
||||||
// Setup mocks
|
// Setup mocks
|
||||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-v3.5') // modelName
|
||||||
|
.mockReturnValueOnce(3); // topN (default)
|
||||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
|
||||||
new Error('Missing credentials'),
|
new Error('Missing credentials'),
|
||||||
);
|
);
|
||||||
@@ -126,7 +136,9 @@ describe('RerankerCohere', () => {
|
|||||||
it('should use fallback model when parameter is not provided', async () => {
|
it('should use fallback model when parameter is not provided', async () => {
|
||||||
// Setup mocks - getNodeParameter returns the fallback value
|
// Setup mocks - getNodeParameter returns the fallback value
|
||||||
const mockCredentials = { apiKey: 'test-api-key' };
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-v3.5') // modelName (fallback value)
|
||||||
|
.mockReturnValueOnce(3); // topN (fallback value)
|
||||||
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
@@ -136,6 +148,65 @@ describe('RerankerCohere', () => {
|
|||||||
expect(CohereRerank).toHaveBeenCalledWith({
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
apiKey: 'test-api-key',
|
apiKey: 'test-api-key',
|
||||||
model: 'rerank-v3.5',
|
model: 'rerank-v3.5',
|
||||||
|
topN: 3,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create CohereRerank with custom topN value', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-v3.5') // modelName
|
||||||
|
.mockReturnValueOnce(10); // topN (custom value)
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||||
|
|
||||||
|
// Verify custom topN is used
|
||||||
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
model: 'rerank-v3.5',
|
||||||
|
topN: 10,
|
||||||
|
});
|
||||||
|
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith('topN', 0, 3);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create CohereRerank with topN value of 1', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-english-v3.0') // modelName
|
||||||
|
.mockReturnValueOnce(1); // topN (edge case value)
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||||
|
|
||||||
|
// Verify edge case topN is used
|
||||||
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
model: 'rerank-english-v3.0',
|
||||||
|
topN: 1,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create CohereRerank with large topN value', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock)
|
||||||
|
.mockReturnValueOnce('rerank-multilingual-v3.0') // modelName
|
||||||
|
.mockReturnValueOnce(100); // topN (large value)
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||||
|
|
||||||
|
// Verify large topN is used
|
||||||
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
model: 'rerank-multilingual-v3.0',
|
||||||
|
topN: 100,
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user