refactor: Implement LLM tracing callback to improve parsing of token usage stats (#9311)

Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
This commit is contained in:
oleg
2024-05-12 21:12:07 +02:00
committed by GitHub
parent 244520547b
commit 359ade45bc
19 changed files with 282 additions and 111 deletions

View File

@@ -4,20 +4,13 @@ import type { ConnectionTypes, IExecuteFunctions, INodeExecutionData } from 'n8n
import { Tool } from '@langchain/core/tools';
import type { BaseMessage } from '@langchain/core/messages';
import type { InputValues, MemoryVariables, OutputValues } from '@langchain/core/memory';
import type { ChatResult } from '@langchain/core/outputs';
import { BaseChatMessageHistory } from '@langchain/core/chat_history';
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type {
CallbackManagerForLLMRun,
BaseCallbackConfig,
Callbacks,
} from '@langchain/core/callbacks/manager';
import type { BaseCallbackConfig, Callbacks } from '@langchain/core/callbacks/manager';
import { Embeddings } from '@langchain/core/embeddings';
import { VectorStore } from '@langchain/core/vectorstores';
import type { Document } from '@langchain/core/documents';
import { TextSplitter } from 'langchain/text_splitter';
import { BaseLLM } from '@langchain/core/language_models/llms';
import { BaseChatMemory } from '@langchain/community/memory/chat_memory';
import { BaseRetriever } from '@langchain/core/retrievers';
import type { FormatInstructionsOptions } from '@langchain/core/output_parsers';
@@ -26,7 +19,7 @@ import { isObject } from 'lodash';
import type { BaseDocumentLoader } from 'langchain/dist/document_loaders/base';
import { N8nJsonLoader } from './N8nJsonLoader';
import { N8nBinaryLoader } from './N8nBinaryLoader';
import { isChatInstance, logAiEvent } from './helpers';
import { logAiEvent } from './helpers';
const errorsMap: { [key: string]: { message: string; description: string } } = {
'You exceeded your current quota, please check your plan and billing details.': {
@@ -115,9 +108,7 @@ export function callMethodSync<T>(
export function logWrapper(
originalInstance:
| Tool
| BaseChatModel
| BaseChatMemory
| BaseLLM
| BaseChatMessageHistory
| BaseOutputParser
| BaseRetriever
@@ -229,56 +220,6 @@ export function logWrapper(
}
}
// ========== BaseChatModel ==========
if (originalInstance instanceof BaseLLM || isChatInstance(originalInstance)) {
if (prop === '_generate' && '_generate' in target) {
return async (
messages: BaseMessage[] & string[],
options: any,
runManager?: CallbackManagerForLLMRun,
): Promise<ChatResult> => {
connectionType = NodeConnectionType.AiLanguageModel;
const { index } = executeFunctions.addInputData(connectionType, [
[{ json: { messages, options } }],
]);
try {
const response = (await callMethodAsync.call(target, {
executeFunctions,
connectionType,
currentNodeRunIndex: index,
method: target[prop],
arguments: [
messages,
{ ...options, signal: executeFunctions.getExecutionCancelSignal() },
runManager,
],
})) as ChatResult;
const parsedMessages =
typeof messages === 'string'
? messages
: messages.map((message) => {
if (typeof message === 'string') return message;
if (typeof message?.toJSON === 'function') return message.toJSON();
return message;
});
void logAiEvent(executeFunctions, 'n8n.ai.llm.generated', {
messages: parsedMessages,
options,
response,
});
executeFunctions.addOutputData(connectionType, index, [[{ json: { response } }]]);
return response;
} catch (error) {
// Mute AbortError as they are expected
if (error?.name === 'AbortError') return { generations: [] };
throw error;
}
};
}
}
// ========== BaseOutputParser ==========
if (originalInstance instanceof BaseOutputParser) {
if (prop === 'getFormatInstructions' && 'getFormatInstructions' in target) {