feat(MCP Server Trigger Node): Handle multiple tool calls in mcp server trigger (#15064)

This commit is contained in:
Yiorgis Gozadinos
2025-05-07 21:37:07 +02:00
committed by GitHub
parent b37387180c
commit 59ba162bd9
4 changed files with 157 additions and 67 deletions

View File

@@ -1,7 +1,11 @@
import type { Tool } from '@langchain/core/tools'; import type { Tool } from '@langchain/core/tools';
import { Server } from '@modelcontextprotocol/sdk/server/index.js'; import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js'; import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; import type {
JSONRPCMessage,
ServerRequest,
ServerNotification,
} from '@modelcontextprotocol/sdk/types.js';
import { import {
JSONRPCMessageSchema, JSONRPCMessageSchema,
ListToolsRequestSchema, ListToolsRequestSchema,
@@ -33,6 +37,20 @@ function wasToolCall(body: string) {
} }
} }
/**
* Extracts the request ID from a JSONRPC message
* Returns undefined if the message doesn't have an ID or can't be parsed
*/
function getRequestId(body: string): string | undefined {
try {
const message: unknown = JSON.parse(body);
const parsedMessage: JSONRPCMessage = JSONRPCMessageSchema.parse(message);
return 'id' in parsedMessage ? String(parsedMessage.id) : undefined;
} catch {
return undefined;
}
}
export class McpServer { export class McpServer {
servers: { [sessionId: string]: Server } = {}; servers: { [sessionId: string]: Server } = {};
@@ -42,7 +60,7 @@ export class McpServer {
private tools: { [sessionId: string]: Tool[] } = {}; private tools: { [sessionId: string]: Tool[] } = {};
private resolveFunctions: { [sessionId: string]: CallableFunction } = {}; private resolveFunctions: { [callId: string]: CallableFunction } = {};
constructor(logger: Logger) { constructor(logger: Logger) {
this.logger = logger; this.logger = logger;
@@ -59,7 +77,6 @@ export class McpServer {
resp.on('close', async () => { resp.on('close', async () => {
this.logger.debug(`Deleting transport for ${sessionId}`); this.logger.debug(`Deleting transport for ${sessionId}`);
delete this.tools[sessionId]; delete this.tools[sessionId];
delete this.resolveFunctions[sessionId];
delete this.transports[sessionId]; delete this.transports[sessionId];
delete this.servers[sessionId]; delete this.servers[sessionId];
}); });
@@ -75,16 +92,25 @@ export class McpServer {
async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) { async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) {
const sessionId = req.query.sessionId as string; const sessionId = req.query.sessionId as string;
const transport = this.transports[sessionId]; const transport = this.transports[sessionId];
this.tools[sessionId] = connectedTools;
if (transport) { if (transport) {
// We need to add a promise here because the `handlePostMessage` will send something to the // We need to add a promise here because the `handlePostMessage` will send something to the
// MCP Server, that will run in a different context. This means that the return will happen // MCP Server, that will run in a different context. This means that the return will happen
// almost immediately, and will lead to marking the sub-node as "running" in the final execution // almost immediately, and will lead to marking the sub-node as "running" in the final execution
await new Promise(async (resolve) => { const bodyString = req.rawBody.toString();
this.resolveFunctions[sessionId] = resolve; const messageId = getRequestId(bodyString);
await transport.handlePostMessage(req, resp, req.rawBody.toString());
}); // Use session & message ID if available, otherwise fall back to sessionId
delete this.resolveFunctions[sessionId]; const callId = messageId ? `${sessionId}_${messageId}` : sessionId;
this.tools[sessionId] = connectedTools;
try {
await new Promise(async (resolve) => {
this.resolveFunctions[callId] = resolve;
await transport.handlePostMessage(req, resp, bodyString);
});
} finally {
delete this.resolveFunctions[callId];
}
} else { } else {
this.logger.warn(`No transport found for session ${sessionId}`); this.logger.warn(`No transport found for session ${sessionId}`);
resp.status(401).send('No transport found for sessionId'); resp.status(401).send('No transport found for sessionId');
@@ -94,8 +120,6 @@ export class McpServer {
resp.flush(); resp.flush();
} }
delete this.tools[sessionId]; // Clean up to avoid keeping all tools in memory
return wasToolCall(req.rawBody.toString()); return wasToolCall(req.rawBody.toString());
} }
@@ -110,57 +134,68 @@ export class McpServer {
}, },
); );
server.setRequestHandler(ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => { server.setRequestHandler(
if (!extra.sessionId) { ListToolsRequestSchema,
throw new OperationalError('Require a sessionId for the listing of tools'); async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
} if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the listing of tools');
return {
tools: this.tools[extra.sessionId].map((tool) => {
return {
name: tool.name,
description: tool.description,
// Allow additional properties on tool call input
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
};
}),
};
});
server.setRequestHandler(CallToolRequestSchema, async (request, extra: RequestHandlerExtra) => {
if (!request.params?.name || !request.params?.arguments) {
throw new OperationalError('Require a name and arguments for the tool call');
}
if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the tool call');
}
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
(tool) => tool.name === request.params.name,
);
if (!requestedTool) {
throw new OperationalError('Tool not found');
}
try {
const result = await requestedTool.invoke(request.params.arguments);
this.resolveFunctions[extra.sessionId]();
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
if (typeof result === 'object') {
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
} }
if (typeof result === 'string') {
return { content: [{ type: 'text', text: result }] }; return {
tools: this.tools[extra.sessionId].map((tool) => {
return {
name: tool.name,
description: tool.description,
// Allow additional properties on tool call input
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
};
}),
};
},
);
server.setRequestHandler(
CallToolRequestSchema,
async (request, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
if (!request.params?.name || !request.params?.arguments) {
throw new OperationalError('Require a name and arguments for the tool call');
} }
return { content: [{ type: 'text', text: String(result) }] }; if (!extra.sessionId) {
} catch (error) { throw new OperationalError('Require a sessionId for the tool call');
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`); }
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
} const callId = extra.requestId ? `${extra.sessionId}_${extra.requestId}` : extra.sessionId;
});
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
(tool) => tool.name === request.params.name,
);
if (!requestedTool) {
throw new OperationalError('Tool not found');
}
try {
const result = await requestedTool.invoke(request.params.arguments);
if (this.resolveFunctions[callId]) {
this.resolveFunctions[callId]();
} else {
this.logger.warn(`No resolve function found for ${callId}`);
}
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
if (typeof result === 'object') {
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
}
if (typeof result === 'string') {
return { content: [{ type: 'text', text: result }] };
}
return { content: [{ type: 'text', text: String(result) }] };
} catch (error) {
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
}
},
);
server.onclose = () => { server.onclose = () => {
this.logger.debug('Closing MCP Server'); this.logger.debug('Closing MCP Server');

View File

@@ -78,7 +78,7 @@ describe('McpServer', () => {
it('should call transport.handlePostMessage when transport exists', async () => { it('should call transport.handlePostMessage when transport exists', async () => {
mockTransport.handlePostMessage.mockImplementation(async () => { mockTransport.handlePostMessage.mockImplementation(async () => {
// @ts-expect-error private property `resolveFunctions` // @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[sessionId](); mcpServer.resolveFunctions[`${sessionId}_123`]();
}); });
// Add the transport directly // Add the transport directly
@@ -110,6 +110,61 @@ describe('McpServer', () => {
expect(mockResponse.flush).toHaveBeenCalled(); expect(mockResponse.flush).toHaveBeenCalled();
}); });
it('should handle multiple tool calls with different ids', async () => {
const firstId = 123;
const secondId = 456;
mockTransport.handlePostMessage.mockImplementation(async () => {
const requestKey = mockRequest.rawBody?.toString().includes(`"id":${firstId}`)
? `${sessionId}_${firstId}`
: `${sessionId}_${secondId}`;
// @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[requestKey]();
});
// Add the transport directly
mcpServer.transports[sessionId] = mockTransport;
// First tool call
mockRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: firstId,
params: { name: 'mockTool', arguments: { param: 'first call' } },
}),
);
// Handle first tool call
const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
expect(firstResult).toBe(true);
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
mockRequest,
mockResponse,
expect.any(String),
);
// Second tool call with different id
mockRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: secondId,
params: { name: 'mockTool', arguments: { param: 'second call' } },
}),
);
// Handle second tool call
const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
expect(secondResult).toBe(true);
// Verify transport's handlePostMessage was called twice
expect(mockTransport.handlePostMessage).toHaveBeenCalledTimes(2);
// Verify flush was called for both requests
expect(mockResponse.flush).toHaveBeenCalledTimes(2);
});
it('should return 401 when transport does not exist', async () => { it('should return 401 when transport does not exist', async () => {
// Call without setting up transport // Call without setting up transport
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]); await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);

View File

@@ -173,7 +173,7 @@
"@langchain/qdrant": "0.1.1", "@langchain/qdrant": "0.1.1",
"@langchain/redis": "0.1.0", "@langchain/redis": "0.1.0",
"@langchain/textsplitters": "0.1.0", "@langchain/textsplitters": "0.1.0",
"@modelcontextprotocol/sdk": "1.9.0", "@modelcontextprotocol/sdk": "1.11.0",
"@mozilla/readability": "0.6.0", "@mozilla/readability": "0.6.0",
"@n8n/client-oauth2": "workspace:*", "@n8n/client-oauth2": "workspace:*",
"@n8n/json-schema-to-zod": "workspace:*", "@n8n/json-schema-to-zod": "workspace:*",

10
pnpm-lock.yaml generated
View File

@@ -783,8 +783,8 @@ importers:
specifier: 0.1.0 specifier: 0.1.0
version: 0.1.0(@langchain/core@0.3.30(openai@4.78.1(encoding@0.1.13)(zod@3.24.1))) version: 0.1.0(@langchain/core@0.3.30(openai@4.78.1(encoding@0.1.13)(zod@3.24.1)))
'@modelcontextprotocol/sdk': '@modelcontextprotocol/sdk':
specifier: 1.9.0 specifier: 1.11.0
version: 1.9.0 version: 1.11.0
'@mozilla/readability': '@mozilla/readability':
specifier: 0.6.0 specifier: 0.6.0
version: 0.6.0 version: 0.6.0
@@ -4810,8 +4810,8 @@ packages:
peerDependencies: peerDependencies:
zod: '>= 3' zod: '>= 3'
'@modelcontextprotocol/sdk@1.9.0': '@modelcontextprotocol/sdk@1.11.0':
resolution: {integrity: sha512-Jq2EUCQpe0iyO5FGpzVYDNFR6oR53AIrwph9yWl7uSc7IWUMsrmpmSaTGra5hQNunXpM+9oit85p924jWuHzUA==} resolution: {integrity: sha512-k/1pb70eD638anoi0e8wUGAlbMJXyvdV4p62Ko+EZ7eBe1xMx8Uhak1R5DgfoofsK5IBBnRwsYGTaLZl+6/+RQ==}
engines: {node: '>=18'} engines: {node: '>=18'}
'@mongodb-js/saslprep@1.1.9': '@mongodb-js/saslprep@1.1.9':
@@ -17242,7 +17242,7 @@ snapshots:
dependencies: dependencies:
zod: 3.24.1 zod: 3.24.1
'@modelcontextprotocol/sdk@1.9.0': '@modelcontextprotocol/sdk@1.11.0':
dependencies: dependencies:
content-type: 1.0.5 content-type: 1.0.5
cors: 2.8.5 cors: 2.8.5