fix(Token Splitter Node): Cache tokenizer JSONs in memory (#17201)

This commit is contained in:
oleg
2025-07-10 19:08:29 +02:00
committed by GitHub
parent 36b410abdb
commit 2402926573
5 changed files with 115 additions and 54 deletions

View File

@@ -1,10 +1,9 @@
import type { TokenTextSplitterParams } from '@langchain/textsplitters'; import type { TokenTextSplitterParams } from '@langchain/textsplitters';
import { TextSplitter } from '@langchain/textsplitters'; import { TextSplitter } from '@langchain/textsplitters';
import type * as tiktoken from 'js-tiktoken';
import { hasLongSequentialRepeat } from '@utils/helpers'; import { hasLongSequentialRepeat } from '@utils/helpers';
import { getEncoding } from '@utils/tokenizer/tiktoken'; import { getEncoding } from '@utils/tokenizer/tiktoken';
import { estimateTextSplitsByTokens } from '@utils/tokenizer/token-estimator'; import { estimateTextSplitsByTokens } from '@utils/tokenizer/token-estimator';
import type * as tiktoken from 'js-tiktoken';
/** /**
* Implementation of splitter which looks at tokens. * Implementation of splitter which looks at tokens.
@@ -52,9 +51,7 @@ export class TokenTextSplitter extends TextSplitter implements TokenTextSplitter
// Use tiktoken for normal text // Use tiktoken for normal text
try { try {
if (!this.tokenizer) { this.tokenizer ??= getEncoding(this.encodingName);
this.tokenizer = await getEncoding(this.encodingName);
}
const splits: string[] = []; const splits: string[] = [];
const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial); const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial);

View File

@@ -20,7 +20,7 @@ describe('TokenTextSplitter', () => {
encode: jest.fn(), encode: jest.fn(),
decode: jest.fn(), decode: jest.fn(),
}; };
(tiktokenUtils.getEncoding as jest.Mock).mockResolvedValue(mockTokenizer); (tiktokenUtils.getEncoding as jest.Mock).mockReturnValue(mockTokenizer);
// Default mock for hasLongSequentialRepeat - no repetition // Default mock for hasLongSequentialRepeat - no repetition
(helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false); (helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false);
}); });
@@ -306,7 +306,9 @@ describe('TokenTextSplitter', () => {
const text = 'This will cause tiktoken to fail'; const text = 'This will cause tiktoken to fail';
(helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false); (helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false);
(tiktokenUtils.getEncoding as jest.Mock).mockRejectedValue(new Error('Tiktoken error')); (tiktokenUtils.getEncoding as jest.Mock).mockImplementation(() => {
throw new Error('Tiktoken error');
});
(tokenEstimator.estimateTextSplitsByTokens as jest.Mock).mockReturnValue([ (tokenEstimator.estimateTextSplitsByTokens as jest.Mock).mockReturnValue([
'fallback chunk', 'fallback chunk',
]); ]);

View File

@@ -13,113 +13,165 @@ jest.mock('js-tiktoken/lite', () => ({
getEncodingNameForModel: jest.fn(), getEncodingNameForModel: jest.fn(),
})); }));
jest.mock('../tokenizer/cl100k_base.json', () => ({ mockCl100kBase: 'data' }), { virtual: true }); jest.mock('fs', () => ({
jest.mock('../tokenizer/o200k_base.json', () => ({ mockO200kBase: 'data' }), { virtual: true }); readFileSync: jest.fn(),
}));
jest.mock('n8n-workflow', () => ({
jsonParse: jest.fn(),
}));
describe('tiktoken utils', () => { describe('tiktoken utils', () => {
const mockReadFileSync = require('fs').readFileSync;
const mockJsonParse = require('n8n-workflow').jsonParse;
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks();
// Set up mock implementations
mockReadFileSync.mockImplementation((path: string) => {
if (path.includes('cl100k_base.json')) {
return JSON.stringify({ mockCl100kBase: 'data' });
}
if (path.includes('o200k_base.json')) {
return JSON.stringify({ mockO200kBase: 'data' });
}
throw new Error(`Unexpected file path: ${path}`);
});
mockJsonParse.mockImplementation((content: string) => JSON.parse(content));
}); });
describe('getEncoding', () => { describe('getEncoding', () => {
it('should return Tiktoken instance for cl100k_base encoding', async () => { it('should return Tiktoken instance for cl100k_base encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('cl100k_base'); const result = getEncoding('cl100k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should return Tiktoken instance for o200k_base encoding', async () => { it('should return Tiktoken instance for o200k_base encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('o200k_base'); const result = getEncoding('o200k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map p50k_base to cl100k_base encoding', async () => { it('should map p50k_base to cl100k_base encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('p50k_base'); const result = getEncoding('p50k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map r50k_base to cl100k_base encoding', async () => { it('should map r50k_base to cl100k_base encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('r50k_base'); const result = getEncoding('r50k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map gpt2 to cl100k_base encoding', async () => { it('should map gpt2 to cl100k_base encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('gpt2'); const result = getEncoding('gpt2');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map p50k_edit to cl100k_base encoding', async () => { it('should map p50k_edit to cl100k_base encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('p50k_edit'); const result = getEncoding('p50k_edit');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should return cl100k_base for unknown encoding', async () => { it('should return cl100k_base for unknown encoding', () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await getEncoding('unknown_encoding' as unknown as TiktokenEncoding); const result = getEncoding('unknown_encoding' as unknown as TiktokenEncoding);
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should use cache for repeated calls with same encoding', () => {
const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
// Clear any previous calls to isolate this test
jest.clearAllMocks();
// Use a unique encoding that hasn't been cached yet
const uniqueEncoding = 'test_encoding' as TiktokenEncoding;
// First call
const result1 = getEncoding(uniqueEncoding);
expect(Tiktoken).toHaveBeenCalledTimes(1);
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); // Falls back to cl100k_base
// Second call - should use cache
const result2 = getEncoding(uniqueEncoding);
expect(Tiktoken).toHaveBeenCalledTimes(1); // Still only called once
expect(result1).toBe(result2);
});
}); });
describe('encodingForModel', () => { describe('encodingForModel', () => {
it('should call getEncodingNameForModel and return encoding for cl100k_base', async () => { it('should call getEncodingNameForModel and return encoding for cl100k_base', () => {
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel; const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
const mockTiktoken = {}; const mockTiktoken = {};
mockGetEncodingNameForModel.mockReturnValue('cl100k_base'); mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await encodingForModel('gpt-3.5-turbo'); // Clear previous calls since cl100k_base might be cached from previous tests
jest.clearAllMocks();
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
const result = encodingForModel('gpt-3.5-turbo');
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo'); expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); // Since cl100k_base was already loaded in previous tests, Tiktoken constructor
expect(result).toBe(mockTiktoken); // won't be called again due to caching
expect(result).toBeTruthy();
}); });
it('should handle gpt-4 model with cl100k_base', async () => { it('should handle gpt-4 model with o200k_base', () => {
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel; const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
const mockTiktoken = {}; const mockTiktoken = { isO200k: true };
mockGetEncodingNameForModel.mockReturnValue('cl100k_base'); // Use o200k_base to test a different encoding
mockGetEncodingNameForModel.mockReturnValue('o200k_base');
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = await encodingForModel('gpt-4'); // Clear mocks and set up for this test
jest.clearAllMocks();
mockGetEncodingNameForModel.mockReturnValue('o200k_base');
const result = encodingForModel('gpt-4');
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4'); expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); // Since o200k_base was already loaded in previous tests, we just verify the result
expect(result).toBe(mockTiktoken); expect(result).toBeTruthy();
}); });
}); });
}); });

View File

@@ -1,30 +1,40 @@
import { readFileSync } from 'fs';
import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite'; import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite';
import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite'; import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite';
import { jsonParse } from 'n8n-workflow';
import { join } from 'path';
import cl100k_base from './cl100k_base.json'; const cache: Record<string, Tiktoken> = {};
import o200k_base from './o200k_base.json';
export async function getEncoding(encoding: TiktokenEncoding) { const loadJSONFile = (filename: string): TiktokenBPE => {
const encodings = { const filePath = join(__dirname, filename);
cl100k_base: cl100k_base as TiktokenBPE, const content = readFileSync(filePath, 'utf-8');
o200k_base: o200k_base as TiktokenBPE, return jsonParse(content);
}; };
const encodingsMap: Record<TiktokenEncoding, TiktokenBPE> = {
cl100k_base: encodings.cl100k_base,
p50k_base: encodings.cl100k_base,
r50k_base: encodings.cl100k_base,
gpt2: encodings.cl100k_base,
p50k_edit: encodings.cl100k_base,
o200k_base: encodings.o200k_base,
};
if (!(encoding in encodingsMap)) { export function getEncoding(encoding: TiktokenEncoding): Tiktoken {
return new Tiktoken(cl100k_base); if (cache[encoding]) {
return cache[encoding];
} }
return new Tiktoken(encodingsMap[encoding]); let jsonData: TiktokenBPE;
switch (encoding) {
case 'o200k_base':
jsonData = loadJSONFile('./o200k_base.json');
break;
case 'cl100k_base':
jsonData = loadJSONFile('./cl100k_base.json');
break;
default:
// Fall back to cl100k_base for unsupported encodings
jsonData = loadJSONFile('./cl100k_base.json');
}
cache[encoding] = new Tiktoken(jsonData);
return cache[encoding];
} }
export async function encodingForModel(model: TiktokenModel) { export function encodingForModel(model: TiktokenModel): Tiktoken {
return await getEncoding(getEncodingNameForModel(model)); return getEncoding(getEncodingNameForModel(model));
} }

View File

@@ -136,7 +136,7 @@ export async function estimateTokensFromStringList(
return 0; return 0;
} }
const encoder = await encodingForModel(model); const encoder = encodingForModel(model);
const encodedListLength = await Promise.all( const encodedListLength = await Promise.all(
list.map(async (text) => { list.map(async (text) => {
try { try {