diff --git a/packages/@n8n/config/src/configs/logging.config.ts b/packages/@n8n/config/src/configs/logging.config.ts index 3b9a9d1570..8bbac021b0 100644 --- a/packages/@n8n/config/src/configs/logging.config.ts +++ b/packages/@n8n/config/src/configs/logging.config.ts @@ -18,6 +18,7 @@ export const LOG_SCOPES = [ 'task-runner', 'insights', 'workflow-activation', + 'ssh-client', ] as const; export type LogScope = (typeof LOG_SCOPES)[number]; diff --git a/packages/core/src/execution-engine/__tests__/ssh-clients-manager.test.ts b/packages/core/src/execution-engine/__tests__/ssh-clients-manager.test.ts index a549cbbd25..66ac55dd0f 100644 --- a/packages/core/src/execution-engine/__tests__/ssh-clients-manager.test.ts +++ b/packages/core/src/execution-engine/__tests__/ssh-clients-manager.test.ts @@ -1,34 +1,43 @@ +import type { Logger } from '@n8n/backend-common'; +import { Container } from '@n8n/di'; import { mock } from 'jest-mock-extended'; import type { SSHCredentials } from 'n8n-workflow'; import { Client } from 'ssh2'; -import { SSHClientsManager } from '../ssh-clients-manager'; +import { SSHClientsConfig, SSHClientsManager } from '../ssh-clients-manager'; -describe('SSHClientsManager', () => { - const idleTimeout = 5 * 60; - const credentials: SSHCredentials = { - sshAuthenticateWith: 'password', - sshHost: 'example.com', - sshPort: 22, - sshUser: 'username', - sshPassword: 'password', - }; +const idleTimeout = 5 * 60; +const cleanUpInterval = 60; +const credentials: SSHCredentials = { + sshAuthenticateWith: 'password', + sshHost: 'example.com', + sshPort: 22, + sshUser: 'username', + sshPassword: 'password', +}; - let sshClientsManager: SSHClientsManager; - const connectSpy = jest.spyOn(Client.prototype, 'connect'); - const endSpy = jest.spyOn(Client.prototype, 'end'); +let sshClientsManager: SSHClientsManager; +const connectSpy = jest.spyOn(Client.prototype, 'connect'); +const endSpy = jest.spyOn(Client.prototype, 'end'); - beforeEach(() => { - jest.clearAllMocks(); - jest.useFakeTimers(); +beforeEach(() => { + jest.clearAllMocks(); - sshClientsManager = new SSHClientsManager(mock({ idleTimeout })); - connectSpy.mockImplementation(function (this: Client) { - this.emit('ready'); - return this; - }); + sshClientsManager = new SSHClientsManager( + mock({ idleTimeout }), + mock({ scoped: () => mock() }), + ); + connectSpy.mockImplementation(function (this: Client) { + this.emit('ready'); + return this; }); +}); +afterEach(() => { + sshClientsManager.onShutdown(); +}); + +describe('getClient', () => { it('should create a new SSH client', async () => { const client = await sshClientsManager.getClient(credentials); @@ -49,22 +58,180 @@ describe('SSHClientsManager', () => { expect(client1).toBe(client2); }); - it('should close all SSH connections on process exit', async () => { + it('should not create multiple clients for the same credentials in parallel', async () => { + // ARRANGE + connectSpy.mockImplementation(function (this: Client) { + setTimeout(() => this.emit('ready'), Math.random() * 10); + return this; + }); + + // ACT + const clients = await Promise.all([ + sshClientsManager.getClient(credentials), + sshClientsManager.getClient(credentials), + sshClientsManager.getClient(credentials), + sshClientsManager.getClient(credentials), + sshClientsManager.getClient(credentials), + sshClientsManager.getClient(credentials), + ]); + + // ASSERT + // returns the same client for all invocations + const ogClient = await sshClientsManager.getClient(credentials); + expect(clients).toHaveLength(6); + for (const client of clients) { + expect(client).toBe(ogClient); + } + expect(connectSpy).toHaveBeenCalledTimes(1); + }); +}); + +describe('onShutdown', () => { + it('should close all SSH connections when onShutdown is called', async () => { await sshClientsManager.getClient(credentials); sshClientsManager.onShutdown(); expect(endSpy).toHaveBeenCalledTimes(1); }); + it('should close all SSH connections on process exit', async () => { + // ARRANGE + await sshClientsManager.getClient(credentials); + + // ACT + // @ts-expect-error we're not supposed to emit `exit` so it's missing from + // the type definition + process.emit('exit'); + + // ASSERT + expect(endSpy).toHaveBeenCalledTimes(1); + }); +}); + +describe('cleanup', () => { + beforeEach(async () => { + jest.useFakeTimers(); + sshClientsManager = new SSHClientsManager( + mock({ idleTimeout }), + mock({ scoped: () => mock() }), + ); + }); + it('should cleanup stale SSH connections', async () => { await sshClientsManager.getClient({ ...credentials, sshHost: 'host1' }); await sshClientsManager.getClient({ ...credentials, sshHost: 'host2' }); await sshClientsManager.getClient({ ...credentials, sshHost: 'host3' }); - jest.advanceTimersByTime((idleTimeout + 1) * 1000); - sshClientsManager.cleanupStaleConnections(); + jest.advanceTimersByTime((idleTimeout + cleanUpInterval + 1) * 1000); expect(endSpy).toHaveBeenCalledTimes(3); expect(sshClientsManager.clients.size).toBe(0); }); + + describe('updateLastUsed', () => { + test('updates lastUsed in the registration', async () => { + // ARRANGE + const client = await sshClientsManager.getClient(credentials); + // schedule client for clean up soon + jest.advanceTimersByTime((idleTimeout - 1) * 1000); + + // ACT 1 + // updating lastUsed should prevent the clean up + sshClientsManager.updateLastUsed(client); + jest.advanceTimersByTime(idleTimeout * 1000); + + // ASSERT 1 + expect(endSpy).toHaveBeenCalledTimes(0); + + // ACT 1 + jest.advanceTimersByTime(cleanUpInterval * 1000); + + // ASSERT 1 + expect(endSpy).toHaveBeenCalledTimes(1); + }); + }); +}); + +describe('abort controller', () => { + test('call `abort` when the client emits `error`', async () => { + // ARRANGE + const abortController = new AbortController(); + const client = await sshClientsManager.getClient(credentials, abortController); + + // ACT 1 + client.emit('error', new Error()); + + // ASSERT 1 + expect(abortController.signal.aborted).toBe(true); + expect(endSpy).toHaveBeenCalledTimes(1); + }); + + test('call `abort` when the client emits `end`', async () => { + // ARRANGE + const abortController = new AbortController(); + const client = await sshClientsManager.getClient(credentials, abortController); + + // ACT 1 + client.emit('end'); + + // ASSERT 1 + expect(abortController.signal.aborted).toBe(true); + expect(endSpy).toHaveBeenCalledTimes(1); + }); + + test('call `abort` when the client emits `close`', async () => { + // ARRANGE + const abortController = new AbortController(); + const client = await sshClientsManager.getClient(credentials, abortController); + + // ACT 1 + client.emit('close'); + + // ASSERT 1 + expect(abortController.signal.aborted).toBe(true); + expect(endSpy).toHaveBeenCalledTimes(1); + }); + + test('closes client when `abort` is being called', async () => { + // ARRANGE + const abortController = new AbortController(); + await sshClientsManager.getClient(credentials, abortController); + + // ACT 1 + abortController.abort(); + + // ASSERT 1 + expect(endSpy).toHaveBeenCalledTimes(1); + }); +}); + +describe('SSHClientsConfig', () => { + beforeEach(() => { + Container.reset(); + }); + + test('allows overriding the default idle timeout', async () => { + // ARRANGE + process.env.N8N_SSH_TUNNEL_IDLE_TIMEOUT = '5'; + + // ACT + const config = Container.get(SSHClientsConfig); + + // ASSERT + expect(config.idleTimeout).toBe(5); + }); + + test.each(['-5', '0', 'foo'])( + 'fall back to default if N8N_SSH_TUNNEL_IDLE_TIMEOUT is `%s`', + async (value) => { + // ARRANGE + process.env.N8N_SSH_TUNNEL_IDLE_TIMEOUT = value; + + // ACT + const config = Container.get(SSHClientsConfig); + + // ASSERT + expect(config.idleTimeout).toBe(300); + }, + ); }); diff --git a/packages/core/src/execution-engine/node-execution-context/utils/__tests__/ssh-tunnel-helper-functions.test.ts b/packages/core/src/execution-engine/node-execution-context/utils/__tests__/ssh-tunnel-helper-functions.test.ts index 5300a7d3c2..be26686035 100644 --- a/packages/core/src/execution-engine/node-execution-context/utils/__tests__/ssh-tunnel-helper-functions.test.ts +++ b/packages/core/src/execution-engine/node-execution-context/utils/__tests__/ssh-tunnel-helper-functions.test.ts @@ -7,6 +7,7 @@ import { SSHClientsManager } from '../../../ssh-clients-manager'; import { getSSHTunnelFunctions } from '../ssh-tunnel-helper-functions'; describe('getSSHTunnelFunctions', () => { + const abortController = new AbortController(); const credentials = mock(); const sshClientsManager = mockInstance(SSHClientsManager); const sshTunnelFunctions = getSSHTunnelFunctions(); @@ -17,9 +18,22 @@ describe('getSSHTunnelFunctions', () => { describe('getSSHClient', () => { it('should invoke sshClientsManager.getClient', async () => { - await sshTunnelFunctions.getSSHClient(credentials); + await sshTunnelFunctions.getSSHClient(credentials, abortController); - expect(sshClientsManager.getClient).toHaveBeenCalledWith(credentials); + expect(sshClientsManager.getClient).toHaveBeenCalledWith(credentials, abortController); + }); + }); + + describe('updateLastUsed', () => { + it('should invoke sshClientsManager.updateLastUsed', async () => { + // ARRANGE + const client = await sshTunnelFunctions.getSSHClient(credentials, abortController); + + // ACT + sshTunnelFunctions.updateLastUsed(client); + + // ASSERT + expect(sshClientsManager.updateLastUsed).toHaveBeenCalledWith(client); }); }); }); diff --git a/packages/core/src/execution-engine/node-execution-context/utils/ssh-tunnel-helper-functions.ts b/packages/core/src/execution-engine/node-execution-context/utils/ssh-tunnel-helper-functions.ts index daeadd2e9f..651b46d17b 100644 --- a/packages/core/src/execution-engine/node-execution-context/utils/ssh-tunnel-helper-functions.ts +++ b/packages/core/src/execution-engine/node-execution-context/utils/ssh-tunnel-helper-functions.ts @@ -1,11 +1,14 @@ import { Container } from '@n8n/di'; import type { SSHTunnelFunctions } from 'n8n-workflow'; +import type { Client } from 'ssh2'; import { SSHClientsManager } from '../../ssh-clients-manager'; export const getSSHTunnelFunctions = (): SSHTunnelFunctions => { const sshClientsManager = Container.get(SSHClientsManager); return { - getSSHClient: async (credentials) => await sshClientsManager.getClient(credentials), + getSSHClient: async (credentials, abortController) => + await sshClientsManager.getClient(credentials, abortController), + updateLastUsed: (client: Client) => sshClientsManager.updateLastUsed(client), }; }; diff --git a/packages/core/src/execution-engine/ssh-clients-manager.ts b/packages/core/src/execution-engine/ssh-clients-manager.ts index 4156cb3ec3..7bec4f2b4a 100644 --- a/packages/core/src/execution-engine/ssh-clients-manager.ts +++ b/packages/core/src/execution-engine/ssh-clients-manager.ts @@ -1,31 +1,99 @@ +import { Logger } from '@n8n/backend-common'; import { Config, Env } from '@n8n/config'; import { Service } from '@n8n/di'; import type { SSHCredentials } from 'n8n-workflow'; import { createHash } from 'node:crypto'; import { Client, type ConnectConfig } from 'ssh2'; +import { z } from 'zod'; @Config -class SSHClientsConfig { +export class SSHClientsConfig { /** How many seconds before an idle SSH tunnel is closed */ - @Env('N8N_SSH_TUNNEL_IDLE_TIMEOUT') + @Env( + 'N8N_SSH_TUNNEL_IDLE_TIMEOUT', + z + .string() + .transform((value) => Number.parseInt(value)) + .superRefine((value, ctx) => { + if (Number.isNaN(value)) { + return ctx.addIssue({ + message: 'must be a valid integer', + code: 'custom', + }); + } + + if (value <= 0) { + return ctx.addIssue({ + message: 'must be positive', + code: 'too_small', + minimum: 0, + inclusive: false, + type: 'number', + }); + } + }), + ) idleTimeout: number = 5 * 60; } +type Registration = { + client: Client; + + /** + * We keep this timestamp to check if a client hasn't been used in a while, + * and if it needs to be closed. + */ + lastUsed: Date; + + abortController: AbortController; + + returnPromise: Promise; +}; + @Service() export class SSHClientsManager { - readonly clients = new Map(); + readonly clients = new Map(); - constructor(private readonly config: SSHClientsConfig) { + readonly clientsReversed = new WeakMap(); + + private cleanupTimer: NodeJS.Timeout; + + constructor( + private readonly config: SSHClientsConfig, + private readonly logger: Logger, + ) { // Close all SSH connections when the process exits process.on('exit', () => this.onShutdown()); - if (process.env.NODE_ENV === 'test') return; - // Regularly close stale SSH connections - setInterval(() => this.cleanupStaleConnections(), 60 * 1000); + this.cleanupTimer = setInterval(() => this.cleanupStaleConnections(), 60 * 1000); + + this.logger = logger.scoped('ssh-client'); } - async getClient(credentials: SSHCredentials): Promise { + updateLastUsed(client: Client) { + const key = this.clientsReversed.get(client); + + if (key) { + const registration = this.clients.get(key); + + if (registration) { + registration.lastUsed = new Date(); + } + } else { + const metadata = {}; + // eslint-disable-next-line @typescript-eslint/unbound-method + Error.captureStackTrace(metadata, this.updateLastUsed); + this.logger.warn( + 'Tried to update `lastUsed` for a client that has been cleaned up already. Probably forgot to subscribe to the AbortController somewhere.', + metadata, + ); + } + } + + async getClient(credentials: SSHCredentials, abortController?: AbortController): Promise { + abortController = abortController ?? new AbortController(); + const { sshAuthenticateWith, sshHost, sshPort, sshUser } = credentials; const sshConfig: ConnectConfig = { host: sshHost, @@ -44,28 +112,73 @@ export class SSHClientsManager { const existing = this.clients.get(clientHash); if (existing) { existing.lastUsed = new Date(); - return existing.client; + return await existing.returnPromise; } - return await new Promise((resolve, reject) => { - const sshClient = new Client(); + const sshClient = this.withCleanupHandler(new Client(), abortController, clientHash); + const returnPromise = new Promise((resolve, reject) => { sshClient.once('error', reject); sshClient.once('ready', () => { sshClient.off('error', reject); - sshClient.once('close', () => this.clients.delete(clientHash)); - this.clients.set(clientHash, { - client: sshClient, - lastUsed: new Date(), - }); resolve(sshClient); }); sshClient.connect(sshConfig); }); + + this.clients.set(clientHash, { + client: sshClient, + lastUsed: new Date(), + abortController, + returnPromise, + }); + this.clientsReversed.set(sshClient, clientHash); + + return await returnPromise; + } + + /** + * Registers the cleanup handler for events (error, close, end) on the ssh + * client and in the abort signal is received. + */ + private withCleanupHandler(sshClient: Client, abortController: AbortController, key: string) { + sshClient.on('error', (error) => { + this.logger.error('encountered error, calling cleanup', { error }); + this.cleanupClient(key); + }); + sshClient.on('end', () => { + this.logger.debug('socket was disconnected, calling abort signal', {}); + this.cleanupClient(key); + }); + sshClient.on('close', () => { + this.logger.debug('socket was closed, calling abort signal', {}); + this.cleanupClient(key); + }); + abortController.signal.addEventListener('abort', () => { + this.logger.debug('Got abort signal, cleaning up ssh client.', { + reason: abortController.signal.reason, + }); + this.cleanupClient(key); + }); + + return sshClient; + } + + private cleanupClient(key: string) { + const registration = this.clients.get(key); + if (registration) { + this.clients.delete(key); + registration.client.end(); + if (!registration.abortController.signal.aborted) { + registration.abortController.abort(); + } + } } onShutdown() { - for (const { client } of this.clients.values()) { - client.end(); + this.logger.debug('Shutting down. Cleaning up all clients'); + clearInterval(this.cleanupTimer); + for (const key of this.clients.keys()) { + this.cleanupClient(key); } } @@ -74,10 +187,10 @@ export class SSHClientsManager { if (clients.size === 0) return; const now = Date.now(); - for (const [hash, { client, lastUsed }] of clients.entries()) { + for (const [key, { lastUsed }] of clients.entries()) { if (now - lastUsed.getTime() > this.config.idleTimeout * 1000) { - client.end(); - clients.delete(hash); + this.logger.debug('Found stale client. Cleaning it up.'); + this.cleanupClient(key); } } } diff --git a/packages/nodes-base/nodes/Postgres/transport/index.ts b/packages/nodes-base/nodes/Postgres/transport/index.ts index 71684ddf95..bd0ed93307 100644 --- a/packages/nodes-base/nodes/Postgres/transport/index.ts +++ b/packages/nodes-base/nodes/Postgres/transport/index.ts @@ -3,8 +3,9 @@ import type { ICredentialTestFunctions, ILoadOptionsFunctions, ITriggerFunctions, + Logger, } from 'n8n-workflow'; -import { createServer, type AddressInfo } from 'node:net'; +import { createServer, type AddressInfo, type Server } from 'node:net'; import pgPromise from 'pg-promise'; import { ConnectionPoolManager } from '@utils/connection-pool-manager'; @@ -53,14 +54,37 @@ const getPostgresConfig = ( return dbConfig; }; +function withCleanupHandler(proxy: Server, abortController: AbortController, logger: Logger) { + proxy.on('error', (error) => { + logger.error('TCP Proxy: Got error, calling abort controller', { error }); + abortController.abort(); + }); + proxy.on('close', () => { + logger.error('TCP Proxy: Was closed, calling abort controller'); + abortController.abort(); + }); + proxy.on('drop', (dropArgument) => { + logger.error('TCP Proxy: Connection was dropped, calling abort controller', { + dropArgument, + }); + abortController.abort(); + }); + abortController.signal.addEventListener('abort', () => { + logger.debug('Got abort signal. Closing TCP proxy server.'); + proxy.close(); + }); + + return proxy; +} + export async function configurePostgres( this: IExecuteFunctions | ICredentialTestFunctions | ILoadOptionsFunctions | ITriggerFunctions, credentials: PostgresNodeCredentials, options: PostgresNodeOptions = {}, ): Promise { - const poolManager = ConnectionPoolManager.getInstance(); + const poolManager = ConnectionPoolManager.getInstance(this.logger); - const fallBackHandler = async () => { + const fallBackHandler = async (abortController: AbortController) => { const pgp = pgPromise({ // prevent spam in console "WARNING: Creating a duplicate database object for the same connection." // duplicate connections created when auto loading parameters, they are closed immediately after, but several could be open at the same time @@ -101,74 +125,33 @@ export async function configurePostgres( if (credentials.sshAuthenticateWith === 'privateKey' && credentials.privateKey) { credentials.privateKey = formatPrivateKey(credentials.privateKey); } - const sshClient = await this.helpers.getSSHClient(credentials); + const sshClient = await this.helpers.getSSHClient(credentials, abortController); // Create a TCP proxy listening on a random available port - const proxy = createServer(); + const proxy = withCleanupHandler(createServer(), abortController, this.logger); + const proxyPort = await new Promise((resolve) => { proxy.listen(0, LOCALHOST, () => { resolve((proxy.address() as AddressInfo).port); }); }); - const close = () => { - proxy.close(); - sshClient.off('end', close); - sshClient.off('error', close); - }; - sshClient.on('end', close); - sshClient.on('error', close); - - await new Promise((resolve, reject) => { - proxy.on('error', (err) => reject(err)); - proxy.on('connection', (localSocket) => { - sshClient.forwardOut( - LOCALHOST, - localSocket.remotePort!, - credentials.host, - credentials.port, - (err, clientChannel) => { - if (err) { - proxy.close(); - localSocket.destroy(); - } else { - localSocket.pipe(clientChannel); - clientChannel.pipe(localSocket); - } - }, - ); - }); - resolve(); - }).catch((err) => { - proxy.close(); - - let message = err.message; - let description = err.description; - - if (err.message.includes('ECONNREFUSED')) { - message = 'Connection refused'; - try { - description = err.message.split('ECONNREFUSED ')[1].trim(); - } catch (e) {} - } - - if (err.message.includes('ENOTFOUND')) { - message = 'Host not found'; - try { - description = err.message.split('ENOTFOUND ')[1].trim(); - } catch (e) {} - } - - if (err.message.includes('ETIMEDOUT')) { - message = 'Connection timed out'; - try { - description = err.message.split('ETIMEDOUT ')[1].trim(); - } catch (e) {} - } - - err.message = message; - err.description = description; - throw err; + proxy.on('connection', (localSocket) => { + sshClient.forwardOut( + LOCALHOST, + localSocket.remotePort!, + credentials.host, + credentials.port, + (error, clientChannel) => { + if (error) { + this.logger.error('SSH Client: Port forwarding encountered an error', { error }); + abortController.abort(); + } else { + localSocket.pipe(clientChannel); + clientChannel.pipe(localSocket); + } + }, + ); }); const db = pgp({ @@ -176,7 +159,20 @@ export async function configurePostgres( port: proxyPort, host: LOCALHOST, }); - return { db, pgp }; + + abortController.signal.addEventListener('abort', async () => { + this.logger.debug('configurePostgres: Got abort signal, closing pg connection.'); + try { + if (!db.$pool.ended) await db.$pool.end(); + } catch (error) { + this.logger.error('configurePostgres: Encountered error while closing the pool.', { + error, + }); + throw error; + } + }); + + return { db, pgp, sshClient }; } }; @@ -185,8 +181,10 @@ export async function configurePostgres( nodeType: 'postgres', nodeVersion: options.nodeVersion as unknown as string, fallBackHandler, - cleanUpHandler: async ({ db }) => { - if (!db.$pool.ended) await db.$pool.end(); + wasUsed: ({ sshClient }) => { + if (sshClient) { + this.helpers.updateLastUsed(sshClient); + } }, }); } diff --git a/packages/nodes-base/utils/__tests__/connection-pool-manager.test.ts b/packages/nodes-base/utils/__tests__/connection-pool-manager.test.ts index fe2c814582..64b53a713a 100644 --- a/packages/nodes-base/utils/__tests__/connection-pool-manager.test.ts +++ b/packages/nodes-base/utils/__tests__/connection-pool-manager.test.ts @@ -1,26 +1,31 @@ +import { mock } from 'jest-mock-extended'; +import { OperationalError, type Logger } from 'n8n-workflow'; + import { ConnectionPoolManager } from '@utils/connection-pool-manager'; const ttl = 5 * 60 * 1000; const cleanUpInterval = 60 * 1000; +const logger = mock(); + let cpm: ConnectionPoolManager; beforeAll(() => { jest.useFakeTimers(); - cpm = ConnectionPoolManager.getInstance(); + cpm = ConnectionPoolManager.getInstance(logger); }); beforeEach(async () => { - await cpm.purgeConnections(); + cpm.purgeConnections(); }); afterAll(() => { - cpm.onShutdown(); + cpm.purgeConnections(); }); test('getInstance returns a singleton', () => { - const instance1 = ConnectionPoolManager.getInstance(); - const instance2 = ConnectionPoolManager.getInstance(); + const instance1 = ConnectionPoolManager.getInstance(logger); + const instance2 = ConnectionPoolManager.getInstance(logger); expect(instance1).toBe(instance2); }); @@ -29,25 +34,27 @@ describe('getConnection', () => { test('calls fallBackHandler only once and returns the first value', async () => { // ARRANGE const connectionType = {}; - const fallBackHandler = jest.fn().mockResolvedValue(connectionType); - const cleanUpHandler = jest.fn(); + const fallBackHandler = jest.fn(async () => { + return connectionType; + }); + const options = { credentials: {}, nodeType: 'example', nodeVersion: '1', fallBackHandler, - cleanUpHandler, + wasUsed: jest.fn(), }; // ACT 1 - const connection = await cpm.getConnection(options); + const connection = await cpm.getConnection(options); // ASSERT 1 expect(fallBackHandler).toHaveBeenCalledTimes(1); expect(connection).toBe(connectionType); // ACT 2 - const connection2 = await cpm.getConnection(options); + const connection2 = await cpm.getConnection(options); // ASSERT 2 expect(fallBackHandler).toHaveBeenCalledTimes(1); expect(connection2).toBe(connectionType); @@ -56,27 +63,29 @@ describe('getConnection', () => { test('creates different pools for different node versions', async () => { // ARRANGE const connectionType1 = {}; - const fallBackHandler1 = jest.fn().mockResolvedValue(connectionType1); - const cleanUpHandler1 = jest.fn(); + const fallBackHandler1 = jest.fn(async () => { + return connectionType1; + }); const connectionType2 = {}; - const fallBackHandler2 = jest.fn().mockResolvedValue(connectionType2); - const cleanUpHandler2 = jest.fn(); + const fallBackHandler2 = jest.fn(async () => { + return connectionType2; + }); // ACT 1 - const connection1 = await cpm.getConnection({ + const connection1 = await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '1', fallBackHandler: fallBackHandler1, - cleanUpHandler: cleanUpHandler1, + wasUsed: jest.fn(), }); - const connection2 = await cpm.getConnection({ + const connection2 = await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '2', fallBackHandler: fallBackHandler2, - cleanUpHandler: cleanUpHandler2, + wasUsed: jest.fn(), }); // ASSERT @@ -92,21 +101,52 @@ describe('getConnection', () => { test('calls cleanUpHandler after TTL expires', async () => { // ARRANGE const connectionType = {}; - const fallBackHandler = jest.fn().mockResolvedValue(connectionType); - const cleanUpHandler = jest.fn(); - await cpm.getConnection({ + let abortController: AbortController | undefined; + const fallBackHandler = jest.fn(async (ac: AbortController) => { + abortController = ac; + return connectionType; + }); + await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '1', fallBackHandler, - cleanUpHandler, + wasUsed: jest.fn(), }); // ACT jest.advanceTimersByTime(ttl + cleanUpInterval * 2); // ASSERT - expect(cleanUpHandler).toHaveBeenCalledTimes(1); + if (abortController === undefined) { + fail("abortController haven't been initialized"); + } + expect(abortController.signal.aborted).toBe(true); + }); + + test('throws OperationsError if the fallBackHandler aborts during connection initialization', async () => { + // ARRANGE + const connectionType = {}; + const fallBackHandler = jest.fn(async (ac: AbortController) => { + ac.abort(); + return connectionType; + }); + + // ACT + const connectionPromise = cpm.getConnection({ + credentials: {}, + nodeType: 'example', + nodeVersion: '1', + fallBackHandler, + wasUsed: jest.fn(), + }); + + // ASSERT + + await expect(connectionPromise).rejects.toThrow(OperationalError); + await expect(connectionPromise).rejects.toThrow( + 'Could not create pool. Connection attempt was aborted.', + ); }); }); @@ -114,66 +154,115 @@ describe('onShutdown', () => { test('calls all clean up handlers', async () => { // ARRANGE const connectionType1 = {}; - const fallBackHandler1 = jest.fn().mockResolvedValue(connectionType1); - const cleanUpHandler1 = jest.fn(); - await cpm.getConnection({ + let abortController1: AbortController | undefined; + const fallBackHandler1 = jest.fn(async (ac: AbortController) => { + abortController1 = ac; + return connectionType1; + }); + await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '1', fallBackHandler: fallBackHandler1, - cleanUpHandler: cleanUpHandler1, + wasUsed: jest.fn(), }); const connectionType2 = {}; - const fallBackHandler2 = jest.fn().mockResolvedValue(connectionType2); - const cleanUpHandler2 = jest.fn(); - await cpm.getConnection({ + let abortController2: AbortController | undefined; + const fallBackHandler2 = jest.fn(async (ac: AbortController) => { + abortController2 = ac; + return connectionType2; + }); + await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '2', fallBackHandler: fallBackHandler2, - cleanUpHandler: cleanUpHandler2, + wasUsed: jest.fn(), }); - // ACT 1 - cpm.onShutdown(); + // ACT + cpm.purgeConnections(); // ASSERT - expect(cleanUpHandler1).toHaveBeenCalledTimes(1); - expect(cleanUpHandler2).toHaveBeenCalledTimes(1); + if (abortController1 === undefined || abortController2 === undefined) { + fail("abortController haven't been initialized"); + } + expect(abortController1.signal.aborted).toBe(true); + expect(abortController2.signal.aborted).toBe(true); }); test('calls all clean up handlers when `exit` is emitted on process', async () => { // ARRANGE const connectionType1 = {}; - const fallBackHandler1 = jest.fn().mockResolvedValue(connectionType1); - const cleanUpHandler1 = jest.fn(); - await cpm.getConnection({ + let abortController1: AbortController | undefined; + const fallBackHandler1 = jest.fn(async (ac: AbortController) => { + abortController1 = ac; + return connectionType1; + }); + await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '1', fallBackHandler: fallBackHandler1, - cleanUpHandler: cleanUpHandler1, + wasUsed: jest.fn(), }); const connectionType2 = {}; - const fallBackHandler2 = jest.fn().mockResolvedValue(connectionType2); - const cleanUpHandler2 = jest.fn(); - await cpm.getConnection({ + let abortController2: AbortController | undefined; + const fallBackHandler2 = jest.fn(async (ac: AbortController) => { + abortController2 = ac; + return connectionType2; + }); + await cpm.getConnection({ credentials: {}, nodeType: 'example', nodeVersion: '2', fallBackHandler: fallBackHandler2, - cleanUpHandler: cleanUpHandler2, + wasUsed: jest.fn(), }); - // ACT 1 + // ACT // @ts-expect-error we're not supposed to emit `exit` so it's missing from // the type definition process.emit('exit'); // ASSERT - expect(cleanUpHandler1).toHaveBeenCalledTimes(1); - expect(cleanUpHandler2).toHaveBeenCalledTimes(1); + if (abortController1 === undefined || abortController2 === undefined) { + fail("abortController haven't been initialized"); + } + expect(abortController1.signal.aborted).toBe(true); + expect(abortController2.signal.aborted).toBe(true); + }); +}); + +describe('wasUsed', () => { + test('is called for every successive `getConnection` call', async () => { + // ARRANGE + const connectionType = {}; + const fallBackHandler = jest.fn(async () => { + return connectionType; + }); + + const wasUsed = jest.fn(); + const options = { + credentials: {}, + nodeType: 'example', + nodeVersion: '1', + fallBackHandler, + wasUsed, + }; + + // ACT 1 + await cpm.getConnection(options); + + // ASSERT 1 + expect(wasUsed).toHaveBeenCalledTimes(0); + + // ACT 2 + await cpm.getConnection(options); + + // ASSERT 2 + expect(wasUsed).toHaveBeenCalledTimes(1); }); }); diff --git a/packages/nodes-base/utils/connection-pool-manager.ts b/packages/nodes-base/utils/connection-pool-manager.ts index 9cdf86b0b4..635c29362b 100644 --- a/packages/nodes-base/utils/connection-pool-manager.ts +++ b/packages/nodes-base/utils/connection-pool-manager.ts @@ -1,4 +1,5 @@ import { createHash } from 'crypto'; +import { OperationalError, type Logger } from 'n8n-workflow'; let instance: ConnectionPoolManager; @@ -15,19 +16,23 @@ type RegistrationOptions = { }; type GetConnectionOption = RegistrationOptions & { - /** When a node requests for a connection pool, but none is available, this handler is called to create new instance of the pool, which then cached and re-used until it goes stale. */ - fallBackHandler: () => Promise; + /** + * When a node requests for a connection pool, but none is available, this + * handler is called to create new instance of the pool, which is then cached + * and re-used until it goes stale. + */ + fallBackHandler: (abortController: AbortController) => Promise; - /** When a pool hasn't been used in a while, or when the server is shutting down, this handler is invoked to close the pool */ - cleanUpHandler: (pool: Pool) => Promise; + wasUsed: (pool: Pool) => void; }; type Registration = { /** This is an instance of a Connection Pool class, that gets reused across multiple executions */ pool: Pool; - /** @see GetConnectionOption['closeHandler'] */ - cleanUpHandler: (pool: Pool) => Promise; + abortController: AbortController; + + wasUsed: (pool: Pool) => void; /** We keep this timestamp to check if a pool hasn't been used in a while, and if it needs to be closed */ lastUsed: number; @@ -38,9 +43,9 @@ export class ConnectionPoolManager { * Gets the singleton instance of the ConnectionPoolManager. * Creates a new instance if one doesn't exist. */ - static getInstance(): ConnectionPoolManager { + static getInstance(logger: Logger): ConnectionPoolManager { if (!instance) { - instance = new ConnectionPoolManager(); + instance = new ConnectionPoolManager(logger); } return instance; } @@ -51,9 +56,12 @@ export class ConnectionPoolManager { * Private constructor that initializes the connection pool manager. * Sets up cleanup handlers for process exit and stale connections. */ - private constructor() { + private constructor(private readonly logger: Logger) { // Close all open pools when the process exits - process.on('exit', () => this.onShutdown()); + process.on('exit', () => { + this.logger.debug('ConnectionPoolManager: Shutting down. Cleaning up all pools'); + this.purgeConnections(); + }); // Regularly close stale pools setInterval(() => this.cleanupStaleConnections(), cleanUpInterval); @@ -84,54 +92,67 @@ export class ConnectionPoolManager { const key = this.makeKey(options); let value = this.map.get(key); - if (!value) { - value = { - pool: await options.fallBackHandler(), - cleanUpHandler: options.cleanUpHandler, - } as Registration; + + if (value) { + value.lastUsed = Date.now(); + value.wasUsed(value.pool); + return value.pool as T; + } + + const abortController = new AbortController(); + value = { + pool: await options.fallBackHandler(abortController), + abortController, + wasUsed: options.wasUsed, + } as Registration; + + // It's possible that `options.fallBackHandler` already called the abort + // function. If that's the case let's not continue. + if (abortController.signal.aborted) { + throw new OperationalError('Could not create pool. Connection attempt was aborted.', { + cause: abortController.signal.reason, + }); } this.map.set(key, { ...value, lastUsed: Date.now() }); + abortController.signal.addEventListener('abort', async () => { + this.logger.debug('ConnectionPoolManager: Got abort signal, cleaning up pool.'); + this.cleanupConnection(key); + }); + return value.pool as T; } + private cleanupConnection(key: string) { + const registration = this.map.get(key); + + if (registration) { + this.map.delete(key); + registration.abortController.abort(); + } + } + /** * Removes and cleans up connection pools that haven't been used within the * TTL. */ private cleanupStaleConnections() { const now = Date.now(); - for (const [key, { cleanUpHandler, lastUsed, pool }] of this.map.entries()) { + for (const [key, { lastUsed }] of this.map.entries()) { if (now - lastUsed > ttl) { - void cleanUpHandler(pool); - this.map.delete(key); + this.logger.debug('ConnectionPoolManager: Found stale pool. Cleaning it up.'); + void this.cleanupConnection(key); } } } /** * Removes and cleans up all existing connection pools. + * Connections are closed in the background. */ - async purgeConnections(): Promise { - await Promise.all( - [...this.map.entries()].map(async ([key, value]) => { - this.map.delete(key); - - return await value.cleanUpHandler(value.pool); - }), - ); - } - - /** - * Cleans up all connection pools when the process is shutting down. - * Does not wait for cleanup promises to resolve also does not remove the - * references from the pool. - * - * Only call this on process shutdown. - */ - onShutdown() { - for (const { cleanUpHandler, pool } of this.map.values()) { - void cleanUpHandler(pool); + purgeConnections(): void { + for (const key of this.map.keys()) { + this.cleanupConnection(key); } } } diff --git a/packages/workflow/src/interfaces.ts b/packages/workflow/src/interfaces.ts index 0cd7cd6f43..ef9b94500d 100644 --- a/packages/workflow/src/interfaces.ts +++ b/packages/workflow/src/interfaces.ts @@ -829,7 +829,8 @@ export type SSHCredentials = { ); export interface SSHTunnelFunctions { - getSSHClient(credentials: SSHCredentials): Promise; + getSSHClient(credentials: SSHCredentials, abortController?: AbortController): Promise; + updateLastUsed(client: SSHClient): void; } type CronUnit = number | '*' | `*/${number}`;