mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 10:02:05 +00:00
feat(MCP Server Trigger Node): Handle multiple tool calls in mcp server trigger (#15064)
This commit is contained in:
committed by
GitHub
parent
b37387180c
commit
59ba162bd9
@@ -1,7 +1,11 @@
|
|||||||
import type { Tool } from '@langchain/core/tools';
|
import type { Tool } from '@langchain/core/tools';
|
||||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||||
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
|
import type { RequestHandlerExtra } from '@modelcontextprotocol/sdk/shared/protocol.js';
|
||||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
import type {
|
||||||
|
JSONRPCMessage,
|
||||||
|
ServerRequest,
|
||||||
|
ServerNotification,
|
||||||
|
} from '@modelcontextprotocol/sdk/types.js';
|
||||||
import {
|
import {
|
||||||
JSONRPCMessageSchema,
|
JSONRPCMessageSchema,
|
||||||
ListToolsRequestSchema,
|
ListToolsRequestSchema,
|
||||||
@@ -33,6 +37,20 @@ function wasToolCall(body: string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts the request ID from a JSONRPC message
|
||||||
|
* Returns undefined if the message doesn't have an ID or can't be parsed
|
||||||
|
*/
|
||||||
|
function getRequestId(body: string): string | undefined {
|
||||||
|
try {
|
||||||
|
const message: unknown = JSON.parse(body);
|
||||||
|
const parsedMessage: JSONRPCMessage = JSONRPCMessageSchema.parse(message);
|
||||||
|
return 'id' in parsedMessage ? String(parsedMessage.id) : undefined;
|
||||||
|
} catch {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export class McpServer {
|
export class McpServer {
|
||||||
servers: { [sessionId: string]: Server } = {};
|
servers: { [sessionId: string]: Server } = {};
|
||||||
|
|
||||||
@@ -42,7 +60,7 @@ export class McpServer {
|
|||||||
|
|
||||||
private tools: { [sessionId: string]: Tool[] } = {};
|
private tools: { [sessionId: string]: Tool[] } = {};
|
||||||
|
|
||||||
private resolveFunctions: { [sessionId: string]: CallableFunction } = {};
|
private resolveFunctions: { [callId: string]: CallableFunction } = {};
|
||||||
|
|
||||||
constructor(logger: Logger) {
|
constructor(logger: Logger) {
|
||||||
this.logger = logger;
|
this.logger = logger;
|
||||||
@@ -59,7 +77,6 @@ export class McpServer {
|
|||||||
resp.on('close', async () => {
|
resp.on('close', async () => {
|
||||||
this.logger.debug(`Deleting transport for ${sessionId}`);
|
this.logger.debug(`Deleting transport for ${sessionId}`);
|
||||||
delete this.tools[sessionId];
|
delete this.tools[sessionId];
|
||||||
delete this.resolveFunctions[sessionId];
|
|
||||||
delete this.transports[sessionId];
|
delete this.transports[sessionId];
|
||||||
delete this.servers[sessionId];
|
delete this.servers[sessionId];
|
||||||
});
|
});
|
||||||
@@ -75,16 +92,25 @@ export class McpServer {
|
|||||||
async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) {
|
async handlePostMessage(req: express.Request, resp: CompressionResponse, connectedTools: Tool[]) {
|
||||||
const sessionId = req.query.sessionId as string;
|
const sessionId = req.query.sessionId as string;
|
||||||
const transport = this.transports[sessionId];
|
const transport = this.transports[sessionId];
|
||||||
this.tools[sessionId] = connectedTools;
|
|
||||||
if (transport) {
|
if (transport) {
|
||||||
// We need to add a promise here because the `handlePostMessage` will send something to the
|
// We need to add a promise here because the `handlePostMessage` will send something to the
|
||||||
// MCP Server, that will run in a different context. This means that the return will happen
|
// MCP Server, that will run in a different context. This means that the return will happen
|
||||||
// almost immediately, and will lead to marking the sub-node as "running" in the final execution
|
// almost immediately, and will lead to marking the sub-node as "running" in the final execution
|
||||||
await new Promise(async (resolve) => {
|
const bodyString = req.rawBody.toString();
|
||||||
this.resolveFunctions[sessionId] = resolve;
|
const messageId = getRequestId(bodyString);
|
||||||
await transport.handlePostMessage(req, resp, req.rawBody.toString());
|
|
||||||
});
|
// Use session & message ID if available, otherwise fall back to sessionId
|
||||||
delete this.resolveFunctions[sessionId];
|
const callId = messageId ? `${sessionId}_${messageId}` : sessionId;
|
||||||
|
this.tools[sessionId] = connectedTools;
|
||||||
|
|
||||||
|
try {
|
||||||
|
await new Promise(async (resolve) => {
|
||||||
|
this.resolveFunctions[callId] = resolve;
|
||||||
|
await transport.handlePostMessage(req, resp, bodyString);
|
||||||
|
});
|
||||||
|
} finally {
|
||||||
|
delete this.resolveFunctions[callId];
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
this.logger.warn(`No transport found for session ${sessionId}`);
|
this.logger.warn(`No transport found for session ${sessionId}`);
|
||||||
resp.status(401).send('No transport found for sessionId');
|
resp.status(401).send('No transport found for sessionId');
|
||||||
@@ -94,8 +120,6 @@ export class McpServer {
|
|||||||
resp.flush();
|
resp.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
delete this.tools[sessionId]; // Clean up to avoid keeping all tools in memory
|
|
||||||
|
|
||||||
return wasToolCall(req.rawBody.toString());
|
return wasToolCall(req.rawBody.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,57 +134,68 @@ export class McpServer {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
server.setRequestHandler(ListToolsRequestSchema, async (_, extra: RequestHandlerExtra) => {
|
server.setRequestHandler(
|
||||||
if (!extra.sessionId) {
|
ListToolsRequestSchema,
|
||||||
throw new OperationalError('Require a sessionId for the listing of tools');
|
async (_, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
|
||||||
}
|
if (!extra.sessionId) {
|
||||||
|
throw new OperationalError('Require a sessionId for the listing of tools');
|
||||||
return {
|
|
||||||
tools: this.tools[extra.sessionId].map((tool) => {
|
|
||||||
return {
|
|
||||||
name: tool.name,
|
|
||||||
description: tool.description,
|
|
||||||
// Allow additional properties on tool call input
|
|
||||||
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
|
|
||||||
};
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
server.setRequestHandler(CallToolRequestSchema, async (request, extra: RequestHandlerExtra) => {
|
|
||||||
if (!request.params?.name || !request.params?.arguments) {
|
|
||||||
throw new OperationalError('Require a name and arguments for the tool call');
|
|
||||||
}
|
|
||||||
if (!extra.sessionId) {
|
|
||||||
throw new OperationalError('Require a sessionId for the tool call');
|
|
||||||
}
|
|
||||||
|
|
||||||
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
|
|
||||||
(tool) => tool.name === request.params.name,
|
|
||||||
);
|
|
||||||
if (!requestedTool) {
|
|
||||||
throw new OperationalError('Tool not found');
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const result = await requestedTool.invoke(request.params.arguments);
|
|
||||||
|
|
||||||
this.resolveFunctions[extra.sessionId]();
|
|
||||||
|
|
||||||
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
|
|
||||||
|
|
||||||
if (typeof result === 'object') {
|
|
||||||
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
|
|
||||||
}
|
}
|
||||||
if (typeof result === 'string') {
|
|
||||||
return { content: [{ type: 'text', text: result }] };
|
return {
|
||||||
|
tools: this.tools[extra.sessionId].map((tool) => {
|
||||||
|
return {
|
||||||
|
name: tool.name,
|
||||||
|
description: tool.description,
|
||||||
|
// Allow additional properties on tool call input
|
||||||
|
inputSchema: zodToJsonSchema(tool.schema, { removeAdditionalStrategy: 'strict' }),
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
server.setRequestHandler(
|
||||||
|
CallToolRequestSchema,
|
||||||
|
async (request, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
|
||||||
|
if (!request.params?.name || !request.params?.arguments) {
|
||||||
|
throw new OperationalError('Require a name and arguments for the tool call');
|
||||||
}
|
}
|
||||||
return { content: [{ type: 'text', text: String(result) }] };
|
if (!extra.sessionId) {
|
||||||
} catch (error) {
|
throw new OperationalError('Require a sessionId for the tool call');
|
||||||
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
|
}
|
||||||
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
|
|
||||||
}
|
const callId = extra.requestId ? `${extra.sessionId}_${extra.requestId}` : extra.sessionId;
|
||||||
});
|
|
||||||
|
const requestedTool: Tool | undefined = this.tools[extra.sessionId].find(
|
||||||
|
(tool) => tool.name === request.params.name,
|
||||||
|
);
|
||||||
|
if (!requestedTool) {
|
||||||
|
throw new OperationalError('Tool not found');
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await requestedTool.invoke(request.params.arguments);
|
||||||
|
if (this.resolveFunctions[callId]) {
|
||||||
|
this.resolveFunctions[callId]();
|
||||||
|
} else {
|
||||||
|
this.logger.warn(`No resolve function found for ${callId}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.logger.debug(`Got request for ${requestedTool.name}, and executed it.`);
|
||||||
|
|
||||||
|
if (typeof result === 'object') {
|
||||||
|
return { content: [{ type: 'text', text: JSON.stringify(result) }] };
|
||||||
|
}
|
||||||
|
if (typeof result === 'string') {
|
||||||
|
return { content: [{ type: 'text', text: result }] };
|
||||||
|
}
|
||||||
|
return { content: [{ type: 'text', text: String(result) }] };
|
||||||
|
} catch (error) {
|
||||||
|
this.logger.error(`Error while executing Tool ${requestedTool.name}: ${error}`);
|
||||||
|
return { isError: true, content: [{ type: 'text', text: `Error: ${error.message}` }] };
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
server.onclose = () => {
|
server.onclose = () => {
|
||||||
this.logger.debug('Closing MCP Server');
|
this.logger.debug('Closing MCP Server');
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ describe('McpServer', () => {
|
|||||||
it('should call transport.handlePostMessage when transport exists', async () => {
|
it('should call transport.handlePostMessage when transport exists', async () => {
|
||||||
mockTransport.handlePostMessage.mockImplementation(async () => {
|
mockTransport.handlePostMessage.mockImplementation(async () => {
|
||||||
// @ts-expect-error private property `resolveFunctions`
|
// @ts-expect-error private property `resolveFunctions`
|
||||||
mcpServer.resolveFunctions[sessionId]();
|
mcpServer.resolveFunctions[`${sessionId}_123`]();
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add the transport directly
|
// Add the transport directly
|
||||||
@@ -110,6 +110,61 @@ describe('McpServer', () => {
|
|||||||
expect(mockResponse.flush).toHaveBeenCalled();
|
expect(mockResponse.flush).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle multiple tool calls with different ids', async () => {
|
||||||
|
const firstId = 123;
|
||||||
|
const secondId = 456;
|
||||||
|
|
||||||
|
mockTransport.handlePostMessage.mockImplementation(async () => {
|
||||||
|
const requestKey = mockRequest.rawBody?.toString().includes(`"id":${firstId}`)
|
||||||
|
? `${sessionId}_${firstId}`
|
||||||
|
: `${sessionId}_${secondId}`;
|
||||||
|
// @ts-expect-error private property `resolveFunctions`
|
||||||
|
mcpServer.resolveFunctions[requestKey]();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add the transport directly
|
||||||
|
mcpServer.transports[sessionId] = mockTransport;
|
||||||
|
|
||||||
|
// First tool call
|
||||||
|
mockRequest.rawBody = Buffer.from(
|
||||||
|
JSON.stringify({
|
||||||
|
jsonrpc: '2.0',
|
||||||
|
method: 'tools/call',
|
||||||
|
id: firstId,
|
||||||
|
params: { name: 'mockTool', arguments: { param: 'first call' } },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Handle first tool call
|
||||||
|
const firstResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
||||||
|
expect(firstResult).toBe(true);
|
||||||
|
expect(mockTransport.handlePostMessage).toHaveBeenCalledWith(
|
||||||
|
mockRequest,
|
||||||
|
mockResponse,
|
||||||
|
expect.any(String),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Second tool call with different id
|
||||||
|
mockRequest.rawBody = Buffer.from(
|
||||||
|
JSON.stringify({
|
||||||
|
jsonrpc: '2.0',
|
||||||
|
method: 'tools/call',
|
||||||
|
id: secondId,
|
||||||
|
params: { name: 'mockTool', arguments: { param: 'second call' } },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Handle second tool call
|
||||||
|
const secondResult = await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
||||||
|
expect(secondResult).toBe(true);
|
||||||
|
|
||||||
|
// Verify transport's handlePostMessage was called twice
|
||||||
|
expect(mockTransport.handlePostMessage).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
|
// Verify flush was called for both requests
|
||||||
|
expect(mockResponse.flush).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
|
||||||
it('should return 401 when transport does not exist', async () => {
|
it('should return 401 when transport does not exist', async () => {
|
||||||
// Call without setting up transport
|
// Call without setting up transport
|
||||||
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
await mcpServer.handlePostMessage(mockRequest, mockResponse, [mockTool]);
|
||||||
|
|||||||
@@ -173,7 +173,7 @@
|
|||||||
"@langchain/qdrant": "0.1.1",
|
"@langchain/qdrant": "0.1.1",
|
||||||
"@langchain/redis": "0.1.0",
|
"@langchain/redis": "0.1.0",
|
||||||
"@langchain/textsplitters": "0.1.0",
|
"@langchain/textsplitters": "0.1.0",
|
||||||
"@modelcontextprotocol/sdk": "1.9.0",
|
"@modelcontextprotocol/sdk": "1.11.0",
|
||||||
"@mozilla/readability": "0.6.0",
|
"@mozilla/readability": "0.6.0",
|
||||||
"@n8n/client-oauth2": "workspace:*",
|
"@n8n/client-oauth2": "workspace:*",
|
||||||
"@n8n/json-schema-to-zod": "workspace:*",
|
"@n8n/json-schema-to-zod": "workspace:*",
|
||||||
|
|||||||
10
pnpm-lock.yaml
generated
10
pnpm-lock.yaml
generated
@@ -783,8 +783,8 @@ importers:
|
|||||||
specifier: 0.1.0
|
specifier: 0.1.0
|
||||||
version: 0.1.0(@langchain/core@0.3.30(openai@4.78.1(encoding@0.1.13)(zod@3.24.1)))
|
version: 0.1.0(@langchain/core@0.3.30(openai@4.78.1(encoding@0.1.13)(zod@3.24.1)))
|
||||||
'@modelcontextprotocol/sdk':
|
'@modelcontextprotocol/sdk':
|
||||||
specifier: 1.9.0
|
specifier: 1.11.0
|
||||||
version: 1.9.0
|
version: 1.11.0
|
||||||
'@mozilla/readability':
|
'@mozilla/readability':
|
||||||
specifier: 0.6.0
|
specifier: 0.6.0
|
||||||
version: 0.6.0
|
version: 0.6.0
|
||||||
@@ -4810,8 +4810,8 @@ packages:
|
|||||||
peerDependencies:
|
peerDependencies:
|
||||||
zod: '>= 3'
|
zod: '>= 3'
|
||||||
|
|
||||||
'@modelcontextprotocol/sdk@1.9.0':
|
'@modelcontextprotocol/sdk@1.11.0':
|
||||||
resolution: {integrity: sha512-Jq2EUCQpe0iyO5FGpzVYDNFR6oR53AIrwph9yWl7uSc7IWUMsrmpmSaTGra5hQNunXpM+9oit85p924jWuHzUA==}
|
resolution: {integrity: sha512-k/1pb70eD638anoi0e8wUGAlbMJXyvdV4p62Ko+EZ7eBe1xMx8Uhak1R5DgfoofsK5IBBnRwsYGTaLZl+6/+RQ==}
|
||||||
engines: {node: '>=18'}
|
engines: {node: '>=18'}
|
||||||
|
|
||||||
'@mongodb-js/saslprep@1.1.9':
|
'@mongodb-js/saslprep@1.1.9':
|
||||||
@@ -17242,7 +17242,7 @@ snapshots:
|
|||||||
dependencies:
|
dependencies:
|
||||||
zod: 3.24.1
|
zod: 3.24.1
|
||||||
|
|
||||||
'@modelcontextprotocol/sdk@1.9.0':
|
'@modelcontextprotocol/sdk@1.11.0':
|
||||||
dependencies:
|
dependencies:
|
||||||
content-type: 1.0.5
|
content-type: 1.0.5
|
||||||
cors: 2.8.5
|
cors: 2.8.5
|
||||||
|
|||||||
Reference in New Issue
Block a user