From 07a636eed6cb4d1f487018b34632810d24d99824 Mon Sep 17 00:00:00 2001 From: Yiorgis Gozadinos Date: Thu, 29 May 2025 15:07:17 +0200 Subject: [PATCH] feat(MCP Server Trigger Node): Cleanup MCP server management, use sanitized trigger node's name as name for the MCP server (#15751) --- .../nodes/mcp/McpTrigger/McpServer.ts | 89 +++++++++---------- .../nodes/mcp/McpTrigger/McpTrigger.node.ts | 16 ++-- .../mcp/McpTrigger/__test__/McpServer.test.ts | 58 +++++++----- .../__test__/McpTrigger.node.test.ts | 67 +++++++++++--- 4 files changed, 142 insertions(+), 88 deletions(-) diff --git a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts index 34a6ed7803..961d2aa5e2 100644 --- a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts +++ b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts @@ -38,8 +38,9 @@ function wasToolCall(body: string) { } /** - * Extracts the request ID from a JSONRPC message - * Returns undefined if the message doesn't have an ID or can't be parsed + * Extracts the request ID from a JSONRPC message (for example for tool calls). + * Returns undefined if the message doesn't have an ID (for example on a tool list request) + * */ function getRequestId(body: string): string | undefined { try { @@ -51,25 +52,56 @@ function getRequestId(body: string): string | undefined { } } -export class McpServer { +/** + * This singleton is shared across the instance, making sure it is the one + * keeping account of MCP servers. + * It needs to stay in memory to keep track of the long-lived connections. + * It requires a logger at first creation to set everything up. + */ +export class McpServerManager { + static #instance: McpServerManager; + servers: { [sessionId: string]: Server } = {}; transports: { [sessionId: string]: FlushingSSEServerTransport } = {}; - logger: Logger; - private tools: { [sessionId: string]: Tool[] } = {}; private resolveFunctions: { [callId: string]: CallableFunction } = {}; - constructor(logger: Logger) { + logger: Logger; + + private constructor(logger: Logger) { this.logger = logger; this.logger.debug('MCP Server created'); } - async connectTransport(postUrl: string, resp: CompressionResponse): Promise { + static instance(logger: Logger): McpServerManager { + if (!McpServerManager.#instance) { + McpServerManager.#instance = new McpServerManager(logger); + logger.debug('Created singleton MCP manager'); + } + + return McpServerManager.#instance; + } + + async createServerAndTransport( + serverName: string, + postUrl: string, + resp: CompressionResponse, + ): Promise { const transport = new FlushingSSEServerTransport(postUrl, resp); - const server = this.setUpServer(); + const server = new Server( + { + name: serverName, + version: '0.1.0', + }, + { + capabilities: { tools: {} }, + }, + ); + + this.setUpHandlers(server); const { sessionId } = transport; this.transports[sessionId] = transport; this.servers[sessionId] = server; @@ -123,17 +155,7 @@ export class McpServer { return wasToolCall(req.rawBody.toString()); } - setUpServer(): Server { - const server = new Server( - { - name: 'n8n-mcp-server', - version: '0.1.0', - }, - { - capabilities: { tools: {} }, - }, - ); - + setUpHandlers(server: Server) { server.setRequestHandler( ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => { @@ -203,34 +225,5 @@ export class McpServer { server.onerror = (error: unknown) => { this.logger.error(`MCP Error: ${error}`); }; - return server; - } -} - -/** - * This singleton is shared across the instance, making sure we only have one server to worry about. - * It needs to stay in memory to keep track of the long-lived connections. - * It requires a logger at first creation to set everything up. - */ -export class McpServerSingleton { - static #instance: McpServerSingleton; - - private _serverData: McpServer; - - private constructor(logger: Logger) { - this._serverData = new McpServer(logger); - } - - static instance(logger: Logger): McpServer { - if (!McpServerSingleton.#instance) { - McpServerSingleton.#instance = new McpServerSingleton(logger); - logger.debug('Created singleton for MCP Servers'); - } - - return McpServerSingleton.#instance.serverData; - } - - get serverData() { - return this._serverData; } } diff --git a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpTrigger.node.ts b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpTrigger.node.ts index a97dfe1632..7719098e48 100644 --- a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpTrigger.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpTrigger.node.ts @@ -3,11 +3,10 @@ import { validateWebhookAuthentication } from 'n8n-nodes-base/dist/nodes/Webhook import type { INodeTypeDescription, IWebhookFunctions, IWebhookResponseData } from 'n8n-workflow'; import { NodeConnectionTypes, Node } from 'n8n-workflow'; -import { getConnectedTools } from '@utils/helpers'; +import { getConnectedTools, nodeNameToToolName } from '@utils/helpers'; import type { CompressionResponse } from './FlushingSSEServerTransport'; -import { McpServerSingleton } from './McpServer'; -import type { McpServer } from './McpServer'; +import { McpServerManager } from './McpServer'; const MCP_SSE_SETUP_PATH = 'sse'; const MCP_SSE_MESSAGES_PATH = 'messages'; @@ -21,7 +20,7 @@ export class McpTrigger extends Node { dark: 'file:../mcp.dark.svg', }, group: ['trigger'], - version: 1, + version: [1, 1.1], description: 'Expose n8n tools as an MCP Server endpoint', activationMessage: 'You can now connect your MCP Clients to the SSE URL.', defaults: { @@ -143,8 +142,11 @@ export class McpTrigger extends Node { } throw error; } + const node = context.getNode(); + // Get a url/tool friendly name for the server, based on the node name + const serverName = node.typeVersion > 1 ? nodeNameToToolName(node) : 'n8n-mcp-server'; - const mcpServer: McpServer = McpServerSingleton.instance(context.logger); + const mcpServerManager: McpServerManager = McpServerManager.instance(context.logger); if (webhookName === 'setup') { // Sets up the transport and opens the long-lived connection. This resp @@ -153,7 +155,7 @@ export class McpTrigger extends Node { new RegExp(`/${MCP_SSE_SETUP_PATH}$`), `/${MCP_SSE_MESSAGES_PATH}`, ); - await mcpServer.connectTransport(postUrl, resp); + await mcpServerManager.createServerAndTransport(serverName, postUrl, resp); return { noWebhookResponse: true }; } else if (webhookName === 'default') { @@ -162,7 +164,7 @@ export class McpTrigger extends Node { // 'setup' call const connectedTools = await getConnectedTools(context, true); - const wasToolCall = await mcpServer.handlePostMessage(req, resp, connectedTools); + const wasToolCall = await mcpServerManager.handlePostMessage(req, resp, connectedTools); if (wasToolCall) return { noWebhookResponse: true, workflowData: [[{ json: {} }]] }; return { noWebhookResponse: true }; diff --git a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts index a698b2bd03..3e64029e47 100644 --- a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts @@ -6,7 +6,7 @@ import { captor, mock } from 'jest-mock-extended'; import type { CompressionResponse } from '../FlushingSSEServerTransport'; import { FlushingSSEServerTransport } from '../FlushingSSEServerTransport'; -import { McpServer } from '../McpServer'; +import { McpServerManager } from '../McpServer'; const sessionId = 'mock-session-id'; const mockServer = mock(); @@ -28,20 +28,18 @@ describe('McpServer', () => { const mockResponse = mock(); const mockTool = mock({ name: 'mockTool' }); - let mcpServer: McpServer; + const mcpServerManager = McpServerManager.instance(mock()); beforeEach(() => { jest.clearAllMocks(); mockResponse.status.mockReturnThis(); - - mcpServer = new McpServer(mock()); }); describe('connectTransport', () => { const postUrl = '/post-url'; it('should set up a transport and server', async () => { - await mcpServer.connectTransport(postUrl, mockResponse); + await mcpServerManager.createServerAndTransport('mcpServer', postUrl, mockResponse); // Check that FlushingSSEServerTransport was initialized with correct params expect(FlushingSSEServerTransport).toHaveBeenCalledWith(postUrl, mockResponse); @@ -50,18 +48,18 @@ describe('McpServer', () => { expect(Server).toHaveBeenCalled(); // Check that transport and server are stored - expect(mcpServer.transports[sessionId]).toBeDefined(); - expect(mcpServer.servers[sessionId]).toBeDefined(); + expect(mcpServerManager.transports[sessionId]).toBeDefined(); + expect(mcpServerManager.servers[sessionId]).toBeDefined(); // Check that connect was called on the server - expect(mcpServer.servers[sessionId].connect).toHaveBeenCalled(); + expect(mcpServerManager.servers[sessionId].connect).toHaveBeenCalled(); // Check that flush was called if available expect(mockResponse.flush).toHaveBeenCalled(); }); it('should set up close handler that cleans up resources', async () => { - await mcpServer.connectTransport(postUrl, mockResponse); + await mcpServerManager.createServerAndTransport('mcpServer', postUrl, mockResponse); // Get the close callback and execute it const closeCallbackCaptor = captor<() => Promise>(); @@ -69,8 +67,8 @@ describe('McpServer', () => { await closeCallbackCaptor.value(); // Check that resources were cleaned up - expect(mcpServer.transports[sessionId]).toBeUndefined(); - expect(mcpServer.servers[sessionId]).toBeUndefined(); + expect(mcpServerManager.transports[sessionId]).toBeUndefined(); + expect(mcpServerManager.servers[sessionId]).toBeUndefined(); }); }); @@ -78,11 +76,11 @@ describe('McpServer', () => { it('should call transport.handlePostMessage when transport exists', async () => { mockTransport.handlePostMessage.mockImplementation(async () => { // @ts-expect-error private property `resolveFunctions` - mcpServer.resolveFunctions[`${sessionId}_123`](); + mcpServerManager.resolveFunctions[`${sessionId}_123`](); }); // Add the transport directly - mcpServer.transports[sessionId] = mockTransport; + mcpServerManager.transports[sessionId] = mockTransport; mockRequest.rawBody = Buffer.from( JSON.stringify({ @@ -94,7 +92,9 @@ describe('McpServer', () => { ); // Call the method - const result = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); + const result = await mcpServerManager.handlePostMessage(mockRequest, mockResponse, [ + mockTool, + ]); // Verify that transport's handlePostMessage was called expect(mockTransport.handlePostMessage).toHaveBeenCalledWith( @@ -119,11 +119,11 @@ describe('McpServer', () => { ? `${sessionId}_${firstId}` : `${sessionId}_${secondId}`; // @ts-expect-error private property `resolveFunctions` - mcpServer.resolveFunctions[requestKey](); + mcpServerManager.resolveFunctions[requestKey](); }); // Add the transport directly - mcpServer.transports[sessionId] = mockTransport; + mcpServerManager.transports[sessionId] = mockTransport; // First tool call mockRequest.rawBody = Buffer.from( @@ -136,7 +136,9 @@ describe('McpServer', () => { ); // Handle first tool call - const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); + const firstResult = await mcpServerManager.handlePostMessage(mockRequest, mockResponse, [ + mockTool, + ]); expect(firstResult).toBe(true); expect(mockTransport.handlePostMessage).toHaveBeenCalledWith( mockRequest, @@ -155,7 +157,9 @@ describe('McpServer', () => { ); // Handle second tool call - const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); + const secondResult = await mcpServerManager.handlePostMessage(mockRequest, mockResponse, [ + mockTool, + ]); expect(secondResult).toBe(true); // Verify transport's handlePostMessage was called twice @@ -166,8 +170,22 @@ describe('McpServer', () => { }); it('should return 401 when transport does not exist', async () => { - // Call without setting up transport - await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); + // Set up request with rawBody and ensure sessionId is properly set + const testRequest = mock({ + query: { sessionId: 'non-existent-session' }, + path: '/sse', + }); + testRequest.rawBody = Buffer.from( + JSON.stringify({ + jsonrpc: '2.0', + method: 'tools/call', + id: 123, + params: { name: 'mockTool' }, + }), + ); + + // Call without setting up transport for this sessionId + await mcpServerManager.handlePostMessage(testRequest, mockResponse, [mockTool]); // Verify error status was set expect(mockResponse.status).toHaveBeenCalledWith(401); diff --git a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpTrigger.node.test.ts b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpTrigger.node.test.ts index 9719ffd9f3..3cfd39180f 100644 --- a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpTrigger.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpTrigger.node.test.ts @@ -2,20 +2,20 @@ import { jest } from '@jest/globals'; import type { Tool } from '@langchain/core/tools'; import type { Request, Response } from 'express'; import { mock } from 'jest-mock-extended'; -import type { IWebhookFunctions } from 'n8n-workflow'; +import type { INode, IWebhookFunctions } from 'n8n-workflow'; -import type { McpServer } from '../McpServer'; +import * as helpers from '@utils/helpers'; + +import type { McpServerManager } from '../McpServer'; import { McpTrigger } from '../McpTrigger.node'; const mockTool = mock({ name: 'mockTool' }); -jest.mock('@utils/helpers', () => ({ - getConnectedTools: jest.fn().mockImplementation(() => [mockTool]), -})); +jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue([mockTool]); -const mockServer = mock(); +const mockServerManager = mock(); jest.mock('../McpServer', () => ({ - McpServerSingleton: { - instance: jest.fn().mockImplementation(() => mockServer), + McpServerManager: { + instance: jest.fn().mockImplementation(() => mockServerManager), }, })); @@ -30,9 +30,12 @@ describe('McpTrigger Node', () => { jest.clearAllMocks(); mcpTrigger = new McpTrigger(); - mockContext.getRequestObject.mockReturnValue(mockRequest); mockContext.getResponseObject.mockReturnValue(mockResponse); + mockContext.getNode.mockReturnValue({ + name: 'McpTrigger', + typeVersion: 1.1, + } as INode); }); describe('webhook method', () => { @@ -44,7 +47,8 @@ describe('McpTrigger Node', () => { const result = await mcpTrigger.webhook(mockContext); // Verify that the connectTransport method was called with correct URL - expect(mockServer.connectTransport).toHaveBeenCalledWith( + expect(mockServerManager.createServerAndTransport).toHaveBeenCalledWith( + 'McpTrigger', '/custom-path/messages', mockResponse, ); @@ -58,13 +62,13 @@ describe('McpTrigger Node', () => { mockContext.getWebhookName.mockReturnValue('default'); // Mock that the server executes a tool and returns true - mockServer.handlePostMessage.mockResolvedValueOnce(true); + mockServerManager.handlePostMessage.mockResolvedValueOnce(true); // Call the webhook method const result = await mcpTrigger.webhook(mockContext); // Verify that handlePostMessage was called with request, response and tools - expect(mockServer.handlePostMessage).toHaveBeenCalledWith(mockRequest, mockResponse, [ + expect(mockServerManager.handlePostMessage).toHaveBeenCalledWith(mockRequest, mockResponse, [ mockTool, ]); @@ -80,7 +84,7 @@ describe('McpTrigger Node', () => { mockContext.getWebhookName.mockReturnValue('default'); // Mock that the server doesn't execute a tool and returns false - mockServer.handlePostMessage.mockResolvedValueOnce(false); + mockServerManager.handlePostMessage.mockResolvedValueOnce(false); // Call the webhook method const result = await mcpTrigger.webhook(mockContext); @@ -88,5 +92,42 @@ describe('McpTrigger Node', () => { // Verify the returned result when no tool was called expect(result).toEqual({ noWebhookResponse: true }); }); + + it('should pass the correct server name to McpServerSingleton.instance for version > 1', async () => { + // Configure node with version > 1 and custom name + mockContext.getNode.mockReturnValue({ + name: 'My custom MCP server!', + typeVersion: 1.1, + } as INode); + mockContext.getWebhookName.mockReturnValue('setup'); + + // Call the webhook method + await mcpTrigger.webhook(mockContext); + + // Verify that connectTransport was called with the sanitized server name + expect(mockServerManager.createServerAndTransport).toHaveBeenCalledWith( + 'My_custom_MCP_server_', + '/custom-path/messages', + mockResponse, + ); + }); + + it('should use default server name for version 1', async () => { + // Configure node with version 1 + mockContext.getNode.mockReturnValue({ + typeVersion: 1, + } as INode); + mockContext.getWebhookName.mockReturnValue('setup'); + + // Call the webhook method + await mcpTrigger.webhook(mockContext); + + // Verify that connectTransport was called with the default server name + expect(mockServerManager.createServerAndTransport).toHaveBeenCalledWith( + 'n8n-mcp-server', + '/custom-path/messages', + mockResponse, + ); + }); }); });