From afaa0bec71d6eb8f6614a83af8083da220ec0d79 Mon Sep 17 00:00:00 2001 From: Mutasem Aldmour <4711238+mutdmour@users.noreply.github.com> Date: Wed, 20 Aug 2025 13:50:53 +0200 Subject: [PATCH] fix: Handle AI errors better in builder (no-changelog) (#18406) --- .../src/ai-workflow-builder-agent.service.ts | 2 +- .../src/test/workflow-builder-agent.test.ts | 362 ++++++++++++++++++ .../src/workflow-builder-agent.ts | 155 ++++++-- .../messages/ErrorMessage.vue | 1 + .../composables/useBuilderMessages.test.ts | 238 ++++++++++++ .../src/composables/useBuilderMessages.ts | 23 +- .../editor-ui/src/stores/builder.store.ts | 1 + 7 files changed, 742 insertions(+), 40 deletions(-) create mode 100644 packages/@n8n/ai-workflow-builder.ee/src/test/workflow-builder-agent.test.ts diff --git a/packages/@n8n/ai-workflow-builder.ee/src/ai-workflow-builder-agent.service.ts b/packages/@n8n/ai-workflow-builder.ee/src/ai-workflow-builder-agent.service.ts index 8c88881aed..90a64f684d 100644 --- a/packages/@n8n/ai-workflow-builder.ee/src/ai-workflow-builder-agent.service.ts +++ b/packages/@n8n/ai-workflow-builder.ee/src/ai-workflow-builder-agent.service.ts @@ -88,7 +88,7 @@ export class AiWorkflowBuilderService { }, }); } catch (error) { - const llmError = new LLMServiceError('Failed to setup LLM models', { + const llmError = new LLMServiceError('Failed to connect to LLM Provider', { cause: error, tags: { hasClient: !!this.client, diff --git a/packages/@n8n/ai-workflow-builder.ee/src/test/workflow-builder-agent.test.ts b/packages/@n8n/ai-workflow-builder.ee/src/test/workflow-builder-agent.test.ts new file mode 100644 index 0000000000..766c959ad0 --- /dev/null +++ b/packages/@n8n/ai-workflow-builder.ee/src/test/workflow-builder-agent.test.ts @@ -0,0 +1,362 @@ +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import type { ToolMessage } from '@langchain/core/messages'; +import { AIMessage, HumanMessage } from '@langchain/core/messages'; +import type { MemorySaver } from '@langchain/langgraph'; +import { GraphRecursionError } from '@langchain/langgraph'; +import type { Logger } from '@n8n/backend-common'; +import { mock } from 'jest-mock-extended'; +import type { INodeTypeDescription } from 'n8n-workflow'; +import { ApplicationError } from 'n8n-workflow'; + +import { MAX_AI_BUILDER_PROMPT_LENGTH } from '@/constants'; +import { ValidationError } from '@/errors'; +import type { StreamOutput } from '@/types/streaming'; +import { createStreamProcessor, formatMessages } from '@/utils/stream-processor'; +import { + WorkflowBuilderAgent, + type WorkflowBuilderAgentConfig, + type ChatPayload, +} from '@/workflow-builder-agent'; + +jest.mock('@/tools/add-node.tool', () => ({ + createAddNodeTool: jest.fn().mockReturnValue({ name: 'add_node' }), +})); +jest.mock('@/tools/connect-nodes.tool', () => ({ + createConnectNodesTool: jest.fn().mockReturnValue({ name: 'connect_nodes' }), +})); +jest.mock('@/tools/node-details.tool', () => ({ + createNodeDetailsTool: jest.fn().mockReturnValue({ name: 'node_details' }), +})); +jest.mock('@/tools/node-search.tool', () => ({ + createNodeSearchTool: jest.fn().mockReturnValue({ name: 'node_search' }), +})); +jest.mock('@/tools/remove-node.tool', () => ({ + createRemoveNodeTool: jest.fn().mockReturnValue({ name: 'remove_node' }), +})); +jest.mock('@/tools/update-node-parameters.tool', () => ({ + createUpdateNodeParametersTool: jest.fn().mockReturnValue({ name: 'update_node_parameters' }), +})); +jest.mock('@/tools/prompts/main-agent.prompt', () => ({ + mainAgentPrompt: { + invoke: jest.fn().mockResolvedValue('mocked prompt'), + }, +})); +jest.mock('@/utils/operations-processor', () => ({ + processOperations: jest.fn(), +})); + +jest.mock('@/utils/stream-processor', () => ({ + createStreamProcessor: jest.fn(), + formatMessages: jest.fn(), +})); +jest.mock('@/utils/tool-executor', () => ({ + executeToolsInParallel: jest.fn(), +})); +jest.mock('@/chains/conversation-compact', () => ({ + conversationCompactChain: jest.fn(), +})); + +const mockRandomUUID = jest.fn(); +Object.defineProperty(global, 'crypto', { + value: { + randomUUID: mockRandomUUID, + }, + writable: true, +}); + +describe('WorkflowBuilderAgent', () => { + let agent: WorkflowBuilderAgent; + let mockLlmSimple: BaseChatModel; + let mockLlmComplex: BaseChatModel; + let mockLogger: Logger; + let mockCheckpointer: MemorySaver; + let parsedNodeTypes: INodeTypeDescription[]; + let config: WorkflowBuilderAgentConfig; + + const mockCreateStreamProcessor = createStreamProcessor as jest.MockedFunction< + typeof createStreamProcessor + >; + const mockFormatMessages = formatMessages as jest.MockedFunction; + + beforeEach(() => { + mockLlmSimple = mock({ + _llmType: jest.fn().mockReturnValue('test-llm'), + bindTools: jest.fn().mockReturnThis(), + invoke: jest.fn(), + }); + + mockLlmComplex = mock({ + _llmType: jest.fn().mockReturnValue('test-llm-complex'), + bindTools: jest.fn().mockReturnThis(), + invoke: jest.fn(), + }); + + mockLogger = mock({ + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + }); + + mockCheckpointer = mock(); + mockCheckpointer.getTuple = jest.fn(); + mockCheckpointer.put = jest.fn(); + mockCheckpointer.list = jest.fn(); + + parsedNodeTypes = [ + { + name: 'TestNode', + displayName: 'Test Node', + description: 'A test node', + version: 1, + defaults: {}, + inputs: [], + outputs: [], + properties: [], + group: ['transform'], + } as INodeTypeDescription, + ]; + + config = { + parsedNodeTypes, + llmSimpleTask: mockLlmSimple, + llmComplexTask: mockLlmComplex, + logger: mockLogger, + checkpointer: mockCheckpointer, + }; + + agent = new WorkflowBuilderAgent(config); + }); + + describe('generateThreadId', () => { + beforeEach(() => { + mockRandomUUID.mockReset(); + }); + + it('should generate thread ID with workflowId and userId', () => { + const workflowId = 'workflow-123'; + const userId = 'user-456'; + const threadId = WorkflowBuilderAgent.generateThreadId(workflowId, userId); + expect(threadId).toBe('workflow-workflow-123-user-user-456'); + }); + + it('should generate thread ID with workflowId but without userId', () => { + const workflowId = 'workflow-123'; + const threadId = WorkflowBuilderAgent.generateThreadId(workflowId); + expect(threadId).toMatch(/^workflow-workflow-123-user-\d+$/); + }); + + it('should generate random UUID when no workflowId provided', () => { + const mockUuid = 'test-uuid-1234-5678-9012'; + mockRandomUUID.mockReturnValue(mockUuid); + + const threadId = WorkflowBuilderAgent.generateThreadId(); + + expect(mockRandomUUID).toHaveBeenCalled(); + expect(threadId).toBe(mockUuid); + }); + }); + + describe('chat method', () => { + let mockPayload: ChatPayload; + + beforeEach(() => { + mockPayload = { + message: 'Create a workflow', + workflowContext: { + currentWorkflow: { id: 'workflow-123' }, + }, + }; + }); + + it('should throw ValidationError when message exceeds maximum length', async () => { + const longMessage = 'x'.repeat(MAX_AI_BUILDER_PROMPT_LENGTH + 1); + const payload: ChatPayload = { + message: longMessage, + }; + + await expect(async () => { + const generator = agent.chat(payload); + await generator.next(); + }).rejects.toThrow(ValidationError); + + expect(mockLogger.warn).toHaveBeenCalledWith('Message exceeds maximum length', { + messageLength: longMessage.length, + maxLength: MAX_AI_BUILDER_PROMPT_LENGTH, + }); + }); + + it('should handle valid message length', async () => { + const validMessage = 'Create a simple workflow'; + const payload: ChatPayload = { + message: validMessage, + }; + + // Mock the stream processing to return a proper StreamOutput + const mockStreamOutput: StreamOutput = { + messages: [ + { + role: 'assistant', + type: 'message', + text: 'Processing...', + }, + ], + }; + const mockAsyncGenerator = (async function* () { + yield mockStreamOutput; + })(); + + mockCreateStreamProcessor.mockReturnValue(mockAsyncGenerator); + + // Mock the LLM to return a simple response + (mockLlmSimple.invoke as jest.Mock).mockResolvedValue({ + content: 'Mocked response', + tool_calls: [], + }); + + const generator = agent.chat(payload); + const result = await generator.next(); + + expect(result.value).toEqual(mockStreamOutput); + }); + + it('should handle GraphRecursionError', async () => { + mockCreateStreamProcessor.mockImplementation(() => { + // eslint-disable-next-line require-yield + return (async function* () { + throw new GraphRecursionError('Recursion limit exceeded'); + })(); + }); + + await expect(async () => { + const generator = agent.chat(mockPayload); + await generator.next(); + }).rejects.toThrow(ApplicationError); + }); + + it('should handle invalid request errors', async () => { + const invalidRequestError = Object.assign(new Error('Request failed'), { + error: { + error: { + type: 'invalid_request_error', + message: 'Invalid API request', + }, + }, + }); + + (mockLlmSimple.invoke as jest.Mock).mockRejectedValue(invalidRequestError); + + await expect(async () => { + const generator = agent.chat(mockPayload); + await generator.next(); + }).rejects.toThrow(ApplicationError); + }); + + it('should rethrow unknown errors', async () => { + const unknownError = new Error('Unknown error'); + + // Mock createStreamProcessor to throw an unknown error (not GraphRecursionError or abort) + mockCreateStreamProcessor.mockImplementation(() => { + // eslint-disable-next-line require-yield + return (async function* () { + throw unknownError; + })(); + }); + + await expect(async () => { + const generator = agent.chat(mockPayload); + await generator.next(); + }).rejects.toThrow(unknownError); + }); + }); + + describe('getSessions', () => { + beforeEach(() => { + mockFormatMessages.mockImplementation( + (messages: Array) => + messages.map((m) => ({ type: m.constructor.name.toLowerCase(), content: m.content })), + ); + }); + + it('should return session for existing workflowId', async () => { + const workflowId = 'workflow-123'; + const userId = 'user-456'; + const mockCheckpoint = { + checkpoint: { + channel_values: { + messages: [new HumanMessage('Hello'), new AIMessage('Hi there!')], + }, + ts: '2023-12-01T12:00:00Z', + }, + }; + + (mockCheckpointer.getTuple as jest.Mock).mockResolvedValue(mockCheckpoint); + + const result = await agent.getSessions(workflowId, userId); + + expect(result.sessions).toHaveLength(1); + expect(result.sessions[0]).toMatchObject({ + sessionId: 'workflow-workflow-123-user-user-456', + lastUpdated: '2023-12-01T12:00:00Z', + }); + expect(result.sessions[0].messages).toHaveLength(2); + }); + + it('should return empty sessions when workflowId is undefined', async () => { + const result = await agent.getSessions(undefined); + + expect(result.sessions).toHaveLength(0); + expect(mockCheckpointer.getTuple).not.toHaveBeenCalled(); + }); + + it('should return empty sessions when no checkpoint exists', async () => { + const workflowId = 'workflow-123'; + (mockCheckpointer.getTuple as jest.Mock).mockRejectedValue(new Error('Thread not found')); + + const result = await agent.getSessions(workflowId); + + expect(result.sessions).toHaveLength(0); + expect(mockLogger.debug).toHaveBeenCalledWith('No session found for workflow:', { + workflowId, + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + error: expect.any(Error), + }); + }); + + it('should handle checkpoint without messages', async () => { + const workflowId = 'workflow-123'; + const mockCheckpoint = { + checkpoint: { + channel_values: {}, + ts: '2023-12-01T12:00:00Z', + }, + }; + + (mockCheckpointer.getTuple as jest.Mock).mockResolvedValue(mockCheckpoint); + + const result = await agent.getSessions(workflowId); + + expect(result.sessions).toHaveLength(1); + expect(result.sessions[0].messages).toHaveLength(0); + }); + + it('should handle checkpoint with null messages', async () => { + const workflowId = 'workflow-123'; + const mockCheckpoint = { + checkpoint: { + channel_values: { + messages: null, + }, + ts: '2023-12-01T12:00:00Z', + }, + }; + + (mockCheckpointer.getTuple as jest.Mock).mockResolvedValue(mockCheckpoint); + + const result = await agent.getSessions(workflowId); + + expect(result.sessions).toHaveLength(1); + expect(result.sessions[0].messages).toHaveLength(0); + }); + }); +}); diff --git a/packages/@n8n/ai-workflow-builder.ee/src/workflow-builder-agent.ts b/packages/@n8n/ai-workflow-builder.ee/src/workflow-builder-agent.ts index 8a5aaa5c98..da596021bd 100644 --- a/packages/@n8n/ai-workflow-builder.ee/src/workflow-builder-agent.ts +++ b/packages/@n8n/ai-workflow-builder.ee/src/workflow-builder-agent.ts @@ -3,13 +3,14 @@ import type { ToolMessage } from '@langchain/core/messages'; import { AIMessage, HumanMessage, RemoveMessage } from '@langchain/core/messages'; import type { RunnableConfig } from '@langchain/core/runnables'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; -import { StateGraph, MemorySaver, END } from '@langchain/langgraph'; +import { StateGraph, MemorySaver, END, GraphRecursionError } from '@langchain/langgraph'; import type { Logger } from '@n8n/backend-common'; -import type { - INodeTypeDescription, - IRunExecutionData, - IWorkflowBase, - NodeExecutionSchema, +import { + ApplicationError, + type INodeTypeDescription, + type IRunExecutionData, + type IWorkflowBase, + type NodeExecutionSchema, } from 'n8n-workflow'; import { workflowNameChain } from '@/chains/workflow-name'; @@ -148,7 +149,7 @@ export class WorkflowBuilderAgent { }; const shouldContinue = ({ messages }: typeof WorkflowState.State) => { - const lastMessage = messages[messages.length - 1] as AIMessage; + const lastMessage: AIMessage = messages[messages.length - 1]; if (lastMessage.tool_calls?.length) { return 'tools'; @@ -295,10 +296,26 @@ export class WorkflowBuilderAgent { } async *chat(payload: ChatPayload, userId?: string, abortSignal?: AbortSignal) { - // Check for the message maximum length - if (payload.message.length > MAX_AI_BUILDER_PROMPT_LENGTH) { + this.validateMessageLength(payload.message); + + const { agent, threadConfig, streamConfig } = this.setupAgentAndConfigs( + payload, + userId, + abortSignal, + ); + + try { + const stream = await this.createAgentStream(payload, streamConfig, agent); + yield* this.processAgentStream(stream, agent, threadConfig); + } catch (error: unknown) { + this.handleStreamError(error); + } + } + + private validateMessageLength(message: string): void { + if (message.length > MAX_AI_BUILDER_PROMPT_LENGTH) { this.logger?.warn('Message exceeds maximum length', { - messageLength: payload.message.length, + messageLength: message.length, maxLength: MAX_AI_BUILDER_PROMPT_LENGTH, }); @@ -306,7 +323,9 @@ export class WorkflowBuilderAgent { `Message exceeds maximum length of ${MAX_AI_BUILDER_PROMPT_LENGTH} characters`, ); } + } + private setupAgentAndConfigs(payload: ChatPayload, userId?: string, abortSignal?: AbortSignal) { const agent = this.createWorkflow().compile({ checkpointer: this.checkpointer }); const workflowId = payload.workflowContext?.currentWorkflow?.id; // Generate thread ID from workflowId and userId @@ -320,12 +339,20 @@ export class WorkflowBuilderAgent { const streamConfig = { ...threadConfig, streamMode: ['updates', 'custom'], - recursionLimit: 30, + recursionLimit: 50, signal: abortSignal, callbacks: this.tracer ? [this.tracer] : undefined, - } as RunnableConfig; + }; - const stream = await agent.stream( + return { agent, threadConfig, streamConfig }; + } + + private async createAgentStream( + payload: ChatPayload, + streamConfig: RunnableConfig, + agent: ReturnType['compile']>, + ) { + return await agent.stream( { messages: [new HumanMessage({ content: payload.message })], workflowJSON: this.getDefaultWorkflowJSON(payload), @@ -334,39 +361,95 @@ export class WorkflowBuilderAgent { }, streamConfig, ); + } + private handleStreamError(error: unknown): never { + const invalidRequestErrorMessage = this.getInvalidRequestError(error); + if (invalidRequestErrorMessage) { + throw new ValidationError(invalidRequestErrorMessage); + } + + throw error; + } + + private async *processAgentStream( + stream: AsyncGenerator<[string, unknown], void, unknown>, + agent: ReturnType['compile']>, + threadConfig: RunnableConfig, + ) { try { const streamProcessor = createStreamProcessor(stream); for await (const output of streamProcessor) { yield output; } } catch (error) { - if ( - error && - typeof error === 'object' && - 'message' in error && - typeof error.message === 'string' && - // This is naive, but it's all we get from LangGraph AbortError - ['Abort', 'Aborted'].includes(error.message) - ) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access - const messages = (await agent.getState(threadConfig)).values.messages as Array< - AIMessage | HumanMessage | ToolMessage - >; - - // Handle abort errors gracefully - const abortedAiMessage = new AIMessage({ - content: '[Task aborted]', - id: crypto.randomUUID(), - }); - // TODO: Should we clear tool calls that are in progress? - await agent.updateState(threadConfig, { messages: [...messages, abortedAiMessage] }); - return; - } - throw error; + await this.handleAgentStreamError(error, agent, threadConfig); } } + private async handleAgentStreamError( + error: unknown, + agent: ReturnType['compile']>, + threadConfig: RunnableConfig, + ): Promise { + if ( + error && + typeof error === 'object' && + 'message' in error && + typeof error.message === 'string' && + // This is naive, but it's all we get from LangGraph AbortError + ['Abort', 'Aborted'].includes(error.message) + ) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-member-access + const messages = (await agent.getState(threadConfig)).values.messages as Array< + AIMessage | HumanMessage | ToolMessage + >; + + // Handle abort errors gracefully + const abortedAiMessage = new AIMessage({ + content: '[Task aborted]', + id: crypto.randomUUID(), + }); + // TODO: Should we clear tool calls that are in progress? + await agent.updateState(threadConfig, { messages: [...messages, abortedAiMessage] }); + return; + } + + // If it's not an abort error, check for GraphRecursionError + if (error instanceof GraphRecursionError) { + throw new ApplicationError( + 'Workflow generation stopped: The AI reached the maximum number of steps while building your workflow. This usually means the workflow design became too complex or got stuck in a loop while trying to create the nodes and connections.', + ); + } + + // Re-throw any other errors + throw error; + } + + private getInvalidRequestError(error: unknown): string | undefined { + if ( + error instanceof Error && + 'error' in error && + typeof error.error === 'object' && + error.error + ) { + const innerError = error.error; + if ('error' in innerError && typeof innerError.error === 'object' && innerError.error) { + const errorDetails = innerError.error; + if ( + 'type' in errorDetails && + errorDetails.type === 'invalid_request_error' && + 'message' in errorDetails && + typeof errorDetails.message === 'string' + ) { + return errorDetails.message; + } + } + } + + return undefined; + } + async getSessions(workflowId: string | undefined, userId?: string) { // For now, we'll return the current session if we have a workflowId // MemorySaver doesn't expose a way to list all threads, so we'll need to diff --git a/packages/frontend/@n8n/design-system/src/components/AskAssistantChat/messages/ErrorMessage.vue b/packages/frontend/@n8n/design-system/src/components/AskAssistantChat/messages/ErrorMessage.vue index cf1288fc8e..f11e20e402 100644 --- a/packages/frontend/@n8n/design-system/src/components/AskAssistantChat/messages/ErrorMessage.vue +++ b/packages/frontend/@n8n/design-system/src/components/AskAssistantChat/messages/ErrorMessage.vue @@ -60,6 +60,7 @@ const { t } = useI18n(); font-weight: var(--font-weight-regular); line-height: var(--font-line-height-tight); word-break: break-word; + flex-grow: 1; } .retryButton { diff --git a/packages/frontend/editor-ui/src/composables/useBuilderMessages.test.ts b/packages/frontend/editor-ui/src/composables/useBuilderMessages.test.ts index de5e839f8e..d66eb8c66e 100644 --- a/packages/frontend/editor-ui/src/composables/useBuilderMessages.test.ts +++ b/packages/frontend/editor-ui/src/composables/useBuilderMessages.test.ts @@ -1453,6 +1453,244 @@ describe('useBuilderMessages', () => { }); }); + describe('error message handling with retry', () => { + it('should pass retry function to error messages from processAssistantMessages', () => { + const retryFn = vi.fn(async () => {}); + const currentMessages: ChatUI.AssistantMessage[] = []; + const newMessages: ChatRequest.MessageResponse[] = [ + { + type: 'error', + role: 'assistant', + content: 'Something went wrong', + }, + ]; + + const result = builderMessages.processAssistantMessages( + currentMessages, + newMessages, + 'test-id', + retryFn, + ); + + expect(result.messages).toHaveLength(1); + const errorMessage = result.messages[0] as ChatUI.ErrorMessage; + expect(errorMessage).toMatchObject({ + id: 'test-id-0', + role: 'assistant', + type: 'error', + content: 'Something went wrong', + read: false, + }); + expect(errorMessage.retry).toBe(retryFn); + }); + + it('should not pass retry function to non-error messages', () => { + const retryFn = vi.fn(async () => {}); + const currentMessages: ChatUI.AssistantMessage[] = []; + const newMessages: ChatRequest.MessageResponse[] = [ + { + type: 'message', + role: 'assistant', + text: 'This is a normal text message', + }, + { + role: 'assistant', + type: 'tool', + toolName: 'add_nodes', + toolCallId: 'call-1', + status: 'running', + updates: [], + }, + ]; + + const result = builderMessages.processAssistantMessages( + currentMessages, + newMessages, + 'test-id', + retryFn, + ); + + expect(result.messages).toHaveLength(2); + + const textMessage = result.messages[0] as ChatUI.TextMessage; + expect(textMessage.type).toBe('text'); + expect('retry' in textMessage).toBe(false); + + const toolMessage = result.messages[1] as ChatUI.ToolMessage; + expect(toolMessage.type).toBe('tool'); + expect('retry' in toolMessage).toBe(false); + }); + + it('should clear retry from previous error messages when processing new messages', () => { + const oldRetryFn = vi.fn(async () => {}); + const newRetryFn = vi.fn(async () => {}); + + const currentMessages: ChatUI.AssistantMessage[] = [ + { + id: 'error-1', + role: 'assistant', + type: 'error', + content: 'First error', + retry: oldRetryFn, + read: false, + } as ChatUI.ErrorMessage, + { + id: 'error-2', + role: 'assistant', + type: 'error', + content: 'Second error', + retry: oldRetryFn, + read: false, + } as ChatUI.ErrorMessage, + ]; + + const newMessages: ChatRequest.MessageResponse[] = [ + { + type: 'error', + role: 'assistant', + content: 'New error', + }, + ]; + + const result = builderMessages.processAssistantMessages( + currentMessages, + newMessages, + 'test-id', + newRetryFn, + ); + + expect(result.messages).toHaveLength(3); + + // First error should have retry removed + const firstError = result.messages[0] as ChatUI.ErrorMessage; + expect(firstError.content).toBe('First error'); + expect('retry' in firstError).toBe(false); + + // Second error should have retry removed + const secondError = result.messages[1] as ChatUI.ErrorMessage; + expect(secondError.content).toBe('Second error'); + expect('retry' in secondError).toBe(false); + + // New error should have the new retry function + const newError = result.messages[2] as ChatUI.ErrorMessage; + expect(newError.content).toBe('New error'); + expect(newError.retry).toBe(newRetryFn); + }); + + it('should only keep retry on the last error message when multiple errors exist', () => { + const retryFn = vi.fn(async () => {}); + const currentMessages: ChatUI.AssistantMessage[] = []; + const newMessages: ChatRequest.MessageResponse[] = [ + { + type: 'error', + role: 'assistant', + content: 'First error in batch', + }, + { + type: 'error', + role: 'assistant', + content: 'Second error in batch', + }, + { + type: 'error', + role: 'assistant', + content: 'Third error in batch', + }, + ]; + + const result = builderMessages.processAssistantMessages( + currentMessages, + newMessages, + 'test-id', + retryFn, + ); + + expect(result.messages).toHaveLength(3); + + // First error should not have retry + const firstError = result.messages[0] as ChatUI.ErrorMessage; + expect(firstError.content).toBe('First error in batch'); + expect('retry' in firstError).toBe(false); + + // Second error should not have retry + const secondError = result.messages[1] as ChatUI.ErrorMessage; + expect(secondError.content).toBe('Second error in batch'); + expect('retry' in secondError).toBe(false); + + // Only the last error should have retry + const lastError = result.messages[2] as ChatUI.ErrorMessage; + expect(lastError.content).toBe('Third error in batch'); + expect(lastError.retry).toBe(retryFn); + }); + + it('should handle mixed message types and only affect error messages with retry logic', () => { + const retryFn = vi.fn(async () => {}); + const currentMessages: ChatUI.AssistantMessage[] = [ + { + id: 'msg-1', + role: 'assistant', + type: 'text', + content: 'Normal message', + read: false, + }, + { + id: 'error-1', + role: 'assistant', + type: 'error', + content: 'Old error', + retry: retryFn, + read: false, + } as ChatUI.ErrorMessage, + ]; + + const newMessages: ChatRequest.MessageResponse[] = [ + { + type: 'message', + role: 'assistant', + text: 'New text message', + }, + { + type: 'error', + role: 'assistant', + content: 'New error message', + }, + ]; + + const result = builderMessages.processAssistantMessages( + currentMessages, + newMessages, + 'test-id', + retryFn, + ); + + expect(result.messages).toHaveLength(4); + + // Normal text message should be unchanged + expect(result.messages[0]).toMatchObject({ + type: 'text', + content: 'Normal message', + }); + expect('retry' in result.messages[0]).toBe(false); + + // Old error should have retry removed + const oldError = result.messages[1] as ChatUI.ErrorMessage; + expect(oldError.content).toBe('Old error'); + expect('retry' in oldError).toBe(false); + + // New text message should not have retry + expect(result.messages[2]).toMatchObject({ + type: 'text', + content: 'New text message', + }); + expect('retry' in result.messages[2]).toBe(false); + + // Only the new error should have retry + const newError = result.messages[3] as ChatUI.ErrorMessage; + expect(newError.content).toBe('New error message'); + expect(newError.retry).toBe(retryFn); + }); + }); + describe('applyRatingLogic', () => { it('should apply rating to the last assistant text message after workflow-updated when no tools are running', () => { const messages: ChatUI.AssistantMessage[] = [ diff --git a/packages/frontend/editor-ui/src/composables/useBuilderMessages.ts b/packages/frontend/editor-ui/src/composables/useBuilderMessages.ts index 42c57b7fd5..b71f2fcad6 100644 --- a/packages/frontend/editor-ui/src/composables/useBuilderMessages.ts +++ b/packages/frontend/editor-ui/src/composables/useBuilderMessages.ts @@ -97,6 +97,7 @@ export function useBuilderMessages() { messages: ChatUI.AssistantMessage[], msg: ChatRequest.MessageResponse, messageId: string, + retry?: () => Promise, ): boolean { let shouldClearThinking = false; @@ -127,6 +128,7 @@ export function useBuilderMessages() { type: 'error', content: msg.content, read: false, + retry, }); shouldClearThinking = true; } @@ -248,6 +250,7 @@ export function useBuilderMessages() { currentMessages: ChatUI.AssistantMessage[], newMessages: ChatRequest.MessageResponse[], baseId: string, + retry?: () => Promise, ): MessageProcessingResult { const mutableMessages = [...currentMessages]; let shouldClearThinking = false; @@ -255,22 +258,36 @@ export function useBuilderMessages() { newMessages.forEach((msg, index) => { // Generate unique ID for each message in the batch const messageId = `${baseId}-${index}`; - const clearThinking = processSingleMessage(mutableMessages, msg, messageId); + const clearThinking = processSingleMessage(mutableMessages, msg, messageId, retry); shouldClearThinking = shouldClearThinking || clearThinking; }); const thinkingMessage = determineThinkingMessage(mutableMessages); // Apply rating logic only to messages after workflow-updated - const finalMessages = applyRatingLogic(mutableMessages); + const messagesWithRatingLogic = applyRatingLogic(mutableMessages); + + // Remove retry from all error messages except the last one + const messagesWithRetryLogic = removeRetryFromOldErrorMessages(messagesWithRatingLogic); return { - messages: finalMessages, + messages: messagesWithRetryLogic, thinkingMessage, shouldClearThinking: shouldClearThinking && mutableMessages.length > currentMessages.length, }; } + function removeRetryFromOldErrorMessages(messages: ChatUI.AssistantMessage[]) { + // Remove retry from all error messages except the last one + return messages.map((message, index) => { + if (message.type === 'error' && message.retry && index !== messages.length - 1) { + const { retry, ...messageWithoutRetry } = message; + return messageWithoutRetry; + } + return message; + }); + } + function createUserMessage(content: string, id: string): ChatUI.AssistantMessage { return { id, diff --git a/packages/frontend/editor-ui/src/stores/builder.store.ts b/packages/frontend/editor-ui/src/stores/builder.store.ts index 14b82254c2..121d52335e 100644 --- a/packages/frontend/editor-ui/src/stores/builder.store.ts +++ b/packages/frontend/editor-ui/src/stores/builder.store.ts @@ -286,6 +286,7 @@ export const useBuilderStore = defineStore(STORES.BUILDER, () => { chatMessages.value, response.messages, generateMessageId(), + retry, ); chatMessages.value = result.messages;