feat(MongoDB Vector Store Node): Allow pre and post filtering (#18506)

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
This commit is contained in:
Durran Jordan
2025-09-05 08:38:32 +02:00
committed by GitHub
parent 6456b7c07d
commit ee91aa00f1
2 changed files with 204 additions and 23 deletions

View File

@@ -1,11 +1,12 @@
import { mock } from 'jest-mock-extended'; import { mock } from 'jest-mock-extended';
import { MongoClient } from 'mongodb'; import { MongoClient } from 'mongodb';
import type { ILoadOptionsFunctions } from 'n8n-workflow'; import type { ILoadOptionsFunctions, ISupplyDataFunctions } from 'n8n-workflow';
import { import {
EMBEDDING_NAME, EMBEDDING_NAME,
getCollectionName, getCollectionName,
getEmbeddingFieldName, getEmbeddingFieldName,
getFilterValue,
getMetadataFieldName, getMetadataFieldName,
getMongoClient, getMongoClient,
getVectorIndexName, getVectorIndexName,
@@ -22,6 +23,8 @@ jest.mock('mongodb', () => ({
describe('VectorStoreMongoDBAtlas', () => { describe('VectorStoreMongoDBAtlas', () => {
const helpers = mock<ILoadOptionsFunctions['helpers']>(); const helpers = mock<ILoadOptionsFunctions['helpers']>();
const executeFunctions = mock<ILoadOptionsFunctions>({ helpers }); const executeFunctions = mock<ILoadOptionsFunctions>({ helpers });
const dataHelpers = mock<ISupplyDataFunctions['helpers']>();
const dataFunctions = mock<ISupplyDataFunctions>({ helpers: dataHelpers });
beforeEach(() => { beforeEach(() => {
jest.resetAllMocks(); jest.resetAllMocks();
@@ -148,4 +151,88 @@ describe('VectorStoreMongoDBAtlas', () => {
expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata'); expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata');
}); });
}); });
// Covers every branch of getFilterValue for both filter options:
// option missing -> undefined, valid JSON string -> parsed value,
// malformed JSON string -> throws, truthy non-string value -> throws.
describe('.getFilterValue', () => {
	describe('post filter pipeline', () => {
		describe('when no post filter is present', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return {};
				});
			});

			it('returns undefined', () => {
				expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toBeUndefined();
			});
		});

		describe('when the JSON is valid', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return { postFilterPipeline: '[{ "$match": { "name": "value" }}]' };
				});
			});

			it('returns the post filter pipeline', () => {
				expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual([
					{ $match: { name: 'value' } },
				]);
			});
		});

		describe('when the JSON is invalid', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return { postFilterPipeline: '[{ "$match": { "name":}}]' };
				});
			});

			it('throws an error', () => {
				expect(() => {
					getFilterValue('postFilterPipeline', dataFunctions, 0);
				}).toThrow();
			});
		});

		// An already-parsed value is rejected: only JSON strings are accepted.
		describe('when the value is not a string', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return { postFilterPipeline: [{ $match: { name: 'value' } }] };
				});
			});

			it('throws an error', () => {
				expect(() => {
					getFilterValue('postFilterPipeline', dataFunctions, 0);
				}).toThrow();
			});
		});
	});

	describe('pre filter', () => {
		describe('when no pre filter is present', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return {};
				});
			});

			it('returns undefined', () => {
				expect(getFilterValue('preFilter', dataFunctions, 0)).toBeUndefined();
			});
		});

		describe('when the JSON is valid', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return { preFilter: '{ "name": "value" }' };
				});
			});

			it('returns the pre filter', () => {
				expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual({ name: 'value' });
			});
		});

		describe('when the JSON is invalid', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return { preFilter: '"name":}}]' };
				});
			});

			it('throws an error', () => {
				expect(() => {
					getFilterValue('preFilter', dataFunctions, 0);
				}).toThrow();
			});
		});

		// An already-parsed value is rejected: only JSON strings are accepted.
		describe('when the value is not a string', () => {
			beforeEach(() => {
				dataFunctions.getNodeParameter.mockImplementation(() => {
					return { preFilter: { name: 'value' } };
				});
			});

			it('throws an error', () => {
				expect(() => {
					getFilterValue('preFilter', dataFunctions, 0);
				}).toThrow();
			});
		});
	});
});
}); });

View File

@@ -1,6 +1,8 @@
import { MongoDBAtlasVectorSearch } from '@langchain/mongodb'; import type { EmbeddingsInterface } from '@langchain/core/embeddings';
import { MongoDBAtlasVectorSearch, type MongoDBAtlasVectorSearchLibArgs } from '@langchain/mongodb';
import { MongoClient } from 'mongodb'; import { MongoClient } from 'mongodb';
import { import {
type IDataObject,
type ILoadOptionsFunctions, type ILoadOptionsFunctions,
NodeOperationError, NodeOperationError,
type INodeProperties, type INodeProperties,
@@ -11,9 +13,20 @@ import { metadataFilterField } from '@utils/sharedFields';
import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode'; import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode';
/**
 * Constants for the name of the credentials and Node parameters.
 *
 * They are declared ahead of the property definitions in this file so the
 * definitions can use them for their `name` instead of repeating string literals.
 */
// Credentials name — NOTE(review): its use is not visible in this part of the file.
export const MONGODB_CREDENTIALS = 'mongoDb';
// `name` of the "MongoDB Collection" resource-locator property.
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
// `name` of the "Vector Index Name" property.
export const VECTOR_INDEX_NAME = 'vectorIndexName';
// `name` of the "Embedding" property (the field with the embedding array).
export const EMBEDDING_NAME = 'embedding';
// `name` of the "Metadata Field" property (the text field of the raw data).
export const METADATA_FIELD_NAME = 'metadata_field';
// `name` of the "Pre Filter" option; also the key passed to getFilterValue.
export const PRE_FILTER_NAME = 'preFilter';
// `name` of the "Post Filter Pipeline" option; also the key passed to getFilterValue.
export const POST_FILTER_NAME = 'postFilterPipeline';
const mongoCollectionRLC: INodeProperties = { const mongoCollectionRLC: INodeProperties = {
displayName: 'MongoDB Collection', displayName: 'MongoDB Collection',
name: 'mongoCollection', name: MONGODB_COLLECTION_NAME,
type: 'resourceLocator', type: 'resourceLocator',
default: { mode: 'list', value: '' }, default: { mode: 'list', value: '' },
required: true, required: true,
@@ -37,7 +50,7 @@ const mongoCollectionRLC: INodeProperties = {
const vectorIndexName: INodeProperties = { const vectorIndexName: INodeProperties = {
displayName: 'Vector Index Name', displayName: 'Vector Index Name',
name: 'vectorIndexName', name: VECTOR_INDEX_NAME,
type: 'string', type: 'string',
default: '', default: '',
description: 'The name of the vector index', description: 'The name of the vector index',
@@ -46,7 +59,7 @@ const vectorIndexName: INodeProperties = {
const embeddingField: INodeProperties = { const embeddingField: INodeProperties = {
displayName: 'Embedding', displayName: 'Embedding',
name: 'embedding', name: EMBEDDING_NAME,
type: 'string', type: 'string',
default: 'embedding', default: 'embedding',
description: 'The field with the embedding array', description: 'The field with the embedding array',
@@ -55,7 +68,7 @@ const embeddingField: INodeProperties = {
const metadataField: INodeProperties = { const metadataField: INodeProperties = {
displayName: 'Metadata Field', displayName: 'Metadata Field',
name: 'metadata_field', name: METADATA_FIELD_NAME,
type: 'string', type: 'string',
default: 'text', default: 'text',
description: 'The text field of the raw data', description: 'The text field of the raw data',
@@ -77,6 +90,34 @@ const mongoNamespaceField: INodeProperties = {
default: '', default: '',
}; };
/**
 * "Pre Filter" option offered in the retrieve `Options` collection.
 *
 * The value is entered as JSON text (see `placeholder`) and is stored under
 * PRE_FILTER_NAME. Per the hint, it is the filter applied in the
 * `$vectorSearch` stage. The empty-string default means "no pre-filter".
 */
const preFilterField: INodeProperties = {
displayName: 'Pre Filter',
name: PRE_FILTER_NAME,
type: 'json',
typeOptions: {
alwaysOpenEditWindow: true,
},
default: '',
placeholder: '{ "key": "value" }',
hint: 'This is a filter applied in the $vectorSearch stage <a href="https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter">here</a>',
required: true,
description: 'MongoDB Atlas Vector Search pre-filter',
};
/**
 * "Post Filter Pipeline" option offered in the retrieve `Options` collection.
 *
 * The value is entered as JSON text describing an array of aggregation
 * stages and is stored under POST_FILTER_NAME. The empty-string default
 * means "no post-filter pipeline".
 *
 * The placeholder shows a syntactically valid pipeline: `$match` takes a
 * query document keyed by field name, and each stage object holds exactly
 * one stage operator, so further stages are separate array elements.
 */
const postFilterField: INodeProperties = {
	displayName: 'Post Filter Pipeline',
	name: POST_FILTER_NAME,
	type: 'json',
	typeOptions: {
		alwaysOpenEditWindow: true,
	},
	default: '',
	placeholder: '[{ "$match": { "released": { "$gt": "1950-01-01" } } }, ...]',
	hint: 'Learn more about aggregation pipeline <a href="https://docs.mongodb.com/manual/core/aggregation-pipeline/">here</a>',
	required: true,
	description: 'MongoDB aggregation pipeline in JSON format',
};
const retrieveFields: INodeProperties[] = [ const retrieveFields: INodeProperties[] = [
{ {
displayName: 'Options', displayName: 'Options',
@@ -84,7 +125,7 @@ const retrieveFields: INodeProperties[] = [
type: 'collection', type: 'collection',
placeholder: 'Add Option', placeholder: 'Add Option',
default: {}, default: {},
options: [mongoNamespaceField, metadataFilterField], options: [mongoNamespaceField, metadataFilterField, preFilterField, postFilterField],
}, },
]; ];
@@ -113,15 +154,6 @@ export const mongoConfig = {
connectionString: '', connectionString: '',
}; };
/**
* Constants for the name of the credentials and Node parameters.
*/
export const MONGODB_CREDENTIALS = 'mongoDb';
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
export const VECTOR_INDEX_NAME = 'vectorIndexName';
export const EMBEDDING_NAME = 'embedding';
export const METADATA_FIELD_NAME = 'metadata_field';
/** /**
* Type used for cleaner, more intentional typing. * Type used for cleaner, more intentional typing.
*/ */
@@ -202,6 +234,57 @@ export const getVectorIndexName = getParameter.bind(null, VECTOR_INDEX_NAME);
export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME); export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME);
export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME); export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME);
/**
 * Reads an optional JSON filter from the node's `options` collection and parses it.
 *
 * @param name - Key of the option to read inside `options`.
 * @param context - Execution context used to read the `options` parameter and
 *   to attribute errors to the current node.
 * @param itemIndex - Index of the item whose parameters are being resolved.
 * @returns The parsed JSON value, or `undefined` when the option is missing or
 *   holds a falsy value (for example the empty-string default).
 * @throws NodeOperationError when the option holds a truthy non-string value,
 *   or a string that is not valid JSON.
 */
export function getFilterValue<T>(
	name: string,
	context: IExecuteFunctions | ISupplyDataFunctions,
	itemIndex: number,
): T | undefined {
	const options: IDataObject = context.getNodeParameter('options', itemIndex, {});
	// Read once so the type narrowing below applies to a single local value.
	const rawValue = options[name];

	// Option not set (or left empty): there is no filter to apply.
	if (!rawValue) {
		return undefined;
	}

	// Only JSON text is accepted; anything else is reported as a user error.
	if (typeof rawValue !== 'string') {
		throw new NodeOperationError(context.getNode(), 'Error: No JSON string provided.', {
			itemIndex,
			description: `Could not parse JSON for ${name}`,
		});
	}

	try {
		return JSON.parse(rawValue) as T;
	} catch (error) {
		// Narrow before reading `.message` so a non-Error throw still yields a usable message.
		const reason = error instanceof Error ? error.message : String(error);
		throw new NodeOperationError(context.getNode(), `Error: ${reason}`, {
			itemIndex,
			description: `Could not parse JSON for ${name}`,
		});
	}
}
/**
 * MongoDBAtlasVectorSearch variant that is constructed with a pre-filter and an
 * optional post-filter pipeline and passes both to every vector similarity search.
 */
class ExtendedMongoDBAtlasVectorSearch extends MongoDBAtlasVectorSearch {
	/**
	 * @param embeddings - Embeddings implementation handed to the base class.
	 * @param options - Base-class configuration, handed to the base class unchanged.
	 * @param preFilter - Filter document stored on the instance for later searches.
	 * @param postFilterPipeline - Optional pipeline stored on the instance for later searches.
	 */
	constructor(
		embeddings: EmbeddingsInterface,
		options: MongoDBAtlasVectorSearchLibArgs,
		public preFilter: IDataObject,
		public postFilterPipeline?: IDataObject[],
	) {
		super(embeddings, options);
	}

	/**
	 * Delegates to the base-class search, always supplying the filters stored on this instance.
	 *
	 * NOTE(review): this override declares only `(query, k)`, so a filter argument
	 * passed by a caller is not forwarded to the base class — confirm this is intended.
	 */
	async similaritySearchVectorWithScore(query: number[], k: number) {
		const { preFilter, postFilterPipeline } = this;
		const configuredFilter: MongoDBAtlasVectorSearch['FilterType'] = {
			preFilter,
			postFilterPipeline,
		};

		return await super.similaritySearchVectorWithScore(query, k, configuredFilter);
	}
}
export class VectorStoreMongoDBAtlas extends createVectorStoreNode({ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
meta: { meta: {
displayName: 'MongoDB Atlas Vector Store', displayName: 'MongoDB Atlas Vector Store',
@@ -245,13 +328,24 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
description: 'Please check that the index exists in your collection', description: 'Please check that the index exists in your collection',
}); });
} }
const preFilter = getFilterValue<IDataObject>(PRE_FILTER_NAME, context, itemIndex);
const postFilterPipeline = getFilterValue<IDataObject[]>(
POST_FILTER_NAME,
context,
itemIndex,
);
return new MongoDBAtlasVectorSearch(embeddings, { return new ExtendedMongoDBAtlasVectorSearch(
collection, embeddings,
indexName: mongoVectorIndexName, // Default index name {
textKey: metadataFieldName, // Field containing raw text collection,
embeddingKey: embeddingFieldName, // Field containing embeddings indexName: mongoVectorIndexName, // Default index name
}); textKey: metadataFieldName, // Field containing raw text
embeddingKey: embeddingFieldName, // Field containing embeddings
},
preFilter ?? {},
postFilterPipeline,
);
} catch (error) { } catch (error) {
if (error instanceof NodeOperationError) { if (error instanceof NodeOperationError) {
throw error; throw error;
@@ -277,7 +371,7 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
await db.createCollection(collectionName); await db.createCollection(collectionName);
} }
const collection = db.collection(collectionName); const collection = db.collection(collectionName);
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, { await ExtendedMongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
collection, collection,
indexName: mongoVectorIndexName, // Default index name indexName: mongoVectorIndexName, // Default index name
textKey: metadataFieldName, // Field containing raw text textKey: metadataFieldName, // Field containing raw text