mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-16 09:36:44 +00:00
feat: Add Cohere reranking capability to vector stores (#16014)
Co-authored-by: Yiorgis Gozadinos <yiorgis@n8n.io> Co-authored-by: Mutasem Aldmour <mutasem@n8n.io>
This commit is contained in:
@@ -0,0 +1,90 @@
|
|||||||
|
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
|
||||||
|
import { CohereRerank } from '@langchain/cohere';
|
||||||
|
import {
|
||||||
|
NodeConnectionTypes,
|
||||||
|
type INodeType,
|
||||||
|
type INodeTypeDescription,
|
||||||
|
type ISupplyDataFunctions,
|
||||||
|
type SupplyData,
|
||||||
|
} from 'n8n-workflow';
|
||||||
|
|
||||||
|
import { logWrapper } from '@utils/logWrapper';
|
||||||
|
|
||||||
|
export class RerankerCohere implements INodeType {
|
||||||
|
description: INodeTypeDescription = {
|
||||||
|
displayName: 'Reranker Cohere',
|
||||||
|
name: 'rerankerCohere',
|
||||||
|
icon: { light: 'file:cohere.svg', dark: 'file:cohere.dark.svg' },
|
||||||
|
group: ['transform'],
|
||||||
|
version: 1,
|
||||||
|
description:
|
||||||
|
'Use Cohere Reranker to reorder documents after retrieval from a vector store by relevance to the given query.',
|
||||||
|
defaults: {
|
||||||
|
name: 'Reranker Cohere',
|
||||||
|
},
|
||||||
|
requestDefaults: {
|
||||||
|
ignoreHttpStatusErrors: true,
|
||||||
|
baseURL: '={{ $credentials.host }}',
|
||||||
|
},
|
||||||
|
credentials: [
|
||||||
|
{
|
||||||
|
name: 'cohereApi',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
codex: {
|
||||||
|
categories: ['AI'],
|
||||||
|
subcategories: {
|
||||||
|
AI: ['Rerankers'],
|
||||||
|
},
|
||||||
|
resources: {
|
||||||
|
primaryDocumentation: [
|
||||||
|
{
|
||||||
|
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.rerankercohere/',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputs: [],
|
||||||
|
outputs: [NodeConnectionTypes.AiReranker],
|
||||||
|
outputNames: ['Reranker'],
|
||||||
|
properties: [
|
||||||
|
{
|
||||||
|
displayName: 'Model',
|
||||||
|
name: 'modelName',
|
||||||
|
type: 'options',
|
||||||
|
description:
|
||||||
|
'The model that should be used to rerank the documents. <a href="https://docs.cohere.com/docs/models">Learn more</a>.',
|
||||||
|
default: 'rerank-v3.5',
|
||||||
|
options: [
|
||||||
|
{
|
||||||
|
name: 'rerank-v3.5',
|
||||||
|
value: 'rerank-v3.5',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'rerank-english-v3.0',
|
||||||
|
value: 'rerank-english-v3.0',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'rerank-multilingual-v3.0',
|
||||||
|
value: 'rerank-multilingual-v3.0',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
|
||||||
|
this.logger.debug('Supply data for reranking Cohere');
|
||||||
|
const modelName = this.getNodeParameter('modelName', itemIndex, 'rerank-v3.5') as string;
|
||||||
|
const credentials = await this.getCredentials<{ apiKey: string }>('cohereApi');
|
||||||
|
const reranker = new CohereRerank({
|
||||||
|
apiKey: credentials.apiKey,
|
||||||
|
model: modelName,
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
response: logWrapper(reranker, this),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.96 23.84C14.0267 23.84 16.16 23.7867 19.1467 22.56C22.6133 21.12 29.44 18.56 34.4 15.8933C37.8667 14.0267 39.36 11.5733 39.36 8.26667C39.36 3.73333 35.68 0 31.0933 0H11.8933C5.33333 0 0 5.33333 0 11.8933C0 18.4533 5.01333 23.84 12.96 23.84Z" fill="white"/>
|
||||||
|
<path fill-rule="evenodd" clip-rule="evenodd" d="M16.2134 31.9999C16.2134 28.7999 18.1334 25.8666 21.12 24.6399L27.1467 22.1333C33.28 19.6266 40 24.1066 40 30.7199C40 35.8399 35.84 39.9999 30.72 39.9999H24.16C19.7867 39.9999 16.2134 36.4266 16.2134 31.9999Z" fill="white"/>
|
||||||
|
<path d="M6.88 25.3867C3.09333 25.3867 0 28.4801 0 32.2667V33.1734C0 36.9067 3.09333 40.0001 6.88 40.0001C10.6667 40.0001 13.76 36.9067 13.76 33.1201V32.2134C13.7067 28.4801 10.6667 25.3867 6.88 25.3867Z" fill="white"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 907 B |
@@ -0,0 +1,5 @@
|
|||||||
|
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.96 23.84C14.0267 23.84 16.16 23.7867 19.1467 22.56C22.6133 21.12 29.44 18.56 34.4 15.8933C37.8667 14.0267 39.36 11.5733 39.36 8.26667C39.36 3.73333 35.68 0 31.0933 0H11.8933C5.33333 0 0 5.33333 0 11.8933C0 18.4533 5.01333 23.84 12.96 23.84Z" fill="#39594D"/>
|
||||||
|
<path fill-rule="evenodd" clip-rule="evenodd" d="M16.2134 31.9999C16.2134 28.7999 18.1334 25.8666 21.12 24.6399L27.1467 22.1333C33.28 19.6266 40 24.1066 40 30.7199C40 35.8399 35.84 39.9999 30.72 39.9999H24.16C19.7867 39.9999 16.2134 36.4266 16.2134 31.9999Z" fill="#D18EE2"/>
|
||||||
|
<path d="M6.88 25.3867C3.09333 25.3867 0 28.4801 0 32.2667V33.1734C0 36.9067 3.09333 40.0001 6.88 40.0001C10.6667 40.0001 13.76 36.9067 13.76 33.1201V32.2134C13.7067 28.4801 10.6667 25.3867 6.88 25.3867Z" fill="#FF7759"/>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 913 B |
@@ -0,0 +1,141 @@
|
|||||||
|
import { CohereRerank } from '@langchain/cohere';
|
||||||
|
import { mock } from 'jest-mock-extended';
|
||||||
|
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||||
|
|
||||||
|
import { logWrapper } from '@utils/logWrapper';
|
||||||
|
|
||||||
|
import { RerankerCohere } from '../RerankerCohere.node';
|
||||||
|
|
||||||
|
// Mock the CohereRerank class
|
||||||
|
jest.mock('@langchain/cohere', () => ({
|
||||||
|
CohereRerank: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock the logWrapper utility
|
||||||
|
jest.mock('@utils/logWrapper', () => ({
|
||||||
|
logWrapper: jest.fn().mockImplementation((obj) => ({ logWrapped: obj })),
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe('RerankerCohere', () => {
|
||||||
|
let rerankerCohere: RerankerCohere;
|
||||||
|
let mockSupplyDataFunctions: ISupplyDataFunctions;
|
||||||
|
let mockCohereRerank: jest.Mocked<CohereRerank>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
rerankerCohere = new RerankerCohere();
|
||||||
|
|
||||||
|
// Reset the mock
|
||||||
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
// Create a mock CohereRerank instance
|
||||||
|
mockCohereRerank = {
|
||||||
|
compressDocuments: jest.fn(),
|
||||||
|
} as unknown as jest.Mocked<CohereRerank>;
|
||||||
|
|
||||||
|
// Make the CohereRerank constructor return our mock instance
|
||||||
|
(CohereRerank as jest.MockedClass<typeof CohereRerank>).mockImplementation(
|
||||||
|
() => mockCohereRerank,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create mock supply data functions
|
||||||
|
mockSupplyDataFunctions = mock<ISupplyDataFunctions>({
|
||||||
|
logger: {
|
||||||
|
debug: jest.fn(),
|
||||||
|
error: jest.fn(),
|
||||||
|
info: jest.fn(),
|
||||||
|
warn: jest.fn(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock specific methods with proper jest functions
|
||||||
|
mockSupplyDataFunctions.getNodeParameter = jest.fn();
|
||||||
|
mockSupplyDataFunctions.getCredentials = jest.fn();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create CohereRerank with default model and return wrapped instance', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
const result = await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||||
|
|
||||||
|
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith(
|
||||||
|
'modelName',
|
||||||
|
0,
|
||||||
|
'rerank-v3.5',
|
||||||
|
);
|
||||||
|
expect(mockSupplyDataFunctions.getCredentials).toHaveBeenCalledWith('cohereApi');
|
||||||
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
model: 'rerank-v3.5',
|
||||||
|
});
|
||||||
|
expect(logWrapper).toHaveBeenCalledWith(mockCohereRerank, mockSupplyDataFunctions);
|
||||||
|
expect(result.response).toEqual({ logWrapped: mockCohereRerank });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create CohereRerank with custom model', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockCredentials = { apiKey: 'custom-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue(
|
||||||
|
'rerank-multilingual-v3.0',
|
||||||
|
);
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
|
apiKey: 'custom-api-key',
|
||||||
|
model: 'rerank-multilingual-v3.0',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle different item indices', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-english-v3.0');
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute with different item index
|
||||||
|
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 2);
|
||||||
|
|
||||||
|
// Verify the correct item index is passed
|
||||||
|
expect(mockSupplyDataFunctions.getNodeParameter).toHaveBeenCalledWith(
|
||||||
|
'modelName',
|
||||||
|
2,
|
||||||
|
'rerank-v3.5',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw error when credentials are missing', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5');
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockRejectedValue(
|
||||||
|
new Error('Missing credentials'),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Execute and verify error
|
||||||
|
await expect(rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0)).rejects.toThrow(
|
||||||
|
'Missing credentials',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use fallback model when parameter is not provided', async () => {
|
||||||
|
// Setup mocks - getNodeParameter returns the fallback value
|
||||||
|
const mockCredentials = { apiKey: 'test-api-key' };
|
||||||
|
(mockSupplyDataFunctions.getNodeParameter as jest.Mock).mockReturnValue('rerank-v3.5'); // fallback value
|
||||||
|
(mockSupplyDataFunctions.getCredentials as jest.Mock).mockResolvedValue(mockCredentials);
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
await rerankerCohere.supplyData.call(mockSupplyDataFunctions, 0);
|
||||||
|
|
||||||
|
// Verify fallback is used
|
||||||
|
expect(CohereRerank).toHaveBeenCalledWith({
|
||||||
|
apiKey: 'test-api-key',
|
||||||
|
model: 'rerank-v3.5',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -41,8 +41,13 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
|||||||
"inputs": "={{
|
"inputs": "={{
|
||||||
((parameters) => {
|
((parameters) => {
|
||||||
const mode = parameters?.mode;
|
const mode = parameters?.mode;
|
||||||
|
const useReranker = parameters?.useReranker;
|
||||||
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
|
const inputs = [{ displayName: "Embedding", type: "ai_embedding", required: true, maxConnections: 1}]
|
||||||
|
|
||||||
|
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||||
|
inputs.push({ displayName: "Reranker", type: "ai_reranker", required: true, maxConnections: 1})
|
||||||
|
}
|
||||||
|
|
||||||
if (mode === 'retrieve-as-tool') {
|
if (mode === 'retrieve-as-tool') {
|
||||||
return inputs;
|
return inputs;
|
||||||
}
|
}
|
||||||
@@ -233,6 +238,21 @@ exports[`createVectorStoreNode retrieve mode supplies vector store as data 1`] =
|
|||||||
"name": "includeDocumentMetadata",
|
"name": "includeDocumentMetadata",
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"default": false,
|
||||||
|
"description": "Whether or not to rerank results",
|
||||||
|
"displayName": "Rerank Results",
|
||||||
|
"displayOptions": {
|
||||||
|
"show": {
|
||||||
|
"mode": [
|
||||||
|
"load",
|
||||||
|
"retrieve-as-tool",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"name": "useReranker",
|
||||||
|
"type": "boolean",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"default": "",
|
"default": "",
|
||||||
"description": "ID of an embedding entry",
|
"description": "ID of an embedding entry",
|
||||||
|
|||||||
@@ -69,8 +69,13 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
|||||||
inputs: `={{
|
inputs: `={{
|
||||||
((parameters) => {
|
((parameters) => {
|
||||||
const mode = parameters?.mode;
|
const mode = parameters?.mode;
|
||||||
|
const useReranker = parameters?.useReranker;
|
||||||
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
|
const inputs = [{ displayName: "Embedding", type: "${NodeConnectionTypes.AiEmbedding}", required: true, maxConnections: 1}]
|
||||||
|
|
||||||
|
if (['load', 'retrieve-as-tool'].includes(mode) && useReranker) {
|
||||||
|
inputs.push({ displayName: "Reranker", type: "${NodeConnectionTypes.AiReranker}", required: true, maxConnections: 1})
|
||||||
|
}
|
||||||
|
|
||||||
if (mode === 'retrieve-as-tool') {
|
if (mode === 'retrieve-as-tool') {
|
||||||
return inputs;
|
return inputs;
|
||||||
}
|
}
|
||||||
@@ -202,6 +207,18 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
displayName: 'Rerank Results',
|
||||||
|
name: 'useReranker',
|
||||||
|
type: 'boolean',
|
||||||
|
default: false,
|
||||||
|
description: 'Whether or not to rerank results',
|
||||||
|
displayOptions: {
|
||||||
|
show: {
|
||||||
|
mode: ['load', 'retrieve-as-tool'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
// ID is always used for update operation
|
// ID is always used for update operation
|
||||||
{
|
{
|
||||||
displayName: 'ID',
|
displayName: 'ID',
|
||||||
@@ -233,7 +250,6 @@ export const createVectorStoreNode = <T extends VectorStore = VectorStore>(
|
|||||||
*/
|
*/
|
||||||
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
|
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> {
|
||||||
const mode = this.getNodeParameter('mode', 0) as NodeOperationMode;
|
const mode = this.getNodeParameter('mode', 0) as NodeOperationMode;
|
||||||
|
|
||||||
// Get the embeddings model connected to this node
|
// Get the embeddings model connected to this node
|
||||||
const embeddings = (await this.getInputConnectionData(
|
const embeddings = (await this.getInputConnectionData(
|
||||||
NodeConnectionTypes.AiEmbedding,
|
NodeConnectionTypes.AiEmbedding,
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
/* eslint-disable @typescript-eslint/unbound-method */
|
/* eslint-disable @typescript-eslint/unbound-method */
|
||||||
import type { Document } from '@langchain/core/documents';
|
import type { Document } from '@langchain/core/documents';
|
||||||
import type { Embeddings } from '@langchain/core/embeddings';
|
import type { Embeddings } from '@langchain/core/embeddings';
|
||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import type { MockProxy } from 'jest-mock-extended';
|
import type { MockProxy } from 'jest-mock-extended';
|
||||||
import { mock } from 'jest-mock-extended';
|
import { mock } from 'jest-mock-extended';
|
||||||
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
|
import type { IDataObject, IExecuteFunctions } from 'n8n-workflow';
|
||||||
|
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||||
|
|
||||||
import { logAiEvent } from '@utils/helpers';
|
import { logAiEvent } from '@utils/helpers';
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ describe('handleLoadOperation', () => {
|
|||||||
let mockContext: MockProxy<IExecuteFunctions>;
|
let mockContext: MockProxy<IExecuteFunctions>;
|
||||||
let mockEmbeddings: MockProxy<Embeddings>;
|
let mockEmbeddings: MockProxy<Embeddings>;
|
||||||
let mockVectorStore: MockProxy<VectorStore>;
|
let mockVectorStore: MockProxy<VectorStore>;
|
||||||
|
let mockReranker: MockProxy<BaseDocumentCompressor>;
|
||||||
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
||||||
let nodeParameters: Record<string, any>;
|
let nodeParameters: Record<string, any>;
|
||||||
|
|
||||||
@@ -30,6 +33,7 @@ describe('handleLoadOperation', () => {
|
|||||||
prompt: 'test search query',
|
prompt: 'test search query',
|
||||||
topK: 3,
|
topK: 3,
|
||||||
includeDocumentMetadata: true,
|
includeDocumentMetadata: true,
|
||||||
|
useReranker: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
mockContext = mock<IExecuteFunctions>();
|
mockContext = mock<IExecuteFunctions>();
|
||||||
@@ -48,6 +52,24 @@ describe('handleLoadOperation', () => {
|
|||||||
[{ pageContent: 'test content 3', metadata: { test: 'metadata 3' } } as Document, 0.75],
|
[{ pageContent: 'test content 3', metadata: { test: 'metadata 3' } } as Document, 0.75],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
mockReranker = mock<BaseDocumentCompressor>();
|
||||||
|
mockReranker.compressDocuments.mockResolvedValue([
|
||||||
|
{
|
||||||
|
pageContent: 'test content 2',
|
||||||
|
metadata: { test: 'metadata 2', relevanceScore: 0.98 },
|
||||||
|
} as Document,
|
||||||
|
{
|
||||||
|
pageContent: 'test content 1',
|
||||||
|
metadata: { test: 'metadata 1', relevanceScore: 0.92 },
|
||||||
|
} as Document,
|
||||||
|
{
|
||||||
|
pageContent: 'test content 3',
|
||||||
|
metadata: { test: 'metadata 3', relevanceScore: 0.88 },
|
||||||
|
} as Document,
|
||||||
|
]);
|
||||||
|
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
|
||||||
|
|
||||||
mockArgs = {
|
mockArgs = {
|
||||||
meta: {
|
meta: {
|
||||||
displayName: 'Test Vector Store',
|
displayName: 'Test Vector Store',
|
||||||
@@ -142,4 +164,82 @@ describe('handleLoadOperation', () => {
|
|||||||
|
|
||||||
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('reranking functionality', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
nodeParameters.useReranker = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use reranker when useReranker is true', async () => {
|
||||||
|
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||||
|
NodeConnectionTypes.AiReranker,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
expect(mockReranker.compressDocuments).toHaveBeenCalledWith(
|
||||||
|
[
|
||||||
|
{ pageContent: 'test content 1', metadata: { test: 'metadata 1' } },
|
||||||
|
{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } },
|
||||||
|
{ pageContent: 'test content 3', metadata: { test: 'metadata 3' } },
|
||||||
|
],
|
||||||
|
'test search query',
|
||||||
|
);
|
||||||
|
expect(result).toHaveLength(3);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return reranked documents with relevance scores', async () => {
|
||||||
|
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
// First result should be the reranked first document (was second in original order)
|
||||||
|
expect((result[0].json?.document as IDataObject)?.pageContent).toEqual('test content 2');
|
||||||
|
expect(result[0].json?.score).toEqual(0.98);
|
||||||
|
|
||||||
|
// Second result should be the reranked second document (was first in original order)
|
||||||
|
expect((result[1].json?.document as IDataObject)?.pageContent).toEqual('test content 1');
|
||||||
|
expect(result[1].json?.score).toEqual(0.92);
|
||||||
|
|
||||||
|
// Third result should be the reranked third document
|
||||||
|
expect((result[2].json?.document as IDataObject)?.pageContent).toEqual('test content 3');
|
||||||
|
expect(result[2].json?.score).toEqual(0.88);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should remove relevanceScore from metadata after reranking', async () => {
|
||||||
|
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
// Check that relevanceScore is not included in the metadata
|
||||||
|
expect((result[0].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 2' });
|
||||||
|
expect((result[1].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 1' });
|
||||||
|
expect((result[2].json?.document as IDataObject)?.metadata).toEqual({ test: 'metadata 3' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle reranking with includeDocumentMetadata false', async () => {
|
||||||
|
nodeParameters.includeDocumentMetadata = false;
|
||||||
|
|
||||||
|
const result = await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
expect(result[0].json?.document).not.toHaveProperty('metadata');
|
||||||
|
expect((result[0].json?.document as IDataObject)?.pageContent).toEqual('test content 2');
|
||||||
|
expect(result[0].json?.score).toEqual(0.98);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not call reranker when useReranker is false', async () => {
|
||||||
|
nodeParameters.useReranker = false;
|
||||||
|
|
||||||
|
await handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
|
||||||
|
expect(mockReranker.compressDocuments).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should release vector store client even if reranking fails', async () => {
|
||||||
|
mockReranker.compressDocuments.mockRejectedValue(new Error('Reranking failed'));
|
||||||
|
|
||||||
|
await expect(handleLoadOperation(mockContext, mockArgs, mockEmbeddings, 0)).rejects.toThrow(
|
||||||
|
'Reranking failed',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
/* eslint-disable @typescript-eslint/unbound-method */
|
/* eslint-disable @typescript-eslint/unbound-method */
|
||||||
import type { Document } from '@langchain/core/documents';
|
import type { Document } from '@langchain/core/documents';
|
||||||
import type { Embeddings } from '@langchain/core/embeddings';
|
import type { Embeddings } from '@langchain/core/embeddings';
|
||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import type { MockProxy } from 'jest-mock-extended';
|
import type { MockProxy } from 'jest-mock-extended';
|
||||||
import { mock } from 'jest-mock-extended';
|
import { mock } from 'jest-mock-extended';
|
||||||
import { DynamicTool } from 'langchain/tools';
|
import { DynamicTool } from 'langchain/tools';
|
||||||
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
import type { ISupplyDataFunctions } from 'n8n-workflow';
|
||||||
|
import { NodeConnectionTypes } from 'n8n-workflow';
|
||||||
|
|
||||||
import { logWrapper } from '@utils/logWrapper';
|
import { logWrapper } from '@utils/logWrapper';
|
||||||
|
|
||||||
@@ -26,6 +28,7 @@ describe('handleRetrieveAsToolOperation', () => {
|
|||||||
let mockContext: MockProxy<ISupplyDataFunctions>;
|
let mockContext: MockProxy<ISupplyDataFunctions>;
|
||||||
let mockEmbeddings: MockProxy<Embeddings>;
|
let mockEmbeddings: MockProxy<Embeddings>;
|
||||||
let mockVectorStore: MockProxy<VectorStore>;
|
let mockVectorStore: MockProxy<VectorStore>;
|
||||||
|
let mockReranker: MockProxy<BaseDocumentCompressor>;
|
||||||
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
let mockArgs: VectorStoreNodeConstructorArgs<VectorStore>;
|
||||||
let nodeParameters: Record<string, any>;
|
let nodeParameters: Record<string, any>;
|
||||||
|
|
||||||
@@ -35,6 +38,7 @@ describe('handleRetrieveAsToolOperation', () => {
|
|||||||
toolDescription: 'Search the test knowledge base',
|
toolDescription: 'Search the test knowledge base',
|
||||||
topK: 3,
|
topK: 3,
|
||||||
includeDocumentMetadata: true,
|
includeDocumentMetadata: true,
|
||||||
|
useReranker: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
mockContext = mock<ISupplyDataFunctions>();
|
mockContext = mock<ISupplyDataFunctions>();
|
||||||
@@ -60,6 +64,20 @@ describe('handleRetrieveAsToolOperation', () => {
|
|||||||
[{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } } as Document, 0.85],
|
[{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } } as Document, 0.85],
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
mockReranker = mock<BaseDocumentCompressor>();
|
||||||
|
mockReranker.compressDocuments.mockResolvedValue([
|
||||||
|
{
|
||||||
|
pageContent: 'test content 2',
|
||||||
|
metadata: { test: 'metadata 2', relevanceScore: 0.98 },
|
||||||
|
} as Document,
|
||||||
|
{
|
||||||
|
pageContent: 'test content 1',
|
||||||
|
metadata: { test: 'metadata 1', relevanceScore: 0.92 },
|
||||||
|
} as Document,
|
||||||
|
]);
|
||||||
|
|
||||||
|
mockContext.getInputConnectionData.mockResolvedValue(mockReranker);
|
||||||
|
|
||||||
mockArgs = {
|
mockArgs = {
|
||||||
meta: {
|
meta: {
|
||||||
displayName: 'Test Vector Store',
|
displayName: 'Test Vector Store',
|
||||||
@@ -215,4 +233,115 @@ describe('handleRetrieveAsToolOperation', () => {
|
|||||||
// Should still release the client
|
// Should still release the client
|
||||||
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('reranking functionality', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
nodeParameters.useReranker = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use reranker when useReranker is true', async () => {
|
||||||
|
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
const tool = result.response as DynamicTool;
|
||||||
|
|
||||||
|
await tool.func('test query');
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).toHaveBeenCalledWith(
|
||||||
|
NodeConnectionTypes.AiReranker,
|
||||||
|
0,
|
||||||
|
);
|
||||||
|
expect(mockReranker.compressDocuments).toHaveBeenCalledWith(
|
||||||
|
[
|
||||||
|
{ pageContent: 'test content 1', metadata: { test: 'metadata 1' } },
|
||||||
|
{ pageContent: 'test content 2', metadata: { test: 'metadata 2' } },
|
||||||
|
],
|
||||||
|
'test query',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return reranked documents in the correct order', async () => {
|
||||||
|
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
const tool = result.response as DynamicTool;
|
||||||
|
|
||||||
|
const toolResult = await tool.func('test query');
|
||||||
|
|
||||||
|
expect(toolResult).toHaveLength(2);
|
||||||
|
|
||||||
|
// First result should be the reranked first document (was second in original order)
|
||||||
|
const parsedFirst = JSON.parse(toolResult[0].text);
|
||||||
|
expect(parsedFirst.pageContent).toEqual('test content 2');
|
||||||
|
expect(parsedFirst.metadata).toEqual({ test: 'metadata 2' });
|
||||||
|
|
||||||
|
// Second result should be the reranked second document (was first in original order)
|
||||||
|
const parsedSecond = JSON.parse(toolResult[1].text);
|
||||||
|
expect(parsedSecond.pageContent).toEqual('test content 1');
|
||||||
|
expect(parsedSecond.metadata).toEqual({ test: 'metadata 1' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle reranking with includeDocumentMetadata false', async () => {
|
||||||
|
nodeParameters.includeDocumentMetadata = false;
|
||||||
|
|
||||||
|
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
const tool = result.response as DynamicTool;
|
||||||
|
|
||||||
|
const toolResult = await tool.func('test query');
|
||||||
|
|
||||||
|
// Parse the JSON text to verify it excludes metadata but maintains reranked order
|
||||||
|
const parsedFirst = JSON.parse(toolResult[0].text);
|
||||||
|
expect(parsedFirst).toHaveProperty('pageContent', 'test content 2');
|
||||||
|
expect(parsedFirst).not.toHaveProperty('metadata');
|
||||||
|
|
||||||
|
const parsedSecond = JSON.parse(toolResult[1].text);
|
||||||
|
expect(parsedSecond).toHaveProperty('pageContent', 'test content 1');
|
||||||
|
expect(parsedSecond).not.toHaveProperty('metadata');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not call reranker when useReranker is false', async () => {
|
||||||
|
nodeParameters.useReranker = false;
|
||||||
|
|
||||||
|
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
const tool = result.response as DynamicTool;
|
||||||
|
|
||||||
|
await tool.func('test query');
|
||||||
|
|
||||||
|
expect(mockContext.getInputConnectionData).not.toHaveBeenCalled();
|
||||||
|
expect(mockReranker.compressDocuments).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should release vector store client even if reranking fails', async () => {
|
||||||
|
mockReranker.compressDocuments.mockRejectedValueOnce(new Error('Reranking failed'));
|
||||||
|
|
||||||
|
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
const tool = result.response as DynamicTool;
|
||||||
|
|
||||||
|
await expect(tool.func('test query')).rejects.toThrow('Reranking failed');
|
||||||
|
|
||||||
|
// Should still release the client
|
||||||
|
expect(mockArgs.releaseVectorStoreClient).toHaveBeenCalledWith(mockVectorStore);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should properly handle relevanceScore from reranker metadata', async () => {
|
||||||
|
// Mock reranker to return documents with relevanceScore in different metadata structure
|
||||||
|
mockReranker.compressDocuments.mockResolvedValueOnce([
|
||||||
|
{
|
||||||
|
pageContent: 'test content 2',
|
||||||
|
metadata: { test: 'metadata 2', relevanceScore: 0.98, otherField: 'value' },
|
||||||
|
} as Document,
|
||||||
|
{
|
||||||
|
pageContent: 'test content 1',
|
||||||
|
metadata: { test: 'metadata 1', relevanceScore: 0.92 },
|
||||||
|
} as Document,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const result = await handleRetrieveAsToolOperation(mockContext, mockArgs, mockEmbeddings, 0);
|
||||||
|
const tool = result.response as DynamicTool;
|
||||||
|
|
||||||
|
const toolResult = await tool.invoke('test query');
|
||||||
|
|
||||||
|
// Check that relevanceScore is used but not included in the final metadata
|
||||||
|
const parsedFirst = JSON.parse(toolResult[0].text);
|
||||||
|
expect(parsedFirst.pageContent).toEqual('test content 2');
|
||||||
|
expect(parsedFirst.metadata).toEqual({ test: 'metadata 2', otherField: 'value' });
|
||||||
|
expect(parsedFirst.metadata).not.toHaveProperty('relevanceScore');
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import type { Embeddings } from '@langchain/core/embeddings';
|
import type { Embeddings } from '@langchain/core/embeddings';
|
||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
|
import { NodeConnectionTypes, type IExecuteFunctions, type INodeExecutionData } from 'n8n-workflow';
|
||||||
|
|
||||||
import { getMetadataFiltersValues, logAiEvent } from '@utils/helpers';
|
import { getMetadataFiltersValues, logAiEvent } from '@utils/helpers';
|
||||||
|
|
||||||
@@ -29,6 +30,8 @@ export async function handleLoadOperation<T extends VectorStore = VectorStore>(
|
|||||||
// Get the search parameters from the node
|
// Get the search parameters from the node
|
||||||
const prompt = context.getNodeParameter('prompt', itemIndex) as string;
|
const prompt = context.getNodeParameter('prompt', itemIndex) as string;
|
||||||
const topK = context.getNodeParameter('topK', itemIndex, 4) as number;
|
const topK = context.getNodeParameter('topK', itemIndex, 4) as number;
|
||||||
|
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
|
||||||
|
|
||||||
const includeDocumentMetadata = context.getNodeParameter(
|
const includeDocumentMetadata = context.getNodeParameter(
|
||||||
'includeDocumentMetadata',
|
'includeDocumentMetadata',
|
||||||
itemIndex,
|
itemIndex,
|
||||||
@@ -39,7 +42,22 @@ export async function handleLoadOperation<T extends VectorStore = VectorStore>(
|
|||||||
const embeddedPrompt = await embeddings.embedQuery(prompt);
|
const embeddedPrompt = await embeddings.embedQuery(prompt);
|
||||||
|
|
||||||
// Get the most similar documents to the embedded prompt
|
// Get the most similar documents to the embedded prompt
|
||||||
const docs = await vectorStore.similaritySearchVectorWithScore(embeddedPrompt, topK, filter);
|
let docs = await vectorStore.similaritySearchVectorWithScore(embeddedPrompt, topK, filter);
|
||||||
|
|
||||||
|
// If reranker is used, rerank the documents
|
||||||
|
if (useReranker && docs.length > 0) {
|
||||||
|
const reranker = (await context.getInputConnectionData(
|
||||||
|
NodeConnectionTypes.AiReranker,
|
||||||
|
0,
|
||||||
|
)) as BaseDocumentCompressor;
|
||||||
|
const documents = docs.map(([doc]) => doc);
|
||||||
|
|
||||||
|
const rerankedDocuments = await reranker.compressDocuments(documents, prompt);
|
||||||
|
docs = rerankedDocuments.map((doc) => {
|
||||||
|
const { relevanceScore, ...metadata } = doc.metadata || {};
|
||||||
|
return [{ ...doc, metadata }, relevanceScore];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Format the documents for the output
|
// Format the documents for the output
|
||||||
const serializedDocs = docs.map(([doc, score]) => {
|
const serializedDocs = docs.map(([doc, score]) => {
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import type { Embeddings } from '@langchain/core/embeddings';
|
import type { Embeddings } from '@langchain/core/embeddings';
|
||||||
|
import type { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { VectorStore } from '@langchain/core/vectorstores';
|
import type { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import { DynamicTool } from 'langchain/tools';
|
import { DynamicTool } from 'langchain/tools';
|
||||||
import type { ISupplyDataFunctions, SupplyData } from 'n8n-workflow';
|
import { NodeConnectionTypes, type ISupplyDataFunctions, type SupplyData } from 'n8n-workflow';
|
||||||
|
|
||||||
import { getMetadataFiltersValues, nodeNameToToolName } from '@utils/helpers';
|
import { getMetadataFiltersValues, nodeNameToToolName } from '@utils/helpers';
|
||||||
import { logWrapper } from '@utils/logWrapper';
|
import { logWrapper } from '@utils/logWrapper';
|
||||||
@@ -29,6 +30,7 @@ export async function handleRetrieveAsToolOperation<T extends VectorStore = Vect
|
|||||||
: nodeNameToToolName(node);
|
: nodeNameToToolName(node);
|
||||||
|
|
||||||
const topK = context.getNodeParameter('topK', itemIndex, 4) as number;
|
const topK = context.getNodeParameter('topK', itemIndex, 4) as number;
|
||||||
|
const useReranker = context.getNodeParameter('useReranker', itemIndex, false) as boolean;
|
||||||
const includeDocumentMetadata = context.getNodeParameter(
|
const includeDocumentMetadata = context.getNodeParameter(
|
||||||
'includeDocumentMetadata',
|
'includeDocumentMetadata',
|
||||||
itemIndex,
|
itemIndex,
|
||||||
@@ -58,12 +60,27 @@ export async function handleRetrieveAsToolOperation<T extends VectorStore = Vect
|
|||||||
const embeddedPrompt = await embeddings.embedQuery(input);
|
const embeddedPrompt = await embeddings.embedQuery(input);
|
||||||
|
|
||||||
// Search for similar documents
|
// Search for similar documents
|
||||||
const documents = await vectorStore.similaritySearchVectorWithScore(
|
let documents = await vectorStore.similaritySearchVectorWithScore(
|
||||||
embeddedPrompt,
|
embeddedPrompt,
|
||||||
topK,
|
topK,
|
||||||
filter,
|
filter,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// If reranker is used, rerank the documents
|
||||||
|
if (useReranker && documents.length > 0) {
|
||||||
|
const reranker = (await context.getInputConnectionData(
|
||||||
|
NodeConnectionTypes.AiReranker,
|
||||||
|
0,
|
||||||
|
)) as BaseDocumentCompressor;
|
||||||
|
|
||||||
|
const docs = documents.map(([doc]) => doc);
|
||||||
|
const rerankedDocuments = await reranker.compressDocuments(docs, input);
|
||||||
|
documents = rerankedDocuments.map((doc) => {
|
||||||
|
const { relevanceScore, ...metadata } = doc.metadata;
|
||||||
|
return [{ ...doc, metadata }, relevanceScore];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Format the documents for the tool output
|
// Format the documents for the tool output
|
||||||
return documents
|
return documents
|
||||||
.map((document) => {
|
.map((document) => {
|
||||||
|
|||||||
@@ -99,6 +99,7 @@
|
|||||||
"dist/nodes/output_parser/OutputParserAutofixing/OutputParserAutofixing.node.js",
|
"dist/nodes/output_parser/OutputParserAutofixing/OutputParserAutofixing.node.js",
|
||||||
"dist/nodes/output_parser/OutputParserItemList/OutputParserItemList.node.js",
|
"dist/nodes/output_parser/OutputParserItemList/OutputParserItemList.node.js",
|
||||||
"dist/nodes/output_parser/OutputParserStructured/OutputParserStructured.node.js",
|
"dist/nodes/output_parser/OutputParserStructured/OutputParserStructured.node.js",
|
||||||
|
"dist/nodes/rerankers/RerankerCohere/RerankerCohere.node.js",
|
||||||
"dist/nodes/retrievers/RetrieverContextualCompression/RetrieverContextualCompression.node.js",
|
"dist/nodes/retrievers/RetrieverContextualCompression/RetrieverContextualCompression.node.js",
|
||||||
"dist/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.js",
|
"dist/nodes/retrievers/RetrieverVectorStore/RetrieverVectorStore.node.js",
|
||||||
"dist/nodes/retrievers/RetrieverMultiQuery/RetrieverMultiQuery.node.js",
|
"dist/nodes/retrievers/RetrieverMultiQuery/RetrieverMultiQuery.node.js",
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import { Embeddings } from '@langchain/core/embeddings';
|
|||||||
import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory';
|
import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory';
|
||||||
import type { BaseMessage } from '@langchain/core/messages';
|
import type { BaseMessage } from '@langchain/core/messages';
|
||||||
import { BaseRetriever } from '@langchain/core/retrievers';
|
import { BaseRetriever } from '@langchain/core/retrievers';
|
||||||
|
import { BaseDocumentCompressor } from '@langchain/core/retrievers/document_compressors';
|
||||||
import type { StructuredTool, Tool } from '@langchain/core/tools';
|
import type { StructuredTool, Tool } from '@langchain/core/tools';
|
||||||
import { VectorStore } from '@langchain/core/vectorstores';
|
import { VectorStore } from '@langchain/core/vectorstores';
|
||||||
import { TextSplitter } from '@langchain/textsplitters';
|
import { TextSplitter } from '@langchain/textsplitters';
|
||||||
@@ -18,7 +19,12 @@ import type {
|
|||||||
ITaskMetadata,
|
ITaskMetadata,
|
||||||
NodeConnectionType,
|
NodeConnectionType,
|
||||||
} from 'n8n-workflow';
|
} from 'n8n-workflow';
|
||||||
import { NodeOperationError, NodeConnectionTypes, parseErrorMetadata } from 'n8n-workflow';
|
import {
|
||||||
|
NodeOperationError,
|
||||||
|
NodeConnectionTypes,
|
||||||
|
parseErrorMetadata,
|
||||||
|
deepCopy,
|
||||||
|
} from 'n8n-workflow';
|
||||||
|
|
||||||
import { logAiEvent, isToolsInstance, isBaseChatMemory, isBaseChatMessageHistory } from './helpers';
|
import { logAiEvent, isToolsInstance, isBaseChatMemory, isBaseChatMessageHistory } from './helpers';
|
||||||
import { N8nBinaryLoader } from './N8nBinaryLoader';
|
import { N8nBinaryLoader } from './N8nBinaryLoader';
|
||||||
@@ -102,6 +108,7 @@ export function logWrapper<
|
|||||||
| BaseChatMemory
|
| BaseChatMemory
|
||||||
| BaseChatMessageHistory
|
| BaseChatMessageHistory
|
||||||
| BaseRetriever
|
| BaseRetriever
|
||||||
|
| BaseDocumentCompressor
|
||||||
| Embeddings
|
| Embeddings
|
||||||
| Document[]
|
| Document[]
|
||||||
| Document
|
| Document
|
||||||
@@ -297,6 +304,32 @@ export function logWrapper<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== Rerankers ==========
|
||||||
|
if (originalInstance instanceof BaseDocumentCompressor) {
|
||||||
|
if (prop === 'compressDocuments' && 'compressDocuments' in target) {
|
||||||
|
return async (documents: Document[], query: string): Promise<Document[]> => {
|
||||||
|
connectionType = NodeConnectionTypes.AiReranker;
|
||||||
|
const { index } = executeFunctions.addInputData(connectionType, [
|
||||||
|
[{ json: { query, documents } }],
|
||||||
|
]);
|
||||||
|
|
||||||
|
const response = (await callMethodAsync.call(target, {
|
||||||
|
executeFunctions,
|
||||||
|
connectionType,
|
||||||
|
currentNodeRunIndex: index,
|
||||||
|
method: target[prop],
|
||||||
|
// compressDocuments mutates the original object
|
||||||
|
// messing up the input data logging
|
||||||
|
arguments: [deepCopy(documents), query],
|
||||||
|
})) as Document[];
|
||||||
|
|
||||||
|
logAiEvent(executeFunctions, 'ai-document-reranked', { query });
|
||||||
|
executeFunctions.addOutputData(connectionType, index, [[{ json: { response } }]]);
|
||||||
|
return response;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ========== N8n Loaders Process All ==========
|
// ========== N8n Loaders Process All ==========
|
||||||
if (
|
if (
|
||||||
originalInstance instanceof N8nJsonLoader ||
|
originalInstance instanceof N8nJsonLoader ||
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ const outputTypeParsers: {
|
|||||||
},
|
},
|
||||||
[NodeConnectionTypes.AiOutputParser]: fallbackParser,
|
[NodeConnectionTypes.AiOutputParser]: fallbackParser,
|
||||||
[NodeConnectionTypes.AiRetriever]: fallbackParser,
|
[NodeConnectionTypes.AiRetriever]: fallbackParser,
|
||||||
|
[NodeConnectionTypes.AiReranker]: fallbackParser,
|
||||||
[NodeConnectionTypes.AiVectorStore](execData: IDataObject) {
|
[NodeConnectionTypes.AiVectorStore](execData: IDataObject) {
|
||||||
if (execData.documents) {
|
if (execData.documents) {
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1871,6 +1871,7 @@ export const NodeConnectionTypes = {
|
|||||||
AiMemory: 'ai_memory',
|
AiMemory: 'ai_memory',
|
||||||
AiOutputParser: 'ai_outputParser',
|
AiOutputParser: 'ai_outputParser',
|
||||||
AiRetriever: 'ai_retriever',
|
AiRetriever: 'ai_retriever',
|
||||||
|
AiReranker: 'ai_reranker',
|
||||||
AiTextSplitter: 'ai_textSplitter',
|
AiTextSplitter: 'ai_textSplitter',
|
||||||
AiTool: 'ai_tool',
|
AiTool: 'ai_tool',
|
||||||
AiVectorStore: 'ai_vectorStore',
|
AiVectorStore: 'ai_vectorStore',
|
||||||
@@ -1881,20 +1882,7 @@ export type NodeConnectionType = (typeof NodeConnectionTypes)[keyof typeof NodeC
|
|||||||
|
|
||||||
export type AINodeConnectionType = Exclude<NodeConnectionType, typeof NodeConnectionTypes.Main>;
|
export type AINodeConnectionType = Exclude<NodeConnectionType, typeof NodeConnectionTypes.Main>;
|
||||||
|
|
||||||
export const nodeConnectionTypes: NodeConnectionType[] = [
|
export const nodeConnectionTypes: NodeConnectionType[] = Object.values(NodeConnectionTypes);
|
||||||
NodeConnectionTypes.AiAgent,
|
|
||||||
NodeConnectionTypes.AiChain,
|
|
||||||
NodeConnectionTypes.AiDocument,
|
|
||||||
NodeConnectionTypes.AiEmbedding,
|
|
||||||
NodeConnectionTypes.AiLanguageModel,
|
|
||||||
NodeConnectionTypes.AiMemory,
|
|
||||||
NodeConnectionTypes.AiOutputParser,
|
|
||||||
NodeConnectionTypes.AiRetriever,
|
|
||||||
NodeConnectionTypes.AiTextSplitter,
|
|
||||||
NodeConnectionTypes.AiTool,
|
|
||||||
NodeConnectionTypes.AiVectorStore,
|
|
||||||
NodeConnectionTypes.Main,
|
|
||||||
];
|
|
||||||
|
|
||||||
export interface INodeInputFilter {
|
export interface INodeInputFilter {
|
||||||
// TODO: Later add more filter options like categories, subcatogries,
|
// TODO: Later add more filter options like categories, subcatogries,
|
||||||
@@ -2333,6 +2321,7 @@ export type AiEvent =
|
|||||||
| 'ai-message-added-to-memory'
|
| 'ai-message-added-to-memory'
|
||||||
| 'ai-output-parsed'
|
| 'ai-output-parsed'
|
||||||
| 'ai-documents-retrieved'
|
| 'ai-documents-retrieved'
|
||||||
|
| 'ai-document-reranked'
|
||||||
| 'ai-document-embedded'
|
| 'ai-document-embedded'
|
||||||
| 'ai-query-embedded'
|
| 'ai-query-embedded'
|
||||||
| 'ai-document-processed'
|
| 'ai-document-processed'
|
||||||
|
|||||||
Reference in New Issue
Block a user