feat(MCP Server Trigger Node): Handle multiple tool calls in mcp server trigger (#15064)

This commit is contained in:
Yiorgis Gozadinos
2025-05-07 21:37:07 +02:00
committed by GitHub
parent b37387180c
commit 59ba162bd9
4 changed files with 157 additions and 67 deletions

View File

@@ -1,7 +1,11 @@
import type { Tool } from '@langchain/core/tools';
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type {
JSONRPCMessage,
ServerRequest,
ServerNotification,
} from '@modelcontextprotocol/sdk/types.js';
import {
JSONRPCMessageSchema,
ListToolsRequestSchema,
@@ -33,6 +37,20 @@ function wasToolCall(body: string) {
}
}
/**
* Extracts the request ID from a JSONRPC message
* Returns undefined if the message doesn't have an ID or can't be parsed
*/
function getRequestId(body: string): string | undefined {
try {
const message: unknown = JSON.parse(body);
const parsedMessage: JSONRPCMessage = JSONRPCMessageSchema.parse(message);
return 'id' in parsedMessage ? String(parsedMessage.id) : undefined;
} catch {
return undefined;
}
}
export class McpServer {
servers: { [sessionId: string]: Server } = {};
@@ -42,7 +60,7 @@ export class McpServer {
private tools: { [sessionId: string]: Tool[] } = {};
private resolveFunctions: { [sessionId: string]: CallableFunction } = {};
private resolveFunctions: { [callId: string]: CallableFunction } = {};
constructor(logger: Logger) {
this.logger = logger;
@@ -59,7 +77,6 @@ export class McpServer {
resp.on('close', async () => {
this.logger.debug(`Deleting transport for ${sessionId}`);
delete this.tools[sessionId];
delete this.resolveFunctions[sessionId];
delete this.transports[sessionId];
delete this.servers[sessionId];
});
@@ -75,16 +92,25 @@ export class McpServer {
async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) {
const sessionId = req.query.sessionId as string;
const transport = this.transports[sessionId];
this.tools[sessionId] = connectedTools;
if (transport) {
// We need to add a promise here because the `handlePostMessage` will send something to the
// MCP Server, that will run in a different context. This means that the return will happen
// almost immediately, and will lead to marking the sub-node as "running" in the final execution
await new Promise(async (resolve) => {
this.resolveFunctions[sessionId] = resolve;
await transport.handlePostMessage(req, resp, req.rawBody.toString());
});
delete this.resolveFunctions[sessionId];
const bodyString = req.rawBody.toString();
const messageId = getRequestId(bodyString);
// Use session & message ID if available, otherwise fall back to sessionId
const callId = messageId ? `${sessionId}_${messageId}` : sessionId;
this.tools[sessionId] = connectedTools;
try {
await new Promise(async (resolve) => {
this.resolveFunctions[callId] = resolve;
await transport.handlePostMessage(req, resp, bodyString);
});
} finally {
delete this.resolveFunctions[callId];
}
} else {
this.logger.warn(`No transport found for session ${sessionId}`);
resp.status(401).send('No transport found for sessionId');
@@ -94,8 +120,6 @@ export class McpServer {
resp.flush();
}
delete this.tools[sessionId]; // Clean up to avoid keeping all tools in memory
return wasToolCall(req.rawBody.toString());
}
@@ -110,57 +134,68 @@ export class McpServer {
},
);
server.setRequestHandler(ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => {
if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the listing of tools');
}
return {
tools: this.tools[extra.sessionId].map((tool) => {
return {
name: tool.name,
description: tool.description,
// Allow additional properties on tool call input
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
};
}),
};
});
server.setRequestHandler(CallToolRequestSchema, async (request, extra: RequestHandlerExtra) => {
if (!request.params?.name || !request.params?.arguments) {
throw new OperationalError('Require a name and arguments for the tool call');
}
if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the tool call');
}
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
(tool) => tool.name === request.params.name,
);
if (!requestedTool) {
throw new OperationalError('Tool not found');
}
try {
const result = await requestedTool.invoke(request.params.arguments);
this.resolveFunctions[extra.sessionId]();
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
if (typeof result === 'object') {
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
server.setRequestHandler(
ListToolsRequestSchema,
async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the listing of tools');
}
if (typeof result === 'string') {
return { content: [{ type: 'text', text: result }] };
return {
tools: this.tools[extra.sessionId].map((tool) => {
return {
name: tool.name,
description: tool.description,
// Allow additional properties on tool call input
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
};
}),
};
},
);
server.setRequestHandler(
CallToolRequestSchema,
async (request, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
if (!request.params?.name || !request.params?.arguments) {
throw new OperationalError('Require a name and arguments for the tool call');
}
return { content: [{ type: 'text', text: String(result) }] };
} catch (error) {
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
}
});
if (!extra.sessionId) {
throw new OperationalError('Require a sessionId for the tool call');
}
const callId = extra.requestId ? `${extra.sessionId}_${extra.requestId}` : extra.sessionId;
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
(tool) => tool.name === request.params.name,
);
if (!requestedTool) {
throw new OperationalError('Tool not found');
}
try {
const result = await requestedTool.invoke(request.params.arguments);
if (this.resolveFunctions[callId]) {
this.resolveFunctions[callId]();
} else {
this.logger.warn(`No resolve function found for ${callId}`);
}
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
if (typeof result === 'object') {
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
}
if (typeof result === 'string') {
return { content: [{ type: 'text', text: result }] };
}
return { content: [{ type: 'text', text: String(result) }] };
} catch (error) {
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
}
},
);
server.onclose = () => {
this.logger.debug('Closing MCP Server');

View File

@@ -78,7 +78,7 @@ describe('McpServer', () => {
it('should call transport.handlePostMessage when transport exists', async () => {
mockTransport.handlePostMessage.mockImplementation(async () => {
// @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[sessionId]();
mcpServer.resolveFunctions[`${sessionId}_123`]();
});
// Add the transport directly
@@ -110,6 +110,61 @@ describe('McpServer', () => {
expect(mockResponse.flush).toHaveBeenCalled();
});
it('should handle multiple tool calls with different ids', async () => {
const firstId = 123;
const secondId = 456;
mockTransport.handlePostMessage.mockImplementation(async () => {
const requestKey = mockRequest.rawBody?.toString().includes(`"id":${firstId}`)
? `${sessionId}_${firstId}`
: `${sessionId}_${secondId}`;
// @ts-expect-error private property `resolveFunctions`
mcpServer.resolveFunctions[requestKey]();
});
// Add the transport directly
mcpServer.transports[sessionId] = mockTransport;
// First tool call
mockRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: firstId,
params: { name: 'mockTool', arguments: { param: 'first call' } },
}),
);
// Handle first tool call
const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
expect(firstResult).toBe(true);
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
mockRequest,
mockResponse,
expect.any(String),
);
// Second tool call with different id
mockRequest.rawBody = Buffer.from(
JSON.stringify({
jsonrpc: '2.0',
method: 'tools/call',
id: secondId,
params: { name: 'mockTool', arguments: { param: 'second call' } },
}),
);
// Handle second tool call
const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
expect(secondResult).toBe(true);
// Verify transport's handlePostMessage was called twice
expect(mockTransport.handlePostMessage).toHaveBeenCalledTimes(2);
// Verify flush was called for both requests
expect(mockResponse.flush).toHaveBeenCalledTimes(2);
});
it('should return 401 when transport does not exist', async () => {
// Call without setting up transport
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);