refactor(core): Use consistent CSRF state validation across oAuth controllers (#9104)

Co-authored-by: Danny Martini <danny@n8n.io>
This commit is contained in:
कारतोफ्फेलस्क्रिप्ट™
2024-05-23 19:08:01 +02:00
committed by GitHub
parent 3b93aae6dc
commit b585777c79
6 changed files with 183 additions and 90 deletions

View File

@@ -1,6 +1,10 @@
import { Service } from 'typedi';
import Csrf from 'csrf';
import type { Response } from 'express';
import { Credentials } from 'n8n-core';
import type { ICredentialDataDecryptedObject, IWorkflowExecuteAdditionalData } from 'n8n-workflow';
import { jsonParse, ApplicationError } from 'n8n-workflow';
import config from '@/config';
import type { CredentialsEntity } from '@db/entities/CredentialsEntity';
import type { User } from '@db/entities/User';
@@ -17,6 +21,11 @@ import { UrlService } from '@/services/url.service';
import { BadRequestError } from '@/errors/response-errors/bad-request.error';
import { NotFoundError } from '@/errors/response-errors/not-found.error';
export interface CsrfStateParam {
cid: string;
token: string;
}
@Service()
export abstract class AbstractOAuthController {
abstract oauthVersion: number;
@@ -108,4 +117,37 @@ export abstract class AbstractOAuthController {
protected async getCredentialWithoutUser(credentialId: string): Promise<ICredentialsDb | null> {
return await this.credentialsRepository.findOneBy({ id: credentialId });
}
protected createCsrfState(credentialsId: string): [string, string] {
const token = new Csrf();
const csrfSecret = token.secretSync();
const state: CsrfStateParam = {
token: token.create(csrfSecret),
cid: credentialsId,
};
return [csrfSecret, Buffer.from(JSON.stringify(state)).toString('base64')];
}
protected decodeCsrfState(encodedState: string): CsrfStateParam {
const errorMessage = 'Invalid state format';
const decoded = jsonParse<CsrfStateParam>(Buffer.from(encodedState, 'base64').toString(), {
errorMessage,
});
if (typeof decoded.cid !== 'string' || typeof decoded.token !== 'string') {
throw new ApplicationError(errorMessage);
}
return decoded;
}
protected verifyCsrfState(decrypted: ICredentialDataDecryptedObject, state: CsrfStateParam) {
const token = new Csrf();
return (
decrypted.csrfSecret === undefined ||
!token.verify(decrypted.csrfSecret as string, state.token)
);
}
protected renderCallbackError(res: Response, message: string, reason?: string) {
res.render('oauth-error-callback', { error: { message, reason } });
}
}

View File

@@ -4,13 +4,11 @@ import axios from 'axios';
import type { RequestOptions } from 'oauth-1.0a';
import clientOAuth1 from 'oauth-1.0a';
import { createHmac } from 'crypto';
import { RESPONSE_ERROR_MESSAGES } from '@/constants';
import { Get, RestController } from '@/decorators';
import { OAuthRequest } from '@/requests';
import { sendErrorResponse } from '@/ResponseHelper';
import { AbstractOAuthController } from './abstractOAuth.controller';
import { AbstractOAuthController, type CsrfStateParam } from './abstractOAuth.controller';
import { NotFoundError } from '@/errors/response-errors/not-found.error';
import { ServiceUnavailableError } from '@/errors/response-errors/service-unavailable.error';
interface OAuth1CredentialData {
signatureMethod: 'HMAC-SHA256' | 'HMAC-SHA512' | 'HMAC-SHA1';
@@ -44,6 +42,7 @@ export class OAuth1CredentialController extends AbstractOAuthController {
decryptedDataOriginal,
additionalData,
);
const [csrfSecret, state] = this.createCsrfState(credential.id);
const signatureMethod = oauthCredentials.signatureMethod;
@@ -61,7 +60,7 @@ export class OAuth1CredentialController extends AbstractOAuthController {
};
const oauthRequestData = {
oauth_callback: `${this.baseUrl}/callback?cid=${credential.id}`,
oauth_callback: `${this.baseUrl}/callback?state=${state}`,
};
await this.externalHooks.run('oauth1.authenticate', [oAuthOptions, oauthRequestData]);
@@ -90,6 +89,7 @@ export class OAuth1CredentialController extends AbstractOAuthController {
const returnUri = `${oauthCredentials.authUrl}?oauth_token=${responseJson.oauth_token}`;
decryptedDataOriginal.csrfSecret = csrfSecret;
await this.encryptAndSaveData(credential, decryptedDataOriginal);
this.logger.verbose('OAuth1 authorization successful for new credential', {
@@ -103,31 +103,31 @@ export class OAuth1CredentialController extends AbstractOAuthController {
/** Verify and store app code. Generate access tokens and store for respective credential */
@Get('/callback', { usesTemplates: true })
async handleCallback(req: OAuthRequest.OAuth1Credential.Callback, res: Response) {
const userId = req.user?.id;
try {
const { oauth_verifier, oauth_token, cid: credentialId } = req.query;
const { oauth_verifier, oauth_token, state: encodedState } = req.query;
if (!oauth_verifier || !oauth_token) {
const errorResponse = new ServiceUnavailableError(
`Insufficient parameters for OAuth1 callback. Received following query parameters: ${JSON.stringify(
req.query,
)}`,
if (!oauth_verifier || !oauth_token || !encodedState) {
return this.renderCallbackError(
res,
'Insufficient parameters for OAuth1 callback.',
`Received following query parameters: ${JSON.stringify(req.query)}`,
);
this.logger.error('OAuth1 callback failed because of insufficient parameters received', {
userId: req.user?.id,
credentialId,
});
return sendErrorResponse(res, errorResponse);
}
const credential = await this.getCredentialWithoutUser(credentialId);
let state: CsrfStateParam;
try {
state = this.decodeCsrfState(encodedState);
} catch (error) {
return this.renderCallbackError(res, (error as Error).message);
}
const credentialId = state.cid;
const credential = await this.getCredentialWithoutUser(credentialId);
if (!credential) {
this.logger.error('OAuth1 callback failed because of insufficient user permissions', {
userId: req.user?.id,
credentialId,
});
const errorResponse = new NotFoundError(RESPONSE_ERROR_MESSAGES.NO_CREDENTIAL);
return sendErrorResponse(res, errorResponse);
const errorMessage = 'OAuth1 callback failed because of insufficient permissions';
this.logger.error(errorMessage, { userId, credentialId });
return this.renderCallbackError(res, errorMessage);
}
const additionalData = await this.getAdditionalData(req.user);
@@ -138,6 +138,12 @@ export class OAuth1CredentialController extends AbstractOAuthController {
additionalData,
);
if (this.verifyCsrfState(decryptedDataOriginal, state)) {
const errorMessage = 'The OAuth1 callback state is invalid!';
this.logger.debug(errorMessage, { userId, credentialId });
return this.renderCallbackError(res, errorMessage);
}
const options: AxiosRequestConfig = {
method: 'POST',
url: oauthCredentials.accessTokenUrl,
@@ -152,10 +158,7 @@ export class OAuth1CredentialController extends AbstractOAuthController {
try {
oauthToken = await axios.request(options);
} catch (error) {
this.logger.error('Unable to fetch tokens for OAuth1 callback', {
userId: req.user?.id,
credentialId,
});
this.logger.error('Unable to fetch tokens for OAuth1 callback', { userId, credentialId });
const errorResponse = new NotFoundError('Unable to get access tokens!');
return sendErrorResponse(res, errorResponse);
}
@@ -171,14 +174,13 @@ export class OAuth1CredentialController extends AbstractOAuthController {
await this.encryptAndSaveData(credential, decryptedDataOriginal);
this.logger.verbose('OAuth1 callback successful for new credential', {
userId: req.user?.id,
userId,
credentialId,
});
return res.render('oauth-callback');
} catch (error) {
this.logger.error('OAuth1 callback failed because of insufficient user permissions', {
userId: req.user?.id,
credentialId: req.query.cid,
userId,
});
// Error response
return sendErrorResponse(res, error as Error);

View File

@@ -1,21 +1,15 @@
import type { ClientOAuth2Options, OAuth2CredentialData } from '@n8n/client-oauth2';
import { ClientOAuth2 } from '@n8n/client-oauth2';
import Csrf from 'csrf';
import { Response } from 'express';
import pkceChallenge from 'pkce-challenge';
import * as qs from 'querystring';
import omit from 'lodash/omit';
import set from 'lodash/set';
import split from 'lodash/split';
import { ApplicationError, jsonParse, jsonStringify } from 'n8n-workflow';
import { Get, RestController } from '@/decorators';
import { jsonStringify } from 'n8n-workflow';
import { OAuthRequest } from '@/requests';
import { AbstractOAuthController } from './abstractOAuth.controller';
interface CsrfStateParam {
cid: string;
token: string;
}
import { AbstractOAuthController, type CsrfStateParam } from './abstractOAuth.controller';
@RestController('/oauth2-credential')
export class OAuth2CredentialController extends AbstractOAuthController {
@@ -87,8 +81,8 @@ export class OAuth2CredentialController extends AbstractOAuthController {
/** Verify and store app code. Generate access tokens and store for respective credential */
@Get('/callback', { usesTemplates: true })
async handleCallback(req: OAuthRequest.OAuth2Credential.Callback, res: Response) {
const userId = req.user?.id;
try {
// realmId it's currently just use for the quickbook OAuth2 flow
const { code, state: encodedState } = req.query;
if (!code || !encodedState) {
return this.renderCallbackError(
@@ -105,13 +99,11 @@ export class OAuth2CredentialController extends AbstractOAuthController {
return this.renderCallbackError(res, (error as Error).message);
}
const credential = await this.getCredentialWithoutUser(state.cid);
const credentialId = state.cid;
const credential = await this.getCredentialWithoutUser(credentialId);
if (!credential) {
const errorMessage = 'OAuth2 callback failed because of insufficient permissions';
this.logger.error(errorMessage, {
userId: req.user?.id,
credentialId: state.cid,
});
this.logger.error(errorMessage, { userId, credentialId });
return this.renderCallbackError(res, errorMessage);
}
@@ -123,16 +115,9 @@ export class OAuth2CredentialController extends AbstractOAuthController {
additionalData,
);
const token = new Csrf();
if (
decryptedDataOriginal.csrfSecret === undefined ||
!token.verify(decryptedDataOriginal.csrfSecret as string, state.token)
) {
if (this.verifyCsrfState(decryptedDataOriginal, state)) {
const errorMessage = 'The OAuth2 callback state is invalid!';
this.logger.debug(errorMessage, {
userId: req.user?.id,
credentialId: credential.id,
});
this.logger.debug(errorMessage, { userId, credentialId });
return this.renderCallbackError(res, errorMessage);
}
@@ -171,10 +156,7 @@ export class OAuth2CredentialController extends AbstractOAuthController {
if (oauthToken === undefined) {
const errorMessage = 'Unable to get OAuth2 access tokens!';
this.logger.error(errorMessage, {
userId: req.user?.id,
credentialId: credential.id,
});
this.logger.error(errorMessage, { userId, credentialId });
return this.renderCallbackError(res, errorMessage);
}
@@ -191,8 +173,8 @@ export class OAuth2CredentialController extends AbstractOAuthController {
await this.encryptAndSaveData(credential, decryptedDataOriginal);
this.logger.verbose('OAuth2 callback successful for credential', {
userId: req.user?.id,
credentialId: credential.id,
userId,
credentialId,
});
return res.render('oauth-callback');
@@ -219,29 +201,4 @@ export class OAuth2CredentialController extends AbstractOAuthController {
ignoreSSLIssues: credential.ignoreSSLIssues ?? false,
};
}
private renderCallbackError(res: Response, message: string, reason?: string) {
res.render('oauth-error-callback', { error: { message, reason } });
}
private createCsrfState(credentialsId: string): [string, string] {
const token = new Csrf();
const csrfSecret = token.secretSync();
const state: CsrfStateParam = {
token: token.create(csrfSecret),
cid: credentialsId,
};
return [csrfSecret, Buffer.from(JSON.stringify(state)).toString('base64')];
}
private decodeCsrfState(encodedState: string): CsrfStateParam {
const errorMessage = 'Invalid state format';
const decoded = jsonParse<CsrfStateParam>(Buffer.from(encodedState, 'base64').toString(), {
errorMessage,
});
if (typeof decoded.cid !== 'string' || typeof decoded.token !== 'string') {
throw new ApplicationError(errorMessage);
}
return decoded;
}
}

View File

@@ -377,7 +377,7 @@ export declare namespace OAuthRequest {
{},
{},
{},
{ oauth_verifier: string; oauth_token: string; cid: string }
{ oauth_verifier: string; oauth_token: string; state: string }
> & {
user?: User;
};

View File

@@ -1,10 +1,12 @@
import nock from 'nock';
import Container from 'typedi';
import type { Response } from 'express';
import Csrf from 'csrf';
import { Cipher } from 'n8n-core';
import { mock } from 'jest-mock-extended';
import { OAuth1CredentialController } from '@/controllers/oauth/oAuth1Credential.controller';
import type { CredentialsEntity } from '@db/entities/CredentialsEntity';
import { CredentialsEntity } from '@db/entities/CredentialsEntity';
import type { User } from '@db/entities/User';
import type { OAuthRequest } from '@/requests';
import { CredentialsRepository } from '@db/repositories/credentials.repository';
@@ -14,11 +16,11 @@ import { Logger } from '@/Logger';
import { VariablesService } from '@/environments/variables/variables.service.ee';
import { SecretsHelper } from '@/SecretsHelpers';
import { CredentialsHelper } from '@/CredentialsHelper';
import { mockInstance } from '../../shared/mocking';
import { BadRequestError } from '@/errors/response-errors/bad-request.error';
import { NotFoundError } from '@/errors/response-errors/not-found.error';
import { mockInstance } from '../../../shared/mocking';
describe('OAuth1CredentialController', () => {
mockInstance(Logger);
mockInstance(ExternalHooks);
@@ -30,6 +32,8 @@ describe('OAuth1CredentialController', () => {
const credentialsHelper = mockInstance(CredentialsHelper);
const credentialsRepository = mockInstance(CredentialsRepository);
const sharedCredentialsRepository = mockInstance(SharedCredentialsRepository);
const csrfSecret = 'csrf-secret';
const user = mock<User>({
id: '123',
password: 'password',
@@ -66,6 +70,8 @@ describe('OAuth1CredentialController', () => {
});
it('should return a valid auth URI', async () => {
jest.spyOn(Csrf.prototype, 'secretSync').mockReturnValueOnce(csrfSecret);
jest.spyOn(Csrf.prototype, 'create').mockReturnValueOnce('token');
sharedCredentialsRepository.findCredentialForUser.mockResolvedValueOnce(credential);
credentialsHelper.getDecrypted.mockResolvedValueOnce({});
credentialsHelper.applyDefaultsAndOverwrites.mockReturnValueOnce({
@@ -75,7 +81,8 @@ describe('OAuth1CredentialController', () => {
});
nock('https://example.domain')
.post('/oauth/request_token', {
oauth_callback: 'http://localhost:5678/rest/oauth1-credential/callback?cid=1',
oauth_callback:
'http://localhost:5678/rest/oauth1-credential/callback?state=eyJ0b2tlbiI6InRva2VuIiwiY2lkIjoiMSJ9',
})
.reply(200, { oauth_token: 'random-token' });
cipher.encrypt.mockReturnValue('encrypted');
@@ -92,6 +99,91 @@ describe('OAuth1CredentialController', () => {
type: 'oAuth1Api',
}),
);
expect(cipher.encrypt).toHaveBeenCalledWith({ csrfSecret });
});
});
describe('handleCallback', () => {
const validState = Buffer.from(
JSON.stringify({
token: 'token',
cid: '1',
}),
).toString('base64');
it('should render the error page when required query params are missing', async () => {
const req = mock<OAuthRequest.OAuth1Credential.Callback>();
const res = mock<Response>();
req.query = { state: 'test' } as OAuthRequest.OAuth1Credential.Callback['query'];
await controller.handleCallback(req, res);
expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'Insufficient parameters for OAuth1 callback.',
reason: 'Received following query parameters: {"state":"test"}',
},
});
expect(credentialsRepository.findOneBy).not.toHaveBeenCalled();
});
it('should render the error page when `state` query param is invalid', async () => {
const req = mock<OAuthRequest.OAuth1Credential.Callback>();
const res = mock<Response>();
req.query = {
oauth_verifier: 'verifier',
oauth_token: 'token',
state: 'test',
} as OAuthRequest.OAuth1Credential.Callback['query'];
await controller.handleCallback(req, res);
expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'Invalid state format',
},
});
expect(credentialsRepository.findOneBy).not.toHaveBeenCalled();
});
it('should render the error page when credential is not found in DB', async () => {
credentialsRepository.findOneBy.mockResolvedValueOnce(null);
const req = mock<OAuthRequest.OAuth1Credential.Callback>();
const res = mock<Response>();
req.query = {
oauth_verifier: 'verifier',
oauth_token: 'token',
state: validState,
} as OAuthRequest.OAuth1Credential.Callback['query'];
await controller.handleCallback(req, res);
expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'OAuth1 callback failed because of insufficient permissions',
},
});
expect(credentialsRepository.findOneBy).toHaveBeenCalledTimes(1);
expect(credentialsRepository.findOneBy).toHaveBeenCalledWith({ id: '1' });
});
it('should render the error page when state differs from the stored state in the credential', async () => {
credentialsRepository.findOneBy.mockResolvedValue(new CredentialsEntity());
credentialsHelper.getDecrypted.mockResolvedValue({ csrfSecret: 'invalid' });
const req = mock<OAuthRequest.OAuth1Credential.Callback>();
const res = mock<Response>();
req.query = {
oauth_verifier: 'verifier',
oauth_token: 'token',
state: validState,
} as OAuthRequest.OAuth1Credential.Callback['query'];
await controller.handleCallback(req, res);
expect(res.render).toHaveBeenCalledWith('oauth-error-callback', {
error: {
message: 'The OAuth1 callback state is invalid!',
},
});
});
});
});

View File

@@ -16,11 +16,11 @@ import { Logger } from '@/Logger';
import { VariablesService } from '@/environments/variables/variables.service.ee';
import { SecretsHelper } from '@/SecretsHelpers';
import { CredentialsHelper } from '@/CredentialsHelper';
import { mockInstance } from '../../shared/mocking';
import { BadRequestError } from '@/errors/response-errors/bad-request.error';
import { NotFoundError } from '@/errors/response-errors/not-found.error';
import { mockInstance } from '../../../shared/mocking';
describe('OAuth2CredentialController', () => {
mockInstance(Logger);
mockInstance(SecretsHelper);