mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 17:46:45 +00:00
feat(MongoDB Vector Store Node): Allow pre and post filtering (#18506)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,11 +1,12 @@
|
|||||||
import { mock } from 'jest-mock-extended';
|
import { mock } from 'jest-mock-extended';
|
||||||
import { MongoClient } from 'mongodb';
|
import { MongoClient } from 'mongodb';
|
||||||
import type { ILoadOptionsFunctions } from 'n8n-workflow';
|
import type { ILoadOptionsFunctions, ISupplyDataFunctions } from 'n8n-workflow';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
EMBEDDING_NAME,
|
EMBEDDING_NAME,
|
||||||
getCollectionName,
|
getCollectionName,
|
||||||
getEmbeddingFieldName,
|
getEmbeddingFieldName,
|
||||||
|
getFilterValue,
|
||||||
getMetadataFieldName,
|
getMetadataFieldName,
|
||||||
getMongoClient,
|
getMongoClient,
|
||||||
getVectorIndexName,
|
getVectorIndexName,
|
||||||
@@ -22,6 +23,8 @@ jest.mock('mongodb', () => ({
|
|||||||
describe('VectorStoreMongoDBAtlas', () => {
|
describe('VectorStoreMongoDBAtlas', () => {
|
||||||
const helpers = mock<ILoadOptionsFunctions['helpers']>();
|
const helpers = mock<ILoadOptionsFunctions['helpers']>();
|
||||||
const executeFunctions = mock<ILoadOptionsFunctions>({ helpers });
|
const executeFunctions = mock<ILoadOptionsFunctions>({ helpers });
|
||||||
|
const dataHelpers = mock<ISupplyDataFunctions['helpers']>();
|
||||||
|
const dataFunctions = mock<ISupplyDataFunctions>({ helpers: dataHelpers });
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.resetAllMocks();
|
jest.resetAllMocks();
|
||||||
@@ -148,4 +151,88 @@ describe('VectorStoreMongoDBAtlas', () => {
|
|||||||
expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata');
|
expect(getMetadataFieldName(executeFunctions, 0)).toEqual('testMetadata');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('.getFilterValue', () => {
|
||||||
|
describe('when no post filter is present', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||||
|
return {};
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns undefined', () => {
|
||||||
|
expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual(undefined);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('when a post filter is present', () => {
|
||||||
|
describe('when the JSON is valid', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||||
|
return { postFilterPipeline: '[{ "$match": { "name": "value" }}]' };
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns the post filter pipeline', () => {
|
||||||
|
expect(getFilterValue('postFilterPipeline', dataFunctions, 0)).toEqual([
|
||||||
|
{ $match: { name: 'value' } },
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('when the JSON is invalid', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||||
|
return { postFilterPipeline: '[{ "$match": { "name":}}]' };
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws an error', () => {
|
||||||
|
expect(() => {
|
||||||
|
getFilterValue('postFilterPipeline', dataFunctions, 0);
|
||||||
|
}).toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('when no pre filter is present', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||||
|
return {};
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns undefined', () => {
|
||||||
|
expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual(undefined);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('when a pre filter is present', () => {
|
||||||
|
describe('when the JSON is valid', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||||
|
return { preFilter: '{ "name": "value" }' };
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns the pre filter', () => {
|
||||||
|
expect(getFilterValue('preFilter', dataFunctions, 0)).toEqual({ name: 'value' });
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('when the JSON is invalid', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
dataFunctions.getNodeParameter.mockImplementation(() => {
|
||||||
|
return { preFilter: '"name":}}]' };
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws an error', () => {
|
||||||
|
expect(() => {
|
||||||
|
getFilterValue('preFilter', dataFunctions, 0);
|
||||||
|
}).toThrow();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import { MongoDBAtlasVectorSearch } from '@langchain/mongodb';
|
import type { EmbeddingsInterface } from '@langchain/core/embeddings';
|
||||||
|
import { MongoDBAtlasVectorSearch, type MongoDBAtlasVectorSearchLibArgs } from '@langchain/mongodb';
|
||||||
import { MongoClient } from 'mongodb';
|
import { MongoClient } from 'mongodb';
|
||||||
import {
|
import {
|
||||||
|
type IDataObject,
|
||||||
type ILoadOptionsFunctions,
|
type ILoadOptionsFunctions,
|
||||||
NodeOperationError,
|
NodeOperationError,
|
||||||
type INodeProperties,
|
type INodeProperties,
|
||||||
@@ -11,9 +13,20 @@ import { metadataFilterField } from '@utils/sharedFields';
|
|||||||
|
|
||||||
import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode';
|
import { createVectorStoreNode } from '../shared/createVectorStoreNode/createVectorStoreNode';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constants for the name of the credentials and Node parameters.
|
||||||
|
*/
|
||||||
|
export const MONGODB_CREDENTIALS = 'mongoDb';
|
||||||
|
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
|
||||||
|
export const VECTOR_INDEX_NAME = 'vectorIndexName';
|
||||||
|
export const EMBEDDING_NAME = 'embedding';
|
||||||
|
export const METADATA_FIELD_NAME = 'metadata_field';
|
||||||
|
export const PRE_FILTER_NAME = 'preFilter';
|
||||||
|
export const POST_FILTER_NAME = 'postFilterPipeline';
|
||||||
|
|
||||||
const mongoCollectionRLC: INodeProperties = {
|
const mongoCollectionRLC: INodeProperties = {
|
||||||
displayName: 'MongoDB Collection',
|
displayName: 'MongoDB Collection',
|
||||||
name: 'mongoCollection',
|
name: MONGODB_COLLECTION_NAME,
|
||||||
type: 'resourceLocator',
|
type: 'resourceLocator',
|
||||||
default: { mode: 'list', value: '' },
|
default: { mode: 'list', value: '' },
|
||||||
required: true,
|
required: true,
|
||||||
@@ -37,7 +50,7 @@ const mongoCollectionRLC: INodeProperties = {
|
|||||||
|
|
||||||
const vectorIndexName: INodeProperties = {
|
const vectorIndexName: INodeProperties = {
|
||||||
displayName: 'Vector Index Name',
|
displayName: 'Vector Index Name',
|
||||||
name: 'vectorIndexName',
|
name: VECTOR_INDEX_NAME,
|
||||||
type: 'string',
|
type: 'string',
|
||||||
default: '',
|
default: '',
|
||||||
description: 'The name of the vector index',
|
description: 'The name of the vector index',
|
||||||
@@ -46,7 +59,7 @@ const vectorIndexName: INodeProperties = {
|
|||||||
|
|
||||||
const embeddingField: INodeProperties = {
|
const embeddingField: INodeProperties = {
|
||||||
displayName: 'Embedding',
|
displayName: 'Embedding',
|
||||||
name: 'embedding',
|
name: EMBEDDING_NAME,
|
||||||
type: 'string',
|
type: 'string',
|
||||||
default: 'embedding',
|
default: 'embedding',
|
||||||
description: 'The field with the embedding array',
|
description: 'The field with the embedding array',
|
||||||
@@ -55,7 +68,7 @@ const embeddingField: INodeProperties = {
|
|||||||
|
|
||||||
const metadataField: INodeProperties = {
|
const metadataField: INodeProperties = {
|
||||||
displayName: 'Metadata Field',
|
displayName: 'Metadata Field',
|
||||||
name: 'metadata_field',
|
name: METADATA_FIELD_NAME,
|
||||||
type: 'string',
|
type: 'string',
|
||||||
default: 'text',
|
default: 'text',
|
||||||
description: 'The text field of the raw data',
|
description: 'The text field of the raw data',
|
||||||
@@ -77,6 +90,34 @@ const mongoNamespaceField: INodeProperties = {
|
|||||||
default: '',
|
default: '',
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const preFilterField: INodeProperties = {
|
||||||
|
displayName: 'Pre Filter',
|
||||||
|
name: PRE_FILTER_NAME,
|
||||||
|
type: 'json',
|
||||||
|
typeOptions: {
|
||||||
|
alwaysOpenEditWindow: true,
|
||||||
|
},
|
||||||
|
default: '',
|
||||||
|
placeholder: '{ "key": "value" }',
|
||||||
|
hint: 'This is a filter applied in the $vectorSearch stage <a href="https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter">here</a>',
|
||||||
|
required: true,
|
||||||
|
description: 'MongoDB Atlas Vector Search pre-filter',
|
||||||
|
};
|
||||||
|
|
||||||
|
const postFilterField: INodeProperties = {
|
||||||
|
displayName: 'Post Filter Pipeline',
|
||||||
|
name: POST_FILTER_NAME,
|
||||||
|
type: 'json',
|
||||||
|
typeOptions: {
|
||||||
|
alwaysOpenEditWindow: true,
|
||||||
|
},
|
||||||
|
default: '',
|
||||||
|
placeholder: '[{ "$match": { "$gt": "1950-01-01" }, ... }]',
|
||||||
|
hint: 'Learn more about aggregation pipeline <a href="https://docs.mongodb.com/manual/core/aggregation-pipeline/">here</a>',
|
||||||
|
required: true,
|
||||||
|
description: 'MongoDB aggregation pipeline in JSON format',
|
||||||
|
};
|
||||||
|
|
||||||
const retrieveFields: INodeProperties[] = [
|
const retrieveFields: INodeProperties[] = [
|
||||||
{
|
{
|
||||||
displayName: 'Options',
|
displayName: 'Options',
|
||||||
@@ -84,7 +125,7 @@ const retrieveFields: INodeProperties[] = [
|
|||||||
type: 'collection',
|
type: 'collection',
|
||||||
placeholder: 'Add Option',
|
placeholder: 'Add Option',
|
||||||
default: {},
|
default: {},
|
||||||
options: [mongoNamespaceField, metadataFilterField],
|
options: [mongoNamespaceField, metadataFilterField, preFilterField, postFilterField],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
@@ -113,15 +154,6 @@ export const mongoConfig = {
|
|||||||
connectionString: '',
|
connectionString: '',
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
|
||||||
* Constants for the name of the credentials and Node parameters.
|
|
||||||
*/
|
|
||||||
export const MONGODB_CREDENTIALS = 'mongoDb';
|
|
||||||
export const MONGODB_COLLECTION_NAME = 'mongoCollection';
|
|
||||||
export const VECTOR_INDEX_NAME = 'vectorIndexName';
|
|
||||||
export const EMBEDDING_NAME = 'embedding';
|
|
||||||
export const METADATA_FIELD_NAME = 'metadata_field';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Type used for cleaner, more intentional typing.
|
* Type used for cleaner, more intentional typing.
|
||||||
*/
|
*/
|
||||||
@@ -202,6 +234,57 @@ export const getVectorIndexName = getParameter.bind(null, VECTOR_INDEX_NAME);
|
|||||||
export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME);
|
export const getEmbeddingFieldName = getParameter.bind(null, EMBEDDING_NAME);
|
||||||
export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME);
|
export const getMetadataFieldName = getParameter.bind(null, METADATA_FIELD_NAME);
|
||||||
|
|
||||||
|
export function getFilterValue<T>(
|
||||||
|
name: string,
|
||||||
|
context: IExecuteFunctions | ISupplyDataFunctions,
|
||||||
|
itemIndex: number,
|
||||||
|
): T | undefined {
|
||||||
|
const options: IDataObject = context.getNodeParameter('options', itemIndex, {});
|
||||||
|
|
||||||
|
if (options[name]) {
|
||||||
|
if (typeof options[name] === 'string') {
|
||||||
|
try {
|
||||||
|
return JSON.parse(options[name]);
|
||||||
|
} catch (error) {
|
||||||
|
throw new NodeOperationError(context.getNode(), `Error: ${error.message}`, {
|
||||||
|
itemIndex,
|
||||||
|
description: `Could not parse JSON for ${name}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new NodeOperationError(context.getNode(), 'Error: No JSON string provided.', {
|
||||||
|
itemIndex,
|
||||||
|
description: `Could not parse JSON for ${name}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
class ExtendedMongoDBAtlasVectorSearch extends MongoDBAtlasVectorSearch {
|
||||||
|
preFilter: IDataObject;
|
||||||
|
postFilterPipeline?: IDataObject[];
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
embeddings: EmbeddingsInterface,
|
||||||
|
options: MongoDBAtlasVectorSearchLibArgs,
|
||||||
|
preFilter: IDataObject,
|
||||||
|
postFilterPipeline?: IDataObject[],
|
||||||
|
) {
|
||||||
|
super(embeddings, options);
|
||||||
|
this.preFilter = preFilter;
|
||||||
|
this.postFilterPipeline = postFilterPipeline;
|
||||||
|
}
|
||||||
|
|
||||||
|
async similaritySearchVectorWithScore(query: number[], k: number) {
|
||||||
|
const mergedFilter: MongoDBAtlasVectorSearch['FilterType'] = {
|
||||||
|
preFilter: this.preFilter,
|
||||||
|
postFilterPipeline: this.postFilterPipeline,
|
||||||
|
};
|
||||||
|
return await super.similaritySearchVectorWithScore(query, k, mergedFilter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
||||||
meta: {
|
meta: {
|
||||||
displayName: 'MongoDB Atlas Vector Store',
|
displayName: 'MongoDB Atlas Vector Store',
|
||||||
@@ -245,13 +328,24 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
|||||||
description: 'Please check that the index exists in your collection',
|
description: 'Please check that the index exists in your collection',
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
const preFilter = getFilterValue<IDataObject>(PRE_FILTER_NAME, context, itemIndex);
|
||||||
|
const postFilterPipeline = getFilterValue<IDataObject[]>(
|
||||||
|
POST_FILTER_NAME,
|
||||||
|
context,
|
||||||
|
itemIndex,
|
||||||
|
);
|
||||||
|
|
||||||
return new MongoDBAtlasVectorSearch(embeddings, {
|
return new ExtendedMongoDBAtlasVectorSearch(
|
||||||
collection,
|
embeddings,
|
||||||
indexName: mongoVectorIndexName, // Default index name
|
{
|
||||||
textKey: metadataFieldName, // Field containing raw text
|
collection,
|
||||||
embeddingKey: embeddingFieldName, // Field containing embeddings
|
indexName: mongoVectorIndexName, // Default index name
|
||||||
});
|
textKey: metadataFieldName, // Field containing raw text
|
||||||
|
embeddingKey: embeddingFieldName, // Field containing embeddings
|
||||||
|
},
|
||||||
|
preFilter ?? {},
|
||||||
|
postFilterPipeline,
|
||||||
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof NodeOperationError) {
|
if (error instanceof NodeOperationError) {
|
||||||
throw error;
|
throw error;
|
||||||
@@ -277,7 +371,7 @@ export class VectorStoreMongoDBAtlas extends createVectorStoreNode({
|
|||||||
await db.createCollection(collectionName);
|
await db.createCollection(collectionName);
|
||||||
}
|
}
|
||||||
const collection = db.collection(collectionName);
|
const collection = db.collection(collectionName);
|
||||||
await MongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
|
await ExtendedMongoDBAtlasVectorSearch.fromDocuments(documents, embeddings, {
|
||||||
collection,
|
collection,
|
||||||
indexName: mongoVectorIndexName, // Default index name
|
indexName: mongoVectorIndexName, // Default index name
|
||||||
textKey: metadataFieldName, // Field containing raw text
|
textKey: metadataFieldName, // Field containing raw text
|
||||||
|
|||||||
Reference in New Issue
Block a user