feat(MCP Server Trigger Node): Handle multiple tool calls in mcp server trigger (#15064)

This commit is contained in:
Yiorgis Gozadinos
2025-05-07 21:37:07 +02:00
committed by GitHub
parent b37387180c
commit 59ba162bd9
4 changed files with 157 additions and 67 deletions

View File

@@ -1,7 +1,11 @@
import type { Tool } from '@langchain/core/tools';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type {
JSONRPCMessage,
ServerRequest,
ServerNotification,
} from '@modelcontextprotocol/sdk/types.js';
import {
JSONRPCMessageSchema,
ListToolsRequestSchema,
@@ -33,6 +37,20 @@ function wasToolCall(body: string) {
}
}
/**
* Extracts the request ID from a JSONRPC message
* Returns undefined if the message doesn't have an ID or can't be parsed
*/
function getRequestId(body: string): string | undefined {
try {
const message: unknown = JSON.parse(body);
const parsedMessage: JSONRPCMessage = JSONRPCMessageSchema.parse(message);
return 'id' in parsedMessage ? String(parsedMessage.id) : undefined;
} catch {
return undefined;
}
}
export class McpServer {
servers: { [sessionId: string]: Server } = {};
@@ -42,7 +60,7 @@ export class McpServer {
private tools: { [sessionId: string]: Tool[] } = {};
private resolveFunctions: { [sessionId: string]: CallableFunction } = {};
private resolveFunctions: { [callId: string]: CallableFunction } = {};
constructor(logger: Logger) {
this.logger = logger;
@@ -59,7 +77,6 @@ export class McpServer {
resp.on('close', async () => {
this.logger.debug(`Deleting transport for ${sessionId}`);
delete this.tools[sessionId];
delete this.resolveFunctions[sessionId];
delete this.transports[sessionId];
delete this.servers[sessionId];
});
@@ -75,16 +92,25 @@ export class McpServer {
async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) {
const sessionId = req.query.sessionId as string;
const transport = this.transports[sessionId];
this.tools[sessionId] = connectedTools;
if (transport) {
// We need to add a promise here because the `handlePostMessage` will send something to the
// MCP Server, that will run in a different context. This means that the return will happen
// almost immediately, and will lead to marking the sub-node as "running" in the final execution
const bodyString = req.rawBody.toString();
const messageId = getRequestId(bodyString);
// Use session & message ID if available, otherwise fall back to sessionId
const callId = messageId ? `${sessionId}_${messageId}` : sessionId;
this.tools[sessionId] = connectedTools;
try {
await new Promise(async (resolve) => {
this.resolveFunctions[sessionId] = resolve;
await transport.handlePostMessage(req, resp, req.rawBody.toString());
this.resolveFunctions[callId] = resolve;
await transport.handlePostMessage(req, resp, bodyString);
});
delete this.resolveFunctions[sessionId];
} finally {
delete this.resolveFunctions[callId];
}
} else {
this.logger.warn(`No transport found for session ${sessionId}`);
resp.status(401).send('No transport found for sessionId');
@@ -94,8 +120,6 @@ export class McpServer {
resp.flush();
}
delete this.tools[sessionId]; // Clean up to avoid keeping all tools in memory
return wasToolCall(req.rawBody.toString());
}
@@ -110,7 +134,9 @@ export class McpServer {
},
);
server.setRequestHandler(ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => {
server.setRequestHandler(
ListToolsRequestSchema,
async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the listing of tools');
}
@@ -125,9 +151,12 @@ export class McpServer {
};
}),
};
});
},
);
server.setRequestHandler(CallToolRequestSchema, async (request, extra: RequestHandlerExtra) => {
server.setRequestHandler(
CallToolRequestSchema,
async (request, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
if (!request.params?.name || !request.params?.arguments) {
throw new OperationalError('Require a name and arguments for the tool call');
}
@@ -135,6 +164,8 @@ export class McpServer {
throw new OperationalError('Require a sessionId for the tool call');
}
const callId = extra.requestId ? `${extra.sessionId}_${extra.requestId}` : extra.sessionId;
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
(tool) => tool.name === request.params.name,
);
@@ -144,8 +175,11 @@ export class McpServer {
try {
const result = await requestedTool.invoke(request.params.arguments);
this.resolveFunctions[extra.sessionId]();
if (this.resolveFunctions[callId]) {
this.resolveFunctions[callId]();
} else {
this.logger.warn(`No resolve function found for ${callId}`);
}
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
@@ -160,7 +194,8 @@ export class McpServer {
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
}
});
},
);
server.onclose = () => {
this.logger.debug('Closing MCP Server');

View File

@@ -78,7 +78,7 @@ describe('McpServer', () => {
it('should call transport.handlePostMessage when transport exists', async () => {
mockTransport.handlePostMessage.mockImplementation(async () => {
// @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[sessionId]();
mcpServer.resolveFunctions[`${sessionId}_123`]();
});
// Add the transport directly
@@ -110,6 +110,61 @@ describe('McpServer', () => {
expect(mockResponse.flush).toHaveBeenCalled();
});
it('should handle multiple tool calls with different ids', async () => {
const firstId = 123;
const secondId = 456;
mockTransport.handlePostMessage.mockImplementation(async () => {
const requestKey = mockRequest.rawBody?.toString().includes(`"id":${firstId}`)
? `${sessionId}_${firstId}`
: `${sessionId}_${secondId}`;
// @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[requestKey]();
});
// Add the transport directly
mcpServer.transports[sessionId] = mockTransport;
// First tool call
mockRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: firstId,
params: { name: 'mockTool', arguments: { param: 'first call' } },
}),
);
// Handle first tool call
const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
expect(firstResult).toBe(true);
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
mockRequest,
mockResponse,
expect.any(String),
);
// Second tool call with different id
mockRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: secondId,
params: { name: 'mockTool', arguments: { param: 'second call' } },
}),
);
// Handle second tool call
const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
expect(secondResult).toBe(true);
// Verify transport's handlePostMessage was called twice
expect(mockTransport.handlePostMessage).toHaveBeenCalledTimes(2);
// Verify flush was called for both requests
expect(mockResponse.flush).toHaveBeenCalledTimes(2);
});
it('should return 401 when transport does not exist', async () => {
// Call without setting up transport
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);

View File

@@ -173,7 +173,7 @@
"@langchain/qdrant": "0.1.1",
"@langchain/redis": "0.1.0",
"@langchain/textsplitters": "0.1.0",
"@modelcontextprotocol/sdk": "1.9.0",
"@modelcontextprotocol/sdk": "1.11.0",
"@mozilla/readability": "0.6.0",
"@n8n/client-oauth2": "workspace:*",
"@n8n/json-schema-to-zod": "workspace:*",

10
pnpm-lock.yaml generated
View File

@@ -783,8 +783,8 @@ importers:
specifier: 0.1.0
version: 0.1.0(@langchain/core@0.3.30(openai@4.78.1(encoding@0.1.13)(zod@3.24.1)))
'@modelcontextprotocol/sdk':
specifier: 1.9.0
version: 1.9.0
specifier: 1.11.0
version: 1.11.0
'@mozilla/readability':
specifier: 0.6.0
version: 0.6.0
@@ -4810,8 +4810,8 @@ packages:
peerDependencies:
zod: '>= 3'
'@modelcontextprotocol/sdk@1.9.0':
resolution: {integrity: sha512-Jq2EUCQpe0iyO5FGpzVYDNFR6oR53AIrwph9yWl7uSc7IWUMsrmpmSaTGra5hQNunXpM+9oit85p924jWuHzUA==}
'@modelcontextprotocol/sdk@1.11.0':
resolution: {integrity: sha512-k/1pb70eD638anoi0e8wUGAlbMJXyvdV4p62Ko+EZ7eBe1xMx8Uhak1R5DgfoofsK5IBBnRwsYGTaLZl+6/+RQ==}
engines: {node: '>=18'}
'@mongodb-js/saslprep@1.1.9':
@@ -17242,7 +17242,7 @@ snapshots:
dependencies:
zod: 3.24.1
'@modelcontextprotocol/sdk@1.9.0':
'@modelcontextprotocol/sdk@1.11.0':
dependencies:
content-type: 1.0.5
cors: 2.8.5