feat(Token Splitter Node): Replace remote tiktoken encoding with local implementation (#16548)

oleg
2025-06-20 16:08:16 +02:00
committed by GitHub
parent 79650ea55a
commit 2d638023be
11 changed files with 427 additions and 32 deletions

View File

@@ -8,12 +8,12 @@ import type {
} from '@langchain/core/load/serializable';
import type { BaseMessage } from '@langchain/core/messages';
import type { LLMResult } from '@langchain/core/outputs';
import { encodingForModel } from '@langchain/core/utils/tiktoken';
import pick from 'lodash/pick';
import type { IDataObject, ISupplyDataFunctions, JsonObject } from 'n8n-workflow';
import { NodeConnectionTypes, NodeError, NodeOperationError } from 'n8n-workflow';
import { logAiEvent } from '@utils/helpers';
import { encodingForModel } from '@utils/tokenizer/tiktoken';
type TokensUsageParser = (llmOutput: LLMResult['llmOutput']) => {
completionTokens: number;
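
The change in this file is the import swap at the top: the token-usage parser now resolves encodings through the local @utils/tokenizer/tiktoken module rather than @langchain/core/utils/tiktoken, which fetches encodings remotely.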

View File

@@ -1,5 +1,4 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import { TokenTextSplitter } from '@langchain/textsplitters';
import {
NodeConnectionTypes,
type INodeType,
@@ -11,6 +10,8 @@ import {
import { logWrapper } from '@utils/logWrapper';
import { getConnectionHintNoticeField } from '@utils/sharedFields';
import { TokenTextSplitter } from './TokenTextSplitter';
export class TextSplitterTokenSplitter implements INodeType {
description: INodeTypeDescription = {
displayName: 'Token Splitter',
@@ -71,9 +72,6 @@ export class TextSplitterTokenSplitter implements INodeType {
disallowedSpecial: 'all',
encodingName: 'cl100k_base',
keepSeparator: false,
// allowedSpecial: 'all',
// disallowedSpecial: 'all',
// encodingName: 'cl100k_base',
});
return {
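
Besides swapping the import to the local TokenTextSplitter, the node now passes the previously commented-out encoding options (disallowedSpecial, encodingName) as real constructor arguments.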

View File

@@ -0,0 +1,57 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import type { TokenTextSplitterParams } from '@langchain/textsplitters';
import { TextSplitter } from '@langchain/textsplitters';
import type * as tiktoken from 'js-tiktoken';
import { getEncoding } from '@utils/tokenizer/tiktoken';
/**
* Implementation of a splitter that works on token boundaries.
* This is an override of the LangChain TokenTextSplitter that uses
* the n8n tokenizer utility, which loads local JSON encodings instead
* of fetching them remotely.
*/
export class TokenTextSplitter extends TextSplitter implements TokenTextSplitterParams {
static lc_name() {
return 'TokenTextSplitter';
}
encodingName: tiktoken.TiktokenEncoding;
allowedSpecial: 'all' | string[];
disallowedSpecial: 'all' | string[];
private tokenizer: tiktoken.Tiktoken | undefined;
constructor(fields?: Partial<TokenTextSplitterParams>) {
super(fields);
this.encodingName = fields?.encodingName ?? 'cl100k_base';
this.allowedSpecial = fields?.allowedSpecial ?? [];
this.disallowedSpecial = fields?.disallowedSpecial ?? 'all';
}
async splitText(text: string): Promise<string[]> {
if (!this.tokenizer) {
// Lazily load the bundled encoding on first use
this.tokenizer = await getEncoding(this.encodingName);
}
const splits: string[] = [];
const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial);
let start_idx = 0;
while (start_idx < input_ids.length) {
// Step back by the overlap so consecutive chunks share trailing tokens
if (start_idx > 0) {
start_idx -= this.chunkOverlap;
}
const end_idx = Math.min(start_idx + this.chunkSize, input_ids.length);
const chunk_ids = input_ids.slice(start_idx, end_idx);
splits.push(this.tokenizer.decode(chunk_ids));
start_idx = end_idx;
}
return splits;
}
}
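
A minimal usage sketch of the new splitter; the parameter values and input text are illustrative:

import { TokenTextSplitter } from './TokenTextSplitter';

async function demo(): Promise<void> {
// chunkSize and chunkOverlap are counted in tokens, not characters
const splitter = new TokenTextSplitter({
chunkSize: 100,
chunkOverlap: 10,
encodingName: 'cl100k_base',
});
const chunks = await splitter.splitText('some long document text...');
console.log(chunks.length);
}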

View File

@@ -0,0 +1,165 @@
import * as tiktokenUtils from '../../../../utils/tokenizer/tiktoken';
import { TokenTextSplitter } from '../TokenTextSplitter';
jest.mock('../../../../utils/tokenizer/tiktoken');
describe('TokenTextSplitter', () => {
let mockTokenizer: jest.Mocked<{
encode: jest.Mock;
decode: jest.Mock;
}>;
beforeEach(() => {
mockTokenizer = {
encode: jest.fn(),
decode: jest.fn(),
};
(tiktokenUtils.getEncoding as jest.Mock).mockResolvedValue(mockTokenizer);
});
afterEach(() => {
jest.clearAllMocks();
});
describe('constructor', () => {
it('should initialize with default parameters', () => {
const splitter = new TokenTextSplitter();
expect(splitter.encodingName).toBe('cl100k_base');
expect(splitter.allowedSpecial).toEqual([]);
expect(splitter.disallowedSpecial).toBe('all');
});
it('should initialize with custom parameters', () => {
const splitter = new TokenTextSplitter({
encodingName: 'o200k_base',
allowedSpecial: ['<|special|>'],
disallowedSpecial: ['<|bad|>'],
chunkSize: 500,
chunkOverlap: 50,
});
expect(splitter.encodingName).toBe('o200k_base');
expect(splitter.allowedSpecial).toEqual(['<|special|>']);
expect(splitter.disallowedSpecial).toEqual(['<|bad|>']);
expect(splitter.chunkSize).toBe(500);
expect(splitter.chunkOverlap).toBe(50);
});
it('should have correct lc_name', () => {
expect(TokenTextSplitter.lc_name()).toBe('TokenTextSplitter');
});
});
describe('splitText', () => {
it('should split text into chunks based on token count', async () => {
const splitter = new TokenTextSplitter({
chunkSize: 3,
chunkOverlap: 0,
});
const inputText = 'Hello world, this is a test';
const mockTokenIds = [1, 2, 3, 4, 5, 6, 7, 8];
mockTokenizer.encode.mockReturnValue(mockTokenIds);
mockTokenizer.decode.mockImplementation((tokens: number[]) => {
const chunks = [
[1, 2, 3],
[4, 5, 6],
[7, 8],
];
const chunkTexts = ['Hello world,', ' this is', ' a test'];
const index = chunks.findIndex(
(chunk) => chunk.length === tokens.length && chunk.every((val, i) => val === tokens[i]),
);
return chunkTexts[index] || '';
});
const result = await splitter.splitText(inputText);
expect(tiktokenUtils.getEncoding).toHaveBeenCalledWith('cl100k_base');
expect(mockTokenizer.encode).toHaveBeenCalledWith(inputText, [], 'all');
expect(result).toEqual(['Hello world,', ' this is', ' a test']);
});
it('should handle empty text', async () => {
const splitter = new TokenTextSplitter();
mockTokenizer.encode.mockReturnValue([]);
const result = await splitter.splitText('');
expect(result).toEqual([]);
});
it('should handle text shorter than chunk size', async () => {
const splitter = new TokenTextSplitter({
chunkSize: 10,
chunkOverlap: 0,
});
const inputText = 'Short text';
const mockTokenIds = [1, 2];
mockTokenizer.encode.mockReturnValue(mockTokenIds);
mockTokenizer.decode.mockReturnValue('Short text');
const result = await splitter.splitText(inputText);
expect(result).toEqual(['Short text']);
});
it('should use custom encoding and special tokens', async () => {
const splitter = new TokenTextSplitter({
encodingName: 'o200k_base',
allowedSpecial: ['<|special|>'],
disallowedSpecial: ['<|bad|>'],
});
const inputText = 'Text with <|special|> tokens';
mockTokenizer.encode.mockReturnValue([1, 2, 3]);
mockTokenizer.decode.mockReturnValue('Text with <|special|> tokens');
await splitter.splitText(inputText);
expect(tiktokenUtils.getEncoding).toHaveBeenCalledWith('o200k_base');
expect(mockTokenizer.encode).toHaveBeenCalledWith(inputText, ['<|special|>'], ['<|bad|>']);
});
it('should reuse tokenizer on subsequent calls', async () => {
const splitter = new TokenTextSplitter();
mockTokenizer.encode.mockReturnValue([1, 2, 3]);
mockTokenizer.decode.mockReturnValue('test');
await splitter.splitText('first call');
await splitter.splitText('second call');
expect(tiktokenUtils.getEncoding).toHaveBeenCalledTimes(1);
});
it('should handle large text with multiple chunks and overlap', async () => {
const splitter = new TokenTextSplitter({
chunkSize: 2,
chunkOverlap: 1,
});
const inputText = 'One two three four five six';
const mockTokenIds = [1, 2, 3, 4, 5, 6];
mockTokenizer.encode.mockReturnValue(mockTokenIds);
mockTokenizer.decode.mockImplementation((tokens: number[]) => {
const chunkMap: Record<string, string> = {
'1,2': 'One two',
'2,3': 'two three',
'3,4': 'three four',
'4,5': 'four five',
'5,6': 'five six',
};
return chunkMap[tokens.join(',')] || '';
});
const result = await splitter.splitText(inputText);
expect(result).toEqual(['One two', 'two three', 'three four', 'four five', 'five six']);
});
});
});
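
The window arithmetic in the overlap test can be checked by hand: with chunkSize 2 and chunkOverlap 1, each chunk starts one token before the previous one ended, giving the windows [0,2), [1,3), [2,4), [3,5), [4,6) over the six token ids. A small sketch of the same sliding-window logic the splitter uses (assuming chunkOverlap < chunkSize, otherwise the window never advances):

function chunkWindows(tokenCount: number, chunkSize: number, chunkOverlap: number): Array<[number, number]> {
const windows: Array<[number, number]> = [];
let start = 0;
while (start < tokenCount) {
// Mirror the splitter: step back by the overlap before taking the next chunk
if (start > 0) start -= chunkOverlap;
const end = Math.min(start + chunkSize, tokenCount);
windows.push([start, end]);
start = end;
}
return windows;
}

// chunkWindows(6, 2, 1) -> [[0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]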

View File

@@ -8,7 +8,7 @@
"dev": "pnpm run watch",
"typecheck": "tsc --noEmit",
"copy-nodes-json": "node ../../nodes-base/scripts/copy-nodes-json.js .",
"build": "tsup --tsconfig tsconfig.build.json && pnpm copy-nodes-json && tsc-alias -p tsconfig.build.json && pnpm n8n-copy-static-files && pnpm n8n-generate-metadata",
"build": "tsup --tsconfig tsconfig.build.json && pnpm copy-nodes-json && tsc-alias -p tsconfig.build.json && cp utils/tokenizer/*.json dist/utils/tokenizer/ && pnpm n8n-copy-static-files && pnpm n8n-generate-metadata",
"format": "biome format --write .",
"format:check": "biome ci .",
"lint": "eslint nodes credentials utils --quiet",
@@ -198,6 +198,7 @@
"html-to-text": "9.0.5",
"https-proxy-agent": "catalog:",
"jsdom": "23.0.1",
"js-tiktoken": "^1.0.12",
"langchain": "0.3.28",
"lodash": "catalog:",
"mammoth": "1.7.2",

View File

@@ -0,0 +1,125 @@
/* eslint-disable @typescript-eslint/no-unsafe-call */
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
/* eslint-disable @typescript-eslint/no-unsafe-member-access */
/* eslint-disable @typescript-eslint/no-var-requires */
import type { TiktokenEncoding } from 'js-tiktoken/lite';
import { Tiktoken } from 'js-tiktoken/lite';
import { getEncoding, encodingForModel } from '../tokenizer/tiktoken';
jest.mock('js-tiktoken/lite', () => ({
Tiktoken: jest.fn(),
getEncodingNameForModel: jest.fn(),
}));
jest.mock('../tokenizer/cl100k_base.json', () => ({ mockCl100kBase: 'data' }), { virtual: true });
jest.mock('../tokenizer/o200k_base.json', () => ({ mockO200kBase: 'data' }), { virtual: true });
describe('tiktoken utils', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('getEncoding', () => {
it('should return Tiktoken instance for cl100k_base encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('cl100k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should return Tiktoken instance for o200k_base encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('o200k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should map p50k_base to cl100k_base encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('p50k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should map r50k_base to cl100k_base encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('r50k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should map gpt2 to cl100k_base encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('gpt2');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should map p50k_edit to cl100k_base encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('p50k_edit');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should return cl100k_base for unknown encoding', async () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
const result = await getEncoding('unknown_encoding' as unknown as TiktokenEncoding);
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
});
describe('encodingForModel', () => {
it('should call getEncodingNameForModel and return encoding for cl100k_base', async () => {
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
const mockTiktoken = {};
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await encodingForModel('gpt-3.5-turbo');
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
it('should handle gpt-4 model with cl100k_base', async () => {
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
const mockTiktoken = {};
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await encodingForModel('gpt-4');
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken);
});
});
});

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,30 @@
import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite';
import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite';
import cl100k_base from './cl100k_base.json';
import o200k_base from './o200k_base.json';
// Kept async to stay drop-in compatible with the @langchain/core tiktoken API it replaces
export async function getEncoding(encoding: TiktokenEncoding) {
const encodings = {
cl100k_base: cl100k_base as TiktokenBPE,
o200k_base: o200k_base as TiktokenBPE,
};
// Only cl100k_base and o200k_base are bundled; legacy encodings fall back to cl100k_base
const encodingsMap: Record<TiktokenEncoding, TiktokenBPE> = {
cl100k_base: encodings.cl100k_base,
p50k_base: encodings.cl100k_base,
r50k_base: encodings.cl100k_base,
gpt2: encodings.cl100k_base,
p50k_edit: encodings.cl100k_base,
o200k_base: encodings.o200k_base,
};
if (!(encoding in encodingsMap)) {
return new Tiktoken(encodings.cl100k_base);
}
return new Tiktoken(encodingsMap[encoding]);
}
export async function encodingForModel(model: TiktokenModel) {
return await getEncoding(getEncodingNameForModel(model));
}
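
A short sketch of counting tokens with the local utility; the model name is illustrative:

import { encodingForModel } from '@utils/tokenizer/tiktoken';

async function countTokens(text: string): Promise<number> {
// Resolves the model to a bundled encoding; no remote fetch is involved
const tokenizer = await encodingForModel('gpt-4');
return tokenizer.encode(text).length;
}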