From 59a08eed723e39a1a855d5a23808057cc5cec424 Mon Sep 17 00:00:00 2001 From: Durran Jordan Date: Thu, 14 Aug 2025 09:48:38 +0200 Subject: [PATCH] refactor(MongoDB Vector Store Node): Refactor mongodb vector store node (#16239) Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> --- .../VectorStoreMongoDBAtlas.node.test.ts | 185 ++++++++++++----- .../VectorStoreMongoDBAtlas.node.ts | 196 ++++++++++-------- 2 files changed, 236 insertions(+), 145 deletions(-) diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts index ba3dda5eea..3596a10532 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts @@ -1,78 +1,151 @@ +import { mock } from 'jest-mock-extended'; import { MongoClient } from 'mongodb'; +import type { ILoadOptionsFunctions } from 'n8n-workflow'; -import { getMongoClient, mongoConfig } from './VectorStoreMongoDBAtlas.node'; +import { + EMBEDDING_NAME, + getCollectionName, + getEmbeddingFieldName, + getMetadataFieldName, + getMongoClient, + getVectorIndexName, + mongoConfig, + METADATA_FIELD_NAME, + MONGODB_COLLECTION_NAME, + VECTOR_INDEX_NAME, +} from './VectorStoreMongoDBAtlas.node'; jest.mock('mongodb', () => ({ MongoClient: jest.fn(), })); -describe('VectorStoreMongoDBAtlas -> getMongoClient', () => { - const mockContext = { - getCredentials: jest.fn(), - }; - const mockClient1 = { - connect: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - }; - const mockClient2 = { - connect: jest.fn().mockResolvedValue(undefined), - close: jest.fn().mockResolvedValue(undefined), - }; - const MockMongoClient = MongoClient as jest.MockedClass; +describe('VectorStoreMongoDBAtlas', () => { + const helpers = mock(); + const executeFunctions = mock({ helpers }); beforeEach(() => { jest.resetAllMocks(); - mongoConfig.client = null; - mongoConfig.connectionString = ''; }); - it('should reuse the same client when connection string is unchanged', async () => { - MockMongoClient.mockImplementation(() => mockClient1 as unknown as MongoClient); - mockContext.getCredentials.mockResolvedValue({ - connectionString: 'mongodb://localhost:27017', + describe('.getMongoClient', () => { + const mockContext = { + getCredentials: jest.fn(), + }; + const mockClient1 = { + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + }; + const mockClient2 = { + connect: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + }; + const MockMongoClient = MongoClient as jest.MockedClass; + + beforeEach(() => { + mongoConfig.client = null; + mongoConfig.connectionString = ''; }); - const client1 = await getMongoClient(mockContext); - const client2 = await getMongoClient(mockContext); - - expect(MockMongoClient).toHaveBeenCalledTimes(1); - expect(MockMongoClient).toHaveBeenCalledWith('mongodb://localhost:27017', { - appName: 'devrel.integration.n8n_vector_integ', - }); - expect(mockClient1.connect).toHaveBeenCalledTimes(1); - expect(mockClient1.close).not.toHaveBeenCalled(); - expect(mockClient2.connect).not.toHaveBeenCalled(); - expect(client1).toBe(mockClient1); - expect(client2).toBe(mockClient1); - }); - - it('should create new client when connection string changes', async () => { - MockMongoClient.mockImplementationOnce( - () => mockClient1 as unknown as MongoClient, - ).mockImplementationOnce(() => mockClient2 as unknown as MongoClient); - mockContext.getCredentials - .mockResolvedValueOnce({ + it('should reuse the same client when connection string is unchanged', async () => { + MockMongoClient.mockImplementation(() => mockClient1 as unknown as MongoClient); + mockContext.getCredentials.mockResolvedValue({ connectionString: 'mongodb://localhost:27017', - }) - .mockResolvedValueOnce({ - connectionString: 'mongodb://different-host:27017', }); - const client1 = await getMongoClient(mockContext); - const client2 = await getMongoClient(mockContext); + const client1 = await getMongoClient(mockContext); + const client2 = await getMongoClient(mockContext); - expect(MockMongoClient).toHaveBeenCalledTimes(2); - expect(MockMongoClient).toHaveBeenNthCalledWith(1, 'mongodb://localhost:27017', { - appName: 'devrel.integration.n8n_vector_integ', + expect(MockMongoClient).toHaveBeenCalledTimes(1); + expect(MockMongoClient).toHaveBeenCalledWith('mongodb://localhost:27017', { + appName: 'devrel.integration.n8n_vector_integ', + }); + expect(mockClient1.connect).toHaveBeenCalledTimes(1); + expect(mockClient1.close).not.toHaveBeenCalled(); + expect(mockClient2.connect).not.toHaveBeenCalled(); + expect(client1).toBe(mockClient1); + expect(client2).toBe(mockClient1); }); - expect(MockMongoClient).toHaveBeenNthCalledWith(2, 'mongodb://different-host:27017', { - appName: 'devrel.integration.n8n_vector_integ', + + it('should create new client when connection string changes', async () => { + MockMongoClient.mockImplementationOnce( + () => mockClient1 as unknown as MongoClient, + ).mockImplementationOnce(() => mockClient2 as unknown as MongoClient); + mockContext.getCredentials + .mockResolvedValueOnce({ + connectionString: 'mongodb://localhost:27017', + }) + .mockResolvedValueOnce({ + connectionString: 'mongodb://different-host:27017', + }); + + const client1 = await getMongoClient(mockContext); + const client2 = await getMongoClient(mockContext); + + expect(MockMongoClient).toHaveBeenCalledTimes(2); + expect(MockMongoClient).toHaveBeenNthCalledWith(1, 'mongodb://localhost:27017', { + appName: 'devrel.integration.n8n_vector_integ', + }); + expect(MockMongoClient).toHaveBeenNthCalledWith(2, 'mongodb://different-host:27017', { + appName: 'devrel.integration.n8n_vector_integ', + }); + expect(mockClient1.connect).toHaveBeenCalledTimes(1); + expect(mockClient1.close).toHaveBeenCalledTimes(1); + expect(mockClient2.connect).toHaveBeenCalledTimes(1); + expect(mockClient2.close).not.toHaveBeenCalled(); + expect(client1).toBe(mockClient1); + expect(client2).toBe(mockClient2); + }); + }); + + describe('.getCollectionName', () => { + beforeEach(() => { + executeFunctions.getNodeParameter.mockImplementation((paramName: string) => { + if (paramName === MONGODB_COLLECTION_NAME) return 'testCollection'; + return ''; + }); + }); + + it('returns the collection name from the context', () => { + expect(getCollectionName(executeFunctions, 0)).toEqual('testCollection'); + }); + }); + + describe('.getVectorIndexName', () => { + beforeEach(() => { + executeFunctions.getNodeParameter.mockImplementation((paramName: string) => { + if (paramName === VECTOR_INDEX_NAME) return 'testIndex'; + return ''; + }); + }); + + it('returns the index name from the context', () => { + expect(getVectorIndexName(executeFunctions, 0)).toEqual('testIndex'); + }); + }); + + describe('.getEmbeddingFieldName', () => { + beforeEach(() => { + executeFunctions.getNodeParameter.mockImplementation((paramName: string) => { + if (paramName === EMBEDDING_NAME) return 'testEmbedding'; + return ''; + }); + }); + + it('returns the embedding name from the context', () => { + expect(getEmbeddingFieldName(executeFunctions, 0)).toEqual('testEmbedding'); + }); + }); + + describe('.getMetadataFieldName', () => { + beforeEach(() => { + executeFunctions.getNodeParameter.mockImplementation((paramName: string) => { + if (paramName === METADATA_FIELD_NAME) return 'testMetadata'; + return ''; + }); + }); + + it('returns the metadata field name from the context', () => { + expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata'); }); - expect(mockClient1.connect).toHaveBeenCalledTimes(1); - expect(mockClient1.close).toHaveBeenCalledTimes(1); - expect(mockClient2.connect).toHaveBeenCalledTimes(1); - expect(mockClient2.close).not.toHaveBeenCalled(); - expect(client1).toBe(mockClient1); - expect(client2).toBe(mockClient2); }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts index fb016226f0..48d1e27cdc 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts @@ -1,7 +1,12 @@ import { MongoDBAtlasVectorSearch } from '@langchain/mongodb'; import { MongoClient } from 'mongodb'; -import { type ILoadOptionsFunctions, NodeOperationError, type INodeProperties } from 'n8n-workflow'; - +import { + type ILoadOptionsFunctions, + NodeOperationError, + type INodeProperties, + type IExecuteFunctions, + type ISupplyDataFunctions, +} from 'n8n-workflow'; import { metadataFilterField } from '@utils/sharedFields'; import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode'; @@ -108,8 +113,27 @@ export const mongoConfig = { connectionString: '', }; +/** + * Constants for the name of the credentials and Node parameters. + */ +export const MONGODB_CREDENTIALS = 'mongoDb'; +export const MONGODB_COLLECTION_NAME = 'mongoCollection'; +export const VECTOR_INDEX_NAME = 'vectorIndexName'; +export const EMBEDDING_NAME = 'embedding'; +export const METADATA_FIELD_NAME = 'metadata_field'; + +/** + * Type used for cleaner, more intentional typing. + */ +type IFunctionsContext = IExecuteFunctions | ISupplyDataFunctions | ILoadOptionsFunctions; + +/** + * Get the mongo client. + * @param context - The context. + * @returns the MongoClient for the node. + */ export async function getMongoClient(context: any) { - const credentials = await context.getCredentials('mongoDb'); + const credentials = await context.getCredentials(MONGODB_CREDENTIALS); const connectionString = credentials.connectionString as string; if (!mongoConfig.client || mongoConfig.connectionString !== connectionString) { if (mongoConfig.client) { @@ -125,16 +149,25 @@ export async function getMongoClient(context: any) { return mongoConfig.client; } -async function mongoClientAndDatabase(context: any) { - const client = await getMongoClient(context); - const credentials = await context.getCredentials('mongoDb'); - const db = client.db(credentials.database as string); - return { client, db }; +/** + * Get the database object from the MongoClient by the configured name. + * @param context - The context. + * @returns the Db object. + */ +export async function getDatabase(context: IFunctionsContext, client: MongoClient) { + const credentials = await context.getCredentials(MONGODB_CREDENTIALS); + return client.db(credentials.database as string); } -async function mongoCollectionSearch(this: ILoadOptionsFunctions) { - const { db } = await mongoClientAndDatabase(this); +/** + * Get all the collection in the database. + * @param this The load options context. + * @returns The list of collections. + */ +export async function getCollections(this: ILoadOptionsFunctions) { try { + const client = await getMongoClient(this); + const db = await getDatabase(this, client); const collections = await db.listCollections().toArray(); const results = collections.map((collection) => ({ name: collection.name, @@ -146,6 +179,29 @@ async function mongoCollectionSearch(this: ILoadOptionsFunctions) { throw new NodeOperationError(this.getNode(), `Error: ${error.message}`); } } + +/** + * Get a parameter from the context. + * @param key - The key of the parameter. + * @param context - The context. + * @param itemIndex - The index. + * @returns The value. + */ +export function getParameter(key: string, context: IFunctionsContext, itemIndex: number): string { + const value = context.getNodeParameter(key, itemIndex, '', { + extractValue: true, + }) as string; + if (typeof value !== 'string') { + throw new NodeOperationError(context.getNode(), `Parameter ${key} must be a string`); + } + return value; +} + +export const getCollectionName = getParameter.bind(null, MONGODB_COLLECTION_NAME); +export const getVectorIndexName = getParameter.bind(null, VECTOR_INDEX_NAME); +export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME); +export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME); + export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ meta: { displayName: 'MongoDB Atlas Vector Store', @@ -162,64 +218,44 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ ], operationModes: ['load', 'insert', 'retrieve', 'update', 'retrieve-as-tool'], }, - methods: { listSearch: { mongoCollectionSearch } }, + methods: { listSearch: { mongoCollectionSearch: getCollections } }, retrieveFields, loadFields: retrieveFields, insertFields, sharedFields, async getVectorStoreClient(context, _filter, embeddings, itemIndex) { try { - const { db } = await mongoClientAndDatabase(context); - try { - const collectionName = context.getNodeParameter('mongoCollection', itemIndex, '', { - extractValue: true, - }) as string; + const client = await getMongoClient(context); + const db = await getDatabase(context, client); + const collectionName = getCollectionName(context, itemIndex); + const mongoVectorIndexName = getVectorIndexName(context, itemIndex); + const embeddingFieldName = getEmbeddingFieldName(context, itemIndex); + const metadataFieldName = getMetadataFieldName(context, itemIndex); - const mongoVectorIndexName = context.getNodeParameter('vectorIndexName', itemIndex, '', { - extractValue: true, - }) as string; + const collection = db.collection(collectionName); - const embeddingFieldName = context.getNodeParameter('embedding', itemIndex, '', { - extractValue: true, - }) as string; + // test index exists + const indexes = await collection.listSearchIndexes().toArray(); - const metadataFieldName = context.getNodeParameter('metadata_field', itemIndex, '', { - extractValue: true, - }) as string; + const indexExists = indexes.some((index) => index.name === mongoVectorIndexName); - const collection = db.collection(collectionName); - - // test index exists - const indexes = await collection.listSearchIndexes().toArray(); - - const indexExists = indexes.some((index) => index.name === mongoVectorIndexName); - - if (!indexExists) { - throw new NodeOperationError( - context.getNode(), - `Index ${mongoVectorIndexName} not found`, - { - itemIndex, - description: 'Please check that the index exists in your collection', - }, - ); - } - - return new MongoDBAtlasVectorSearch(embeddings, { - collection, - indexName: mongoVectorIndexName, // Default index name - textKey: metadataFieldName, // Field containing raw text - embeddingKey: embeddingFieldName, // Field containing embeddings - }); - } catch (error) { - throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { + if (!indexExists) { + throw new NodeOperationError(context.getNode(), `Index ${mongoVectorIndexName} not found`, { itemIndex, - description: 'Please check your MongoDB Atlas connection details', + description: 'Please check that the index exists in your collection', }); - } finally { - // Don't close the client here to maintain connection pooling } + + return new MongoDBAtlasVectorSearch(embeddings, { + collection, + indexName: mongoVectorIndexName, // Default index name + textKey: metadataFieldName, // Field containing raw text + embeddingKey: embeddingFieldName, // Field containing embeddings + }); } catch (error) { + if (error instanceof NodeOperationError) { + throw error; + } throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { itemIndex, description: 'Please check your MongoDB Atlas connection details', @@ -228,43 +264,25 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ }, async populateVectorStore(context, embeddings, documents, itemIndex) { try { - const { db } = await mongoClientAndDatabase(context); - try { - const mongoCollectionName = context.getNodeParameter('mongoCollection', itemIndex, '', { - extractValue: true, - }) as string; - const embeddingFieldName = context.getNodeParameter('embedding', itemIndex, '', { - extractValue: true, - }) as string; + const client = await getMongoClient(context); + const db = await getDatabase(context, client); + const collectionName = getCollectionName(context, itemIndex); + const mongoVectorIndexName = getVectorIndexName(context, itemIndex); + const embeddingFieldName = getEmbeddingFieldName(context, itemIndex); + const metadataFieldName = getMetadataFieldName(context, itemIndex); - const metadataFieldName = context.getNodeParameter('metadata_field', itemIndex, '', { - extractValue: true, - }) as string; - - const mongoDBAtlasVectorIndex = context.getNodeParameter('vectorIndexName', itemIndex, '', { - extractValue: true, - }) as string; - - // Check if collection exists - const collections = await db.listCollections({ name: mongoCollectionName }).toArray(); - if (collections.length === 0) { - await db.createCollection(mongoCollectionName); - } - const collection = db.collection(mongoCollectionName); - await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, { - collection, - indexName: mongoDBAtlasVectorIndex, // Default index name - textKey: metadataFieldName, // Field containing raw text - embeddingKey: embeddingFieldName, // Field containing embeddings - }); - } catch (error) { - throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { - itemIndex, - description: 'Please check your MongoDB Atlas connection details', - }); - } finally { - // Don't close the client here to maintain connection pooling + // Check if collection exists + const collections = await db.listCollections({ name: collectionName }).toArray(); + if (collections.length === 0) { + await db.createCollection(collectionName); } + const collection = db.collection(collectionName); + await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, { + collection, + indexName: mongoVectorIndexName, // Default index name + textKey: metadataFieldName, // Field containing raw text + embeddingKey: embeddingFieldName, // Field containing embeddings + }); } catch (error) { throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { itemIndex,