refactor(MongoDB Vector Store Node): Refactor mongodb vector store node (#16239)

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
This commit is contained in:
Durran Jordan
2025-08-14 09:48:38 +02:00
committed by GitHub
parent 9043869b10
commit 59a08eed72
2 changed files with 236 additions and 145 deletions

View File

@@ -1,78 +1,151 @@
import { mock } from 'jest-mock-extended';
import { MongoClient } from 'mongodb'; import { MongoClient } from 'mongodb';
import type { ILoadOptionsFunctions } from 'n8n-workflow';
import { getMongoClient, mongoConfig } from './VectorStoreMongoDBAtlas.node'; import {
EMBEDDING_NAME,
getCollectionName,
getEmbeddingFieldName,
getMetadataFieldName,
getMongoClient,
getVectorIndexName,
mongoConfig,
METADATA_FIELD_NAME,
MONGODB_COLLECTION_NAME,
VECTOR_INDEX_NAME,
} from './VectorStoreMongoDBAtlas.node';
jest.mock('mongodb', () => ({ jest.mock('mongodb', () => ({
MongoClient: jest.fn(), MongoClient: jest.fn(),
})); }));
describe('VectorStoreMongoDBAtlas -> getMongoClient', () => { describe('VectorStoreMongoDBAtlas', () => {
const mockContext = { const helpers = mock<ILoadOptionsFunctions['helpers']>();
getCredentials: jest.fn(), const executeFunctions = mock<ILoadOptionsFunctions>({ helpers });
};
const mockClient1 = {
connect: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
};
const mockClient2 = {
connect: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
};
const MockMongoClient = MongoClient as jest.MockedClass<typeof MongoClient>;
beforeEach(() => { beforeEach(() => {
jest.resetAllMocks(); jest.resetAllMocks();
mongoConfig.client = null;
mongoConfig.connectionString = '';
}); });
it('should reuse the same client when connection string is unchanged', async () => { describe('.getMongoClient', () => {
MockMongoClient.mockImplementation(() => mockClient1 as unknown as MongoClient); const mockContext = {
mockContext.getCredentials.mockResolvedValue({ getCredentials: jest.fn(),
connectionString: 'mongodb://localhost:27017', };
const mockClient1 = {
connect: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
};
const mockClient2 = {
connect: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
};
const MockMongoClient = MongoClient as jest.MockedClass<typeof MongoClient>;
beforeEach(() => {
mongoConfig.client = null;
mongoConfig.connectionString = '';
}); });
const client1 = await getMongoClient(mockContext); it('should reuse the same client when connection string is unchanged', async () => {
const client2 = await getMongoClient(mockContext); MockMongoClient.mockImplementation(() => mockClient1 as unknown as MongoClient);
mockContext.getCredentials.mockResolvedValue({
expect(MockMongoClient).toHaveBeenCalledTimes(1);
expect(MockMongoClient).toHaveBeenCalledWith('mongodb://localhost:27017', {
appName: 'devrel.integration.n8n_vector_integ',
});
expect(mockClient1.connect).toHaveBeenCalledTimes(1);
expect(mockClient1.close).not.toHaveBeenCalled();
expect(mockClient2.connect).not.toHaveBeenCalled();
expect(client1).toBe(mockClient1);
expect(client2).toBe(mockClient1);
});
it('should create new client when connection string changes', async () => {
MockMongoClient.mockImplementationOnce(
() => mockClient1 as unknown as MongoClient,
).mockImplementationOnce(() => mockClient2 as unknown as MongoClient);
mockContext.getCredentials
.mockResolvedValueOnce({
connectionString: 'mongodb://localhost:27017', connectionString: 'mongodb://localhost:27017',
})
.mockResolvedValueOnce({
connectionString: 'mongodb://different-host:27017',
}); });
const client1 = await getMongoClient(mockContext); const client1 = await getMongoClient(mockContext);
const client2 = await getMongoClient(mockContext); const client2 = await getMongoClient(mockContext);
expect(MockMongoClient).toHaveBeenCalledTimes(2); expect(MockMongoClient).toHaveBeenCalledTimes(1);
expect(MockMongoClient).toHaveBeenNthCalledWith(1, 'mongodb://localhost:27017', { expect(MockMongoClient).toHaveBeenCalledWith('mongodb://localhost:27017', {
appName: 'devrel.integration.n8n_vector_integ', appName: 'devrel.integration.n8n_vector_integ',
});
expect(mockClient1.connect).toHaveBeenCalledTimes(1);
expect(mockClient1.close).not.toHaveBeenCalled();
expect(mockClient2.connect).not.toHaveBeenCalled();
expect(client1).toBe(mockClient1);
expect(client2).toBe(mockClient1);
}); });
expect(MockMongoClient).toHaveBeenNthCalledWith(2, 'mongodb://different-host:27017', {
appName: 'devrel.integration.n8n_vector_integ', it('should create new client when connection string changes', async () => {
MockMongoClient.mockImplementationOnce(
() => mockClient1 as unknown as MongoClient,
).mockImplementationOnce(() => mockClient2 as unknown as MongoClient);
mockContext.getCredentials
.mockResolvedValueOnce({
connectionString: 'mongodb://localhost:27017',
})
.mockResolvedValueOnce({
connectionString: 'mongodb://different-host:27017',
});
const client1 = await getMongoClient(mockContext);
const client2 = await getMongoClient(mockContext);
expect(MockMongoClient).toHaveBeenCalledTimes(2);
expect(MockMongoClient).toHaveBeenNthCalledWith(1, 'mongodb://localhost:27017', {
appName: 'devrel.integration.n8n_vector_integ',
});
expect(MockMongoClient).toHaveBeenNthCalledWith(2, 'mongodb://different-host:27017', {
appName: 'devrel.integration.n8n_vector_integ',
});
expect(mockClient1.connect).toHaveBeenCalledTimes(1);
expect(mockClient1.close).toHaveBeenCalledTimes(1);
expect(mockClient2.connect).toHaveBeenCalledTimes(1);
expect(mockClient2.close).not.toHaveBeenCalled();
expect(client1).toBe(mockClient1);
expect(client2).toBe(mockClient2);
});
});
describe('.getCollectionName', () => {
beforeEach(() => {
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
if (paramName === MONGODB_COLLECTION_NAME) return 'testCollection';
return '';
});
});
it('returns the collection name from the context', () => {
expect(getCollectionName(executeFunctions, 0)).toEqual('testCollection');
});
});
describe('.getVectorIndexName', () => {
beforeEach(() => {
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
if (paramName === VECTOR_INDEX_NAME) return 'testIndex';
return '';
});
});
it('returns the index name from the context', () => {
expect(getVectorIndexName(executeFunctions, 0)).toEqual('testIndex');
});
});
describe('.getEmbeddingFieldName', () => {
beforeEach(() => {
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
if (paramName === EMBEDDING_NAME) return 'testEmbedding';
return '';
});
});
it('returns the embedding name from the context', () => {
expect(getEmbeddingFieldName(executeFunctions, 0)).toEqual('testEmbedding');
});
});
describe('.getMetadataFieldName', () => {
beforeEach(() => {
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
if (paramName === METADATA_FIELD_NAME) return 'testMetadata';
return '';
});
});
it('returns the metadata field name from the context', () => {
expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata');
}); });
expect(mockClient1.connect).toHaveBeenCalledTimes(1);
expect(mockClient1.close).toHaveBeenCalledTimes(1);
expect(mockClient2.connect).toHaveBeenCalledTimes(1);
expect(mockClient2.close).not.toHaveBeenCalled();
expect(client1).toBe(mockClient1);
expect(client2).toBe(mockClient2);
}); });
}); });

View File

@@ -1,7 +1,12 @@
import { MongoDBAtlasVectorSearch } from '@langchain/mongodb'; import { MongoDBAtlasVectorSearch } from '@langchain/mongodb';
import { MongoClient } from 'mongodb'; import { MongoClient } from 'mongodb';
import { type ILoadOptionsFunctions, NodeOperationError, type INodeProperties } from 'n8n-workflow'; import {
type ILoadOptionsFunctions,
NodeOperationError,
type INodeProperties,
type IExecuteFunctions,
type ISupplyDataFunctions,
} from 'n8n-workflow';
import { metadataFilterField } from '@utils/sharedFields'; import { metadataFilterField } from '@utils/sharedFields';
import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode'; import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode';
@@ -108,8 +113,27 @@ export const mongoConfig = {
connectionString: '', connectionString: '',
}; };
/**
 * Names of the credential type and of the node parameters read by this node.
 * They are exported so other modules (such as the tests) can reference the same keys.
 */
// Credential type name passed to `context.getCredentials(...)`.
export const MONGODB_CREDENTIALS = 'mongoDb';
// Parameter key read by `getCollectionName`.
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
// Parameter key read by `getVectorIndexName`.
export const VECTOR_INDEX_NAME = 'vectorIndexName';
// Parameter key read by `getEmbeddingFieldName`.
export const EMBEDDING_NAME = 'embedding';
// Parameter key read by `getMetadataFieldName`.
export const METADATA_FIELD_NAME = 'metadata_field';
/**
 * Union of the n8n function contexts accepted by the helpers in this file
 * (`getDatabase`, `getParameter`), so they can be typed without `any`.
 */
type IFunctionsContext = IExecuteFunctions | ISupplyDataFunctions | ILoadOptionsFunctions;
/**
* Get the mongo client.
* @param context - The context.
* @returns the MongoClient for the node.
*/
export async function getMongoClient(context: any) { export async function getMongoClient(context: any) {
const credentials = await context.getCredentials('mongoDb'); const credentials = await context.getCredentials(MONGODB_CREDENTIALS);
const connectionString = credentials.connectionString as string; const connectionString = credentials.connectionString as string;
if (!mongoConfig.client || mongoConfig.connectionString !== connectionString) { if (!mongoConfig.client || mongoConfig.connectionString !== connectionString) {
if (mongoConfig.client) { if (mongoConfig.client) {
@@ -125,16 +149,25 @@ export async function getMongoClient(context: any) {
return mongoConfig.client; return mongoConfig.client;
} }
async function mongoClientAndDatabase(context: any) { /**
const client = await getMongoClient(context); * Get the database object from the MongoClient by the configured name.
const credentials = await context.getCredentials('mongoDb'); * @param context - The context.
const db = client.db(credentials.database as string); * @returns the Db object.
return { client, db }; */
export async function getDatabase(context: IFunctionsContext, client: MongoClient) {
const credentials = await context.getCredentials(MONGODB_CREDENTIALS);
return client.db(credentials.database as string);
} }
async function mongoCollectionSearch(this: ILoadOptionsFunctions) { /**
const { db } = await mongoClientAndDatabase(this); * Get all the collections in the database.
* @param this The load options context.
* @returns The list of collections.
*/
export async function getCollections(this: ILoadOptionsFunctions) {
try { try {
const client = await getMongoClient(this);
const db = await getDatabase(this, client);
const collections = await db.listCollections().toArray(); const collections = await db.listCollections().toArray();
const results = collections.map((collection) => ({ const results = collections.map((collection) => ({
name: collection.name, name: collection.name,
@@ -146,6 +179,29 @@ async function mongoCollectionSearch(this: ILoadOptionsFunctions) {
throw new NodeOperationError(this.getNode(), `Error: ${error.message}`); throw new NodeOperationError(this.getNode(), `Error: ${error.message}`);
} }
} }
/**
 * Read a node parameter and ensure that it resolved to a string.
 *
 * The lookup is performed with `extractValue: true` and `''` as the fallback value.
 *
 * @param key - The name of the node parameter to read.
 * @param context - The context the parameter is read from.
 * @param itemIndex - The index of the item the parameter is read for.
 * @returns The parameter value.
 * @throws NodeOperationError if the resolved value is not a string.
 */
export function getParameter(key: string, context: IFunctionsContext, itemIndex: number): string {
	// Hold the result as `unknown` so the runtime check below is what narrows it,
	// instead of a cast that would hide a non-string value from the compiler.
	const value: unknown = context.getNodeParameter(key, itemIndex, '', {
		extractValue: true,
	});
	if (typeof value !== 'string') {
		// Pass the item index along, as the other errors raised by this node do,
		// so the failure is attributed to the item that caused it.
		throw new NodeOperationError(context.getNode(), `Parameter ${key} must be a string`, {
			itemIndex,
		});
	}
	return value;
}
/** Reads the `MONGODB_COLLECTION_NAME` parameter for the given item via `getParameter`. */
export const getCollectionName = (context: IFunctionsContext, itemIndex: number): string =>
	getParameter(MONGODB_COLLECTION_NAME, context, itemIndex);

/** Reads the `VECTOR_INDEX_NAME` parameter for the given item via `getParameter`. */
export const getVectorIndexName = (context: IFunctionsContext, itemIndex: number): string =>
	getParameter(VECTOR_INDEX_NAME, context, itemIndex);

/** Reads the `EMBEDDING_NAME` parameter for the given item via `getParameter`. */
export const getEmbeddingFieldName = (context: IFunctionsContext, itemIndex: number): string =>
	getParameter(EMBEDDING_NAME, context, itemIndex);

/** Reads the `METADATA_FIELD_NAME` parameter for the given item via `getParameter`. */
export const getMetadataFieldName = (context: IFunctionsContext, itemIndex: number): string =>
	getParameter(METADATA_FIELD_NAME, context, itemIndex);
export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
meta: { meta: {
displayName: 'MongoDB Atlas Vector Store', displayName: 'MongoDB Atlas Vector Store',
@@ -162,64 +218,44 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
], ],
operationModes: ['load', 'insert', 'retrieve', 'update', 'retrieve-as-tool'], operationModes: ['load', 'insert', 'retrieve', 'update', 'retrieve-as-tool'],
}, },
methods: { listSearch: { mongoCollectionSearch } }, methods: { listSearch: { mongoCollectionSearch: getCollections } },
retrieveFields, retrieveFields,
loadFields: retrieveFields, loadFields: retrieveFields,
insertFields, insertFields,
sharedFields, sharedFields,
async getVectorStoreClient(context, _filter, embeddings, itemIndex) { async getVectorStoreClient(context, _filter, embeddings, itemIndex) {
try { try {
const { db } = await mongoClientAndDatabase(context); const client = await getMongoClient(context);
try { const db = await getDatabase(context, client);
const collectionName = context.getNodeParameter('mongoCollection', itemIndex, '', { const collectionName = getCollectionName(context, itemIndex);
extractValue: true, const mongoVectorIndexName = getVectorIndexName(context, itemIndex);
}) as string; const embeddingFieldName = getEmbeddingFieldName(context, itemIndex);
const metadataFieldName = getMetadataFieldName(context, itemIndex);
const mongoVectorIndexName = context.getNodeParameter('vectorIndexName', itemIndex, '', { const collection = db.collection(collectionName);
extractValue: true,
}) as string;
const embeddingFieldName = context.getNodeParameter('embedding', itemIndex, '', { // test index exists
extractValue: true, const indexes = await collection.listSearchIndexes().toArray();
}) as string;
const metadataFieldName = context.getNodeParameter('metadata_field', itemIndex, '', { const indexExists = indexes.some((index) => index.name === mongoVectorIndexName);
extractValue: true,
}) as string;
const collection = db.collection(collectionName); if (!indexExists) {
throw new NodeOperationError(context.getNode(), `Index ${mongoVectorIndexName} not found`, {
// test index exists
const indexes = await collection.listSearchIndexes().toArray();
const indexExists = indexes.some((index) => index.name === mongoVectorIndexName);
if (!indexExists) {
throw new NodeOperationError(
context.getNode(),
`Index ${mongoVectorIndexName} not found`,
{
itemIndex,
description: 'Please check that the index exists in your collection',
},
);
}
return new MongoDBAtlasVectorSearch(embeddings, {
collection,
indexName: mongoVectorIndexName, // Default index name
textKey: metadataFieldName, // Field containing raw text
embeddingKey: embeddingFieldName, // Field containing embeddings
});
} catch (error) {
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
itemIndex, itemIndex,
description: 'Please check your MongoDB Atlas connection details', description: 'Please check that the index exists in your collection',
}); });
} finally {
// Don't close the client here to maintain connection pooling
} }
return new MongoDBAtlasVectorSearch(embeddings, {
collection,
indexName: mongoVectorIndexName, // Default index name
textKey: metadataFieldName, // Field containing raw text
embeddingKey: embeddingFieldName, // Field containing embeddings
});
} catch (error) { } catch (error) {
if (error instanceof NodeOperationError) {
throw error;
}
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
itemIndex, itemIndex,
description: 'Please check your MongoDB Atlas connection details', description: 'Please check your MongoDB Atlas connection details',
@@ -228,43 +264,25 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
}, },
async populateVectorStore(context, embeddings, documents, itemIndex) { async populateVectorStore(context, embeddings, documents, itemIndex) {
try { try {
const { db } = await mongoClientAndDatabase(context); const client = await getMongoClient(context);
try { const db = await getDatabase(context, client);
const mongoCollectionName = context.getNodeParameter('mongoCollection', itemIndex, '', { const collectionName = getCollectionName(context, itemIndex);
extractValue: true, const mongoVectorIndexName = getVectorIndexName(context, itemIndex);
}) as string; const embeddingFieldName = getEmbeddingFieldName(context, itemIndex);
const embeddingFieldName = context.getNodeParameter('embedding', itemIndex, '', { const metadataFieldName = getMetadataFieldName(context, itemIndex);
extractValue: true,
}) as string;
const metadataFieldName = context.getNodeParameter('metadata_field', itemIndex, '', { // Check if collection exists
extractValue: true, const collections = await db.listCollections({ name: collectionName }).toArray();
}) as string; if (collections.length === 0) {
await db.createCollection(collectionName);
const mongoDBAtlasVectorIndex = context.getNodeParameter('vectorIndexName', itemIndex, '', {
extractValue: true,
}) as string;
// Check if collection exists
const collections = await db.listCollections({ name: mongoCollectionName }).toArray();
if (collections.length === 0) {
await db.createCollection(mongoCollectionName);
}
const collection = db.collection(mongoCollectionName);
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
collection,
indexName: mongoDBAtlasVectorIndex, // Default index name
textKey: metadataFieldName, // Field containing raw text
embeddingKey: embeddingFieldName, // Field containing embeddings
});
} catch (error) {
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
itemIndex,
description: 'Please check your MongoDB Atlas connection details',
});
} finally {
// Don't close the client here to maintain connection pooling
} }
const collection = db.collection(collectionName);
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
collection,
indexName: mongoVectorIndexName, // Default index name
textKey: metadataFieldName, // Field containing raw text
embeddingKey: embeddingFieldName, // Field containing embeddings
});
} catch (error) { } catch (error) {
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
itemIndex, itemIndex,