From 59ba162bd9fe9967d0b0733c3955d65256e062f5 Mon Sep 17 00:00:00 2001 From: Yiorgis Gozadinos Date: Wed, 7 May 2025 21:37:07 +0200 Subject: [PATCH] feat(MCP Server Trigger Node): Handle multiple tool calls in mcp server trigger (#15064) --- .../nodes/mcp/McpTrigger/McpServer.ts | 155 +++++++++++------- .../mcp/McpTrigger/__test__/McpServer.test.ts | 57 ++++++- packages/@n8n/nodes-langchain/package.json | 2 +- pnpm-lock.yaml | 10 +- 4 files changed, 157 insertions(+), 67 deletions(-) diff --git a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts index 4c94fb0ed6..34a6ed7803 100644 --- a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts +++ b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/McpServer.ts @@ -1,7 +1,11 @@ import type { Tool } from '@langchain/core/tools'; import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js'; -import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import type { + JSONRPCMessage, + ServerRequest, + ServerNotification, +} from '@modelcontextprotocol/sdk/types.js'; import { JSONRPCMessageSchema, ListToolsRequestSchema, @@ -33,6 +37,20 @@ function wasToolCall(body: string) { } } +/** + * Extracts the request ID from a JSONRPC message + * Returns undefined if the message doesn't have an ID or can't be parsed + */ +function getRequestId(body: string): string | undefined { + try { + const message: unknown = JSON.parse(body); + const parsedMessage: JSONRPCMessage = JSONRPCMessageSchema.parse(message); + return 'id' in parsedMessage ? String(parsedMessage.id) : undefined; + } catch { + return undefined; + } +} + export class McpServer { servers: { [sessionId: string]: Server } = {}; @@ -42,7 +60,7 @@ export class McpServer { private tools: { [sessionId: string]: Tool[] } = {}; - private resolveFunctions: { [sessionId: string]: CallableFunction } = {}; + private resolveFunctions: { [callId: string]: CallableFunction } = {}; constructor(logger: Logger) { this.logger = logger; @@ -59,7 +77,6 @@ export class McpServer { resp.on('close', async () => { this.logger.debug(`Deleting transport for ${sessionId}`); delete this.tools[sessionId]; - delete this.resolveFunctions[sessionId]; delete this.transports[sessionId]; delete this.servers[sessionId]; }); @@ -75,16 +92,25 @@ export class McpServer { async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) { const sessionId = req.query.sessionId as string; const transport = this.transports[sessionId]; - this.tools[sessionId] = connectedTools; if (transport) { // We need to add a promise here because the `handlePostMessage` will send something to the // MCP Server, that will run in a different context. This means that the return will happen // almost immediately, and will lead to marking the sub-node as "running" in the final execution - await new Promise(async (resolve) => { - this.resolveFunctions[sessionId] = resolve; - await transport.handlePostMessage(req, resp, req.rawBody.toString()); - }); - delete this.resolveFunctions[sessionId]; + const bodyString = req.rawBody.toString(); + const messageId = getRequestId(bodyString); + + // Use session & message ID if available, otherwise fall back to sessionId + const callId = messageId ? `${sessionId}_${messageId}` : sessionId; + this.tools[sessionId] = connectedTools; + + try { + await new Promise(async (resolve) => { + this.resolveFunctions[callId] = resolve; + await transport.handlePostMessage(req, resp, bodyString); + }); + } finally { + delete this.resolveFunctions[callId]; + } } else { this.logger.warn(`No transport found for session ${sessionId}`); resp.status(401).send('No transport found for sessionId'); @@ -94,8 +120,6 @@ export class McpServer { resp.flush(); } - delete this.tools[sessionId]; // Clean up to avoid keeping all tools in memory - return wasToolCall(req.rawBody.toString()); } @@ -110,57 +134,68 @@ export class McpServer { }, ); - server.setRequestHandler(ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => { - if (!extra.sessionId) { - throw new OperationalError('Require a sessionId for the listing of tools'); - } - - return { - tools: this.tools[extra.sessionId].map((tool) => { - return { - name: tool.name, - description: tool.description, - // Allow additional properties on tool call input - inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }), - }; - }), - }; - }); - - server.setRequestHandler(CallToolRequestSchema, async (request, extra: RequestHandlerExtra) => { - if (!request.params?.name || !request.params?.arguments) { - throw new OperationalError('Require a name and arguments for the tool call'); - } - if (!extra.sessionId) { - throw new OperationalError('Require a sessionId for the tool call'); - } - - const requestedTool: Tool | undefined = this.tools[extra.sessionId].find( - (tool) => tool.name === request.params.name, - ); - if (!requestedTool) { - throw new OperationalError('Tool not found'); - } - - try { - const result = await requestedTool.invoke(request.params.arguments); - - this.resolveFunctions[extra.sessionId](); - - this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`); - - if (typeof result === 'object') { - return { content: [{ type: 'text', text: JSON.stringify(result) }] }; + server.setRequestHandler( + ListToolsRequestSchema, + async (_, extra: RequestHandlerExtra) => { + if (!extra.sessionId) { + throw new OperationalError('Require a sessionId for the listing of tools'); } - if (typeof result === 'string') { - return { content: [{ type: 'text', text: result }] }; + + return { + tools: this.tools[extra.sessionId].map((tool) => { + return { + name: tool.name, + description: tool.description, + // Allow additional properties on tool call input + inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }), + }; + }), + }; + }, + ); + + server.setRequestHandler( + CallToolRequestSchema, + async (request, extra: RequestHandlerExtra) => { + if (!request.params?.name || !request.params?.arguments) { + throw new OperationalError('Require a name and arguments for the tool call'); } - return { content: [{ type: 'text', text: String(result) }] }; - } catch (error) { - this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`); - return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] }; - } - }); + if (!extra.sessionId) { + throw new OperationalError('Require a sessionId for the tool call'); + } + + const callId = extra.requestId ? `${extra.sessionId}_${extra.requestId}` : extra.sessionId; + + const requestedTool: Tool | undefined = this.tools[extra.sessionId].find( + (tool) => tool.name === request.params.name, + ); + if (!requestedTool) { + throw new OperationalError('Tool not found'); + } + + try { + const result = await requestedTool.invoke(request.params.arguments); + if (this.resolveFunctions[callId]) { + this.resolveFunctions[callId](); + } else { + this.logger.warn(`No resolve function found for ${callId}`); + } + + this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`); + + if (typeof result === 'object') { + return { content: [{ type: 'text', text: JSON.stringify(result) }] }; + } + if (typeof result === 'string') { + return { content: [{ type: 'text', text: result }] }; + } + return { content: [{ type: 'text', text: String(result) }] }; + } catch (error) { + this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`); + return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] }; + } + }, + ); server.onclose = () => { this.logger.debug('Closing MCP Server'); diff --git a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts index beddc18fec..a698b2bd03 100644 --- a/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/mcp/McpTrigger/__test__/McpServer.test.ts @@ -78,7 +78,7 @@ describe('McpServer', () => { it('should call transport.handlePostMessage when transport exists', async () => { mockTransport.handlePostMessage.mockImplementation(async () => { // @ts-expect-error private property `resolveFunctions` - mcpServer.resolveFunctions[sessionId](); + mcpServer.resolveFunctions[`${sessionId}_123`](); }); // Add the transport directly @@ -110,6 +110,61 @@ describe('McpServer', () => { expect(mockResponse.flush).toHaveBeenCalled(); }); + it('should handle multiple tool calls with different ids', async () => { + const firstId = 123; + const secondId = 456; + + mockTransport.handlePostMessage.mockImplementation(async () => { + const requestKey = mockRequest.rawBody?.toString().includes(`"id":${firstId}`) + ? `${sessionId}_${firstId}` + : `${sessionId}_${secondId}`; + // @ts-expect-error private property `resolveFunctions` + mcpServer.resolveFunctions[requestKey](); + }); + + // Add the transport directly + mcpServer.transports[sessionId] = mockTransport; + + // First tool call + mockRequest.rawBody = Buffer.from( + JSON.stringify({ + jsonrpc: '2.0', + method: 'tools/call', + id: firstId, + params: { name: 'mockTool', arguments: { param: 'first call' } }, + }), + ); + + // Handle first tool call + const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); + expect(firstResult).toBe(true); + expect(mockTransport.handlePostMessage).toHaveBeenCalledWith( + mockRequest, + mockResponse, + expect.any(String), + ); + + // Second tool call with different id + mockRequest.rawBody = Buffer.from( + JSON.stringify({ + jsonrpc: '2.0', + method: 'tools/call', + id: secondId, + params: { name: 'mockTool', arguments: { param: 'second call' } }, + }), + ); + + // Handle second tool call + const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); + expect(secondResult).toBe(true); + + // Verify transport's handlePostMessage was called twice + expect(mockTransport.handlePostMessage).toHaveBeenCalledTimes(2); + + // Verify flush was called for both requests + expect(mockResponse.flush).toHaveBeenCalledTimes(2); + }); + it('should return 401 when transport does not exist', async () => { // Call without setting up transport await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); diff --git a/packages/@n8n/nodes-langchain/package.json b/packages/@n8n/nodes-langchain/package.json index 4241946c1f..29a7cdfdd5 100644 --- a/packages/@n8n/nodes-langchain/package.json +++ b/packages/@n8n/nodes-langchain/package.json @@ -173,7 +173,7 @@ "@langchain/qdrant": "0.1.1", "@langchain/redis": "0.1.0", "@langchain/textsplitters": "0.1.0", - "@modelcontextprotocol/sdk": "1.9.0", + "@modelcontextprotocol/sdk": "1.11.0", "@mozilla/readability": "0.6.0", "@n8n/client-oauth2": "workspace:*", "@n8n/json-schema-to-zod": "workspace:*", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index cd17374f63..5dfb609437 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -783,8 +783,8 @@ importers: specifier: 0.1.0 version: 0.1.0(@langchain/core@0.3.30(openai@4.78.1(encoding@0.1.13)(zod@3.24.1))) '@modelcontextprotocol/sdk': - specifier: 1.9.0 - version: 1.9.0 + specifier: 1.11.0 + version: 1.11.0 '@mozilla/readability': specifier: 0.6.0 version: 0.6.0 @@ -4810,8 +4810,8 @@ packages: peerDependencies: zod: '>= 3' - '@modelcontextprotocol/sdk@1.9.0': - resolution: {integrity: sha512-Jq2EUCQpe0iyO5FGpzVYDNFR6oR53AIrwph9yWl7uSc7IWUMsrmpmSaTGra5hQNunXpM+9oit85p924jWuHzUA==} + '@modelcontextprotocol/sdk@1.11.0': + resolution: {integrity: sha512-k/1pb70eD638anoi0e8wUGAlbMJXyvdV4p62Ko+EZ7eBe1xMx8Uhak1R5DgfoofsK5IBBnRwsYGTaLZl+6/+RQ==} engines: {node: '>=18'} '@mongodb-js/saslprep@1.1.9': @@ -17242,7 +17242,7 @@ snapshots: dependencies: zod: 3.24.1 - '@modelcontextprotocol/sdk@1.9.0': + '@modelcontextprotocol/sdk@1.11.0': dependencies: content-type: 1.0.5 cors: 2.8.5