mirror of
https://github.com/Abdulazizzn/n8n-enterprise-unlocked.git
synced 2025-12-17 18:12:04 +00:00
fix(Token Splitter Node): Cache tokenizer JSONs in memory (#17201)
This commit is contained in:
@@ -1,10 +1,9 @@
|
|||||||
import type { TokenTextSplitterParams } from '@langchain/textsplitters';
|
import type { TokenTextSplitterParams } from '@langchain/textsplitters';
|
||||||
import { TextSplitter } from '@langchain/textsplitters';
|
import { TextSplitter } from '@langchain/textsplitters';
|
||||||
import type * as tiktoken from 'js-tiktoken';
|
|
||||||
|
|
||||||
import { hasLongSequentialRepeat } from '@utils/helpers';
|
import { hasLongSequentialRepeat } from '@utils/helpers';
|
||||||
import { getEncoding } from '@utils/tokenizer/tiktoken';
|
import { getEncoding } from '@utils/tokenizer/tiktoken';
|
||||||
import { estimateTextSplitsByTokens } from '@utils/tokenizer/token-estimator';
|
import { estimateTextSplitsByTokens } from '@utils/tokenizer/token-estimator';
|
||||||
|
import type * as tiktoken from 'js-tiktoken';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of splitter which looks at tokens.
|
* Implementation of splitter which looks at tokens.
|
||||||
@@ -52,9 +51,7 @@ export class TokenTextSplitter extends TextSplitter implements TokenTextSplitter
|
|||||||
|
|
||||||
// Use tiktoken for normal text
|
// Use tiktoken for normal text
|
||||||
try {
|
try {
|
||||||
if (!this.tokenizer) {
|
this.tokenizer ??= getEncoding(this.encodingName);
|
||||||
this.tokenizer = await getEncoding(this.encodingName);
|
|
||||||
}
|
|
||||||
|
|
||||||
const splits: string[] = [];
|
const splits: string[] = [];
|
||||||
const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial);
|
const input_ids = this.tokenizer.encode(text, this.allowedSpecial, this.disallowedSpecial);
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ describe('TokenTextSplitter', () => {
|
|||||||
encode: jest.fn(),
|
encode: jest.fn(),
|
||||||
decode: jest.fn(),
|
decode: jest.fn(),
|
||||||
};
|
};
|
||||||
(tiktokenUtils.getEncoding as jest.Mock).mockResolvedValue(mockTokenizer);
|
(tiktokenUtils.getEncoding as jest.Mock).mockReturnValue(mockTokenizer);
|
||||||
// Default mock for hasLongSequentialRepeat - no repetition
|
// Default mock for hasLongSequentialRepeat - no repetition
|
||||||
(helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false);
|
(helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false);
|
||||||
});
|
});
|
||||||
@@ -306,7 +306,9 @@ describe('TokenTextSplitter', () => {
|
|||||||
const text = 'This will cause tiktoken to fail';
|
const text = 'This will cause tiktoken to fail';
|
||||||
|
|
||||||
(helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false);
|
(helpers.hasLongSequentialRepeat as jest.Mock).mockReturnValue(false);
|
||||||
(tiktokenUtils.getEncoding as jest.Mock).mockRejectedValue(new Error('Tiktoken error'));
|
(tiktokenUtils.getEncoding as jest.Mock).mockImplementation(() => {
|
||||||
|
throw new Error('Tiktoken error');
|
||||||
|
});
|
||||||
(tokenEstimator.estimateTextSplitsByTokens as jest.Mock).mockReturnValue([
|
(tokenEstimator.estimateTextSplitsByTokens as jest.Mock).mockReturnValue([
|
||||||
'fallback chunk',
|
'fallback chunk',
|
||||||
]);
|
]);
|
||||||
|
|||||||
@@ -13,113 +13,165 @@ jest.mock('js-tiktoken/lite', () => ({
|
|||||||
getEncodingNameForModel: jest.fn(),
|
getEncodingNameForModel: jest.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
jest.mock('../tokenizer/cl100k_base.json', () => ({ mockCl100kBase: 'data' }), { virtual: true });
|
jest.mock('fs', () => ({
|
||||||
jest.mock('../tokenizer/o200k_base.json', () => ({ mockO200kBase: 'data' }), { virtual: true });
|
readFileSync: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('n8n-workflow', () => ({
|
||||||
|
jsonParse: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
describe('tiktoken utils', () => {
|
describe('tiktoken utils', () => {
|
||||||
|
const mockReadFileSync = require('fs').readFileSync;
|
||||||
|
const mockJsonParse = require('n8n-workflow').jsonParse;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks();
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
// Set up mock implementations
|
||||||
|
mockReadFileSync.mockImplementation((path: string) => {
|
||||||
|
if (path.includes('cl100k_base.json')) {
|
||||||
|
return JSON.stringify({ mockCl100kBase: 'data' });
|
||||||
|
}
|
||||||
|
if (path.includes('o200k_base.json')) {
|
||||||
|
return JSON.stringify({ mockO200kBase: 'data' });
|
||||||
|
}
|
||||||
|
throw new Error(`Unexpected file path: ${path}`);
|
||||||
|
});
|
||||||
|
|
||||||
|
mockJsonParse.mockImplementation((content: string) => JSON.parse(content));
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getEncoding', () => {
|
describe('getEncoding', () => {
|
||||||
it('should return Tiktoken instance for cl100k_base encoding', async () => {
|
it('should return Tiktoken instance for cl100k_base encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('cl100k_base');
|
const result = getEncoding('cl100k_base');
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return Tiktoken instance for o200k_base encoding', async () => {
|
it('should return Tiktoken instance for o200k_base encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('o200k_base');
|
const result = getEncoding('o200k_base');
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockO200kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should map p50k_base to cl100k_base encoding', async () => {
|
it('should map p50k_base to cl100k_base encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('p50k_base');
|
const result = getEncoding('p50k_base');
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should map r50k_base to cl100k_base encoding', async () => {
|
it('should map r50k_base to cl100k_base encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('r50k_base');
|
const result = getEncoding('r50k_base');
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should map gpt2 to cl100k_base encoding', async () => {
|
it('should map gpt2 to cl100k_base encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('gpt2');
|
const result = getEncoding('gpt2');
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should map p50k_edit to cl100k_base encoding', async () => {
|
it('should map p50k_edit to cl100k_base encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('p50k_edit');
|
const result = getEncoding('p50k_edit');
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return cl100k_base for unknown encoding', async () => {
|
it('should return cl100k_base for unknown encoding', () => {
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await getEncoding('unknown_encoding' as unknown as TiktokenEncoding);
|
const result = getEncoding('unknown_encoding' as unknown as TiktokenEncoding);
|
||||||
|
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBe(mockTiktoken);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should use cache for repeated calls with same encoding', () => {
|
||||||
|
const mockTiktoken = {};
|
||||||
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
|
// Clear any previous calls to isolate this test
|
||||||
|
jest.clearAllMocks();
|
||||||
|
|
||||||
|
// Use a unique encoding that hasn't been cached yet
|
||||||
|
const uniqueEncoding = 'test_encoding' as TiktokenEncoding;
|
||||||
|
|
||||||
|
// First call
|
||||||
|
const result1 = getEncoding(uniqueEncoding);
|
||||||
|
expect(Tiktoken).toHaveBeenCalledTimes(1);
|
||||||
|
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' }); // Falls back to cl100k_base
|
||||||
|
|
||||||
|
// Second call - should use cache
|
||||||
|
const result2 = getEncoding(uniqueEncoding);
|
||||||
|
expect(Tiktoken).toHaveBeenCalledTimes(1); // Still only called once
|
||||||
|
expect(result1).toBe(result2);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('encodingForModel', () => {
|
describe('encodingForModel', () => {
|
||||||
it('should call getEncodingNameForModel and return encoding for cl100k_base', async () => {
|
it('should call getEncodingNameForModel and return encoding for cl100k_base', () => {
|
||||||
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
|
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = {};
|
||||||
|
|
||||||
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
|
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await encodingForModel('gpt-3.5-turbo');
|
// Clear previous calls since cl100k_base might be cached from previous tests
|
||||||
|
jest.clearAllMocks();
|
||||||
|
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
|
||||||
|
|
||||||
|
const result = encodingForModel('gpt-3.5-turbo');
|
||||||
|
|
||||||
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo');
|
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-3.5-turbo');
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
// Since cl100k_base was already loaded in previous tests, Tiktoken constructor
|
||||||
expect(result).toBe(mockTiktoken);
|
// won't be called again due to caching
|
||||||
|
expect(result).toBeTruthy();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle gpt-4 model with cl100k_base', async () => {
|
it('should handle gpt-4 model with o200k_base', () => {
|
||||||
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
|
const mockGetEncodingNameForModel = require('js-tiktoken/lite').getEncodingNameForModel;
|
||||||
const mockTiktoken = {};
|
const mockTiktoken = { isO200k: true };
|
||||||
|
|
||||||
mockGetEncodingNameForModel.mockReturnValue('cl100k_base');
|
// Use o200k_base to test a different encoding
|
||||||
|
mockGetEncodingNameForModel.mockReturnValue('o200k_base');
|
||||||
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
(Tiktoken as unknown as jest.Mock).mockReturnValue(mockTiktoken);
|
||||||
|
|
||||||
const result = await encodingForModel('gpt-4');
|
// Clear mocks and set up for this test
|
||||||
|
jest.clearAllMocks();
|
||||||
|
mockGetEncodingNameForModel.mockReturnValue('o200k_base');
|
||||||
|
|
||||||
|
const result = encodingForModel('gpt-4');
|
||||||
|
|
||||||
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4');
|
expect(mockGetEncodingNameForModel).toHaveBeenCalledWith('gpt-4');
|
||||||
expect(Tiktoken).toHaveBeenCalledWith({ mockCl100kBase: 'data' });
|
// Since o200k_base was already loaded in previous tests, we just verify the result
|
||||||
expect(result).toBe(mockTiktoken);
|
expect(result).toBeTruthy();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,30 +1,40 @@
|
|||||||
|
import { readFileSync } from 'fs';
|
||||||
import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite';
|
import type { TiktokenBPE, TiktokenEncoding, TiktokenModel } from 'js-tiktoken/lite';
|
||||||
import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite';
|
import { Tiktoken, getEncodingNameForModel } from 'js-tiktoken/lite';
|
||||||
|
import { jsonParse } from 'n8n-workflow';
|
||||||
|
import { join } from 'path';
|
||||||
|
|
||||||
import cl100k_base from './cl100k_base.json';
|
const cache: Record<string, Tiktoken> = {};
|
||||||
import o200k_base from './o200k_base.json';
|
|
||||||
|
|
||||||
export async function getEncoding(encoding: TiktokenEncoding) {
|
const loadJSONFile = (filename: string): TiktokenBPE => {
|
||||||
const encodings = {
|
const filePath = join(__dirname, filename);
|
||||||
cl100k_base: cl100k_base as TiktokenBPE,
|
const content = readFileSync(filePath, 'utf-8');
|
||||||
o200k_base: o200k_base as TiktokenBPE,
|
return jsonParse(content);
|
||||||
};
|
};
|
||||||
const encodingsMap: Record<TiktokenEncoding, TiktokenBPE> = {
|
|
||||||
cl100k_base: encodings.cl100k_base,
|
|
||||||
p50k_base: encodings.cl100k_base,
|
|
||||||
r50k_base: encodings.cl100k_base,
|
|
||||||
gpt2: encodings.cl100k_base,
|
|
||||||
p50k_edit: encodings.cl100k_base,
|
|
||||||
o200k_base: encodings.o200k_base,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (!(encoding in encodingsMap)) {
|
export function getEncoding(encoding: TiktokenEncoding): Tiktoken {
|
||||||
return new Tiktoken(cl100k_base);
|
if (cache[encoding]) {
|
||||||
|
return cache[encoding];
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Tiktoken(encodingsMap[encoding]);
|
let jsonData: TiktokenBPE;
|
||||||
|
|
||||||
|
switch (encoding) {
|
||||||
|
case 'o200k_base':
|
||||||
|
jsonData = loadJSONFile('./o200k_base.json');
|
||||||
|
break;
|
||||||
|
case 'cl100k_base':
|
||||||
|
jsonData = loadJSONFile('./cl100k_base.json');
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// Fall back to cl100k_base for unsupported encodings
|
||||||
|
jsonData = loadJSONFile('./cl100k_base.json');
|
||||||
|
}
|
||||||
|
|
||||||
|
cache[encoding] = new Tiktoken(jsonData);
|
||||||
|
return cache[encoding];
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function encodingForModel(model: TiktokenModel) {
|
export function encodingForModel(model: TiktokenModel): Tiktoken {
|
||||||
return await getEncoding(getEncodingNameForModel(model));
|
return getEncoding(getEncodingNameForModel(model));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ export async function estimateTokensFromStringList(
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const encoder = await encodingForModel(model);
|
const encoder = encodingForModel(model);
|
||||||
const encodedListLength = await Promise.all(
|
const encodedListLength = await Promise.all(
|
||||||
list.map(async (text) => {
|
list.map(async (text) => {
|
||||||
try {
|
try {
|
||||||
|
|||||||
Reference in New Issue
Block a user