From ee91aa00f116cec845a5355d051fd536964cb6e3 Mon Sep 17 00:00:00 2001 From: Durran Jordan Date: Fri, 5 Sep 2025 08:38:32 +0200 Subject: [PATCH] feat(MongoDB Vector Store Node): Allow pre and post filtering (#18506) Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> --- .../VectorStoreMongoDBAtlas.node.test.ts | 89 ++++++++++- .../VectorStoreMongoDBAtlas.node.ts | 138 +++++++++++++++--- 2 files changed, 204 insertions(+), 23 deletions(-) diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts index 3596a10532..e9a06afdb7 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.test.ts @@ -1,11 +1,12 @@ import { mock } from 'jest-mock-extended'; import { MongoClient } from 'mongodb'; -import type { ILoadOptionsFunctions } from 'n8n-workflow'; +import type { ILoadOptionsFunctions, ISupplyDataFunctions } from 'n8n-workflow'; import { EMBEDDING_NAME, getCollectionName, getEmbeddingFieldName, + getFilterValue, getMetadataFieldName, getMongoClient, getVectorIndexName, @@ -22,6 +23,8 @@ jest.mock('mongodb', () => ({ describe('VectorStoreMongoDBAtlas', () => { const helpers = mock(); const executeFunctions = mock({ helpers }); + const dataHelpers = mock(); + const dataFunctions = mock({ helpers: dataHelpers }); beforeEach(() => { jest.resetAllMocks(); @@ -148,4 +151,88 @@ describe('VectorStoreMongoDBAtlas', () => { expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata'); }); }); + + describe('.getFilterValue', () => { + describe('when no post filter is present', () => { + beforeEach(() => { + dataFunctions.getNodeParameter.mockImplementation(() => { + return {}; + }); + }); + + it('returns undefined', () => { + expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual(undefined); + }); + }); + + describe('when a post filter is present', () => { + describe('when the JSON is valid', () => { + beforeEach(() => { + dataFunctions.getNodeParameter.mockImplementation(() => { + return { postFilterPipeline: '[{ "$match": { "name": "value" }}]' }; + }); + }); + + it('returns the post filter pipeline', () => { + expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual([ + { $match: { name: 'value' } }, + ]); + }); + }); + + describe('when the JSON is invalid', () => { + beforeEach(() => { + dataFunctions.getNodeParameter.mockImplementation(() => { + return { postFilterPipeline: '[{ "$match": { "name":}}]' }; + }); + }); + + it('throws an error', () => { + expect(() => { + getFilterValue('postFilterPipeline', dataFunctions, 0); + }).toThrow(); + }); + }); + }); + + describe('when no pre filter is present', () => { + beforeEach(() => { + dataFunctions.getNodeParameter.mockImplementation(() => { + return {}; + }); + }); + + it('returns undefined', () => { + expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual(undefined); + }); + }); + + describe('when a pre filter is present', () => { + describe('when the JSON is valid', () => { + beforeEach(() => { + dataFunctions.getNodeParameter.mockImplementation(() => { + return { preFilter: '{ "name": "value" }' }; + }); + }); + + it('returns the pre filter', () => { + expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual({ name: 'value' }); + }); + }); + + describe('when the JSON is invalid', () => { + beforeEach(() => { + dataFunctions.getNodeParameter.mockImplementation(() => { + return { preFilter: '"name":}}]' }; + }); + }); + + it('throws an error', () => { + expect(() => { + getFilterValue('preFilter', dataFunctions, 0); + }).toThrow(); + }); + }); + }); + }); }); diff --git a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts index 48d1e27cdc..7e2a9ddd92 100644 --- a/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/vector_store/VectorStoreMongoDBAtlas/VectorStoreMongoDBAtlas.node.ts @@ -1,6 +1,8 @@ -import { MongoDBAtlasVectorSearch } from '@langchain/mongodb'; +import type { EmbeddingsInterface } from '@langchain/core/embeddings'; +import { MongoDBAtlasVectorSearch, type MongoDBAtlasVectorSearchLibArgs } from '@langchain/mongodb'; import { MongoClient } from 'mongodb'; import { + type IDataObject, type ILoadOptionsFunctions, NodeOperationError, type INodeProperties, @@ -11,9 +13,20 @@ import { metadataFilterField } from '@utils/sharedFields'; import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode'; +/** + * Constants for the name of the credentials and Node parameters. + */ +export const MONGODB_CREDENTIALS = 'mongoDb'; +export const MONGODB_COLLECTION_NAME = 'mongoCollection'; +export const VECTOR_INDEX_NAME = 'vectorIndexName'; +export const EMBEDDING_NAME = 'embedding'; +export const METADATA_FIELD_NAME = 'metadata_field'; +export const PRE_FILTER_NAME = 'preFilter'; +export const POST_FILTER_NAME = 'postFilterPipeline'; + const mongoCollectionRLC: INodeProperties = { displayName: 'MongoDB Collection', - name: 'mongoCollection', + name: MONGODB_COLLECTION_NAME, type: 'resourceLocator', default: { mode: 'list', value: '' }, required: true, @@ -37,7 +50,7 @@ const mongoCollectionRLC: INodeProperties = { const vectorIndexName: INodeProperties = { displayName: 'Vector Index Name', - name: 'vectorIndexName', + name: VECTOR_INDEX_NAME, type: 'string', default: '', description: 'The name of the vector index', @@ -46,7 +59,7 @@ const vectorIndexName: INodeProperties = { const embeddingField: INodeProperties = { displayName: 'Embedding', - name: 'embedding', + name: EMBEDDING_NAME, type: 'string', default: 'embedding', description: 'The field with the embedding array', @@ -55,7 +68,7 @@ const embeddingField: INodeProperties = { const metadataField: INodeProperties = { displayName: 'Metadata Field', - name: 'metadata_field', + name: METADATA_FIELD_NAME, type: 'string', default: 'text', description: 'The text field of the raw data', @@ -77,6 +90,34 @@ const mongoNamespaceField: INodeProperties = { default: '', }; +const preFilterField: INodeProperties = { + displayName: 'Pre Filter', + name: PRE_FILTER_NAME, + type: 'json', + typeOptions: { + alwaysOpenEditWindow: true, + }, + default: '', + placeholder: '{ "key": "value" }', + hint: 'This is a filter applied in the $vectorSearch stage here', + required: true, + description: 'MongoDB Atlas Vector Search pre-filter', +}; + +const postFilterField: INodeProperties = { + displayName: 'Post Filter Pipeline', + name: POST_FILTER_NAME, + type: 'json', + typeOptions: { + alwaysOpenEditWindow: true, + }, + default: '', + placeholder: '[{ "$match": { "$gt": "1950-01-01" }, ... }]', + hint: 'Learn more about aggregation pipeline here', + required: true, + description: 'MongoDB aggregation pipeline in JSON format', +}; + const retrieveFields: INodeProperties[] = [ { displayName: 'Options', @@ -84,7 +125,7 @@ const retrieveFields: INodeProperties[] = [ type: 'collection', placeholder: 'Add Option', default: {}, - options: [mongoNamespaceField, metadataFilterField], + options: [mongoNamespaceField, metadataFilterField, preFilterField, postFilterField], }, ]; @@ -113,15 +154,6 @@ export const mongoConfig = { connectionString: '', }; -/** - * Constants for the name of the credentials and Node parameters. - */ -export const MONGODB_CREDENTIALS = 'mongoDb'; -export const MONGODB_COLLECTION_NAME = 'mongoCollection'; -export const VECTOR_INDEX_NAME = 'vectorIndexName'; -export const EMBEDDING_NAME = 'embedding'; -export const METADATA_FIELD_NAME = 'metadata_field'; - /** * Type used for cleaner, more intentional typing. */ @@ -202,6 +234,57 @@ export const getVectorIndexName = getParameter.bind(null, VECTOR_INDEX_NAME); export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME); export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME); +export function getFilterValue( + name: string, + context: IExecuteFunctions | ISupplyDataFunctions, + itemIndex: number, +): T | undefined { + const options: IDataObject = context.getNodeParameter('options', itemIndex, {}); + + if (options[name]) { + if (typeof options[name] === 'string') { + try { + return JSON.parse(options[name]); + } catch (error) { + throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, { + itemIndex, + description: `Could not parse JSON for ${name}`, + }); + } + } + throw new NodeOperationError(context.getNode(), 'Error: No JSON string provided.', { + itemIndex, + description: `Could not parse JSON for ${name}`, + }); + } + + return undefined; +} + +class ExtendedMongoDBAtlasVectorSearch extends MongoDBAtlasVectorSearch { + preFilter: IDataObject; + postFilterPipeline?: IDataObject[]; + + constructor( + embeddings: EmbeddingsInterface, + options: MongoDBAtlasVectorSearchLibArgs, + preFilter: IDataObject, + postFilterPipeline?: IDataObject[], + ) { + super(embeddings, options); + this.preFilter = preFilter; + this.postFilterPipeline = postFilterPipeline; + } + + async similaritySearchVectorWithScore(query: number[], k: number) { + const mergedFilter: MongoDBAtlasVectorSearch['FilterType'] = { + preFilter: this.preFilter, + postFilterPipeline: this.postFilterPipeline, + }; + return await super.similaritySearchVectorWithScore(query, k, mergedFilter); + } +} + export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ meta: { displayName: 'MongoDB Atlas Vector Store', @@ -245,13 +328,24 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ description: 'Please check that the index exists in your collection', }); } + const preFilter = getFilterValue(PRE_FILTER_NAME, context, itemIndex); + const postFilterPipeline = getFilterValue( + POST_FILTER_NAME, + context, + itemIndex, + ); - return new MongoDBAtlasVectorSearch(embeddings, { - collection, - indexName: mongoVectorIndexName, // Default index name - textKey: metadataFieldName, // Field containing raw text - embeddingKey: embeddingFieldName, // Field containing embeddings - }); + return new ExtendedMongoDBAtlasVectorSearch( + embeddings, + { + collection, + indexName: mongoVectorIndexName, // Default index name + textKey: metadataFieldName, // Field containing raw text + embeddingKey: embeddingFieldName, // Field containing embeddings + }, + preFilter ?? {}, + postFilterPipeline, + ); } catch (error) { if (error instanceof NodeOperationError) { throw error; @@ -277,7 +371,7 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ await db.createCollection(collectionName); } const collection = db.collection(collectionName); - await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, { + await ExtendedMongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, { collection, indexName: mongoVectorIndexName, // Default index name textKey: metadataFieldName, // Field containing raw text