From 40850c95b680a54f16fe8133ff7b801008879df2 Mon Sep 17 00:00:00 2001 From: Nikhil Kuriakose Date: Tue, 3 Jun 2025 17:14:26 +0200 Subject: [PATCH] feat(Default Data Loader Node): Add default text splitter (#15786) Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> --- .../DocumentDefaultDataLoader.node.ts | 84 +++++++++++++--- .../DocumentDefaultDataLoader.node.test.ts | 72 ++++++++++++++ .../DocumentGithubLoader.node.ts | 81 ++++++++++++--- .../test/DocumentGithubLoader.node.test.ts | 99 +++++++++++++++++++ 4 files changed, 309 insertions(+), 27 deletions(-) create mode 100644 packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/test/DocumentDefaultDataLoader.node.test.ts create mode 100644 packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/test/DocumentGithubLoader.node.test.ts diff --git a/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.ts b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.ts index ca9b1e5053..1b5740c56e 100644 --- a/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/DocumentDefaultDataLoader.node.ts @@ -1,11 +1,13 @@ /* eslint-disable n8n-nodes-base/node-dirname-against-convention */ -import type { TextSplitter } from '@langchain/textsplitters'; +import { RecursiveCharacterTextSplitter, type TextSplitter } from '@langchain/textsplitters'; import { NodeConnectionTypes, type INodeType, type INodeTypeDescription, type ISupplyDataFunctions, type SupplyData, + type IDataObject, + type INodeInputConfiguration, } from 'n8n-workflow'; import { logWrapper } from '@utils/logWrapper'; @@ -20,13 +22,31 @@ import 'mammoth'; // for docx import 'epub2'; // for epub import 'pdf-parse'; // for pdf +function 
getInputs(parameters: IDataObject) { + const inputs: INodeInputConfiguration[] = []; + + const textSplittingMode = parameters?.textSplittingMode; + // If text splitting mode is 'custom' or does not exist (v1), we need to add an input for the text splitter + if (!textSplittingMode || textSplittingMode === 'custom') { + inputs.push({ + displayName: 'Text Splitter', + maxConnections: 1, + type: 'ai_textSplitter', + required: true, + }); + } + + return inputs; +} + export class DocumentDefaultDataLoader implements INodeType { description: INodeTypeDescription = { displayName: 'Default Data Loader', name: 'documentDefaultDataLoader', icon: 'file:binary.svg', group: ['transform'], - version: 1, + version: [1, 1.1], + defaultVersion: 1.1, description: 'Load data from previous step in the workflow', defaults: { name: 'Default Data Loader', @@ -45,14 +65,7 @@ export class DocumentDefaultDataLoader implements INodeType { }, }, // eslint-disable-next-line n8n-nodes-base/node-class-description-inputs-wrong-regular-node - inputs: [ - { - displayName: 'Text Splitter', - maxConnections: 1, - type: NodeConnectionTypes.AiTextSplitter, - required: true, - }, - ], + inputs: `={{ ((parameter) => { ${getInputs.toString()}; return getInputs(parameter) })($parameter) }}`, // eslint-disable-next-line n8n-nodes-base/node-class-description-outputs-wrong outputs: [NodeConnectionTypes.AiDocument], outputNames: ['Document'], @@ -64,6 +77,31 @@ export class DocumentDefaultDataLoader implements INodeType { type: 'notice', default: '', }, + { + displayName: 'Text Splitting', + name: 'textSplittingMode', + type: 'options', + default: 'simple', + required: true, + noDataExpression: true, + displayOptions: { + show: { + '@version': [1.1], + }, + }, + options: [ + { + name: 'Simple', + value: 'simple', + description: 'Uses the Recursive Character Text Splitter with default options', + }, + { + name: 'Custom', + value: 'custom', + description: 'Connect a text splitter of your choice', + }, + ], + }, { 
displayName: 'Type of Data', name: 'dataType', @@ -284,11 +322,29 @@ export class DocumentDefaultDataLoader implements INodeType { }; async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> { + const node = this.getNode(); const dataType = this.getNodeParameter('dataType', itemIndex, 'json') as 'json' | 'binary'; - const textSplitter = (await this.getInputConnectionData( - NodeConnectionTypes.AiTextSplitter, - 0, - )) as TextSplitter | undefined; + + let textSplitter: TextSplitter | undefined; + + if (node.typeVersion === 1.1) { + const textSplittingMode = this.getNodeParameter('textSplittingMode', itemIndex, 'simple') as + | 'simple' + | 'custom'; + + if (textSplittingMode === 'simple') { + textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000, chunkOverlap: 200 }); + } else if (textSplittingMode === 'custom') { + textSplitter = (await this.getInputConnectionData(NodeConnectionTypes.AiTextSplitter, 0)) as + | TextSplitter + | undefined; + } + } else { + textSplitter = (await this.getInputConnectionData(NodeConnectionTypes.AiTextSplitter, 0)) as + | TextSplitter + | undefined; + } + const binaryDataKey = this.getNodeParameter('binaryDataKey', itemIndex, '') as string; const processor = diff --git a/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/test/DocumentDefaultDataLoader.node.test.ts b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/test/DocumentDefaultDataLoader.node.test.ts new file mode 100644 index 0000000000..3515ef5163 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentDefaultDataLoader/test/DocumentDefaultDataLoader.node.test.ts @@ -0,0 +1,72 @@ +import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'; +import type { ISupplyDataFunctions } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; + +import { DocumentDefaultDataLoader } from '../DocumentDefaultDataLoader.node'; +
+jest.mock('@langchain/textsplitters', () => ({ + RecursiveCharacterTextSplitter: jest.fn().mockImplementation(() => ({ + splitDocuments: jest.fn( + async (docs: Array<Record<string, unknown>>): Promise<Array<Record<string, unknown>>> => + docs.map((doc) => ({ ...doc, split: true })), + ), + })), +})); + +describe('DocumentDefaultDataLoader', () => { + let loader: DocumentDefaultDataLoader; + + beforeEach(() => { + loader = new DocumentDefaultDataLoader(); + jest.clearAllMocks(); + }); + + it('should supply data with recursive char text splitter', async () => { + const context = { + getNode: jest.fn(() => ({ typeVersion: 1.1 })), + getNodeParameter: jest.fn().mockImplementation((paramName, _itemIndex) => { + switch (paramName) { + case 'dataType': + return 'json'; + case 'textSplittingMode': + return 'simple'; + case 'binaryDataKey': + return 'data'; + default: + return; + } + }), + } as unknown as ISupplyDataFunctions; + + await loader.supplyData.call(context, 0); + expect(RecursiveCharacterTextSplitter).toHaveBeenCalledWith({ + chunkSize: 1000, + chunkOverlap: 200, + }); + }); + + it('should supply data with custom text splitter', async () => { + const customSplitter = { splitDocuments: jest.fn(async (docs) => docs) }; + const context = { + getNode: jest.fn(() => ({ typeVersion: 1.1 })), + getNodeParameter: jest.fn().mockImplementation((paramName, _itemIndex) => { + switch (paramName) { + case 'dataType': + return 'json'; + case 'textSplittingMode': + return 'custom'; + case 'binaryDataKey': + return 'data'; + default: + return; + } + }), + getInputConnectionData: jest.fn(async () => customSplitter), + } as unknown as ISupplyDataFunctions; + await loader.supplyData.call(context, 0); + expect(context.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiTextSplitter, + 0, + ); + }); +}); diff --git a/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/DocumentGithubLoader.node.ts
b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/DocumentGithubLoader.node.ts index 49efeb0cb1..fa30b91d7d 100644 --- a/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/DocumentGithubLoader.node.ts +++ b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/DocumentGithubLoader.node.ts @@ -1,24 +1,45 @@ /* eslint-disable n8n-nodes-base/node-dirname-against-convention */ import { GithubRepoLoader } from '@langchain/community/document_loaders/web/github'; -import type { CharacterTextSplitter } from '@langchain/textsplitters'; +import type { TextSplitter } from '@langchain/textsplitters'; +import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'; import { NodeConnectionTypes, type INodeType, type INodeTypeDescription, type ISupplyDataFunctions, type SupplyData, + type IDataObject, + type INodeInputConfiguration, } from 'n8n-workflow'; import { logWrapper } from '@utils/logWrapper'; import { getConnectionHintNoticeField } from '@utils/sharedFields'; +function getInputs(parameters: IDataObject) { + const inputs: INodeInputConfiguration[] = []; + + const textSplittingMode = parameters?.textSplittingMode; + // If text splitting mode is 'custom' or does not exist (v1), we need to add an input for the text splitter + if (!textSplittingMode || textSplittingMode === 'custom') { + inputs.push({ + displayName: 'Text Splitter', + maxConnections: 1, + type: 'ai_textSplitter', + required: true, + }); + } + + return inputs; +} + export class DocumentGithubLoader implements INodeType { description: INodeTypeDescription = { displayName: 'GitHub Document Loader', name: 'documentGithubLoader', icon: 'file:github.svg', group: ['transform'], - version: 1, + version: [1, 1.1], + defaultVersion: 1.1, description: 'Use GitHub data as input to this chain', defaults: { name: 'GitHub Document Loader', @@ -43,19 +64,38 @@ export class DocumentGithubLoader implements INodeType { }, ], // 
eslint-disable-next-line n8n-nodes-base/node-class-description-inputs-wrong-regular-node - inputs: [ - { - displayName: 'Text Splitter', - maxConnections: 1, - type: NodeConnectionTypes.AiTextSplitter, - }, - ], + inputs: `={{ ((parameter) => { ${getInputs.toString()}; return getInputs(parameter) })($parameter) }}`, inputNames: ['Text Splitter'], // eslint-disable-next-line n8n-nodes-base/node-class-description-outputs-wrong outputs: [NodeConnectionTypes.AiDocument], outputNames: ['Document'], properties: [ getConnectionHintNoticeField([NodeConnectionTypes.AiVectorStore]), + { + displayName: 'Text Splitting', + name: 'textSplittingMode', + type: 'options', + default: 'simple', + required: true, + noDataExpression: true, + displayOptions: { + show: { + '@version': [1.1], + }, + }, + options: [ + { + name: 'Simple', + value: 'simple', + description: 'Uses Recursive Character Text Splitter with default options', + }, + { + name: 'Custom', + value: 'custom', + description: 'Connect a text splitter of your choice', + }, + ], + }, { displayName: 'Repository Link', name: 'repository', @@ -96,6 +136,7 @@ export class DocumentGithubLoader implements INodeType { async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> { this.logger.debug('Supplying data for Github Document Loader'); + const node = this.getNode(); const repository = this.getNodeParameter('repository', itemIndex) as string; const branch = this.getNodeParameter('branch', itemIndex) as string; @@ -104,11 +145,25 @@ export class DocumentGithubLoader implements INodeType { recursive: boolean; ignorePaths: string; }; + let textSplitter: TextSplitter | undefined; - const textSplitter = (await this.getInputConnectionData( - NodeConnectionTypes.AiTextSplitter, - 0, - )) as CharacterTextSplitter | undefined; + if (node.typeVersion === 1.1) { + const textSplittingMode = this.getNodeParameter('textSplittingMode', itemIndex, 'simple') as + | 'simple' + | 'custom'; + + if (textSplittingMode === 'simple') {
+ textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000, chunkOverlap: 200 }); + } else if (textSplittingMode === 'custom') { + textSplitter = (await this.getInputConnectionData(NodeConnectionTypes.AiTextSplitter, 0)) as + | TextSplitter + | undefined; + } + } else { + textSplitter = (await this.getInputConnectionData(NodeConnectionTypes.AiTextSplitter, 0)) as + | TextSplitter + | undefined; + } const { index } = this.addInputData(NodeConnectionTypes.AiDocument, [ [{ json: { repository, branch, ignorePaths, recursive } }], diff --git a/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/test/DocumentGithubLoader.node.test.ts b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/test/DocumentGithubLoader.node.test.ts new file mode 100644 index 0000000000..9b8f16b682 --- /dev/null +++ b/packages/@n8n/nodes-langchain/nodes/document_loaders/DocumentGithubLoader/test/DocumentGithubLoader.node.test.ts @@ -0,0 +1,99 @@ +import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'; +import type { ISupplyDataFunctions } from 'n8n-workflow'; +import { NodeConnectionTypes } from 'n8n-workflow'; + +import { DocumentGithubLoader } from '../DocumentGithubLoader.node'; + +jest.mock('@langchain/textsplitters', () => ({ + RecursiveCharacterTextSplitter: jest.fn().mockImplementation(() => ({ + splitDocuments: jest.fn( + async (docs: Array<{ [key: string]: unknown }>): Promise<Array<{ [key: string]: unknown }>> => + docs.map((doc) => ({ ...doc, split: true })), + ), + })), +})); +jest.mock('@langchain/community/document_loaders/web/github', () => ({ + GithubRepoLoader: jest.fn().mockImplementation(() => ({ + load: jest.fn(async () => [{ pageContent: 'doc1' }, { pageContent: 'doc2' }]), + })), +})); + +const mockLogger = { debug: jest.fn() }; + +describe('DocumentGithubLoader', () => { + let loader: DocumentGithubLoader; + + beforeEach(() => { + loader = new DocumentGithubLoader(); + jest.clearAllMocks(); + }); + + it('should supply data with
recursive char text splitter', async () => { + const context = { + logger: mockLogger, + getNode: jest.fn(() => ({ typeVersion: 1.1 })), + getNodeParameter: jest.fn().mockImplementation((paramName, _itemIndex) => { + switch (paramName) { + case 'repository': + return 'owner/repo'; + case 'branch': + return 'main'; + case 'textSplittingMode': + return 'simple'; + case 'additionalOptions': + return { recursive: true, ignorePaths: 'docs,tests' }; + default: + return; + } + }), + getCredentials: jest.fn().mockResolvedValue({ + accessToken: 'token', + server: 'https://api.github.com', + }), + addInputData: jest.fn(() => ({ index: 0 })), + addOutputData: jest.fn(), + } as unknown as ISupplyDataFunctions; + await loader.supplyData.call(context, 0); + + expect(RecursiveCharacterTextSplitter).toHaveBeenCalledWith({ + chunkSize: 1000, + chunkOverlap: 200, + }); + }); + + it('should use custom text splitter when textSplittingMode is custom', async () => { + const customSplitter = { splitDocuments: jest.fn(async (docs) => docs) }; + const context = { + logger: mockLogger, + getNode: jest.fn(() => ({ typeVersion: 1.1 })), + getNodeParameter: jest.fn().mockImplementation((paramName, _itemIndex) => { + switch (paramName) { + case 'repository': + return 'owner/repo'; + case 'branch': + return 'main'; + case 'textSplittingMode': + return 'custom'; + case 'additionalOptions': + return { recursive: true, ignorePaths: 'docs,tests' }; + default: + return; + } + }), + getCredentials: jest.fn().mockResolvedValue({ + accessToken: 'token', + server: 'https://api.github.com', + }), + getInputConnectionData: jest.fn(async () => customSplitter), + addInputData: jest.fn(() => ({ index: 0 })), + addOutputData: jest.fn(), + } as unknown as ISupplyDataFunctions; + await loader.supplyData.call(context, 0); + + expect(context.getInputConnectionData).toHaveBeenCalledWith( + NodeConnectionTypes.AiTextSplitter, + 0, + ); + expect(customSplitter.splitDocuments).toHaveBeenCalled(); + }); +});