mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
refactor(MongoDB Vector Store Node): Refactor mongodb vector store node (#16239)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,12 +1,33 @@
|
||||
import { mock } from 'jest-mock-extended';
|
||||
import { MongoClient } from 'mongodb';
|
||||
import type { ILoadOptionsFunctions } from 'n8n-workflow';
|
||||
|
||||
import { getMongoClient, mongoConfig } from './VectorStoreMongoDBAtlas.node';
|
||||
import {
|
||||
EMBEDDING_NAME,
|
||||
getCollectionName,
|
||||
getEmbeddingFieldName,
|
||||
getMetadataFieldName,
|
||||
getMongoClient,
|
||||
getVectorIndexName,
|
||||
mongoConfig,
|
||||
METADATA_FIELD_NAME,
|
||||
MONGODB_COLLECTION_NAME,
|
||||
VECTOR_INDEX_NAME,
|
||||
} from './VectorStoreMongoDBAtlas.node';
|
||||
|
||||
jest.mock('mongodb', () => ({
|
||||
MongoClient: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('VectorStoreMongoDBAtlas -> getMongoClient', () => {
|
||||
describe('VectorStoreMongoDBAtlas', () => {
|
||||
const helpers = mock<ILoadOptionsFunctions['helpers']>();
|
||||
const executeFunctions = mock<ILoadOptionsFunctions>({ helpers });
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
});
|
||||
|
||||
describe('.getMongoClient', () => {
|
||||
const mockContext = {
|
||||
getCredentials: jest.fn(),
|
||||
};
|
||||
@@ -21,7 +42,6 @@ describe('VectorStoreMongoDBAtlas -> getMongoClient', () => {
|
||||
const MockMongoClient = MongoClient as jest.MockedClass<typeof MongoClient>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
mongoConfig.client = null;
|
||||
mongoConfig.connectionString = '';
|
||||
});
|
||||
@@ -76,3 +96,56 @@ describe('VectorStoreMongoDBAtlas -> getMongoClient', () => {
|
||||
expect(client2).toBe(mockClient2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('.getCollectionName', () => {
|
||||
beforeEach(() => {
|
||||
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
|
||||
if (paramName === MONGODB_COLLECTION_NAME) return 'testCollection';
|
||||
return '';
|
||||
});
|
||||
});
|
||||
|
||||
it('returns the collection name from the context', () => {
|
||||
expect(getCollectionName(executeFunctions, 0)).toEqual('testCollection');
|
||||
});
|
||||
});
|
||||
|
||||
describe('.getVectorIndexName', () => {
|
||||
beforeEach(() => {
|
||||
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
|
||||
if (paramName === VECTOR_INDEX_NAME) return 'testIndex';
|
||||
return '';
|
||||
});
|
||||
});
|
||||
|
||||
it('returns the index name from the context', () => {
|
||||
expect(getVectorIndexName(executeFunctions, 0)).toEqual('testIndex');
|
||||
});
|
||||
});
|
||||
|
||||
describe('.getEmbeddingFieldName', () => {
|
||||
beforeEach(() => {
|
||||
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
|
||||
if (paramName === EMBEDDING_NAME) return 'testEmbedding';
|
||||
return '';
|
||||
});
|
||||
});
|
||||
|
||||
it('returns the embedding name from the context', () => {
|
||||
expect(getEmbeddingFieldName(executeFunctions, 0)).toEqual('testEmbedding');
|
||||
});
|
||||
});
|
||||
|
||||
describe('.getMetadataFieldName', () => {
|
||||
beforeEach(() => {
|
||||
executeFunctions.getNodeParameter.mockImplementation((paramName: string) => {
|
||||
if (paramName === METADATA_FIELD_NAME) return 'testMetadata';
|
||||
return '';
|
||||
});
|
||||
});
|
||||
|
||||
it('returns the metadata field name from the context', () => {
|
||||
expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import { MongoDBAtlasVectorSearch } from '@langchain/mongodb';
|
||||
import { MongoClient } from 'mongodb';
|
||||
import { type ILoadOptionsFunctions, NodeOperationError, type INodeProperties } from 'n8n-workflow';
|
||||
|
||||
import {
|
||||
type ILoadOptionsFunctions,
|
||||
NodeOperationError,
|
||||
type INodeProperties,
|
||||
type IExecuteFunctions,
|
||||
type ISupplyDataFunctions,
|
||||
} from 'n8n-workflow';
|
||||
import { metadataFilterField } from '@utils/sharedFields';
|
||||
|
||||
import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode';
|
||||
@@ -108,8 +113,27 @@ export const mongoConfig = {
|
||||
connectionString: '',
|
||||
};
|
||||
|
||||
/**
|
||||
* Constants for the name of the credentials and Node parameters.
|
||||
*/
|
||||
export const MONGODB_CREDENTIALS = 'mongoDb';
|
||||
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
|
||||
export const VECTOR_INDEX_NAME = 'vectorIndexName';
|
||||
export const EMBEDDING_NAME = 'embedding';
|
||||
export const METADATA_FIELD_NAME = 'metadata_field';
|
||||
|
||||
/**
|
||||
* Type used for cleaner, more intentional typing.
|
||||
*/
|
||||
type IFunctionsContext = IExecuteFunctions | ISupplyDataFunctions | ILoadOptionsFunctions;
|
||||
|
||||
/**
|
||||
* Get the mongo client.
|
||||
* @param context - The context.
|
||||
* @returns the MongoClient for the node.
|
||||
*/
|
||||
export async function getMongoClient(context: any) {
|
||||
const credentials = await context.getCredentials('mongoDb');
|
||||
const credentials = await context.getCredentials(MONGODB_CREDENTIALS);
|
||||
const connectionString = credentials.connectionString as string;
|
||||
if (!mongoConfig.client || mongoConfig.connectionString !== connectionString) {
|
||||
if (mongoConfig.client) {
|
||||
@@ -125,16 +149,25 @@ export async function getMongoClient(context: any) {
|
||||
return mongoConfig.client;
|
||||
}
|
||||
|
||||
async function mongoClientAndDatabase(context: any) {
|
||||
const client = await getMongoClient(context);
|
||||
const credentials = await context.getCredentials('mongoDb');
|
||||
const db = client.db(credentials.database as string);
|
||||
return { client, db };
|
||||
/**
|
||||
* Get the database object from the MongoClient by the configured name.
|
||||
* @param context - The context.
|
||||
* @returns the Db object.
|
||||
*/
|
||||
export async function getDatabase(context: IFunctionsContext, client: MongoClient) {
|
||||
const credentials = await context.getCredentials(MONGODB_CREDENTIALS);
|
||||
return client.db(credentials.database as string);
|
||||
}
|
||||
|
||||
async function mongoCollectionSearch(this: ILoadOptionsFunctions) {
|
||||
const { db } = await mongoClientAndDatabase(this);
|
||||
/**
|
||||
* Get all the collection in the database.
|
||||
* @param this The load options context.
|
||||
* @returns The list of collections.
|
||||
*/
|
||||
export async function getCollections(this: ILoadOptionsFunctions) {
|
||||
try {
|
||||
const client = await getMongoClient(this);
|
||||
const db = await getDatabase(this, client);
|
||||
const collections = await db.listCollections().toArray();
|
||||
const results = collections.map((collection) => ({
|
||||
name: collection.name,
|
||||
@@ -146,6 +179,29 @@ async function mongoCollectionSearch(this: ILoadOptionsFunctions) {
|
||||
throw new NodeOperationError(this.getNode(), `Error: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a parameter from the context.
|
||||
* @param key - The key of the parameter.
|
||||
* @param context - The context.
|
||||
* @param itemIndex - The index.
|
||||
* @returns The value.
|
||||
*/
|
||||
export function getParameter(key: string, context: IFunctionsContext, itemIndex: number): string {
|
||||
const value = context.getNodeParameter(key, itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
if (typeof value !== 'string') {
|
||||
throw new NodeOperationError(context.getNode(), `Parameter ${key} must be a string`);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
export const getCollectionName = getParameter.bind(null, MONGODB_COLLECTION_NAME);
|
||||
export const getVectorIndexName = getParameter.bind(null, VECTOR_INDEX_NAME);
|
||||
export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME);
|
||||
export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME);
|
||||
|
||||
export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
meta: {
|
||||
displayName: 'MongoDB Atlas Vector Store',
|
||||
@@ -162,30 +218,19 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
],
|
||||
operationModes: ['load', 'insert', 'retrieve', 'update', 'retrieve-as-tool'],
|
||||
},
|
||||
methods: { listSearch: { mongoCollectionSearch } },
|
||||
methods: { listSearch: { mongoCollectionSearch: getCollections } },
|
||||
retrieveFields,
|
||||
loadFields: retrieveFields,
|
||||
insertFields,
|
||||
sharedFields,
|
||||
async getVectorStoreClient(context, _filter, embeddings, itemIndex) {
|
||||
try {
|
||||
const { db } = await mongoClientAndDatabase(context);
|
||||
try {
|
||||
const collectionName = context.getNodeParameter('mongoCollection', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
|
||||
const mongoVectorIndexName = context.getNodeParameter('vectorIndexName', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
|
||||
const embeddingFieldName = context.getNodeParameter('embedding', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
|
||||
const metadataFieldName = context.getNodeParameter('metadata_field', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
const client = await getMongoClient(context);
|
||||
const db = await getDatabase(context, client);
|
||||
const collectionName = getCollectionName(context, itemIndex);
|
||||
const mongoVectorIndexName = getVectorIndexName(context, itemIndex);
|
||||
const embeddingFieldName = getEmbeddingFieldName(context, itemIndex);
|
||||
const metadataFieldName = getMetadataFieldName(context, itemIndex);
|
||||
|
||||
const collection = db.collection(collectionName);
|
||||
|
||||
@@ -195,14 +240,10 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
const indexExists = indexes.some((index) => index.name === mongoVectorIndexName);
|
||||
|
||||
if (!indexExists) {
|
||||
throw new NodeOperationError(
|
||||
context.getNode(),
|
||||
`Index ${mongoVectorIndexName} not found`,
|
||||
{
|
||||
throw new NodeOperationError(context.getNode(), `Index ${mongoVectorIndexName} not found`, {
|
||||
itemIndex,
|
||||
description: 'Please check that the index exists in your collection',
|
||||
},
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return new MongoDBAtlasVectorSearch(embeddings, {
|
||||
@@ -211,65 +252,42 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
textKey: metadataFieldName, // Field containing raw text
|
||||
embeddingKey: embeddingFieldName, // Field containing embeddings
|
||||
});
|
||||
} catch (error) {
|
||||
if (error instanceof NodeOperationError) {
|
||||
throw error;
|
||||
}
|
||||
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||
itemIndex,
|
||||
description: 'Please check your MongoDB Atlas connection details',
|
||||
});
|
||||
}
|
||||
},
|
||||
async populateVectorStore(context, embeddings, documents, itemIndex) {
|
||||
try {
|
||||
const client = await getMongoClient(context);
|
||||
const db = await getDatabase(context, client);
|
||||
const collectionName = getCollectionName(context, itemIndex);
|
||||
const mongoVectorIndexName = getVectorIndexName(context, itemIndex);
|
||||
const embeddingFieldName = getEmbeddingFieldName(context, itemIndex);
|
||||
const metadataFieldName = getMetadataFieldName(context, itemIndex);
|
||||
|
||||
// Check if collection exists
|
||||
const collections = await db.listCollections({ name: collectionName }).toArray();
|
||||
if (collections.length === 0) {
|
||||
await db.createCollection(collectionName);
|
||||
}
|
||||
const collection = db.collection(collectionName);
|
||||
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
|
||||
collection,
|
||||
indexName: mongoVectorIndexName, // Default index name
|
||||
textKey: metadataFieldName, // Field containing raw text
|
||||
embeddingKey: embeddingFieldName, // Field containing embeddings
|
||||
});
|
||||
} catch (error) {
|
||||
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||
itemIndex,
|
||||
description: 'Please check your MongoDB Atlas connection details',
|
||||
});
|
||||
} finally {
|
||||
// Don't close the client here to maintain connection pooling
|
||||
}
|
||||
} catch (error) {
|
||||
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||
itemIndex,
|
||||
description: 'Please check your MongoDB Atlas connection details',
|
||||
});
|
||||
}
|
||||
},
|
||||
async populateVectorStore(context, embeddings, documents, itemIndex) {
|
||||
try {
|
||||
const { db } = await mongoClientAndDatabase(context);
|
||||
try {
|
||||
const mongoCollectionName = context.getNodeParameter('mongoCollection', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
const embeddingFieldName = context.getNodeParameter('embedding', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
|
||||
const metadataFieldName = context.getNodeParameter('metadata_field', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
|
||||
const mongoDBAtlasVectorIndex = context.getNodeParameter('vectorIndexName', itemIndex, '', {
|
||||
extractValue: true,
|
||||
}) as string;
|
||||
|
||||
// Check if collection exists
|
||||
const collections = await db.listCollections({ name: mongoCollectionName }).toArray();
|
||||
if (collections.length === 0) {
|
||||
await db.createCollection(mongoCollectionName);
|
||||
}
|
||||
const collection = db.collection(mongoCollectionName);
|
||||
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
|
||||
collection,
|
||||
indexName: mongoDBAtlasVectorIndex, // Default index name
|
||||
textKey: metadataFieldName, // Field containing raw text
|
||||
embeddingKey: embeddingFieldName, // Field containing embeddings
|
||||
});
|
||||
} catch (error) {
|
||||
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||
itemIndex,
|
||||
description: 'Please check your MongoDB Atlas connection details',
|
||||
});
|
||||
} finally {
|
||||
// Don't close the client here to maintain connection pooling
|
||||
}
|
||||
} catch (error) {
|
||||
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||
itemIndex,
|
||||
description: 'Please check your MongoDB Atlas connection details',
|
||||
});
|
||||
}
|
||||
},
|
||||
}) {}
|
||||
|
||||
Reference in New Issue
Block a user