fix: Postgres node with ssh tunnel getting into a broken state and not being recreated (#16054)

This commit is contained in:
Danny Martini
2025-06-13 14:38:21 +02:00
committed by GitHub
parent 80a784a50c
commit 879114b572
9 changed files with 606 additions and 199 deletions

View File

@@ -18,6 +18,7 @@ export const LOG_SCOPES = [
'task-runner', 'task-runner',
'insights', 'insights',
'workflow-activation', 'workflow-activation',
'ssh-client',
] as const; ] as const;
export type LogScope = (typeof LOG_SCOPES)[number]; export type LogScope = (typeof LOG_SCOPES)[number];

View File

@@ -1,34 +1,43 @@
import type { Logger } from '@n8n/backend-common';
import { Container } from '@n8n/di';
import { mock } from 'jest-mock-extended'; import { mock } from 'jest-mock-extended';
import type { SSHCredentials } from 'n8n-workflow'; import type { SSHCredentials } from 'n8n-workflow';
import { Client } from 'ssh2'; import { Client } from 'ssh2';
import { SSHClientsManager } from '../ssh-clients-manager'; import { SSHClientsConfig, SSHClientsManager } from '../ssh-clients-manager';
describe('SSHClientsManager', () => { const idleTimeout = 5 * 60;
const idleTimeout = 5 * 60; const cleanUpInterval = 60;
const credentials: SSHCredentials = { const credentials: SSHCredentials = {
sshAuthenticateWith: 'password', sshAuthenticateWith: 'password',
sshHost: 'example.com', sshHost: 'example.com',
sshPort: 22, sshPort: 22,
sshUser: 'username', sshUser: 'username',
sshPassword: 'password', sshPassword: 'password',
}; };
let sshClientsManager: SSHClientsManager; let sshClientsManager: SSHClientsManager;
const connectSpy = jest.spyOn(Client.prototype, 'connect'); const connectSpy = jest.spyOn(Client.prototype, 'connect');
const endSpy = jest.spyOn(Client.prototype, 'end'); const endSpy = jest.spyOn(Client.prototype, 'end');
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks();
jest.useFakeTimers();
sshClientsManager = new SSHClientsManager(mock({ idleTimeout })); sshClientsManager = new SSHClientsManager(
connectSpy.mockImplementation(function (this: Client) { mock({ idleTimeout }),
this.emit('ready'); mock<Logger>({ scoped: () => mock<Logger>() }),
return this; );
}); connectSpy.mockImplementation(function (this: Client) {
this.emit('ready');
return this;
}); });
});
afterEach(() => {
sshClientsManager.onShutdown();
});
describe('getClient', () => {
it('should create a new SSH client', async () => { it('should create a new SSH client', async () => {
const client = await sshClientsManager.getClient(credentials); const client = await sshClientsManager.getClient(credentials);
@@ -49,22 +58,180 @@ describe('SSHClientsManager', () => {
expect(client1).toBe(client2); expect(client1).toBe(client2);
}); });
it('should close all SSH connections on process exit', async () => { it('should not create multiple clients for the same credentials in parallel', async () => {
// ARRANGE
connectSpy.mockImplementation(function (this: Client) {
setTimeout(() => this.emit('ready'), Math.random() * 10);
return this;
});
// ACT
const clients = await Promise.all([
sshClientsManager.getClient(credentials),
sshClientsManager.getClient(credentials),
sshClientsManager.getClient(credentials),
sshClientsManager.getClient(credentials),
sshClientsManager.getClient(credentials),
sshClientsManager.getClient(credentials),
]);
// ASSERT
// returns the same client for all invocations
const ogClient = await sshClientsManager.getClient(credentials);
expect(clients).toHaveLength(6);
for (const client of clients) {
expect(client).toBe(ogClient);
}
expect(connectSpy).toHaveBeenCalledTimes(1);
});
});
describe('onShutdown', () => {
it('should close all SSH connections when onShutdown is called', async () => {
await sshClientsManager.getClient(credentials); await sshClientsManager.getClient(credentials);
sshClientsManager.onShutdown(); sshClientsManager.onShutdown();
expect(endSpy).toHaveBeenCalledTimes(1); expect(endSpy).toHaveBeenCalledTimes(1);
}); });
it('should close all SSH connections on process exit', async () => {
// ARRANGE
await sshClientsManager.getClient(credentials);
// ACT
// @ts-expect-error we're not supposed to emit `exit` so it's missing from
// the type definition
process.emit('exit');
// ASSERT
expect(endSpy).toHaveBeenCalledTimes(1);
});
});
describe('cleanup', () => {
beforeEach(async () => {
jest.useFakeTimers();
sshClientsManager = new SSHClientsManager(
mock({ idleTimeout }),
mock<Logger>({ scoped: () => mock<Logger>() }),
);
});
it('should cleanup stale SSH connections', async () => { it('should cleanup stale SSH connections', async () => {
await sshClientsManager.getClient({ ...credentials, sshHost: 'host1' }); await sshClientsManager.getClient({ ...credentials, sshHost: 'host1' });
await sshClientsManager.getClient({ ...credentials, sshHost: 'host2' }); await sshClientsManager.getClient({ ...credentials, sshHost: 'host2' });
await sshClientsManager.getClient({ ...credentials, sshHost: 'host3' }); await sshClientsManager.getClient({ ...credentials, sshHost: 'host3' });
jest.advanceTimersByTime((idleTimeout + 1) * 1000); jest.advanceTimersByTime((idleTimeout + cleanUpInterval + 1) * 1000);
sshClientsManager.cleanupStaleConnections();
expect(endSpy).toHaveBeenCalledTimes(3); expect(endSpy).toHaveBeenCalledTimes(3);
expect(sshClientsManager.clients.size).toBe(0); expect(sshClientsManager.clients.size).toBe(0);
}); });
describe('updateLastUsed', () => {
test('updates lastUsed in the registration', async () => {
// ARRANGE
const client = await sshClientsManager.getClient(credentials);
// schedule client for clean up soon
jest.advanceTimersByTime((idleTimeout - 1) * 1000);
// ACT 1
// updating lastUsed should prevent the clean up
sshClientsManager.updateLastUsed(client);
jest.advanceTimersByTime(idleTimeout * 1000);
// ASSERT 1
expect(endSpy).toHaveBeenCalledTimes(0);
// ACT 1
jest.advanceTimersByTime(cleanUpInterval * 1000);
// ASSERT 1
expect(endSpy).toHaveBeenCalledTimes(1);
});
});
});
describe('abort controller', () => {
test('call `abort` when the client emits `error`', async () => {
// ARRANGE
const abortController = new AbortController();
const client = await sshClientsManager.getClient(credentials, abortController);
// ACT 1
client.emit('error', new Error());
// ASSERT 1
expect(abortController.signal.aborted).toBe(true);
expect(endSpy).toHaveBeenCalledTimes(1);
});
test('call `abort` when the client emits `end`', async () => {
// ARRANGE
const abortController = new AbortController();
const client = await sshClientsManager.getClient(credentials, abortController);
// ACT 1
client.emit('end');
// ASSERT 1
expect(abortController.signal.aborted).toBe(true);
expect(endSpy).toHaveBeenCalledTimes(1);
});
test('call `abort` when the client emits `close`', async () => {
// ARRANGE
const abortController = new AbortController();
const client = await sshClientsManager.getClient(credentials, abortController);
// ACT 1
client.emit('close');
// ASSERT 1
expect(abortController.signal.aborted).toBe(true);
expect(endSpy).toHaveBeenCalledTimes(1);
});
test('closes client when `abort` is being called', async () => {
// ARRANGE
const abortController = new AbortController();
await sshClientsManager.getClient(credentials, abortController);
// ACT 1
abortController.abort();
// ASSERT 1
expect(endSpy).toHaveBeenCalledTimes(1);
});
});
describe('SSHClientsConfig', () => {
beforeEach(() => {
Container.reset();
});
test('allows overriding the default idle timeout', async () => {
// ARRANGE
process.env.N8N_SSH_TUNNEL_IDLE_TIMEOUT = '5';
// ACT
const config = Container.get(SSHClientsConfig);
// ASSERT
expect(config.idleTimeout).toBe(5);
});
test.each(['-5', '0', 'foo'])(
'fall back to default if N8N_SSH_TUNNEL_IDLE_TIMEOUT is `%s`',
async (value) => {
// ARRANGE
process.env.N8N_SSH_TUNNEL_IDLE_TIMEOUT = value;
// ACT
const config = Container.get(SSHClientsConfig);
// ASSERT
expect(config.idleTimeout).toBe(300);
},
);
}); });

View File

@@ -7,6 +7,7 @@ import { SSHClientsManager } from '../../../ssh-clients-manager';
import { getSSHTunnelFunctions } from '../ssh-tunnel-helper-functions'; import { getSSHTunnelFunctions } from '../ssh-tunnel-helper-functions';
describe('getSSHTunnelFunctions', () => { describe('getSSHTunnelFunctions', () => {
const abortController = new AbortController();
const credentials = mock<SSHCredentials>(); const credentials = mock<SSHCredentials>();
const sshClientsManager = mockInstance(SSHClientsManager); const sshClientsManager = mockInstance(SSHClientsManager);
const sshTunnelFunctions = getSSHTunnelFunctions(); const sshTunnelFunctions = getSSHTunnelFunctions();
@@ -17,9 +18,22 @@ describe('getSSHTunnelFunctions', () => {
describe('getSSHClient', () => { describe('getSSHClient', () => {
it('should invoke sshClientsManager.getClient', async () => { it('should invoke sshClientsManager.getClient', async () => {
await sshTunnelFunctions.getSSHClient(credentials); await sshTunnelFunctions.getSSHClient(credentials, abortController);
expect(sshClientsManager.getClient).toHaveBeenCalledWith(credentials); expect(sshClientsManager.getClient).toHaveBeenCalledWith(credentials, abortController);
});
});
describe('updateLastUsed', () => {
it('should invoke sshClientsManager.updateLastUsed', async () => {
// ARRANGE
const client = await sshTunnelFunctions.getSSHClient(credentials, abortController);
// ACT
sshTunnelFunctions.updateLastUsed(client);
// ASSERT
expect(sshClientsManager.updateLastUsed).toHaveBeenCalledWith(client);
}); });
}); });
}); });

View File

@@ -1,11 +1,14 @@
import { Container } from '@n8n/di'; import { Container } from '@n8n/di';
import type { SSHTunnelFunctions } from 'n8n-workflow'; import type { SSHTunnelFunctions } from 'n8n-workflow';
import type { Client } from 'ssh2';
import { SSHClientsManager } from '../../ssh-clients-manager'; import { SSHClientsManager } from '../../ssh-clients-manager';
export const getSSHTunnelFunctions = (): SSHTunnelFunctions => { export const getSSHTunnelFunctions = (): SSHTunnelFunctions => {
const sshClientsManager = Container.get(SSHClientsManager); const sshClientsManager = Container.get(SSHClientsManager);
return { return {
getSSHClient: async (credentials) => await sshClientsManager.getClient(credentials), getSSHClient: async (credentials, abortController) =>
await sshClientsManager.getClient(credentials, abortController),
updateLastUsed: (client: Client) => sshClientsManager.updateLastUsed(client),
}; };
}; };

View File

@@ -1,31 +1,99 @@
import { Logger } from '@n8n/backend-common';
import { Config, Env } from '@n8n/config'; import { Config, Env } from '@n8n/config';
import { Service } from '@n8n/di'; import { Service } from '@n8n/di';
import type { SSHCredentials } from 'n8n-workflow'; import type { SSHCredentials } from 'n8n-workflow';
import { createHash } from 'node:crypto'; import { createHash } from 'node:crypto';
import { Client, type ConnectConfig } from 'ssh2'; import { Client, type ConnectConfig } from 'ssh2';
import { z } from 'zod';
@Config @Config
class SSHClientsConfig { export class SSHClientsConfig {
/** How many seconds before an idle SSH tunnel is closed */ /** How many seconds before an idle SSH tunnel is closed */
@Env('N8N_SSH_TUNNEL_IDLE_TIMEOUT') @Env(
'N8N_SSH_TUNNEL_IDLE_TIMEOUT',
z
.string()
.transform((value) => Number.parseInt(value))
.superRefine((value, ctx) => {
if (Number.isNaN(value)) {
return ctx.addIssue({
message: 'must be a valid integer',
code: 'custom',
});
}
if (value <= 0) {
return ctx.addIssue({
message: 'must be positive',
code: 'too_small',
minimum: 0,
inclusive: false,
type: 'number',
});
}
}),
)
idleTimeout: number = 5 * 60; idleTimeout: number = 5 * 60;
} }
type Registration = {
client: Client;
/**
* We keep this timestamp to check if a client hasn't been used in a while,
* and if it needs to be closed.
*/
lastUsed: Date;
abortController: AbortController;
returnPromise: Promise<Client>;
};
@Service() @Service()
export class SSHClientsManager { export class SSHClientsManager {
readonly clients = new Map<string, { client: Client; lastUsed: Date }>(); readonly clients = new Map<string, Registration>();
constructor(private readonly config: SSHClientsConfig) { readonly clientsReversed = new WeakMap<Client, string>();
private cleanupTimer: NodeJS.Timeout;
constructor(
private readonly config: SSHClientsConfig,
private readonly logger: Logger,
) {
// Close all SSH connections when the process exits // Close all SSH connections when the process exits
process.on('exit', () => this.onShutdown()); process.on('exit', () => this.onShutdown());
if (process.env.NODE_ENV === 'test') return;
// Regularly close stale SSH connections // Regularly close stale SSH connections
setInterval(() => this.cleanupStaleConnections(), 60 * 1000); this.cleanupTimer = setInterval(() => this.cleanupStaleConnections(), 60 * 1000);
this.logger = logger.scoped('ssh-client');
} }
async getClient(credentials: SSHCredentials): Promise<Client> { updateLastUsed(client: Client) {
const key = this.clientsReversed.get(client);
if (key) {
const registration = this.clients.get(key);
if (registration) {
registration.lastUsed = new Date();
}
} else {
const metadata = {};
// eslint-disable-next-line @typescript-eslint/unbound-method
Error.captureStackTrace(metadata, this.updateLastUsed);
this.logger.warn(
'Tried to update `lastUsed` for a client that has been cleaned up already. Probably forgot to subscribe to the AbortController somewhere.',
metadata,
);
}
}
async getClient(credentials: SSHCredentials, abortController?: AbortController): Promise<Client> {
abortController = abortController ?? new AbortController();
const { sshAuthenticateWith, sshHost, sshPort, sshUser } = credentials; const { sshAuthenticateWith, sshHost, sshPort, sshUser } = credentials;
const sshConfig: ConnectConfig = { const sshConfig: ConnectConfig = {
host: sshHost, host: sshHost,
@@ -44,28 +112,73 @@ export class SSHClientsManager {
const existing = this.clients.get(clientHash); const existing = this.clients.get(clientHash);
if (existing) { if (existing) {
existing.lastUsed = new Date(); existing.lastUsed = new Date();
return existing.client; return await existing.returnPromise;
} }
return await new Promise((resolve, reject) => { const sshClient = this.withCleanupHandler(new Client(), abortController, clientHash);
const sshClient = new Client(); const returnPromise = new Promise<Client>((resolve, reject) => {
sshClient.once('error', reject); sshClient.once('error', reject);
sshClient.once('ready', () => { sshClient.once('ready', () => {
sshClient.off('error', reject); sshClient.off('error', reject);
sshClient.once('close', () => this.clients.delete(clientHash));
this.clients.set(clientHash, {
client: sshClient,
lastUsed: new Date(),
});
resolve(sshClient); resolve(sshClient);
}); });
sshClient.connect(sshConfig); sshClient.connect(sshConfig);
}); });
this.clients.set(clientHash, {
client: sshClient,
lastUsed: new Date(),
abortController,
returnPromise,
});
this.clientsReversed.set(sshClient, clientHash);
return await returnPromise;
}
/**
* Registers the cleanup handler for events (error, close, end) on the ssh
* client and in the abort signal is received.
*/
private withCleanupHandler(sshClient: Client, abortController: AbortController, key: string) {
sshClient.on('error', (error) => {
this.logger.error('encountered error, calling cleanup', { error });
this.cleanupClient(key);
});
sshClient.on('end', () => {
this.logger.debug('socket was disconnected, calling abort signal', {});
this.cleanupClient(key);
});
sshClient.on('close', () => {
this.logger.debug('socket was closed, calling abort signal', {});
this.cleanupClient(key);
});
abortController.signal.addEventListener('abort', () => {
this.logger.debug('Got abort signal, cleaning up ssh client.', {
reason: abortController.signal.reason,
});
this.cleanupClient(key);
});
return sshClient;
}
private cleanupClient(key: string) {
const registration = this.clients.get(key);
if (registration) {
this.clients.delete(key);
registration.client.end();
if (!registration.abortController.signal.aborted) {
registration.abortController.abort();
}
}
} }
onShutdown() { onShutdown() {
for (const { client } of this.clients.values()) { this.logger.debug('Shutting down. Cleaning up all clients');
client.end(); clearInterval(this.cleanupTimer);
for (const key of this.clients.keys()) {
this.cleanupClient(key);
} }
} }
@@ -74,10 +187,10 @@ export class SSHClientsManager {
if (clients.size === 0) return; if (clients.size === 0) return;
const now = Date.now(); const now = Date.now();
for (const [hash, { client, lastUsed }] of clients.entries()) { for (const [key, { lastUsed }] of clients.entries()) {
if (now - lastUsed.getTime() > this.config.idleTimeout * 1000) { if (now - lastUsed.getTime() > this.config.idleTimeout * 1000) {
client.end(); this.logger.debug('Found stale client. Cleaning it up.');
clients.delete(hash); this.cleanupClient(key);
} }
} }
} }

View File

@@ -3,8 +3,9 @@ import type {
ICredentialTestFunctions, ICredentialTestFunctions,
ILoadOptionsFunctions, ILoadOptionsFunctions,
ITriggerFunctions, ITriggerFunctions,
Logger,
} from 'n8n-workflow'; } from 'n8n-workflow';
import { createServer, type AddressInfo } from 'node:net'; import { createServer, type AddressInfo, type Server } from 'node:net';
import pgPromise from 'pg-promise'; import pgPromise from 'pg-promise';
import { ConnectionPoolManager } from '@utils/connection-pool-manager'; import { ConnectionPoolManager } from '@utils/connection-pool-manager';
@@ -53,14 +54,37 @@ const getPostgresConfig = (
return dbConfig; return dbConfig;
}; };
function withCleanupHandler(proxy: Server, abortController: AbortController, logger: Logger) {
proxy.on('error', (error) => {
logger.error('TCP Proxy: Got error, calling abort controller', { error });
abortController.abort();
});
proxy.on('close', () => {
logger.error('TCP Proxy: Was closed, calling abort controller');
abortController.abort();
});
proxy.on('drop', (dropArgument) => {
logger.error('TCP Proxy: Connection was dropped, calling abort controller', {
dropArgument,
});
abortController.abort();
});
abortController.signal.addEventListener('abort', () => {
logger.debug('Got abort signal. Closing TCP proxy server.');
proxy.close();
});
return proxy;
}
export async function configurePostgres( export async function configurePostgres(
this: IExecuteFunctions | ICredentialTestFunctions | ILoadOptionsFunctions | ITriggerFunctions, this: IExecuteFunctions | ICredentialTestFunctions | ILoadOptionsFunctions | ITriggerFunctions,
credentials: PostgresNodeCredentials, credentials: PostgresNodeCredentials,
options: PostgresNodeOptions = {}, options: PostgresNodeOptions = {},
): Promise<ConnectionsData> { ): Promise<ConnectionsData> {
const poolManager = ConnectionPoolManager.getInstance(); const poolManager = ConnectionPoolManager.getInstance(this.logger);
const fallBackHandler = async () => { const fallBackHandler = async (abortController: AbortController) => {
const pgp = pgPromise({ const pgp = pgPromise({
// prevent spam in console "WARNING: Creating a duplicate database object for the same connection." // prevent spam in console "WARNING: Creating a duplicate database object for the same connection."
// duplicate connections created when auto loading parameters, they are closed immediately after, but several could be open at the same time // duplicate connections created when auto loading parameters, they are closed immediately after, but several could be open at the same time
@@ -101,74 +125,33 @@ export async function configurePostgres(
if (credentials.sshAuthenticateWith === 'privateKey' && credentials.privateKey) { if (credentials.sshAuthenticateWith === 'privateKey' && credentials.privateKey) {
credentials.privateKey = formatPrivateKey(credentials.privateKey); credentials.privateKey = formatPrivateKey(credentials.privateKey);
} }
const sshClient = await this.helpers.getSSHClient(credentials); const sshClient = await this.helpers.getSSHClient(credentials, abortController);
// Create a TCP proxy listening on a random available port // Create a TCP proxy listening on a random available port
const proxy = createServer(); const proxy = withCleanupHandler(createServer(), abortController, this.logger);
const proxyPort = await new Promise<number>((resolve) => { const proxyPort = await new Promise<number>((resolve) => {
proxy.listen(0, LOCALHOST, () => { proxy.listen(0, LOCALHOST, () => {
resolve((proxy.address() as AddressInfo).port); resolve((proxy.address() as AddressInfo).port);
}); });
}); });
const close = () => { proxy.on('connection', (localSocket) => {
proxy.close(); sshClient.forwardOut(
sshClient.off('end', close); LOCALHOST,
sshClient.off('error', close); localSocket.remotePort!,
}; credentials.host,
sshClient.on('end', close); credentials.port,
sshClient.on('error', close); (error, clientChannel) => {
if (error) {
await new Promise<void>((resolve, reject) => { this.logger.error('SSH Client: Port forwarding encountered an error', { error });
proxy.on('error', (err) => reject(err)); abortController.abort();
proxy.on('connection', (localSocket) => { } else {
sshClient.forwardOut( localSocket.pipe(clientChannel);
LOCALHOST, clientChannel.pipe(localSocket);
localSocket.remotePort!, }
credentials.host, },
credentials.port, );
(err, clientChannel) => {
if (err) {
proxy.close();
localSocket.destroy();
} else {
localSocket.pipe(clientChannel);
clientChannel.pipe(localSocket);
}
},
);
});
resolve();
}).catch((err) => {
proxy.close();
let message = err.message;
let description = err.description;
if (err.message.includes('ECONNREFUSED')) {
message = 'Connection refused';
try {
description = err.message.split('ECONNREFUSED ')[1].trim();
} catch (e) {}
}
if (err.message.includes('ENOTFOUND')) {
message = 'Host not found';
try {
description = err.message.split('ENOTFOUND ')[1].trim();
} catch (e) {}
}
if (err.message.includes('ETIMEDOUT')) {
message = 'Connection timed out';
try {
description = err.message.split('ETIMEDOUT ')[1].trim();
} catch (e) {}
}
err.message = message;
err.description = description;
throw err;
}); });
const db = pgp({ const db = pgp({
@@ -176,7 +159,20 @@ export async function configurePostgres(
port: proxyPort, port: proxyPort,
host: LOCALHOST, host: LOCALHOST,
}); });
return { db, pgp };
abortController.signal.addEventListener('abort', async () => {
this.logger.debug('configurePostgres: Got abort signal, closing pg connection.');
try {
if (!db.$pool.ended) await db.$pool.end();
} catch (error) {
this.logger.error('configurePostgres: Encountered error while closing the pool.', {
error,
});
throw error;
}
});
return { db, pgp, sshClient };
} }
}; };
@@ -185,8 +181,10 @@ export async function configurePostgres(
nodeType: 'postgres', nodeType: 'postgres',
nodeVersion: options.nodeVersion as unknown as string, nodeVersion: options.nodeVersion as unknown as string,
fallBackHandler, fallBackHandler,
cleanUpHandler: async ({ db }) => { wasUsed: ({ sshClient }) => {
if (!db.$pool.ended) await db.$pool.end(); if (sshClient) {
this.helpers.updateLastUsed(sshClient);
}
}, },
}); });
} }

View File

@@ -1,26 +1,31 @@
import { mock } from 'jest-mock-extended';
import { OperationalError, type Logger } from 'n8n-workflow';
import { ConnectionPoolManager } from '@utils/connection-pool-manager'; import { ConnectionPoolManager } from '@utils/connection-pool-manager';
const ttl = 5 * 60 * 1000; const ttl = 5 * 60 * 1000;
const cleanUpInterval = 60 * 1000; const cleanUpInterval = 60 * 1000;
const logger = mock<Logger>();
let cpm: ConnectionPoolManager; let cpm: ConnectionPoolManager;
beforeAll(() => { beforeAll(() => {
jest.useFakeTimers(); jest.useFakeTimers();
cpm = ConnectionPoolManager.getInstance(); cpm = ConnectionPoolManager.getInstance(logger);
}); });
beforeEach(async () => { beforeEach(async () => {
await cpm.purgeConnections(); cpm.purgeConnections();
}); });
afterAll(() => { afterAll(() => {
cpm.onShutdown(); cpm.purgeConnections();
}); });
test('getInstance returns a singleton', () => { test('getInstance returns a singleton', () => {
const instance1 = ConnectionPoolManager.getInstance(); const instance1 = ConnectionPoolManager.getInstance(logger);
const instance2 = ConnectionPoolManager.getInstance(); const instance2 = ConnectionPoolManager.getInstance(logger);
expect(instance1).toBe(instance2); expect(instance1).toBe(instance2);
}); });
@@ -29,25 +34,27 @@ describe('getConnection', () => {
test('calls fallBackHandler only once and returns the first value', async () => { test('calls fallBackHandler only once and returns the first value', async () => {
// ARRANGE // ARRANGE
const connectionType = {}; const connectionType = {};
const fallBackHandler = jest.fn().mockResolvedValue(connectionType); const fallBackHandler = jest.fn(async () => {
const cleanUpHandler = jest.fn(); return connectionType;
});
const options = { const options = {
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '1', nodeVersion: '1',
fallBackHandler, fallBackHandler,
cleanUpHandler, wasUsed: jest.fn(),
}; };
// ACT 1 // ACT 1
const connection = await cpm.getConnection<string>(options); const connection = await cpm.getConnection(options);
// ASSERT 1 // ASSERT 1
expect(fallBackHandler).toHaveBeenCalledTimes(1); expect(fallBackHandler).toHaveBeenCalledTimes(1);
expect(connection).toBe(connectionType); expect(connection).toBe(connectionType);
// ACT 2 // ACT 2
const connection2 = await cpm.getConnection<string>(options); const connection2 = await cpm.getConnection(options);
// ASSERT 2 // ASSERT 2
expect(fallBackHandler).toHaveBeenCalledTimes(1); expect(fallBackHandler).toHaveBeenCalledTimes(1);
expect(connection2).toBe(connectionType); expect(connection2).toBe(connectionType);
@@ -56,27 +63,29 @@ describe('getConnection', () => {
test('creates different pools for different node versions', async () => { test('creates different pools for different node versions', async () => {
// ARRANGE // ARRANGE
const connectionType1 = {}; const connectionType1 = {};
const fallBackHandler1 = jest.fn().mockResolvedValue(connectionType1); const fallBackHandler1 = jest.fn(async () => {
const cleanUpHandler1 = jest.fn(); return connectionType1;
});
const connectionType2 = {}; const connectionType2 = {};
const fallBackHandler2 = jest.fn().mockResolvedValue(connectionType2); const fallBackHandler2 = jest.fn(async () => {
const cleanUpHandler2 = jest.fn(); return connectionType2;
});
// ACT 1 // ACT 1
const connection1 = await cpm.getConnection<string>({ const connection1 = await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '1', nodeVersion: '1',
fallBackHandler: fallBackHandler1, fallBackHandler: fallBackHandler1,
cleanUpHandler: cleanUpHandler1, wasUsed: jest.fn(),
}); });
const connection2 = await cpm.getConnection<string>({ const connection2 = await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '2', nodeVersion: '2',
fallBackHandler: fallBackHandler2, fallBackHandler: fallBackHandler2,
cleanUpHandler: cleanUpHandler2, wasUsed: jest.fn(),
}); });
// ASSERT // ASSERT
@@ -92,21 +101,52 @@ describe('getConnection', () => {
test('calls cleanUpHandler after TTL expires', async () => { test('calls cleanUpHandler after TTL expires', async () => {
// ARRANGE // ARRANGE
const connectionType = {}; const connectionType = {};
const fallBackHandler = jest.fn().mockResolvedValue(connectionType); let abortController: AbortController | undefined;
const cleanUpHandler = jest.fn(); const fallBackHandler = jest.fn(async (ac: AbortController) => {
await cpm.getConnection<string>({ abortController = ac;
return connectionType;
});
await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '1', nodeVersion: '1',
fallBackHandler, fallBackHandler,
cleanUpHandler, wasUsed: jest.fn(),
}); });
// ACT // ACT
jest.advanceTimersByTime(ttl + cleanUpInterval * 2); jest.advanceTimersByTime(ttl + cleanUpInterval * 2);
// ASSERT // ASSERT
expect(cleanUpHandler).toHaveBeenCalledTimes(1); if (abortController === undefined) {
fail("abortController haven't been initialized");
}
expect(abortController.signal.aborted).toBe(true);
});
test('throws OperationsError if the fallBackHandler aborts during connection initialization', async () => {
// ARRANGE
const connectionType = {};
const fallBackHandler = jest.fn(async (ac: AbortController) => {
ac.abort();
return connectionType;
});
// ACT
const connectionPromise = cpm.getConnection({
credentials: {},
nodeType: 'example',
nodeVersion: '1',
fallBackHandler,
wasUsed: jest.fn(),
});
// ASSERT
await expect(connectionPromise).rejects.toThrow(OperationalError);
await expect(connectionPromise).rejects.toThrow(
'Could not create pool. Connection attempt was aborted.',
);
}); });
}); });
@@ -114,66 +154,115 @@ describe('onShutdown', () => {
test('calls all clean up handlers', async () => { test('calls all clean up handlers', async () => {
// ARRANGE // ARRANGE
const connectionType1 = {}; const connectionType1 = {};
const fallBackHandler1 = jest.fn().mockResolvedValue(connectionType1); let abortController1: AbortController | undefined;
const cleanUpHandler1 = jest.fn(); const fallBackHandler1 = jest.fn(async (ac: AbortController) => {
await cpm.getConnection<string>({ abortController1 = ac;
return connectionType1;
});
await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '1', nodeVersion: '1',
fallBackHandler: fallBackHandler1, fallBackHandler: fallBackHandler1,
cleanUpHandler: cleanUpHandler1, wasUsed: jest.fn(),
}); });
const connectionType2 = {}; const connectionType2 = {};
const fallBackHandler2 = jest.fn().mockResolvedValue(connectionType2); let abortController2: AbortController | undefined;
const cleanUpHandler2 = jest.fn(); const fallBackHandler2 = jest.fn(async (ac: AbortController) => {
await cpm.getConnection<string>({ abortController2 = ac;
return connectionType2;
});
await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '2', nodeVersion: '2',
fallBackHandler: fallBackHandler2, fallBackHandler: fallBackHandler2,
cleanUpHandler: cleanUpHandler2, wasUsed: jest.fn(),
}); });
// ACT 1 // ACT
cpm.onShutdown(); cpm.purgeConnections();
// ASSERT // ASSERT
expect(cleanUpHandler1).toHaveBeenCalledTimes(1); if (abortController1 === undefined || abortController2 === undefined) {
expect(cleanUpHandler2).toHaveBeenCalledTimes(1); fail("abortController haven't been initialized");
}
expect(abortController1.signal.aborted).toBe(true);
expect(abortController2.signal.aborted).toBe(true);
}); });
test('calls all clean up handlers when `exit` is emitted on process', async () => { test('calls all clean up handlers when `exit` is emitted on process', async () => {
// ARRANGE // ARRANGE
const connectionType1 = {}; const connectionType1 = {};
const fallBackHandler1 = jest.fn().mockResolvedValue(connectionType1); let abortController1: AbortController | undefined;
const cleanUpHandler1 = jest.fn(); const fallBackHandler1 = jest.fn(async (ac: AbortController) => {
await cpm.getConnection<string>({ abortController1 = ac;
return connectionType1;
});
await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '1', nodeVersion: '1',
fallBackHandler: fallBackHandler1, fallBackHandler: fallBackHandler1,
cleanUpHandler: cleanUpHandler1, wasUsed: jest.fn(),
}); });
const connectionType2 = {}; const connectionType2 = {};
const fallBackHandler2 = jest.fn().mockResolvedValue(connectionType2); let abortController2: AbortController | undefined;
const cleanUpHandler2 = jest.fn(); const fallBackHandler2 = jest.fn(async (ac: AbortController) => {
await cpm.getConnection<string>({ abortController2 = ac;
return connectionType2;
});
await cpm.getConnection({
credentials: {}, credentials: {},
nodeType: 'example', nodeType: 'example',
nodeVersion: '2', nodeVersion: '2',
fallBackHandler: fallBackHandler2, fallBackHandler: fallBackHandler2,
cleanUpHandler: cleanUpHandler2, wasUsed: jest.fn(),
}); });
// ACT 1 // ACT
// @ts-expect-error we're not supposed to emit `exit` so it's missing from // @ts-expect-error we're not supposed to emit `exit` so it's missing from
// the type definition // the type definition
process.emit('exit'); process.emit('exit');
// ASSERT // ASSERT
expect(cleanUpHandler1).toHaveBeenCalledTimes(1); if (abortController1 === undefined || abortController2 === undefined) {
expect(cleanUpHandler2).toHaveBeenCalledTimes(1); fail("abortController haven't been initialized");
}
expect(abortController1.signal.aborted).toBe(true);
expect(abortController2.signal.aborted).toBe(true);
});
});
describe('wasUsed', () => {
test('is called for every successive `getConnection` call', async () => {
// ARRANGE
const connectionType = {};
const fallBackHandler = jest.fn(async () => {
return connectionType;
});
const wasUsed = jest.fn();
const options = {
credentials: {},
nodeType: 'example',
nodeVersion: '1',
fallBackHandler,
wasUsed,
};
// ACT 1
await cpm.getConnection(options);
// ASSERT 1
expect(wasUsed).toHaveBeenCalledTimes(0);
// ACT 2
await cpm.getConnection(options);
// ASSERT 2
expect(wasUsed).toHaveBeenCalledTimes(1);
}); });
}); });

View File

@@ -1,4 +1,5 @@
import { createHash } from 'crypto'; import { createHash } from 'crypto';
import { OperationalError, type Logger } from 'n8n-workflow';
let instance: ConnectionPoolManager; let instance: ConnectionPoolManager;
@@ -15,19 +16,23 @@ type RegistrationOptions = {
}; };
type GetConnectionOption<Pool> = RegistrationOptions & { type GetConnectionOption<Pool> = RegistrationOptions & {
/** When a node requests for a connection pool, but none is available, this handler is called to create new instance of the pool, which then cached and re-used until it goes stale. */ /**
fallBackHandler: () => Promise<Pool>; * When a node requests for a connection pool, but none is available, this
* handler is called to create new instance of the pool, which is then cached
* and re-used until it goes stale.
*/
fallBackHandler: (abortController: AbortController) => Promise<Pool>;
/** When a pool hasn't been used in a while, or when the server is shutting down, this handler is invoked to close the pool */ wasUsed: (pool: Pool) => void;
cleanUpHandler: (pool: Pool) => Promise<void>;
}; };
type Registration<Pool> = { type Registration<Pool> = {
/** This is an instance of a Connection Pool class, that gets reused across multiple executions */ /** This is an instance of a Connection Pool class, that gets reused across multiple executions */
pool: Pool; pool: Pool;
/** @see GetConnectionOption['closeHandler'] */ abortController: AbortController;
cleanUpHandler: (pool: Pool) => Promise<void>;
wasUsed: (pool: Pool) => void;
/** We keep this timestamp to check if a pool hasn't been used in a while, and if it needs to be closed */ /** We keep this timestamp to check if a pool hasn't been used in a while, and if it needs to be closed */
lastUsed: number; lastUsed: number;
@@ -38,9 +43,9 @@ export class ConnectionPoolManager {
* Gets the singleton instance of the ConnectionPoolManager. * Gets the singleton instance of the ConnectionPoolManager.
* Creates a new instance if one doesn't exist. * Creates a new instance if one doesn't exist.
*/ */
static getInstance(): ConnectionPoolManager { static getInstance(logger: Logger): ConnectionPoolManager {
if (!instance) { if (!instance) {
instance = new ConnectionPoolManager(); instance = new ConnectionPoolManager(logger);
} }
return instance; return instance;
} }
@@ -51,9 +56,12 @@ export class ConnectionPoolManager {
* Private constructor that initializes the connection pool manager. * Private constructor that initializes the connection pool manager.
* Sets up cleanup handlers for process exit and stale connections. * Sets up cleanup handlers for process exit and stale connections.
*/ */
private constructor() { private constructor(private readonly logger: Logger) {
// Close all open pools when the process exits // Close all open pools when the process exits
process.on('exit', () => this.onShutdown()); process.on('exit', () => {
this.logger.debug('ConnectionPoolManager: Shutting down. Cleaning up all pools');
this.purgeConnections();
});
// Regularly close stale pools // Regularly close stale pools
setInterval(() => this.cleanupStaleConnections(), cleanUpInterval); setInterval(() => this.cleanupStaleConnections(), cleanUpInterval);
@@ -84,54 +92,67 @@ export class ConnectionPoolManager {
const key = this.makeKey(options); const key = this.makeKey(options);
let value = this.map.get(key); let value = this.map.get(key);
if (!value) {
value = { if (value) {
pool: await options.fallBackHandler(), value.lastUsed = Date.now();
cleanUpHandler: options.cleanUpHandler, value.wasUsed(value.pool);
} as Registration<unknown>; return value.pool as T;
}
const abortController = new AbortController();
value = {
pool: await options.fallBackHandler(abortController),
abortController,
wasUsed: options.wasUsed,
} as Registration<unknown>;
// It's possible that `options.fallBackHandler` already called the abort
// function. If that's the case let's not continue.
if (abortController.signal.aborted) {
throw new OperationalError('Could not create pool. Connection attempt was aborted.', {
cause: abortController.signal.reason,
});
} }
this.map.set(key, { ...value, lastUsed: Date.now() }); this.map.set(key, { ...value, lastUsed: Date.now() });
abortController.signal.addEventListener('abort', async () => {
this.logger.debug('ConnectionPoolManager: Got abort signal, cleaning up pool.');
this.cleanupConnection(key);
});
return value.pool as T; return value.pool as T;
} }
private cleanupConnection(key: string) {
const registration = this.map.get(key);
if (registration) {
this.map.delete(key);
registration.abortController.abort();
}
}
/** /**
* Removes and cleans up connection pools that haven't been used within the * Removes and cleans up connection pools that haven't been used within the
* TTL. * TTL.
*/ */
private cleanupStaleConnections() { private cleanupStaleConnections() {
const now = Date.now(); const now = Date.now();
for (const [key, { cleanUpHandler, lastUsed, pool }] of this.map.entries()) { for (const [key, { lastUsed }] of this.map.entries()) {
if (now - lastUsed > ttl) { if (now - lastUsed > ttl) {
void cleanUpHandler(pool); this.logger.debug('ConnectionPoolManager: Found stale pool. Cleaning it up.');
this.map.delete(key); void this.cleanupConnection(key);
} }
} }
} }
/** /**
* Removes and cleans up all existing connection pools. * Removes and cleans up all existing connection pools.
* Connections are closed in the background.
*/ */
async purgeConnections(): Promise<void> { purgeConnections(): void {
await Promise.all( for (const key of this.map.keys()) {
[...this.map.entries()].map(async ([key, value]) => { this.cleanupConnection(key);
this.map.delete(key);
return await value.cleanUpHandler(value.pool);
}),
);
}
/**
* Cleans up all connection pools when the process is shutting down.
* Does not wait for cleanup promises to resolve also does not remove the
* references from the pool.
*
* Only call this on process shutdown.
*/
onShutdown() {
for (const { cleanUpHandler, pool } of this.map.values()) {
void cleanUpHandler(pool);
} }
} }
} }

View File

@@ -829,7 +829,8 @@ export type SSHCredentials = {
); );
export interface SSHTunnelFunctions { export interface SSHTunnelFunctions {
getSSHClient(credentials: SSHCredentials): Promise<SSHClient>; getSSHClient(credentials: SSHCredentials, abortController?: AbortController): Promise<SSHClient>;
updateLastUsed(client: SSHClient): void;
} }
type CronUnit = number | '*' | `*/${number}`; type CronUnit = number | '*' | `*/${number}`;