mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(MongoDB Vector Store Node): Allow pre and post filtering (#18506)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import { mock } from 'jest-mock-extended';
|
||||
import { MongoClient } from 'mongodb';
|
||||
import type { ILoadOptionsFunctions } from 'n8n-workflow';
|
||||
import type { ILoadOptionsFunctions, ISupplyDataFunctions } from 'n8n-workflow';
|
||||
|
||||
import {
|
||||
EMBEDDING_NAME,
|
||||
getCollectionName,
|
||||
getEmbeddingFieldName,
|
||||
getFilterValue,
|
||||
getMetadataFieldName,
|
||||
getMongoClient,
|
||||
getVectorIndexName,
|
||||
@@ -22,6 +23,8 @@ jest.mock('mongodb', () => ({
|
||||
describe('VectorStoreMongoDBAtlas', () => {
|
||||
const helpers = mock<ILoadOptionsFunctions['helpers']>();
|
||||
const executeFunctions = mock<ILoadOptionsFunctions>({ helpers });
|
||||
const dataHelpers = mock<ISupplyDataFunctions['helpers']>();
|
||||
const dataFunctions = mock<ISupplyDataFunctions>({ helpers: dataHelpers });
|
||||
|
||||
beforeEach(() => {
|
||||
jest.resetAllMocks();
|
||||
@@ -148,4 +151,88 @@ describe('VectorStoreMongoDBAtlas', () => {
|
||||
expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata');
|
||||
});
|
||||
});
|
||||
|
||||
describe('.getFilterValue', () => {
|
||||
describe('when no post filter is present', () => {
|
||||
beforeEach(() => {
|
||||
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||
return {};
|
||||
});
|
||||
});
|
||||
|
||||
it('returns undefined', () => {
|
||||
expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual(undefined);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when a post filter is present', () => {
|
||||
describe('when the JSON is valid', () => {
|
||||
beforeEach(() => {
|
||||
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||
return { postFilterPipeline: '[{ "$match": { "name": "value" }}]' };
|
||||
});
|
||||
});
|
||||
|
||||
it('returns the post filter pipeline', () => {
|
||||
expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual([
|
||||
{ $match: { name: 'value' } },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when the JSON is invalid', () => {
|
||||
beforeEach(() => {
|
||||
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||
return { postFilterPipeline: '[{ "$match": { "name":}}]' };
|
||||
});
|
||||
});
|
||||
|
||||
it('throws an error', () => {
|
||||
expect(() => {
|
||||
getFilterValue('postFilterPipeline', dataFunctions, 0);
|
||||
}).toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('when no pre filter is present', () => {
|
||||
beforeEach(() => {
|
||||
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||
return {};
|
||||
});
|
||||
});
|
||||
|
||||
it('returns undefined', () => {
|
||||
expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual(undefined);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when a pre filter is present', () => {
|
||||
describe('when the JSON is valid', () => {
|
||||
beforeEach(() => {
|
||||
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||
return { preFilter: '{ "name": "value" }' };
|
||||
});
|
||||
});
|
||||
|
||||
it('returns the pre filter', () => {
|
||||
expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual({ name: 'value' });
|
||||
});
|
||||
});
|
||||
|
||||
describe('when the JSON is invalid', () => {
|
||||
beforeEach(() => {
|
||||
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||
return { preFilter: '"name":}}]' };
|
||||
});
|
||||
});
|
||||
|
||||
it('throws an error', () => {
|
||||
expect(() => {
|
||||
getFilterValue('preFilter', dataFunctions, 0);
|
||||
}).toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { MongoDBAtlasVectorSearch } from '@langchain/mongodb';
|
||||
import type { EmbeddingsInterface } from '@langchain/core/embeddings';
|
||||
import { MongoDBAtlasVectorSearch, type MongoDBAtlasVectorSearchLibArgs } from '@langchain/mongodb';
|
||||
import { MongoClient } from 'mongodb';
|
||||
import {
|
||||
type IDataObject,
|
||||
type ILoadOptionsFunctions,
|
||||
NodeOperationError,
|
||||
type INodeProperties,
|
||||
@@ -11,9 +13,20 @@ import { metadataFilterField } from '@utils/sharedFields';
|
||||
|
||||
import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode';
|
||||
|
||||
/**
|
||||
* Constants for the name of the credentials and Node parameters.
|
||||
*/
|
||||
export const MONGODB_CREDENTIALS = 'mongoDb';
|
||||
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
|
||||
export const VECTOR_INDEX_NAME = 'vectorIndexName';
|
||||
export const EMBEDDING_NAME = 'embedding';
|
||||
export const METADATA_FIELD_NAME = 'metadata_field';
|
||||
export const PRE_FILTER_NAME = 'preFilter';
|
||||
export const POST_FILTER_NAME = 'postFilterPipeline';
|
||||
|
||||
const mongoCollectionRLC: INodeProperties = {
|
||||
displayName: 'MongoDB Collection',
|
||||
name: 'mongoCollection',
|
||||
name: MONGODB_COLLECTION_NAME,
|
||||
type: 'resourceLocator',
|
||||
default: { mode: 'list', value: '' },
|
||||
required: true,
|
||||
@@ -37,7 +50,7 @@ const mongoCollectionRLC: INodeProperties = {
|
||||
|
||||
const vectorIndexName: INodeProperties = {
|
||||
displayName: 'Vector Index Name',
|
||||
name: 'vectorIndexName',
|
||||
name: VECTOR_INDEX_NAME,
|
||||
type: 'string',
|
||||
default: '',
|
||||
description: 'The name of the vector index',
|
||||
@@ -46,7 +59,7 @@ const vectorIndexName: INodeProperties = {
|
||||
|
||||
const embeddingField: INodeProperties = {
|
||||
displayName: 'Embedding',
|
||||
name: 'embedding',
|
||||
name: EMBEDDING_NAME,
|
||||
type: 'string',
|
||||
default: 'embedding',
|
||||
description: 'The field with the embedding array',
|
||||
@@ -55,7 +68,7 @@ const embeddingField: INodeProperties = {
|
||||
|
||||
const metadataField: INodeProperties = {
|
||||
displayName: 'Metadata Field',
|
||||
name: 'metadata_field',
|
||||
name: METADATA_FIELD_NAME,
|
||||
type: 'string',
|
||||
default: 'text',
|
||||
description: 'The text field of the raw data',
|
||||
@@ -77,6 +90,34 @@ const mongoNamespaceField: INodeProperties = {
|
||||
default: '',
|
||||
};
|
||||
|
||||
const preFilterField: INodeProperties = {
|
||||
displayName: 'Pre Filter',
|
||||
name: PRE_FILTER_NAME,
|
||||
type: 'json',
|
||||
typeOptions: {
|
||||
alwaysOpenEditWindow: true,
|
||||
},
|
||||
default: '',
|
||||
placeholder: '{ "key": "value" }',
|
||||
hint: 'This is a filter applied in the $vectorSearch stage <a href="https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter">here</a>',
|
||||
required: true,
|
||||
description: 'MongoDB Atlas Vector Search pre-filter',
|
||||
};
|
||||
|
||||
const postFilterField: INodeProperties = {
|
||||
displayName: 'Post Filter Pipeline',
|
||||
name: POST_FILTER_NAME,
|
||||
type: 'json',
|
||||
typeOptions: {
|
||||
alwaysOpenEditWindow: true,
|
||||
},
|
||||
default: '',
|
||||
placeholder: '[{ "$match": { "$gt": "1950-01-01" }, ... }]',
|
||||
hint: 'Learn more about aggregation pipeline <a href="https://docs.mongodb.com/manual/core/aggregation-pipeline/">here</a>',
|
||||
required: true,
|
||||
description: 'MongoDB aggregation pipeline in JSON format',
|
||||
};
|
||||
|
||||
const retrieveFields: INodeProperties[] = [
|
||||
{
|
||||
displayName: 'Options',
|
||||
@@ -84,7 +125,7 @@ const retrieveFields: INodeProperties[] = [
|
||||
type: 'collection',
|
||||
placeholder: 'Add Option',
|
||||
default: {},
|
||||
options: [mongoNamespaceField, metadataFilterField],
|
||||
options: [mongoNamespaceField, metadataFilterField, preFilterField, postFilterField],
|
||||
},
|
||||
];
|
||||
|
||||
@@ -113,15 +154,6 @@ export const mongoConfig = {
|
||||
connectionString: '',
|
||||
};
|
||||
|
||||
/**
|
||||
* Constants for the name of the credentials and Node parameters.
|
||||
*/
|
||||
export const MONGODB_CREDENTIALS = 'mongoDb';
|
||||
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
|
||||
export const VECTOR_INDEX_NAME = 'vectorIndexName';
|
||||
export const EMBEDDING_NAME = 'embedding';
|
||||
export const METADATA_FIELD_NAME = 'metadata_field';
|
||||
|
||||
/**
|
||||
* Type used for cleaner, more intentional typing.
|
||||
*/
|
||||
@@ -202,6 +234,57 @@ export const getVectorIndexName = getParameter.bind(null, VECTOR_INDEX_NAME);
|
||||
export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME);
|
||||
export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME);
|
||||
|
||||
export function getFilterValue<T>(
|
||||
name: string,
|
||||
context: IExecuteFunctions | ISupplyDataFunctions,
|
||||
itemIndex: number,
|
||||
): T | undefined {
|
||||
const options: IDataObject = context.getNodeParameter('options', itemIndex, {});
|
||||
|
||||
if (options[name]) {
|
||||
if (typeof options[name] === 'string') {
|
||||
try {
|
||||
return JSON.parse(options[name]);
|
||||
} catch (error) {
|
||||
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||
itemIndex,
|
||||
description: `Could not parse JSON for ${name}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
throw new NodeOperationError(context.getNode(), 'Error: No JSON string provided.', {
|
||||
itemIndex,
|
||||
description: `Could not parse JSON for ${name}`,
|
||||
});
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
class ExtendedMongoDBAtlasVectorSearch extends MongoDBAtlasVectorSearch {
|
||||
preFilter: IDataObject;
|
||||
postFilterPipeline?: IDataObject[];
|
||||
|
||||
constructor(
|
||||
embeddings: EmbeddingsInterface,
|
||||
options: MongoDBAtlasVectorSearchLibArgs,
|
||||
preFilter: IDataObject,
|
||||
postFilterPipeline?: IDataObject[],
|
||||
) {
|
||||
super(embeddings, options);
|
||||
this.preFilter = preFilter;
|
||||
this.postFilterPipeline = postFilterPipeline;
|
||||
}
|
||||
|
||||
async similaritySearchVectorWithScore(query: number[], k: number) {
|
||||
const mergedFilter: MongoDBAtlasVectorSearch['FilterType'] = {
|
||||
preFilter: this.preFilter,
|
||||
postFilterPipeline: this.postFilterPipeline,
|
||||
};
|
||||
return await super.similaritySearchVectorWithScore(query, k, mergedFilter);
|
||||
}
|
||||
}
|
||||
|
||||
export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
meta: {
|
||||
displayName: 'MongoDB Atlas Vector Store',
|
||||
@@ -245,13 +328,24 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
description: 'Please check that the index exists in your collection',
|
||||
});
|
||||
}
|
||||
const preFilter = getFilterValue<IDataObject>(PRE_FILTER_NAME, context, itemIndex);
|
||||
const postFilterPipeline = getFilterValue<IDataObject[]>(
|
||||
POST_FILTER_NAME,
|
||||
context,
|
||||
itemIndex,
|
||||
);
|
||||
|
||||
return new MongoDBAtlasVectorSearch(embeddings, {
|
||||
collection,
|
||||
indexName: mongoVectorIndexName, // Default index name
|
||||
textKey: metadataFieldName, // Field containing raw text
|
||||
embeddingKey: embeddingFieldName, // Field containing embeddings
|
||||
});
|
||||
return new ExtendedMongoDBAtlasVectorSearch(
|
||||
embeddings,
|
||||
{
|
||||
collection,
|
||||
indexName: mongoVectorIndexName, // Default index name
|
||||
textKey: metadataFieldName, // Field containing raw text
|
||||
embeddingKey: embeddingFieldName, // Field containing embeddings
|
||||
},
|
||||
preFilter ?? {},
|
||||
postFilterPipeline,
|
||||
);
|
||||
} catch (error) {
|
||||
if (error instanceof NodeOperationError) {
|
||||
throw error;
|
||||
@@ -277,7 +371,7 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||
await db.createCollection(collectionName);
|
||||
}
|
||||
const collection = db.collection(collectionName);
|
||||
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
|
||||
await ExtendedMongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
|
||||
collection,
|
||||
indexName: mongoVectorIndexName, // Default index name
|
||||
textKey: metadataFieldName, // Field containing raw text
|
||||
|
||||
Reference in New Issue
Block a user