feat(OpenAI Node): Filter available models by blacklisting rather than whitelisting (#14780)

This commit is contained in:
Yiorgis Gozadinos
2025-04-23 11:09:16 +02:00
committed by GitHub
parent 418a588e89
commit 0e2eceb33f
3 changed files with 54 additions and 21 deletions

View File

@@ -31,6 +31,11 @@ describe('searchModels', () => {
{ id: 'gpt-3.5-turbo-instruct' },
{ id: 'ft:gpt-3.5-turbo' },
{ id: 'o1-model' },
{ id: 'whisper-1' },
{ id: 'davinci-instruct-beta' },
{ id: 'computer-use-preview' },
{ id: 'whisper-1-preview' },
{ id: 'tts-model' },
{ id: 'other-model' },
],
}),
@@ -53,7 +58,13 @@ describe('searchModels', () => {
baseURL: 'https://api.openai.com/v1',
apiKey: 'test-api-key',
});
expect(result.results).toHaveLength(4);
expect(result.results).toEqual([
{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
{ name: 'gpt-4', value: 'gpt-4' },
{ name: 'o1-model', value: 'o1-model' },
{ name: 'other-model', value: 'other-model' },
]);
});
it('should initialize OpenAI with correct credentials', async () => {
@@ -86,8 +97,20 @@ describe('searchModels', () => {
mockContext.getNodeParameter = jest.fn().mockReturnValue('https://custom-api.com');
const result = await searchModels.call(mockContext);
expect(result.results).toHaveLength(6);
expect(result.results).toEqual([
{ name: 'computer-use-preview', value: 'computer-use-preview' },
{ name: 'davinci-instruct-beta', value: 'davinci-instruct-beta' },
{ name: 'ft:gpt-3.5-turbo', value: 'ft:gpt-3.5-turbo' },
{ name: 'gpt-3.5-turbo', value: 'gpt-3.5-turbo' },
{ name: 'gpt-3.5-turbo-instruct', value: 'gpt-3.5-turbo-instruct' },
{ name: 'gpt-4', value: 'gpt-4' },
{ name: 'o1-model', value: 'o1-model' },
{ name: 'other-model', value: 'other-model' },
{ name: 'tts-model', value: 'tts-model' },
{ name: 'whisper-1', value: 'whisper-1' },
{ name: 'whisper-1-preview', value: 'whisper-1-preview' },
]);
expect(result.results).toHaveLength(11);
});
it('should filter models based on search term', async () => {

View File

@@ -16,26 +16,31 @@ export async function searchModels(
const filteredModels = models.filter((model: { id: string }) => {
const url = baseURL && new URL(baseURL);
const isValidModel =
(url && url.hostname !== 'api.openai.com') ||
model.id.startsWith('ft:') ||
model.id.startsWith('o1') ||
model.id.startsWith('o3') ||
(model.id.startsWith('gpt-') && !model.id.includes('instruct'));
const isCustomAPI = url && url.hostname !== 'api.openai.com';
// Filter out TTS, embedding, image generation, and other models
const isInvalidModel =
!isCustomAPI &&
(model.id.startsWith('babbage') ||
model.id.startsWith('davinci') ||
model.id.startsWith('computer-use') ||
model.id.startsWith('dall-e') ||
model.id.startsWith('text-embedding') ||
model.id.startsWith('tts') ||
model.id.startsWith('whisper') ||
model.id.startsWith('omni-moderation') ||
(model.id.startsWith('gpt-') && model.id.includes('instruct')));
if (!filter) return isValidModel;
if (!filter) return !isInvalidModel;
return isValidModel && model.id.toLowerCase().includes(filter.toLowerCase());
return !isInvalidModel && model.id.toLowerCase().includes(filter.toLowerCase());
});
filteredModels.sort((a, b) => a.id.localeCompare(b.id));
const results = {
return {
results: filteredModels.map((model: { id: string }) => ({
name: model.id,
value: model.id,
})),
};
return results;
}

View File

@@ -66,7 +66,6 @@ const getModelSearch =
}
results = results.sort((a, b) => a.name.localeCompare(b.name));
return {
results,
};
@@ -79,14 +78,20 @@ export async function modelSearch(
const credentials = await this.getCredentials<{ url: string }>('openAiApi');
const url = credentials.url && new URL(credentials.url);
const isCustomAPI = url && url.hostname !== 'api.openai.com';
return await getModelSearch(
(model) =>
isCustomAPI ||
model.id.startsWith('gpt-') ||
model.id.startsWith('ft:') ||
model.id.startsWith('o1') ||
model.id.startsWith('o3'),
!isCustomAPI &&
!(
model.id.startsWith('babbage') ||
model.id.startsWith('davinci') ||
model.id.startsWith('computer-use') ||
model.id.startsWith('dall-e') ||
model.id.startsWith('text-embedding') ||
model.id.startsWith('tts') ||
model.id.startsWith('whisper') ||
model.id.startsWith('omni-moderation') ||
(model.id.startsWith('gpt-') && model.id.includes('instruct'))
),
)(this, filter);
}