feat(MCP Server Trigger Node): Cleanup MCP server management, use sanitized trigger node's name as name for the MCP server (#15751)

This commit is contained in:
Yiorgis Gozadinos
2025-05-29 15:07:17 +02:00
committed by GitHub
parent 1daf0ff169
commit 07a636eed6
4 changed files with 142 additions and 88 deletions

View File

@@ -38,8 +38,9 @@ function wasToolCall(body: string) {
} }
/** /**
* Extracts the request ID from a JSONRPC message * Extracts the request ID from a JSONRPC message (for example for tool calls).
* Returns undefined if the message doesn't have an ID or can't be parsed * Returns undefined if the message doesn't have an ID (for example on a tool list request)
*
*/ */
function getRequestId(body: string): string | undefined { function getRequestId(body: string): string | undefined {
try { try {
@@ -51,25 +52,56 @@ function getRequestId(body: string): string | undefined {
} }
} }
export class McpServer { /**
* This singleton is shared across the instance, making sure it is the one
* keeping account of MCP servers.
* It needs to stay in memory to keep track of the long-lived connections.
* It requires a logger at first creation to set everything up.
*/
export class McpServerManager {
static #instance: McpServerManager;
servers: { [sessionId: string]: Server } = {}; servers: { [sessionId: string]: Server } = {};
transports: { [sessionId: string]: FlushingSSEServerTransport } = {}; transports: { [sessionId: string]: FlushingSSEServerTransport } = {};
logger: Logger;
private tools: { [sessionId: string]: Tool[] } = {}; private tools: { [sessionId: string]: Tool[] } = {};
private resolveFunctions: { [callId: string]: CallableFunction } = {}; private resolveFunctions: { [callId: string]: CallableFunction } = {};
constructor(logger: Logger) { logger: Logger;
private constructor(logger: Logger) {
this.logger = logger; this.logger = logger;
this.logger.debug('MCP Server created'); this.logger.debug('MCP Server created');
} }
async connectTransport(postUrl: string, resp: CompressionResponse): Promise<void> { static instance(logger: Logger): McpServerManager {
if (!McpServerManager.#instance) {
McpServerManager.#instance = new McpServerManager(logger);
logger.debug('Created singleton MCP manager');
}
return McpServerManager.#instance;
}
async createServerAndTransport(
serverName: string,
postUrl: string,
resp: CompressionResponse,
): Promise<void> {
const transport = new FlushingSSEServerTransport(postUrl, resp); const transport = new FlushingSSEServerTransport(postUrl, resp);
const server = this.setUpServer(); const server = new Server(
{
name: serverName,
version: '0.1.0',
},
{
capabilities: { tools: {} },
},
);
this.setUpHandlers(server);
const { sessionId } = transport; const { sessionId } = transport;
this.transports[sessionId] = transport; this.transports[sessionId] = transport;
this.servers[sessionId] = server; this.servers[sessionId] = server;
@@ -123,17 +155,7 @@ export class McpServer {
return wasToolCall(req.rawBody.toString()); return wasToolCall(req.rawBody.toString());
} }
setUpServer(): Server { setUpHandlers(server: Server) {
const server = new Server(
{
name: 'n8n-mcp-server',
version: '0.1.0',
},
{
capabilities: { tools: {} },
},
);
server.setRequestHandler( server.setRequestHandler(
ListToolsRequestSchema, ListToolsRequestSchema,
async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => { async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
@@ -203,34 +225,5 @@ export class McpServer {
server.onerror = (error: unknown) => { server.onerror = (error: unknown) => {
this.logger.error(`MCP Error: ${error}`); this.logger.error(`MCP Error: ${error}`);
}; };
return server;
}
}
/**
* This singleton is shared across the instance, making sure we only have one server to worry about.
* It needs to stay in memory to keep track of the long-lived connections.
* It requires a logger at first creation to set everything up.
*/
export class McpServerSingleton {
static #instance: McpServerSingleton;
private _serverData: McpServer;
private constructor(logger: Logger) {
this._serverData = new McpServer(logger);
}
static instance(logger: Logger): McpServer {
if (!McpServerSingleton.#instance) {
McpServerSingleton.#instance = new McpServerSingleton(logger);
logger.debug('Created singleton for MCP Servers');
}
return McpServerSingleton.#instance.serverData;
}
get serverData() {
return this._serverData;
} }
} }

View File

@@ -3,11 +3,10 @@ import { validateWebhookAuthentication } from 'n8n-nodes-base/dist/nodes/Webhook
import type { INodeTypeDescription, IWebhookFunctions, IWebhookResponseData } from 'n8n-workflow'; import type { INodeTypeDescription, IWebhookFunctions, IWebhookResponseData } from 'n8n-workflow';
import { NodeConnectionTypes, Node } from 'n8n-workflow'; import { NodeConnectionTypes, Node } from 'n8n-workflow';
import { getConnectedTools } from '@utils/helpers'; import { getConnectedTools, nodeNameToToolName } from '@utils/helpers';
import type { CompressionResponse } from './FlushingSSEServerTransport'; import type { CompressionResponse } from './FlushingSSEServerTransport';
import { McpServerSingleton } from './McpServer'; import { McpServerManager } from './McpServer';
import type { McpServer } from './McpServer';
const MCP_SSE_SETUP_PATH = 'sse'; const MCP_SSE_SETUP_PATH = 'sse';
const MCP_SSE_MESSAGES_PATH = 'messages'; const MCP_SSE_MESSAGES_PATH = 'messages';
@@ -21,7 +20,7 @@ export class McpTrigger extends Node {
dark: 'file:../mcp.dark.svg', dark: 'file:../mcp.dark.svg',
}, },
group: ['trigger'], group: ['trigger'],
version: 1, version: [1, 1.1],
description: 'Expose n8n tools as an MCP Server endpoint', description: 'Expose n8n tools as an MCP Server endpoint',
activationMessage: 'You can now connect your MCP Clients to the SSE URL.', activationMessage: 'You can now connect your MCP Clients to the SSE URL.',
defaults: { defaults: {
@@ -143,8 +142,11 @@ export class McpTrigger extends Node {
} }
throw error; throw error;
} }
const node = context.getNode();
// Get a url/tool friendly name for the server, based on the node name
const serverName = node.typeVersion > 1 ? nodeNameToToolName(node) : 'n8n-mcp-server';
const mcpServer: McpServer = McpServerSingleton.instance(context.logger); const mcpServerManager: McpServerManager = McpServerManager.instance(context.logger);
if (webhookName === 'setup') { if (webhookName === 'setup') {
// Sets up the transport and opens the long-lived connection. This resp // Sets up the transport and opens the long-lived connection. This resp
@@ -153,7 +155,7 @@ export class McpTrigger extends Node {
new RegExp(`/${MCP_SSE_SETUP_PATH}$`), new RegExp(`/${MCP_SSE_SETUP_PATH}$`),
`/${MCP_SSE_MESSAGES_PATH}`, `/${MCP_SSE_MESSAGES_PATH}`,
); );
await mcpServer.connectTransport(postUrl, resp); await mcpServerManager.createServerAndTransport(serverName, postUrl, resp);
return { noWebhookResponse: true }; return { noWebhookResponse: true };
} else if (webhookName === 'default') { } else if (webhookName === 'default') {
@@ -162,7 +164,7 @@ export class McpTrigger extends Node {
// 'setup' call // 'setup' call
const connectedTools = await getConnectedTools(context, true); const connectedTools = await getConnectedTools(context, true);
const wasToolCall = await mcpServer.handlePostMessage(req, resp, connectedTools); const wasToolCall = await mcpServerManager.handlePostMessage(req, resp, connectedTools);
if (wasToolCall) return { noWebhookResponse: true, workflowData: [[{ json: {} }]] }; if (wasToolCall) return { noWebhookResponse: true, workflowData: [[{ json: {} }]] };
return { noWebhookResponse: true }; return { noWebhookResponse: true };

View File

@@ -6,7 +6,7 @@ import { captor, mock } from 'jest-mock-extended';
import type { CompressionResponse } from '../FlushingSSEServerTransport'; import type { CompressionResponse } from '../FlushingSSEServerTransport';
import { FlushingSSEServerTransport } from '../FlushingSSEServerTransport'; import { FlushingSSEServerTransport } from '../FlushingSSEServerTransport';
import { McpServer } from '../McpServer'; import { McpServerManager } from '../McpServer';
const sessionId = 'mock-session-id'; const sessionId = 'mock-session-id';
const mockServer = mock<Server>(); const mockServer = mock<Server>();
@@ -28,20 +28,18 @@ describe('McpServer', () => {
const mockResponse = mock<CompressionResponse>(); const mockResponse = mock<CompressionResponse>();
const mockTool = mock<Tool>({ name: 'mockTool' }); const mockTool = mock<Tool>({ name: 'mockTool' });
let mcpServer: McpServer; const mcpServerManager = McpServerManager.instance(mock());
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks();
mockResponse.status.mockReturnThis(); mockResponse.status.mockReturnThis();
mcpServer = new McpServer(mock());
}); });
describe('connectTransport', () => { describe('connectTransport', () => {
const postUrl = '/post-url'; const postUrl = '/post-url';
it('should set up a transport and server', async () => { it('should set up a transport and server', async () => {
await mcpServer.connectTransport(postUrl, mockResponse); await mcpServerManager.createServerAndTransport('mcpServer', postUrl, mockResponse);
// Check that FlushingSSEServerTransport was initialized with correct params // Check that FlushingSSEServerTransport was initialized with correct params
expect(FlushingSSEServerTransport).toHaveBeenCalledWith(postUrl, mockResponse); expect(FlushingSSEServerTransport).toHaveBeenCalledWith(postUrl, mockResponse);
@@ -50,18 +48,18 @@ describe('McpServer', () => {
expect(Server).toHaveBeenCalled(); expect(Server).toHaveBeenCalled();
// Check that transport and server are stored // Check that transport and server are stored
expect(mcpServer.transports[sessionId]).toBeDefined(); expect(mcpServerManager.transports[sessionId]).toBeDefined();
expect(mcpServer.servers[sessionId]).toBeDefined(); expect(mcpServerManager.servers[sessionId]).toBeDefined();
// Check that connect was called on the server // Check that connect was called on the server
expect(mcpServer.servers[sessionId].connect).toHaveBeenCalled(); expect(mcpServerManager.servers[sessionId].connect).toHaveBeenCalled();
// Check that flush was called if available // Check that flush was called if available
expect(mockResponse.flush).toHaveBeenCalled(); expect(mockResponse.flush).toHaveBeenCalled();
}); });
it('should set up close handler that cleans up resources', async () => { it('should set up close handler that cleans up resources', async () => {
await mcpServer.connectTransport(postUrl, mockResponse); await mcpServerManager.createServerAndTransport('mcpServer', postUrl, mockResponse);
// Get the close callback and execute it // Get the close callback and execute it
const closeCallbackCaptor = captor<() => Promise<void>>(); const closeCallbackCaptor = captor<() => Promise<void>>();
@@ -69,8 +67,8 @@ describe('McpServer', () => {
await closeCallbackCaptor.value(); await closeCallbackCaptor.value();
// Check that resources were cleaned up // Check that resources were cleaned up
expect(mcpServer.transports[sessionId]).toBeUndefined(); expect(mcpServerManager.transports[sessionId]).toBeUndefined();
expect(mcpServer.servers[sessionId]).toBeUndefined(); expect(mcpServerManager.servers[sessionId]).toBeUndefined();
}); });
}); });
@@ -78,11 +76,11 @@ describe('McpServer', () => {
it('should call transport.handlePostMessage when transport exists', async () => { it('should call transport.handlePostMessage when transport exists', async () => {
mockTransport.handlePostMessage.mockImplementation(async () => { mockTransport.handlePostMessage.mockImplementation(async () => {
// @ts-expect-error private property `resolveFunctions` // @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[`${sessionId}_123`](); mcpServerManager.resolveFunctions[`${sessionId}_123`]();
}); });
// Add the transport directly // Add the transport directly
mcpServer.transports[sessionId] = mockTransport; mcpServerManager.transports[sessionId] = mockTransport;
mockRequest.rawBody = Buffer.from( mockRequest.rawBody = Buffer.from(
JSON.stringify({ JSON.stringify({
@@ -94,7 +92,9 @@ describe('McpServer', () => {
); );
// Call the method // Call the method
const result = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); const result = await mcpServerManager.handlePostMessage(mockRequest, mockResponse, [
mockTool,
]);
// Verify that transport's handlePostMessage was called // Verify that transport's handlePostMessage was called
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith( expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
@@ -119,11 +119,11 @@ describe('McpServer', () => {
? `${sessionId}_${firstId}` ? `${sessionId}_${firstId}`
: `${sessionId}_${secondId}`; : `${sessionId}_${secondId}`;
// @ts-expect-error private property `resolveFunctions` // @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[requestKey](); mcpServerManager.resolveFunctions[requestKey]();
}); });
// Add the transport directly // Add the transport directly
mcpServer.transports[sessionId] = mockTransport; mcpServerManager.transports[sessionId] = mockTransport;
// First tool call // First tool call
mockRequest.rawBody = Buffer.from( mockRequest.rawBody = Buffer.from(
@@ -136,7 +136,9 @@ describe('McpServer', () => {
); );
// Handle first tool call // Handle first tool call
const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); const firstResult = await mcpServerManager.handlePostMessage(mockRequest, mockResponse, [
mockTool,
]);
expect(firstResult).toBe(true); expect(firstResult).toBe(true);
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith( expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
mockRequest, mockRequest,
@@ -155,7 +157,9 @@ describe('McpServer', () => {
); );
// Handle second tool call // Handle second tool call
const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); const secondResult = await mcpServerManager.handlePostMessage(mockRequest, mockResponse, [
mockTool,
]);
expect(secondResult).toBe(true); expect(secondResult).toBe(true);
// Verify transport's handlePostMessage was called twice // Verify transport's handlePostMessage was called twice
@@ -166,8 +170,22 @@ describe('McpServer', () => {
}); });
it('should return 401 when transport does not exist', async () => { it('should return 401 when transport does not exist', async () => {
// Call without setting up transport // Set up request with rawBody and ensure sessionId is properly set
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); const testRequest = mock<Request>({
query: { sessionId: 'non-existent-session' },
path: '/sse',
});
testRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: 123,
params: { name: 'mockTool' },
}),
);
// Call without setting up transport for this sessionId
await mcpServerManager.handlePostMessage(testRequest, mockResponse, [mockTool]);
// Verify error status was set // Verify error status was set
expect(mockResponse.status).toHaveBeenCalledWith(401); expect(mockResponse.status).toHaveBeenCalledWith(401);

View File

@@ -2,20 +2,20 @@ import { jest } from '@jest/globals';
import type { Tool } from '@langchain/core/tools'; import type { Tool } from '@langchain/core/tools';
import type { Request, Response } from 'express'; import type { Request, Response } from 'express';
import { mock } from 'jest-mock-extended'; import { mock } from 'jest-mock-extended';
import type { IWebhookFunctions } from 'n8n-workflow'; import type { INode, IWebhookFunctions } from 'n8n-workflow';
import type { McpServer } from '../McpServer'; import * as helpers from '@utils/helpers';
import type { McpServerManager } from '../McpServer';
import { McpTrigger } from '../McpTrigger.node'; import { McpTrigger } from '../McpTrigger.node';
const mockTool = mock<Tool>({ name: 'mockTool' }); const mockTool = mock<Tool>({ name: 'mockTool' });
jest.mock('@utils/helpers', () => ({ jest.spyOn(helpers, 'getConnectedTools').mockResolvedValue([mockTool]);
getConnectedTools: jest.fn().mockImplementation(() => [mockTool]),
}));
const mockServer = mock<McpServer>(); const mockServerManager = mock<McpServerManager>();
jest.mock('../McpServer', () => ({ jest.mock('../McpServer', () => ({
McpServerSingleton: { McpServerManager: {
instance: jest.fn().mockImplementation(() => mockServer), instance: jest.fn().mockImplementation(() => mockServerManager),
}, },
})); }));
@@ -30,9 +30,12 @@ describe('McpTrigger Node', () => {
jest.clearAllMocks(); jest.clearAllMocks();
mcpTrigger = new McpTrigger(); mcpTrigger = new McpTrigger();
mockContext.getRequestObject.mockReturnValue(mockRequest); mockContext.getRequestObject.mockReturnValue(mockRequest);
mockContext.getResponseObject.mockReturnValue(mockResponse); mockContext.getResponseObject.mockReturnValue(mockResponse);
mockContext.getNode.mockReturnValue({
name: 'McpTrigger',
typeVersion: 1.1,
} as INode);
}); });
describe('webhook method', () => { describe('webhook method', () => {
@@ -44,7 +47,8 @@ describe('McpTrigger Node', () => {
const result = await mcpTrigger.webhook(mockContext); const result = await mcpTrigger.webhook(mockContext);
// Verify that the connectTransport method was called with correct URL // Verify that the connectTransport method was called with correct URL
expect(mockServer.connectTransport).toHaveBeenCalledWith( expect(mockServerManager.createServerAndTransport).toHaveBeenCalledWith(
'McpTrigger',
'/custom-path/messages', '/custom-path/messages',
mockResponse, mockResponse,
); );
@@ -58,13 +62,13 @@ describe('McpTrigger Node', () => {
mockContext.getWebhookName.mockReturnValue('default'); mockContext.getWebhookName.mockReturnValue('default');
// Mock that the server executes a tool and returns true // Mock that the server executes a tool and returns true
mockServer.handlePostMessage.mockResolvedValueOnce(true); mockServerManager.handlePostMessage.mockResolvedValueOnce(true);
// Call the webhook method // Call the webhook method
const result = await mcpTrigger.webhook(mockContext); const result = await mcpTrigger.webhook(mockContext);
// Verify that handlePostMessage was called with request, response and tools // Verify that handlePostMessage was called with request, response and tools
expect(mockServer.handlePostMessage).toHaveBeenCalledWith(mockRequest, mockResponse, [ expect(mockServerManager.handlePostMessage).toHaveBeenCalledWith(mockRequest, mockResponse, [
mockTool, mockTool,
]); ]);
@@ -80,7 +84,7 @@ describe('McpTrigger Node', () => {
mockContext.getWebhookName.mockReturnValue('default'); mockContext.getWebhookName.mockReturnValue('default');
// Mock that the server doesn't execute a tool and returns false // Mock that the server doesn't execute a tool and returns false
mockServer.handlePostMessage.mockResolvedValueOnce(false); mockServerManager.handlePostMessage.mockResolvedValueOnce(false);
// Call the webhook method // Call the webhook method
const result = await mcpTrigger.webhook(mockContext); const result = await mcpTrigger.webhook(mockContext);
@@ -88,5 +92,42 @@ describe('McpTrigger Node', () => {
// Verify the returned result when no tool was called // Verify the returned result when no tool was called
expect(result).toEqual({ noWebhookResponse: true }); expect(result).toEqual({ noWebhookResponse: true });
}); });
it('should pass the correct server name to McpServerSingleton.instance for version > 1', async () => {
// Configure node with version > 1 and custom name
mockContext.getNode.mockReturnValue({
name: 'My custom MCP server!',
typeVersion: 1.1,
} as INode);
mockContext.getWebhookName.mockReturnValue('setup');
// Call the webhook method
await mcpTrigger.webhook(mockContext);
// Verify that connectTransport was called with the sanitized server name
expect(mockServerManager.createServerAndTransport).toHaveBeenCalledWith(
'My_custom_MCP_server_',
'/custom-path/messages',
mockResponse,
);
});
it('should use default server name for version 1', async () => {
// Configure node with version 1
mockContext.getNode.mockReturnValue({
typeVersion: 1,
} as INode);
mockContext.getWebhookName.mockReturnValue('setup');
// Call the webhook method
await mcpTrigger.webhook(mockContext);
// Verify that connectTransport was called with the default server name
expect(mockServerManager.createServerAndTransport).toHaveBeenCalledWith(
'n8n-mcp-server',
'/custom-path/messages',
mockResponse,
);
});
}); });
}); });