refactor(core): Convert tiktoken file loading to async (#18756)

This commit is contained in:
jeanpaul
2025-08-25 20:33:38 +02:00
committed by GitHub
parent fd12b3d5ce
commit e233dfa4a2
4 changed files with 37 additions and 37 deletions

View File

@@ -51,7 +51,7 @@ export class TokenTextSplitter extends TextSplitter implements TokenTextSplitter
// Use tiktoken for normal text // Use tiktoken for normal text
try { try {
this.tokenizer ??= getEncoding(this.encodingName); this.tokenizer ??= await getEncoding(this.encodingName);
const splits: string[] = []; const splits: string[] = [];
const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial); const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial);

View File

@@ -13,8 +13,8 @@ jest.mock('js-tiktoken/lite', () => ({
getEncodingNameForModel: jest.fn(), getEncodingNameForModel: jest.fn(),
})); }));
jest.mock('fs', () => ({ jest.mock('fs/promises', () => ({
readFileSync: jest.fn(), readFile: jest.fn(),
})); }));
jest.mock('n8n-workflow', () => ({ jest.mock('n8n-workflow', () => ({
@@ -22,14 +22,14 @@ jest.mock('n8n-workflow', () => ({
})); }));
describe('tiktoken utils', () => { describe('tiktoken utils', () => {
const mockReadFileSync = require('fs').readFileSync; const mockReadFile = require('fs/promises').readFile;
const mockJsonParse = require('n8n-workflow').jsonParse; const mockJsonParse = require('n8n-workflow').jsonParse;
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks();
// Set up mock implementations // Set up mock implementations
mockReadFileSync.mockImplementation((path: string) => { mockReadFile.mockImplementation(async (path: string) => {
if (path.includes('cl100k_base.json')) { if (path.includes('cl100k_base.json')) {
return JSON.stringify({ mockCl100kBase: 'data' }); return JSON.stringify({ mockCl100kBase: 'data' });
} }
@@ -43,77 +43,77 @@ describe('tiktoken utils', () => {
}); });
describe('getEncoding', () => { describe('getEncoding', () => {
it('should return Tiktoken instance for cl100k_base encoding', () => { it('should return Tiktoken instance for cl100k_base encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('cl100k_base'); const result = await getEncoding('cl100k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should return Tiktoken instance for o200k_base encoding', () => { it('should return Tiktoken instance for o200k_base encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('o200k_base'); const result = await getEncoding('o200k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map p50k_base to cl100k_base encoding', () => { it('should map p50k_base to cl100k_base encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('p50k_base'); const result = await getEncoding('p50k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map r50k_base to cl100k_base encoding', () => { it('should map r50k_base to cl100k_base encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('r50k_base'); const result = await getEncoding('r50k_base');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map gpt2 to cl100k_base encoding', () => { it('should map gpt2 to cl100k_base encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('gpt2'); const result = await getEncoding('gpt2');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should map p50k_edit to cl100k_base encoding', () => { it('should map p50k_edit to cl100k_base encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('p50k_edit'); const result = await getEncoding('p50k_edit');
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should return cl100k_base for unknown encoding', () => { it('should return cl100k_base for unknown encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
const result = getEncoding('unknown_encoding' as unknown as TiktokenEncoding); const result = await getEncoding('unknown_encoding' as unknown as TiktokenEncoding);
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
expect(result).toBe(mockTiktoken); expect(result).toBe(mockTiktoken);
}); });
it('should use cache for repeated calls with same encoding', () => { it('should use cache for repeated calls with same encoding', async () => {
const mockTiktoken = {}; const mockTiktoken = {};
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken); (Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
@@ -124,19 +124,19 @@ describe('tiktoken utils', () => {
const uniqueEncoding = 'test_encoding' as TiktokenEncoding; const uniqueEncoding = 'test_encoding' as TiktokenEncoding;
// First call // First call
const result1 = getEncoding(uniqueEncoding); const result1 = await getEncoding(uniqueEncoding);
expect(Tiktoken).toHaveBeenCalledTimes(1); expect(Tiktoken).toHaveBeenCalledTimes(1);
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); // Falls back to cl100k_base expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); // Falls back to cl100k_base
// Second call - should use cache // Second call - should use cache
const result2 = getEncoding(uniqueEncoding); const result2 = await getEncoding(uniqueEncoding);
expect(Tiktoken).toHaveBeenCalledTimes(1); // Still only called once expect(Tiktoken).toHaveBeenCalledTimes(1); // Still only called once
expect(result1).toBe(result2); expect(result1).toBe(result2);
}); });
}); });
describe('encodingForModel', () => { describe('encodingForModel', () => {
it('should call getEncodingNameForModel and return encoding for cl100k_base', () => { it('should call getEncodingNameForModel and return encoding for cl100k_base', async () => {
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel; const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
const mockTiktoken = {}; const mockTiktoken = {};
@@ -147,7 +147,7 @@ describe('tiktoken utils', () => {
jest.clearAllMocks(); jest.clearAllMocks();
mockGetEncodingNameForModel.mockReturnValue('cl100k_base'); mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
const result = encodingForModel('gpt-3.5-turbo'); const result = await encodingForModel('gpt-3.5-turbo');
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo'); expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo');
// Since cl100k_base was already loaded in previous tests, Tiktoken constructor // Since cl100k_base was already loaded in previous tests, Tiktoken constructor
@@ -155,7 +155,7 @@ describe('tiktoken utils', () => {
expect(result).toBeTruthy(); expect(result).toBeTruthy();
}); });
it('should handle gpt-4 model with o200k_base', () => { it('should handle gpt-4 model with o200k_base', async () => {
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel; const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
const mockTiktoken = { isO200k: true }; const mockTiktoken = { isO200k: true };
@@ -167,7 +167,7 @@ describe('tiktoken utils', () => {
jest.clearAllMocks(); jest.clearAllMocks();
mockGetEncodingNameForModel.mockReturnValue('o200k_base'); mockGetEncodingNameForModel.mockReturnValue('o200k_base');
const result = encodingForModel('gpt-4'); const result = await encodingForModel('gpt-4');
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4'); expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4');
// Since o200k_base was already loaded in previous tests, we just verify the result // Since o200k_base was already loaded in previous tests, we just verify the result

View File

@@ -1,4 +1,4 @@
import { readFileSync } from 'fs'; import { readFile } from 'fs/promises';
import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite'; import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite';
import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite'; import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite';
import { jsonParse } from 'n8n-workflow'; import { jsonParse } from 'n8n-workflow';
@@ -6,13 +6,13 @@ import { join } from 'path';
const cache: Record<string, Tiktoken> = {}; const cache: Record<string, Tiktoken> = {};
const loadJSONFile = (filename: string): TiktokenBPE => { const loadJSONFile = async (filename: string): Promise<TiktokenBPE> => {
const filePath = join(__dirname, filename); const filePath = join(__dirname, filename);
const content = readFileSync(filePath, 'utf-8'); const content = await readFile(filePath, 'utf-8');
return jsonParse(content); return await jsonParse(content);
}; };
export function getEncoding(encoding: TiktokenEncoding): Tiktoken { export async function getEncoding(encoding: TiktokenEncoding): Promise<Tiktoken> {
if (cache[encoding]) { if (cache[encoding]) {
return cache[encoding]; return cache[encoding];
} }
@@ -21,20 +21,20 @@ export function getEncoding(encoding: TiktokenEncoding): Tiktoken {
switch (encoding) { switch (encoding) {
case 'o200k_base': case 'o200k_base':
jsonData = loadJSONFile('./o200k_base.json'); jsonData = await loadJSONFile('./o200k_base.json');
break; break;
case 'cl100k_base': case 'cl100k_base':
jsonData = loadJSONFile('./cl100k_base.json'); jsonData = await loadJSONFile('./cl100k_base.json');
break; break;
default: default:
// Fall back to cl100k_base for unsupported encodings // Fall back to cl100k_base for unsupported encodings
jsonData = loadJSONFile('./cl100k_base.json'); jsonData = await loadJSONFile('./cl100k_base.json');
} }
cache[encoding] = new Tiktoken(jsonData); cache[encoding] = new Tiktoken(jsonData);
return cache[encoding]; return cache[encoding];
} }
export function encodingForModel(model: TiktokenModel): Tiktoken { export async function encodingForModel(model: TiktokenModel): Promise<Tiktoken> {
return getEncoding(getEncodingNameForModel(model)); return await getEncoding(getEncodingNameForModel(model));
} }

View File

@@ -136,7 +136,7 @@ export async function estimateTokensFromStringList(
return 0; return 0;
} }
const encoder = encodingForModel(model); const encoder = await encodingForModel(model);
const encodedListLength = await Promise.all( const encodedListLength = await Promise.all(
list.map(async (text) => { list.map(async (text) => {
try { try {