mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(MCP Server Trigger Node): Handle multiple tool calls in mcp server trigger (#15064)
This commit is contained in:
committed by
GitHub
parent
b37387180c
commit
59ba162bd9
@@ -1,7 +1,11 @@
|
||||
import type { Tool } from '@langchain/core/tools';
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type {
|
||||
JSONRPCMessage,
|
||||
ServerRequest,
|
||||
ServerNotification,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import {
|
||||
JSONRPCMessageSchema,
|
||||
ListToolsRequestSchema,
|
||||
@@ -33,6 +37,20 @@ function wasToolCall(body: string) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts the request ID from a JSONRPC message
|
||||
* Returns undefined if the message doesn't have an ID or can't be parsed
|
||||
*/
|
||||
function getRequestId(body: string): string | undefined {
|
||||
try {
|
||||
const message: unknown = JSON.parse(body);
|
||||
const parsedMessage: JSONRPCMessage = JSONRPCMessageSchema.parse(message);
|
||||
return 'id' in parsedMessage ? String(parsedMessage.id) : undefined;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
export class McpServer {
|
||||
servers: { [sessionId: string]: Server } = {};
|
||||
|
||||
@@ -42,7 +60,7 @@ export class McpServer {
|
||||
|
||||
private tools: { [sessionId: string]: Tool[] } = {};
|
||||
|
||||
private resolveFunctions: { [sessionId: string]: CallableFunction } = {};
|
||||
private resolveFunctions: { [callId: string]: CallableFunction } = {};
|
||||
|
||||
constructor(logger: Logger) {
|
||||
this.logger = logger;
|
||||
@@ -59,7 +77,6 @@ export class McpServer {
|
||||
resp.on('close', async () => {
|
||||
this.logger.debug(`Deleting transport for ${sessionId}`);
|
||||
delete this.tools[sessionId];
|
||||
delete this.resolveFunctions[sessionId];
|
||||
delete this.transports[sessionId];
|
||||
delete this.servers[sessionId];
|
||||
});
|
||||
@@ -75,16 +92,25 @@ export class McpServer {
|
||||
async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) {
|
||||
const sessionId = req.query.sessionId as string;
|
||||
const transport = this.transports[sessionId];
|
||||
this.tools[sessionId] = connectedTools;
|
||||
if (transport) {
|
||||
// We need to add a promise here because the `handlePostMessage` will send something to the
|
||||
// MCP Server, that will run in a different context. This means that the return will happen
|
||||
// almost immediately, and will lead to marking the sub-node as "running" in the final execution
|
||||
await new Promise(async (resolve) => {
|
||||
this.resolveFunctions[sessionId] = resolve;
|
||||
await transport.handlePostMessage(req, resp, req.rawBody.toString());
|
||||
});
|
||||
delete this.resolveFunctions[sessionId];
|
||||
const bodyString = req.rawBody.toString();
|
||||
const messageId = getRequestId(bodyString);
|
||||
|
||||
// Use session & message ID if available, otherwise fall back to sessionId
|
||||
const callId = messageId ? `${sessionId}_${messageId}` : sessionId;
|
||||
this.tools[sessionId] = connectedTools;
|
||||
|
||||
try {
|
||||
await new Promise(async (resolve) => {
|
||||
this.resolveFunctions[callId] = resolve;
|
||||
await transport.handlePostMessage(req, resp, bodyString);
|
||||
});
|
||||
} finally {
|
||||
delete this.resolveFunctions[callId];
|
||||
}
|
||||
} else {
|
||||
this.logger.warn(`No transport found for session ${sessionId}`);
|
||||
resp.status(401).send('No transport found for sessionId');
|
||||
@@ -94,8 +120,6 @@ export class McpServer {
|
||||
resp.flush();
|
||||
}
|
||||
|
||||
delete this.tools[sessionId]; // Clean up to avoid keeping all tools in memory
|
||||
|
||||
return wasToolCall(req.rawBody.toString());
|
||||
}
|
||||
|
||||
@@ -110,57 +134,68 @@ export class McpServer {
|
||||
},
|
||||
);
|
||||
|
||||
server.setRequestHandler(ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => {
|
||||
if (!extra.sessionId) {
|
||||
throw new OperationalError('Require a sessionId for the listing of tools');
|
||||
}
|
||||
|
||||
return {
|
||||
tools: this.tools[extra.sessionId].map((tool) => {
|
||||
return {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
// Allow additional properties on tool call input
|
||||
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
|
||||
};
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
server.setRequestHandler(CallToolRequestSchema, async (request, extra: RequestHandlerExtra) => {
|
||||
if (!request.params?.name || !request.params?.arguments) {
|
||||
throw new OperationalError('Require a name and arguments for the tool call');
|
||||
}
|
||||
if (!extra.sessionId) {
|
||||
throw new OperationalError('Require a sessionId for the tool call');
|
||||
}
|
||||
|
||||
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
|
||||
(tool) => tool.name === request.params.name,
|
||||
);
|
||||
if (!requestedTool) {
|
||||
throw new OperationalError('Tool not found');
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await requestedTool.invoke(request.params.arguments);
|
||||
|
||||
this.resolveFunctions[extra.sessionId]();
|
||||
|
||||
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
|
||||
|
||||
if (typeof result === 'object') {
|
||||
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
|
||||
server.setRequestHandler(
|
||||
ListToolsRequestSchema,
|
||||
async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
|
||||
if (!extra.sessionId) {
|
||||
throw new OperationalError('Require a sessionId for the listing of tools');
|
||||
}
|
||||
if (typeof result === 'string') {
|
||||
return { content: [{ type: 'text', text: result }] };
|
||||
|
||||
return {
|
||||
tools: this.tools[extra.sessionId].map((tool) => {
|
||||
return {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
// Allow additional properties on tool call input
|
||||
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
|
||||
};
|
||||
}),
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
server.setRequestHandler(
|
||||
CallToolRequestSchema,
|
||||
async (request, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
|
||||
if (!request.params?.name || !request.params?.arguments) {
|
||||
throw new OperationalError('Require a name and arguments for the tool call');
|
||||
}
|
||||
return { content: [{ type: 'text', text: String(result) }] };
|
||||
} catch (error) {
|
||||
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
|
||||
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
|
||||
}
|
||||
});
|
||||
if (!extra.sessionId) {
|
||||
throw new OperationalError('Require a sessionId for the tool call');
|
||||
}
|
||||
|
||||
const callId = extra.requestId ? `${extra.sessionId}_${extra.requestId}` : extra.sessionId;
|
||||
|
||||
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
|
||||
(tool) => tool.name === request.params.name,
|
||||
);
|
||||
if (!requestedTool) {
|
||||
throw new OperationalError('Tool not found');
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await requestedTool.invoke(request.params.arguments);
|
||||
if (this.resolveFunctions[callId]) {
|
||||
this.resolveFunctions[callId]();
|
||||
} else {
|
||||
this.logger.warn(`No resolve function found for ${callId}`);
|
||||
}
|
||||
|
||||
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
|
||||
|
||||
if (typeof result === 'object') {
|
||||
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
|
||||
}
|
||||
if (typeof result === 'string') {
|
||||
return { content: [{ type: 'text', text: result }] };
|
||||
}
|
||||
return { content: [{ type: 'text', text: String(result) }] };
|
||||
} catch (error) {
|
||||
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
|
||||
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
server.onclose = () => {
|
||||
this.logger.debug('Closing MCP Server');
|
||||
|
||||
@@ -78,7 +78,7 @@ describe('McpServer', () => {
|
||||
it('should call transport.handlePostMessage when transport exists', async () => {
|
||||
mockTransport.handlePostMessage.mockImplementation(async () => {
|
||||
// @ts-expect-error private property `resolveFunctions`
|
||||
mcpServer.resolveFunctions[sessionId]();
|
||||
mcpServer.resolveFunctions[`${sessionId}_123`]();
|
||||
});
|
||||
|
||||
// Add the transport directly
|
||||
@@ -110,6 +110,61 @@ describe('McpServer', () => {
|
||||
expect(mockResponse.flush).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle multiple tool calls with different ids', async () => {
|
||||
const firstId = 123;
|
||||
const secondId = 456;
|
||||
|
||||
mockTransport.handlePostMessage.mockImplementation(async () => {
|
||||
const requestKey = mockRequest.rawBody?.toString().includes(`"id":${firstId}`)
|
||||
? `${sessionId}_${firstId}`
|
||||
: `${sessionId}_${secondId}`;
|
||||
// @ts-expect-error private property `resolveFunctions`
|
||||
mcpServer.resolveFunctions[requestKey]();
|
||||
});
|
||||
|
||||
// Add the transport directly
|
||||
mcpServer.transports[sessionId] = mockTransport;
|
||||
|
||||
// First tool call
|
||||
mockRequest.rawBody = Buffer.from(
|
||||
JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'tools/call',
|
||||
id: firstId,
|
||||
params: { name: 'mockTool', arguments: { param: 'first call' } },
|
||||
}),
|
||||
);
|
||||
|
||||
// Handle first tool call
|
||||
const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
||||
expect(firstResult).toBe(true);
|
||||
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
|
||||
mockRequest,
|
||||
mockResponse,
|
||||
expect.any(String),
|
||||
);
|
||||
|
||||
// Second tool call with different id
|
||||
mockRequest.rawBody = Buffer.from(
|
||||
JSON.stringify({
|
||||
jsonrpc: '2.0',
|
||||
method: 'tools/call',
|
||||
id: secondId,
|
||||
params: { name: 'mockTool', arguments: { param: 'second call' } },
|
||||
}),
|
||||
);
|
||||
|
||||
// Handle second tool call
|
||||
const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
||||
expect(secondResult).toBe(true);
|
||||
|
||||
// Verify transport's handlePostMessage was called twice
|
||||
expect(mockTransport.handlePostMessage).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify flush was called for both requests
|
||||
expect(mockResponse.flush).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should return 401 when transport does not exist', async () => {
|
||||
// Call without setting up transport
|
||||
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
||||
|
||||
Reference in New Issue
Block a user