diff --git a/packages/@n8n/api-types/src/dto/ai/__tests__/ai-apply-suggestion-request.dto.test.ts b/packages/@n8n/api-types/src/dto/ai/__tests__/ai-apply-suggestion-request.dto.test.ts new file mode 100644 index 0000000000..568900e409 --- /dev/null +++ b/packages/@n8n/api-types/src/dto/ai/__tests__/ai-apply-suggestion-request.dto.test.ts @@ -0,0 +1,36 @@ +import { AiApplySuggestionRequestDto } from '../ai-apply-suggestion-request.dto'; + +describe('AiApplySuggestionRequestDto', () => { + it('should validate a valid suggestion application request', () => { + const validRequest = { + sessionId: 'session-123', + suggestionId: 'suggestion-456', + }; + + const result = AiApplySuggestionRequestDto.safeParse(validRequest); + + expect(result.success).toBe(true); + }); + + it('should fail if sessionId is missing', () => { + const invalidRequest = { + suggestionId: 'suggestion-456', + }; + + const result = AiApplySuggestionRequestDto.safeParse(invalidRequest); + + expect(result.success).toBe(false); + expect(result.error?.issues[0].path).toEqual(['sessionId']); + }); + + it('should fail if suggestionId is missing', () => { + const invalidRequest = { + sessionId: 'session-123', + }; + + const result = AiApplySuggestionRequestDto.safeParse(invalidRequest); + + expect(result.success).toBe(false); + expect(result.error?.issues[0].path).toEqual(['suggestionId']); + }); +}); diff --git a/packages/@n8n/api-types/src/dto/ai/__tests__/ai-ask-request.dto.test.ts b/packages/@n8n/api-types/src/dto/ai/__tests__/ai-ask-request.dto.test.ts new file mode 100644 index 0000000000..a87eb5f3a4 --- /dev/null +++ b/packages/@n8n/api-types/src/dto/ai/__tests__/ai-ask-request.dto.test.ts @@ -0,0 +1,252 @@ +import { AiAskRequestDto } from '../ai-ask-request.dto'; + +describe('AiAskRequestDto', () => { + const validRequest = { + question: 'How can I improve this workflow?', + context: { + schema: [ + { + nodeName: 'TestNode', + schema: { + type: 'string', + key: 'testKey', + value: 'testValue', + path: '/test/path', + }, + }, + ], + inputSchema: { + nodeName: 'InputNode', + schema: { + type: 'object', + key: 'inputKey', + value: [ + { + type: 'string', + key: 'nestedKey', + value: 'nestedValue', + path: '/nested/path', + }, + ], + path: '/input/path', + }, + }, + pushRef: 'push-123', + ndvPushRef: 'ndv-push-456', + }, + forNode: 'TestWorkflowNode', + }; + + it('should validate a valid AI ask request', () => { + const result = AiAskRequestDto.safeParse(validRequest); + + expect(result.success).toBe(true); + }); + + it('should fail if question is missing', () => { + const invalidRequest = { + ...validRequest, + question: undefined, + }; + + const result = AiAskRequestDto.safeParse(invalidRequest); + + expect(result.success).toBe(false); + expect(result.error?.issues[0].path).toEqual(['question']); + }); + + it('should fail if context is invalid', () => { + const invalidRequest = { + ...validRequest, + context: { + ...validRequest.context, + schema: [ + { + nodeName: 'TestNode', + schema: { + type: 'invalid-type', // Invalid type + value: 'testValue', + path: '/test/path', + }, + }, + ], + }, + }; + + const result = AiAskRequestDto.safeParse(invalidRequest); + + expect(result.success).toBe(false); + }); + + it('should fail if forNode is missing', () => { + const invalidRequest = { + ...validRequest, + forNode: undefined, + }; + + const result = AiAskRequestDto.safeParse(invalidRequest); + + expect(result.success).toBe(false); + expect(result.error?.issues[0].path).toEqual(['forNode']); + }); + + it('should validate all possible schema types', () => { + const allTypesRequest = { + question: 'Test all possible types', + context: { + schema: [ + { + nodeName: 'AllTypesNode', + schema: { + type: 'object', + key: 'typesRoot', + value: [ + { type: 'string', key: 'stringType', value: 'string', path: '/types/string' }, + { type: 'number', key: 'numberType', value: 'number', path: '/types/number' }, + { type: 'boolean', key: 'booleanType', value: 'boolean', path: '/types/boolean' }, + { type: 'bigint', key: 'bigintType', value: 'bigint', path: '/types/bigint' }, + { type: 'symbol', key: 'symbolType', value: 'symbol', path: '/types/symbol' }, + { type: 'array', key: 'arrayType', value: [], path: '/types/array' }, + { type: 'object', key: 'objectType', value: [], path: '/types/object' }, + { + type: 'function', + key: 'functionType', + value: 'function', + path: '/types/function', + }, + { type: 'null', key: 'nullType', value: 'null', path: '/types/null' }, + { + type: 'undefined', + key: 'undefinedType', + value: 'undefined', + path: '/types/undefined', + }, + ], + path: '/types/root', + }, + }, + ], + inputSchema: { + nodeName: 'InputNode', + schema: { + type: 'object', + key: 'simpleInput', + value: [ + { + type: 'string', + key: 'simpleKey', + value: 'simpleValue', + path: '/simple/path', + }, + ], + path: '/simple/input/path', + }, + }, + pushRef: 'push-types-123', + ndvPushRef: 'ndv-push-types-456', + }, + forNode: 'TypeCheckNode', + }; + + const result = AiAskRequestDto.safeParse(allTypesRequest); + expect(result.success).toBe(true); + }); + + it('should fail with invalid type', () => { + const invalidTypeRequest = { + question: 'Test invalid type', + context: { + schema: [ + { + nodeName: 'InvalidTypeNode', + schema: { + type: 'invalid-type', // This should fail + key: 'invalidKey', + value: 'invalidValue', + path: '/invalid/path', + }, + }, + ], + inputSchema: { + nodeName: 'InputNode', + schema: { + type: 'object', + key: 'simpleInput', + value: [ + { + type: 'string', + key: 'simpleKey', + value: 'simpleValue', + path: '/simple/path', + }, + ], + path: '/simple/input/path', + }, + }, + pushRef: 'push-invalid-123', + ndvPushRef: 'ndv-push-invalid-456', + }, + forNode: 'InvalidTypeNode', + }; + + const result = AiAskRequestDto.safeParse(invalidTypeRequest); + expect(result.success).toBe(false); + }); + + it('should validate multiple schema entries', () => { + const multiSchemaRequest = { + question: 'Multiple schema test', + context: { + schema: [ + { + nodeName: 'FirstNode', + schema: { + type: 'string', + key: 'firstKey', + value: 'firstValue', + path: '/first/path', + }, + }, + { + nodeName: 'SecondNode', + schema: { + type: 'object', + key: 'secondKey', + value: [ + { + type: 'number', + key: 'nestedKey', + value: 'nestedValue', + path: '/second/nested/path', + }, + ], + path: '/second/path', + }, + }, + ], + inputSchema: { + nodeName: 'InputNode', + schema: { + type: 'object', + key: 'simpleInput', + value: [ + { + type: 'string', + key: 'simpleKey', + value: 'simpleValue', + path: '/simple/path', + }, + ], + path: '/simple/input/path', + }, + }, + pushRef: 'push-multi-123', + ndvPushRef: 'ndv-push-multi-456', + }, + forNode: 'MultiSchemaNode', + }; + + const result = AiAskRequestDto.safeParse(multiSchemaRequest); + expect(result.success).toBe(true); + }); +}); diff --git a/packages/@n8n/api-types/src/dto/ai/__tests__/ai-chat-request.dto.test.ts b/packages/@n8n/api-types/src/dto/ai/__tests__/ai-chat-request.dto.test.ts new file mode 100644 index 0000000000..ce1ccffac5 --- /dev/null +++ b/packages/@n8n/api-types/src/dto/ai/__tests__/ai-chat-request.dto.test.ts @@ -0,0 +1,34 @@ +import { AiChatRequestDto } from '../ai-chat-request.dto'; + +describe('AiChatRequestDto', () => { + it('should validate a request with a payload and session ID', () => { + const validRequest = { + payload: { someKey: 'someValue' }, + sessionId: 'session-123', + }; + + const result = AiChatRequestDto.safeParse(validRequest); + + expect(result.success).toBe(true); + }); + + it('should validate a request with only a payload', () => { + const validRequest = { + payload: { complexObject: { nested: 'value' } }, + }; + + const result = AiChatRequestDto.safeParse(validRequest); + + expect(result.success).toBe(true); + }); + + it('should fail if payload is missing', () => { + const invalidRequest = { + sessionId: 'session-123', + }; + + const result = AiChatRequestDto.safeParse(invalidRequest); + + expect(result.success).toBe(false); + }); +}); diff --git a/packages/@n8n/api-types/src/dto/ai/ai-apply-suggestion-request.dto.ts b/packages/@n8n/api-types/src/dto/ai/ai-apply-suggestion-request.dto.ts new file mode 100644 index 0000000000..cc808dfd24 --- /dev/null +++ b/packages/@n8n/api-types/src/dto/ai/ai-apply-suggestion-request.dto.ts @@ -0,0 +1,7 @@ +import { z } from 'zod'; +import { Z } from 'zod-class'; + +export class AiApplySuggestionRequestDto extends Z.class({ + sessionId: z.string(), + suggestionId: z.string(), +}) {} diff --git a/packages/@n8n/api-types/src/dto/ai/ai-ask-request.dto.ts b/packages/@n8n/api-types/src/dto/ai/ai-ask-request.dto.ts new file mode 100644 index 0000000000..9039243e05 --- /dev/null +++ b/packages/@n8n/api-types/src/dto/ai/ai-ask-request.dto.ts @@ -0,0 +1,53 @@ +import type { AiAssistantSDK, SchemaType } from '@n8n_io/ai-assistant-sdk'; +import { z } from 'zod'; +import { Z } from 'zod-class'; + +// Note: This is copied from the sdk, since this type is not exported +type Schema = { + type: SchemaType; + key?: string; + value: string | Schema[]; + path: string; +}; + +// Create a lazy validator to handle the recursive type +const schemaValidator: z.ZodType = z.lazy(() => + z.object({ + type: z.enum([ + 'string', + 'number', + 'boolean', + 'bigint', + 'symbol', + 'array', + 'object', + 'function', + 'null', + 'undefined', + ]), + key: z.string().optional(), + value: z.union([z.string(), z.lazy(() => schemaValidator.array())]), + path: z.string(), + }), +); + +export class AiAskRequestDto + extends Z.class({ + question: z.string(), + context: z.object({ + schema: z.array( + z.object({ + nodeName: z.string(), + schema: schemaValidator, + }), + ), + inputSchema: z.object({ + nodeName: z.string(), + schema: schemaValidator, + }), + pushRef: z.string(), + ndvPushRef: z.string(), + }), + forNode: z.string(), + }) + implements AiAssistantSDK.AskAiRequestPayload {} diff --git a/packages/@n8n/api-types/src/dto/ai/ai-chat-request.dto.ts b/packages/@n8n/api-types/src/dto/ai/ai-chat-request.dto.ts new file mode 100644 index 0000000000..59e7a26aa3 --- /dev/null +++ b/packages/@n8n/api-types/src/dto/ai/ai-chat-request.dto.ts @@ -0,0 +1,10 @@ +import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk'; +import { z } from 'zod'; +import { Z } from 'zod-class'; + +export class AiChatRequestDto + extends Z.class({ + payload: z.object({}).passthrough(), // Allow any object shape + sessionId: z.string().optional(), + }) + implements AiAssistantSDK.ChatRequestPayload {} diff --git a/packages/@n8n/api-types/src/dto/index.ts b/packages/@n8n/api-types/src/dto/index.ts index 97d5d38459..0e57e07110 100644 --- a/packages/@n8n/api-types/src/dto/index.ts +++ b/packages/@n8n/api-types/src/dto/index.ts @@ -1,3 +1,6 @@ +export { AiAskRequestDto } from './ai/ai-ask-request.dto'; +export { AiChatRequestDto } from './ai/ai-chat-request.dto'; +export { AiApplySuggestionRequestDto } from './ai/ai-apply-suggestion-request.dto'; export { PasswordUpdateRequestDto } from './user/password-update-request.dto'; export { RoleChangeRequestDto } from './user/role-change-request.dto'; export { SettingsUpdateRequestDto } from './user/settings-update-request.dto'; diff --git a/packages/@n8n/config/src/configs/aiAssistant.config.ts b/packages/@n8n/config/src/configs/aiAssistant.config.ts new file mode 100644 index 0000000000..ff8a3986f2 --- /dev/null +++ b/packages/@n8n/config/src/configs/aiAssistant.config.ts @@ -0,0 +1,8 @@ +import { Config, Env } from '../decorators'; + +@Config +export class AiAssistantConfig { + /** Base URL of the AI assistant service */ + @Env('N8N_AI_ASSISTANT_BASE_URL') + baseUrl: string = ''; +} diff --git a/packages/@n8n/config/src/index.ts b/packages/@n8n/config/src/index.ts index a5144d4196..945b5f1237 100644 --- a/packages/@n8n/config/src/index.ts +++ b/packages/@n8n/config/src/index.ts @@ -1,3 +1,4 @@ +import { AiAssistantConfig } from './configs/aiAssistant.config'; import { CacheConfig } from './configs/cache.config'; import { CredentialsConfig } from './configs/credentials.config'; import { DatabaseConfig } from './configs/database.config'; @@ -121,4 +122,7 @@ export class GlobalConfig { @Nested diagnostics: DiagnosticsConfig; + + @Nested + aiAssistant: AiAssistantConfig; } diff --git a/packages/@n8n/config/test/config.test.ts b/packages/@n8n/config/test/config.test.ts index 771d915ee4..d6d19c47fe 100644 --- a/packages/@n8n/config/test/config.test.ts +++ b/packages/@n8n/config/test/config.test.ts @@ -289,6 +289,9 @@ describe('GlobalConfig', () => { apiHost: 'https://ph.n8n.io', }, }, + aiAssistant: { + baseUrl: '', + }, }; it('should use all default values when no env variables are defined', () => { diff --git a/packages/cli/src/config/schema.ts b/packages/cli/src/config/schema.ts index 54fa07e7f5..e8d28cb782 100644 --- a/packages/cli/src/config/schema.ts +++ b/packages/cli/src/config/schema.ts @@ -341,15 +341,6 @@ export const schema = { }, }, - aiAssistant: { - baseUrl: { - doc: 'Base URL of the AI assistant service', - format: String, - default: '', - env: 'N8N_AI_ASSISTANT_BASE_URL', - }, - }, - expression: { evaluator: { doc: 'Expression evaluator to use', diff --git a/packages/cli/src/controllers/__tests__/ai.controller.test.ts b/packages/cli/src/controllers/__tests__/ai.controller.test.ts new file mode 100644 index 0000000000..30785ae938 --- /dev/null +++ b/packages/cli/src/controllers/__tests__/ai.controller.test.ts @@ -0,0 +1,111 @@ +import type { + AiAskRequestDto, + AiApplySuggestionRequestDto, + AiChatRequestDto, +} from '@n8n/api-types'; +import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk'; +import { mock } from 'jest-mock-extended'; + +import { InternalServerError } from '@/errors/response-errors/internal-server.error'; +import type { AuthenticatedRequest } from '@/requests'; +import type { AiService } from '@/services/ai.service'; + +import { AiController, type FlushableResponse } from '../ai.controller'; + +describe('AiController', () => { + const aiService = mock(); + const controller = new AiController(aiService); + + const request = mock({ + user: { id: 'user123' }, + }); + const response = mock(); + + beforeEach(() => { + jest.clearAllMocks(); + + response.header.mockReturnThis(); + }); + + describe('chat', () => { + const payload = mock(); + + it('should handle chat request successfully', async () => { + aiService.chat.mockResolvedValue( + mock({ + body: mock({ + pipeTo: jest.fn().mockImplementation(async (writableStream) => { + // Simulate stream writing + const writer = writableStream.getWriter(); + await writer.write(JSON.stringify({ message: 'test response' })); + await writer.close(); + }), + }), + }), + ); + + await controller.chat(request, response, payload); + + expect(aiService.chat).toHaveBeenCalledWith(payload, request.user); + expect(response.header).toHaveBeenCalledWith('Content-type', 'application/json-lines'); + expect(response.flush).toHaveBeenCalled(); + expect(response.end).toHaveBeenCalled(); + }); + + it('should throw InternalServerError if chat fails', async () => { + const mockError = new Error('Chat failed'); + + aiService.chat.mockRejectedValue(mockError); + + await expect(controller.chat(request, response, payload)).rejects.toThrow( + InternalServerError, + ); + }); + }); + + describe('applySuggestion', () => { + const payload = mock(); + + it('should apply suggestion successfully', async () => { + const clientResponse = mock(); + aiService.applySuggestion.mockResolvedValue(clientResponse); + + const result = await controller.applySuggestion(request, response, payload); + + expect(aiService.applySuggestion).toHaveBeenCalledWith(payload, request.user); + expect(result).toEqual(clientResponse); + }); + + it('should throw InternalServerError if applying suggestion fails', async () => { + const mockError = new Error('Apply suggestion failed'); + aiService.applySuggestion.mockRejectedValue(mockError); + + await expect(controller.applySuggestion(request, response, payload)).rejects.toThrow( + InternalServerError, + ); + }); + }); + + describe('askAi method', () => { + const payload = mock(); + + it('should ask AI successfully', async () => { + const clientResponse = mock(); + aiService.askAi.mockResolvedValue(clientResponse); + + const result = await controller.askAi(request, response, payload); + + expect(aiService.askAi).toHaveBeenCalledWith(payload, request.user); + expect(result).toEqual(clientResponse); + }); + + it('should throw InternalServerError if asking AI fails', async () => { + const mockError = new Error('Ask AI failed'); + aiService.askAi.mockRejectedValue(mockError); + + await expect(controller.askAi(request, response, payload)).rejects.toThrow( + InternalServerError, + ); + }); + }); +}); diff --git a/packages/cli/src/controllers/ai.controller.ts b/packages/cli/src/controllers/ai.controller.ts index be1231911a..59499112a3 100644 --- a/packages/cli/src/controllers/ai.controller.ts +++ b/packages/cli/src/controllers/ai.controller.ts @@ -1,23 +1,24 @@ +import { AiChatRequestDto, AiApplySuggestionRequestDto, AiAskRequestDto } from '@n8n/api-types'; import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk'; -import type { Response } from 'express'; +import { Response } from 'express'; import { strict as assert } from 'node:assert'; import { WritableStream } from 'node:stream/web'; -import { Post, RestController } from '@/decorators'; +import { Body, Post, RestController } from '@/decorators'; import { InternalServerError } from '@/errors/response-errors/internal-server.error'; -import { AiAssistantRequest } from '@/requests'; +import { AuthenticatedRequest } from '@/requests'; import { AiService } from '@/services/ai.service'; -type FlushableResponse = Response & { flush: () => void }; +export type FlushableResponse = Response & { flush: () => void }; @RestController('/ai') export class AiController { constructor(private readonly aiService: AiService) {} @Post('/chat', { rateLimit: { limit: 100 } }) - async chat(req: AiAssistantRequest.Chat, res: FlushableResponse) { + async chat(req: AuthenticatedRequest, res: FlushableResponse, @Body payload: AiChatRequestDto) { try { - const aiResponse = await this.aiService.chat(req.body, req.user); + const aiResponse = await this.aiService.chat(payload, req.user); if (aiResponse.body) { res.header('Content-type', 'application/json-lines').flush(); await aiResponse.body.pipeTo( @@ -38,10 +39,12 @@ export class AiController { @Post('/chat/apply-suggestion') async applySuggestion( - req: AiAssistantRequest.ApplySuggestionPayload, + req: AuthenticatedRequest, + _: Response, + @Body payload: AiApplySuggestionRequestDto, ): Promise { try { - return await this.aiService.applySuggestion(req.body, req.user); + return await this.aiService.applySuggestion(payload, req.user); } catch (e) { assert(e instanceof Error); throw new InternalServerError(e.message, e); @@ -49,9 +52,13 @@ export class AiController { } @Post('/ask-ai') - async askAi(req: AiAssistantRequest.AskAiPayload): Promise { + async askAi( + req: AuthenticatedRequest, + _: Response, + @Body payload: AiAskRequestDto, + ): Promise { try { - return await this.aiService.askAi(req.body, req.user); + return await this.aiService.askAi(payload, req.user); } catch (e) { assert(e instanceof Error); throw new InternalServerError(e.message, e); diff --git a/packages/cli/src/requests.ts b/packages/cli/src/requests.ts index 7afb1e1bd3..f7ac415a75 100644 --- a/packages/cli/src/requests.ts +++ b/packages/cli/src/requests.ts @@ -1,5 +1,4 @@ import type { Scope } from '@n8n/permissions'; -import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk'; import type express from 'express'; import type { BannerName, @@ -574,15 +573,3 @@ export declare namespace NpsSurveyRequest { // once some schema validation is added type NpsSurveyUpdate = AuthenticatedRequest<{}, {}, unknown>; } - -// ---------------------------------- -// /ai-assistant -// ---------------------------------- - -export declare namespace AiAssistantRequest { - type Chat = AuthenticatedRequest<{}, {}, AiAssistantSDK.ChatRequestPayload>; - - type SuggestionPayload = { sessionId: string; suggestionId: string }; - type ApplySuggestionPayload = AuthenticatedRequest<{}, {}, SuggestionPayload>; - type AskAiPayload = AuthenticatedRequest<{}, {}, AiAssistantSDK.AskAiRequestPayload>; -} diff --git a/packages/cli/src/services/__tests__/ai.service.test.ts b/packages/cli/src/services/__tests__/ai.service.test.ts new file mode 100644 index 0000000000..dbdcaa3e71 --- /dev/null +++ b/packages/cli/src/services/__tests__/ai.service.test.ts @@ -0,0 +1,132 @@ +import type { + AiAskRequestDto, + AiApplySuggestionRequestDto, + AiChatRequestDto, +} from '@n8n/api-types'; +import type { GlobalConfig } from '@n8n/config'; +import { AiAssistantClient, type AiAssistantSDK } from '@n8n_io/ai-assistant-sdk'; +import { mock } from 'jest-mock-extended'; +import type { IUser } from 'n8n-workflow'; + +import { N8N_VERSION } from '@/constants'; +import type { License } from '@/license'; + +import { AiService } from '../ai.service'; + +jest.mock('@n8n_io/ai-assistant-sdk', () => ({ + AiAssistantClient: jest.fn(), +})); + +describe('AiService', () => { + let aiService: AiService; + + const baseUrl = 'https://ai-assistant-url.com'; + const user = mock({ id: 'user123' }); + const client = mock(); + const license = mock(); + const globalConfig = mock({ + logging: { level: 'info' }, + aiAssistant: { baseUrl }, + }); + + beforeEach(() => { + jest.clearAllMocks(); + (AiAssistantClient as jest.Mock).mockImplementation(() => client); + aiService = new AiService(license, globalConfig); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('init', () => { + it('should not initialize client if AI assistant is not enabled', async () => { + license.isAiAssistantEnabled.mockReturnValue(false); + + await aiService.init(); + + expect(AiAssistantClient).not.toHaveBeenCalled(); + }); + + it('should initialize client when AI assistant is enabled', async () => { + license.isAiAssistantEnabled.mockReturnValue(true); + license.loadCertStr.mockResolvedValue('mock-license-cert'); + license.getConsumerId.mockReturnValue('mock-consumer-id'); + + await aiService.init(); + + expect(AiAssistantClient).toHaveBeenCalledWith({ + licenseCert: 'mock-license-cert', + consumerId: 'mock-consumer-id', + n8nVersion: N8N_VERSION, + baseUrl, + logLevel: 'info', + }); + }); + }); + + describe('chat', () => { + const payload = mock(); + + it('should call client chat method after initialization', async () => { + license.isAiAssistantEnabled.mockReturnValue(true); + const clientResponse = mock(); + client.chat.mockResolvedValue(clientResponse); + + const result = await aiService.chat(payload, user); + + expect(client.chat).toHaveBeenCalledWith(payload, { id: user.id }); + expect(result).toEqual(clientResponse); + }); + + it('should throw error if client is not initialized', async () => { + license.isAiAssistantEnabled.mockReturnValue(false); + + await expect(aiService.chat(payload, user)).rejects.toThrow('Assistant client not setup'); + }); + }); + + describe('applySuggestion', () => { + const payload = mock(); + + it('should call client applySuggestion', async () => { + license.isAiAssistantEnabled.mockReturnValue(true); + const clientResponse = mock(); + client.applySuggestion.mockResolvedValue(clientResponse); + + const result = await aiService.applySuggestion(payload, user); + + expect(client.applySuggestion).toHaveBeenCalledWith(payload, { id: user.id }); + expect(result).toEqual(clientResponse); + }); + + it('should throw error if client is not initialized', async () => { + license.isAiAssistantEnabled.mockReturnValue(false); + + await expect(aiService.applySuggestion(payload, user)).rejects.toThrow( + 'Assistant client not setup', + ); + }); + }); + + describe('askAi', () => { + const payload = mock(); + + it('should call client askAi method after initialization', async () => { + license.isAiAssistantEnabled.mockReturnValue(true); + const clientResponse = mock(); + client.askAi.mockResolvedValue(clientResponse); + + const result = await aiService.askAi(payload, user); + + expect(client.askAi).toHaveBeenCalledWith(payload, { id: user.id }); + expect(result).toEqual(clientResponse); + }); + + it('should throw error if client is not initialized', async () => { + license.isAiAssistantEnabled.mockReturnValue(false); + + await expect(aiService.askAi(payload, user)).rejects.toThrow('Assistant client not setup'); + }); + }); +}); diff --git a/packages/cli/src/services/ai.service.ts b/packages/cli/src/services/ai.service.ts index a7b07219b5..74e28ad288 100644 --- a/packages/cli/src/services/ai.service.ts +++ b/packages/cli/src/services/ai.service.ts @@ -1,12 +1,13 @@ +import type { + AiApplySuggestionRequestDto, + AiAskRequestDto, + AiChatRequestDto, +} from '@n8n/api-types'; import { GlobalConfig } from '@n8n/config'; -import type { AiAssistantSDK } from '@n8n_io/ai-assistant-sdk'; import { AiAssistantClient } from '@n8n_io/ai-assistant-sdk'; import { assert, type IUser } from 'n8n-workflow'; import { Service } from 'typedi'; -import config from '@/config'; -import type { AiAssistantRequest } from '@/requests'; - import { N8N_VERSION } from '../constants'; import { License } from '../license'; @@ -27,7 +28,7 @@ export class AiService { const licenseCert = await this.licenseService.loadCertStr(); const consumerId = this.licenseService.getConsumerId(); - const baseUrl = config.get('aiAssistant.baseUrl'); + const baseUrl = this.globalConfig.aiAssistant.baseUrl; const logLevel = this.globalConfig.logging.level; this.client = new AiAssistantClient({ @@ -39,7 +40,7 @@ export class AiService { }); } - async chat(payload: AiAssistantSDK.ChatRequestPayload, user: IUser) { + async chat(payload: AiChatRequestDto, user: IUser) { if (!this.client) { await this.init(); } @@ -48,7 +49,7 @@ export class AiService { return await this.client.chat(payload, { id: user.id }); } - async applySuggestion(payload: AiAssistantRequest.SuggestionPayload, user: IUser) { + async applySuggestion(payload: AiApplySuggestionRequestDto, user: IUser) { if (!this.client) { await this.init(); } @@ -57,7 +58,7 @@ export class AiService { return await this.client.applySuggestion(payload, { id: user.id }); } - async askAi(payload: AiAssistantSDK.AskAiRequestPayload, user: IUser) { + async askAi(payload: AiAskRequestDto, user: IUser) { if (!this.client) { await this.init(); }