From c37397cb2b7fb8b14de613ca6fae24c8db7c380c Mon Sep 17 00:00:00 2001
From: oleg
Date: Wed, 9 Jul 2025 13:20:25 +0200
Subject: [PATCH] feat(Cohere Chat Model Node): Add Cohere Chat Model node (#16888)

---
 .../credentials/CohereApi.credentials.ts      |   8 +-
 .../Agent/agents/ToolsAgent/V2/execute.ts     |   5 +-
 .../LMChatAnthropic/LmChatAnthropic.node.ts   |   7 +-
 .../llms/LmChatCohere/LmChatCohere.node.ts    | 177 ++++++++
 .../nodes/llms/LmChatCohere/cohere.dark.svg   |   5 +
 .../nodes/llms/LmChatCohere/cohere.svg        |   5 +
 .../nodes/llms/N8nLlmTracing.ts               |  10 +-
 .../nodes/llms/test/N8nLlmTracing.test.ts     | 390 ++++++++++++++++++
 packages/@n8n/nodes-langchain/package.json    |   4 +-
 pnpm-lock.yaml                                |   3 +
 10 files changed, 604 insertions(+), 10 deletions(-)
 create mode 100644 packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/LmChatCohere.node.ts
 create mode 100644 packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.dark.svg
 create mode 100644 packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.svg
 create mode 100644 packages/@n8n/nodes-langchain/nodes/llms/test/N8nLlmTracing.test.ts

diff --git a/packages/@n8n/nodes-langchain/credentials/CohereApi.credentials.ts b/packages/@n8n/nodes-langchain/credentials/CohereApi.credentials.ts
index f243593e7b..9a1117af9e 100644
--- a/packages/@n8n/nodes-langchain/credentials/CohereApi.credentials.ts
+++ b/packages/@n8n/nodes-langchain/credentials/CohereApi.credentials.ts
@@ -21,6 +21,12 @@ export class CohereApi implements ICredentialType {
       required: true,
       default: '',
     },
+    {
+      displayName: 'Base URL',
+      name: 'url',
+      type: 'hidden',
+      default: 'https://api.cohere.ai',
+    },
   ];
 
   authenticate: IAuthenticateGeneric = {
@@ -34,7 +40,7 @@ export class CohereApi implements ICredentialType {
 
   test: ICredentialTestRequest = {
     request: {
-      baseURL: 'https://api.cohere.ai',
+      baseURL: '={{ $credentials.url }}',
       url: '/v1/models?page_size=1',
     },
   };
diff --git a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts
index 55dbf859c8..f770e09d17 100644
--- a/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts
+++ b/packages/@n8n/nodes-langchain/nodes/agents/Agent/agents/ToolsAgent/V2/execute.ts
@@ -65,7 +65,10 @@ function createAgentExecutor(
     fallbackAgent ? agent.withFallbacks([fallbackAgent]) : agent,
     getAgentStepsParser(outputParser, memory),
     fixEmptyContentMessage,
-  ]);
+  ]) as AgentRunnableSequence;
+
+  runnableAgent.singleAction = false;
+  runnableAgent.streamRunnable = false;
 
   return AgentExecutor.fromAgentAndTools({
     agent: runnableAgent,
diff --git a/packages/@n8n/nodes-langchain/nodes/llms/LMChatAnthropic/LmChatAnthropic.node.ts b/packages/@n8n/nodes-langchain/nodes/llms/LMChatAnthropic/LmChatAnthropic.node.ts
index f5053fda35..f8e506118b 100644
--- a/packages/@n8n/nodes-langchain/nodes/llms/LMChatAnthropic/LmChatAnthropic.node.ts
+++ b/packages/@n8n/nodes-langchain/nodes/llms/LMChatAnthropic/LmChatAnthropic.node.ts
@@ -285,8 +285,11 @@ export class LmChatAnthropic implements INodeType {
     };
     let invocationKwargs = {};
 
-    const tokensUsageParser = (llmOutput: LLMResult['llmOutput']) => {
-      const usage = (llmOutput?.usage as { input_tokens: number; output_tokens: number }) ?? {
+    const tokensUsageParser = (result: LLMResult) => {
+      const usage = (result?.llmOutput?.usage as {
+        input_tokens: number;
+        output_tokens: number;
+      }) ?? {
         input_tokens: 0,
         output_tokens: 0,
       };
diff --git a/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/LmChatCohere.node.ts b/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/LmChatCohere.node.ts
new file mode 100644
index 0000000000..3d8c10a188
--- /dev/null
+++ b/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/LmChatCohere.node.ts
@@ -0,0 +1,177 @@
+import { ChatCohere } from '@langchain/cohere';
+import type { LLMResult } from '@langchain/core/outputs';
+import type {
+  INodeType,
+  INodeTypeDescription,
+  ISupplyDataFunctions,
+  SupplyData,
+} from 'n8n-workflow';
+
+import { getConnectionHintNoticeField } from '@utils/sharedFields';
+
+import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
+import { N8nLlmTracing } from '../N8nLlmTracing';
+
+export function tokensUsageParser(result: LLMResult): {
+  completionTokens: number;
+  promptTokens: number;
+  totalTokens: number;
+} {
+  let totalInputTokens = 0;
+  let totalOutputTokens = 0;
+
+  result.generations?.forEach((generationArray) => {
+    generationArray.forEach((gen) => {
+      const inputTokens = gen.generationInfo?.meta?.tokens?.inputTokens ?? 0;
+      const outputTokens = gen.generationInfo?.meta?.tokens?.outputTokens ?? 0;
+
+      totalInputTokens += inputTokens;
+      totalOutputTokens += outputTokens;
+    });
+  });
+
+  return {
+    completionTokens: totalOutputTokens,
+    promptTokens: totalInputTokens,
+    totalTokens: totalInputTokens + totalOutputTokens,
+  };
+}
+
+export class LmChatCohere implements INodeType {
+  description: INodeTypeDescription = {
+    displayName: 'Cohere Chat Model',
+    name: 'lmChatCohere',
+    icon: { light: 'file:cohere.svg', dark: 'file:cohere.dark.svg' },
+    group: ['transform'],
+    version: [1],
+    description: 'For advanced usage with an AI chain',
+    defaults: {
+      name: 'Cohere Chat Model',
+    },
+    codex: {
+      categories: ['AI'],
+      subcategories: {
+        AI: ['Language Models', 'Root Nodes'],
+        'Language Models': ['Chat Models (Recommended)'],
+      },
+      resources: {
+        primaryDocumentation: [
+          {
+            url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.lmchatcohere/',
+          },
+        ],
+      },
+    },
+    inputs: [],
+    outputs: ['ai_languageModel'],
+    outputNames: ['Model'],
+    credentials: [
+      {
+        name: 'cohereApi',
+        required: true,
+      },
+    ],
+    requestDefaults: {
+      baseURL: '={{$credentials?.url}}',
+      headers: {
+        accept: 'application/json',
+        authorization: '=Bearer {{$credentials?.apiKey}}',
+      },
+    },
+    properties: [
+      getConnectionHintNoticeField(['ai_chain', 'ai_agent']),
+      {
+        displayName: 'Model',
+        name: 'model',
+        type: 'options',
+        description:
+          'The model which will generate the completion. Learn more.',
+        typeOptions: {
+          loadOptions: {
+            routing: {
+              request: {
+                method: 'GET',
+                url: '/v1/models?page_size=100&endpoint=chat',
+              },
+              output: {
+                postReceive: [
+                  {
+                    type: 'rootProperty',
+                    properties: {
+                      property: 'models',
+                    },
+                  },
+                  {
+                    type: 'setKeyValue',
+                    properties: {
+                      name: '={{$responseItem.name}}',
+                      value: '={{$responseItem.name}}',
+                      description: '={{$responseItem.description}}',
+                    },
+                  },
+                  {
+                    type: 'sort',
+                    properties: {
+                      key: 'name',
+                    },
+                  },
+                ],
+              },
+            },
+          },
+        },
+        default: 'command-a-03-2025',
+      },
+      {
+        displayName: 'Options',
+        name: 'options',
+        placeholder: 'Add Option',
+        description: 'Additional options to add',
+        type: 'collection',
+        default: {},
+        options: [
+          {
+            displayName: 'Sampling Temperature',
+            name: 'temperature',
+            default: 0.7,
+            typeOptions: { maxValue: 2, minValue: 0, numberPrecision: 1 },
+            description:
+              'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
+            type: 'number',
+          },
+          {
+            displayName: 'Max Retries',
+            name: 'maxRetries',
+            default: 2,
+            description: 'Maximum number of retries to attempt',
+            type: 'number',
+          },
+        ],
+      },
+    ],
+  };
+
+  async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
+    const credentials = await this.getCredentials<{ url?: string; apiKey?: string }>('cohereApi');
+
+    const modelName = this.getNodeParameter('model', itemIndex) as string;
+
+    const options = this.getNodeParameter('options', itemIndex, {}) as {
+      maxRetries: number;
+      temperature?: number;
+    };
+
+    const model = new ChatCohere({
+      apiKey: credentials.apiKey,
+      model: modelName,
+      temperature: options.temperature,
+      maxRetries: options.maxRetries ?? 2,
+      callbacks: [new N8nLlmTracing(this, { tokensUsageParser })],
+      onFailedAttempt: makeN8nLlmFailedAttemptHandler(this),
+    });
+
+    return {
+      response: model,
+    };
+  }
+}
diff --git a/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.dark.svg b/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.dark.svg
new file mode 100644
index 0000000000..796fe1bcbc
--- /dev/null
+++ b/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.dark.svg
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.svg b/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.svg
new file mode 100644
index 0000000000..c54ba34ee8
--- /dev/null
+++ b/packages/@n8n/nodes-langchain/nodes/llms/LmChatCohere/cohere.svg
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/packages/@n8n/nodes-langchain/nodes/llms/N8nLlmTracing.ts b/packages/@n8n/nodes-langchain/nodes/llms/N8nLlmTracing.ts
index f3ecfa9f8f..6b6e48cf46 100644
--- a/packages/@n8n/nodes-langchain/nodes/llms/N8nLlmTracing.ts
+++ b/packages/@n8n/nodes-langchain/nodes/llms/N8nLlmTracing.ts
@@ -15,7 +15,7 @@ import { NodeConnectionTypes, NodeError, NodeOperationError } from 'n8n-workflow
 import { logAiEvent } from '@utils/helpers';
 import { estimateTokensFromStringList } from '@utils/tokenizer/token-estimator';
 
-type TokensUsageParser = (llmOutput: LLMResult['llmOutput']) => {
+type TokensUsageParser = (result: LLMResult) => {
   completionTokens: number;
   promptTokens: number;
   totalTokens: number;
@@ -53,9 +53,9 @@ export class N8nLlmTracing extends BaseCallbackHandler {
   options = {
     // Default(OpenAI format) parser
-    tokensUsageParser: (llmOutput: LLMResult['llmOutput']) => {
-      const completionTokens = (llmOutput?.tokenUsage?.completionTokens as number) ?? 0;
-      const promptTokens = (llmOutput?.tokenUsage?.promptTokens as number) ?? 0;
+    tokensUsageParser: (result: LLMResult) => {
+      const completionTokens = (result?.llmOutput?.tokenUsage?.completionTokens as number) ?? 0;
+      const promptTokens = (result?.llmOutput?.tokenUsage?.promptTokens as number) ?? 0;
 
       return {
         completionTokens,
         promptTokens,
@@ -101,7 +101,7 @@
       promptTokens: 0,
       totalTokens: 0,
     };
-    const tokenUsage = this.options.tokensUsageParser(output.llmOutput);
+    const tokenUsage = this.options.tokensUsageParser(output);
 
     if (output.generations.length > 0) {
       tokenUsageEstimate.completionTokens = await this.estimateTokensFromGeneration(
diff --git a/packages/@n8n/nodes-langchain/nodes/llms/test/N8nLlmTracing.test.ts b/packages/@n8n/nodes-langchain/nodes/llms/test/N8nLlmTracing.test.ts
new file mode 100644
index 0000000000..889a0fa649
--- /dev/null
+++ b/packages/@n8n/nodes-langchain/nodes/llms/test/N8nLlmTracing.test.ts
@@ -0,0 +1,390 @@
+/* eslint-disable @typescript-eslint/no-unsafe-member-access */
+
+/* eslint-disable @typescript-eslint/unbound-method */
+/* eslint-disable @typescript-eslint/no-unsafe-assignment */
+import type { Serialized } from '@langchain/core/load/serializable';
+import type { LLMResult } from '@langchain/core/outputs';
+import { mock } from 'jest-mock-extended';
+import type { IDataObject, ISupplyDataFunctions } from 'n8n-workflow';
+import { NodeOperationError, NodeApiError } from 'n8n-workflow';
+
+import { N8nLlmTracing } from '../N8nLlmTracing';
+
+describe('N8nLlmTracing', () => {
+  const executionFunctions = mock<ISupplyDataFunctions>({
+    addInputData: jest.fn().mockReturnValue({ index: 0 }),
+    addOutputData: jest.fn(),
+    getNode: jest.fn().mockReturnValue({ name: 'TestNode' }),
+    getNextRunIndex: jest.fn().mockReturnValue(1),
+  });
+
+  beforeEach(() => {
+    jest.clearAllMocks();
+  });
+
+  describe('tokensUsageParser', () => {
+    it('should parse OpenAI format tokens correctly', () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const llmResult: LLMResult = {
+        generations: [],
+        llmOutput: {
+          tokenUsage: {
+            completionTokens: 100,
+            promptTokens: 50,
+          },
+        },
+      };
+
+      const result = tracer.options.tokensUsageParser(llmResult);
+
+      expect(result).toEqual({
+        completionTokens: 100,
+        promptTokens: 50,
+        totalTokens: 150,
+      });
+    });
+
+    it('should handle missing token data', () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const llmResult: LLMResult = {
+        generations: [],
+      };
+
+      const result = tracer.options.tokensUsageParser(llmResult);
+
+      expect(result).toEqual({
+        completionTokens: 0,
+        promptTokens: 0,
+        totalTokens: 0,
+      });
+    });
+
+    it('should handle undefined llmOutput', () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const llmResult: LLMResult = {
+        generations: [],
+        llmOutput: undefined,
+      };
+
+      const result = tracer.options.tokensUsageParser(llmResult);
+
+      expect(result).toEqual({
+        completionTokens: 0,
+        promptTokens: 0,
+        totalTokens: 0,
+      });
+    });
+
+    it('should use custom tokensUsageParser when provided', () => {
+      // Custom parser for Cohere format
+      const customParser = (result: LLMResult) => {
+        let totalInputTokens = 0;
+        let totalOutputTokens = 0;
+
+        result.generations?.forEach((generationArray) => {
+          generationArray.forEach((gen) => {
+            const inputTokens = gen.generationInfo?.meta?.tokens?.inputTokens ?? 0;
+            const outputTokens = gen.generationInfo?.meta?.tokens?.outputTokens ?? 0;
+
+            totalInputTokens += inputTokens;
+            totalOutputTokens += outputTokens;
+          });
+        });
+
+        return {
+          completionTokens: totalOutputTokens,
+          promptTokens: totalInputTokens,
+          totalTokens: totalInputTokens + totalOutputTokens,
+        };
+      };
+
+      const tracer = new N8nLlmTracing(executionFunctions, {
+        tokensUsageParser: customParser,
+      });
+
+      const llmResult: LLMResult = {
+        generations: [
+          [
+            {
+              text: 'Response 1',
+              generationInfo: {
+                meta: {
+                  tokens: {
+                    inputTokens: 30,
+                    outputTokens: 40,
+                  },
+                },
+              },
+            },
+          ],
+          [
+            {
+              text: 'Response 2',
+              generationInfo: {
+                meta: {
+                  tokens: {
+                    inputTokens: 20,
+                    outputTokens: 60,
+                  },
+                },
+              },
+            },
+          ],
+        ],
+      };
+
+      const result = tracer.options.tokensUsageParser(llmResult);
+
+      expect(result).toEqual({
+        completionTokens: 100, // 40 + 60
+        promptTokens: 50, // 30 + 20
+        totalTokens: 150,
+      });
+    });
+
+    it('should handle Anthropic format with custom parser', () => {
+      const anthropicParser = (result: LLMResult) => {
+        const usage = (result?.llmOutput?.usage as {
+          input_tokens: number;
+          output_tokens: number;
+        }) ?? {
+          input_tokens: 0,
+          output_tokens: 0,
+        };
+        return {
+          completionTokens: usage.output_tokens,
+          promptTokens: usage.input_tokens,
+          totalTokens: usage.input_tokens + usage.output_tokens,
+        };
+      };
+
+      const tracer = new N8nLlmTracing(executionFunctions, {
+        tokensUsageParser: anthropicParser,
+      });
+
+      const llmResult: LLMResult = {
+        generations: [],
+        llmOutput: {
+          usage: {
+            input_tokens: 75,
+            output_tokens: 125,
+          },
+        },
+      };
+
+      const result = tracer.options.tokensUsageParser(llmResult);
+
+      expect(result).toEqual({
+        completionTokens: 125,
+        promptTokens: 75,
+        totalTokens: 200,
+      });
+    });
+  });
+
+  describe('handleLLMEnd', () => {
+    it('should process LLM output and use token usage when available', async () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const runId = 'test-run-id';
+
+      // Set up run details
+      tracer.runsMap[runId] = {
+        index: 0,
+        messages: ['Test prompt'],
+        options: { model: 'test-model' },
+      };
+
+      const output: LLMResult = {
+        generations: [
+          [
+            {
+              text: 'Test response',
+              generationInfo: { meta: {} },
+            },
+          ],
+        ],
+        llmOutput: {
+          tokenUsage: {
+            completionTokens: 50,
+            promptTokens: 25,
+          },
+        },
+      };
+
+      await tracer.handleLLMEnd(output, runId);
+
+      expect(executionFunctions.addOutputData).toHaveBeenCalledWith(
+        'ai_languageModel',
+        0,
+        [
+          [
+            {
+              json: expect.objectContaining({
+                response: { generations: output.generations },
+                tokenUsage: {
+                  completionTokens: 50,
+                  promptTokens: 25,
+                  totalTokens: 75,
+                },
+              }),
+            },
+          ],
+        ],
+        undefined,
+        undefined,
+      );
+    });
+
+    it('should use token estimates when actual usage is not available', async () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const runId = 'test-run-id';
+
+      // Set up run details and prompt estimate
+      tracer.runsMap[runId] = {
+        index: 0,
+        messages: ['Test prompt'],
+        options: { model: 'test-model' },
+      };
+      tracer.promptTokensEstimate = 30;
+
+      const output: LLMResult = {
+        generations: [
+          [
+            {
+              text: 'Test response',
+              generationInfo: { meta: {} },
+            },
+          ],
+        ],
+        llmOutput: {},
+      };
+
+      jest.spyOn(tracer, 'estimateTokensFromGeneration').mockResolvedValue(45);
+
+      await tracer.handleLLMEnd(output, runId);
+
+      expect(executionFunctions.addOutputData).toHaveBeenCalledWith(
+        'ai_languageModel',
+        0,
+        [
+          [
+            {
+              json: expect.objectContaining({
+                response: { generations: output.generations },
+                tokenUsageEstimate: {
+                  completionTokens: 45,
+                  promptTokens: 30,
+                  totalTokens: 75,
+                },
+              }),
+            },
+          ],
+        ],
+        undefined,
+        undefined,
+      );
+    });
+  });
+
+  describe('handleLLMError', () => {
+    it('should handle NodeError with custom error description mapper', async () => {
+      const customMapper = jest.fn().mockReturnValue('Mapped error description');
+      const tracer = new N8nLlmTracing(executionFunctions, {
+        errorDescriptionMapper: customMapper,
+      });
+
+      const runId = 'test-run-id';
+      tracer.runsMap[runId] = { index: 0, messages: [], options: {} };
+
+      const error = new NodeApiError(executionFunctions.getNode(), {
+        message: 'Test error',
+        description: 'Original description',
+      });
+
+      await tracer.handleLLMError(error, runId);
+
+      expect(customMapper).toHaveBeenCalledWith(error);
+      expect(error.description).toBe('Mapped error description');
+      expect(executionFunctions.addOutputData).toHaveBeenCalledWith('ai_languageModel', 0, error);
+    });
+
+    it('should wrap non-NodeError in NodeOperationError', async () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const runId = 'test-run-id';
+      tracer.runsMap[runId] = { index: 0, messages: [], options: {} };
+
+      const error = new Error('Regular error');
+
+      await tracer.handleLLMError(error, runId);
+
+      expect(executionFunctions.addOutputData).toHaveBeenCalledWith(
+        'ai_languageModel',
+        0,
+        expect.any(NodeOperationError),
+      );
+    });
+
+    it('should filter out non-x- headers from error objects', async () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const runId = 'test-run-id';
+      tracer.runsMap[runId] = { index: 0, messages: [], options: {} };
+
+      const error = {
+        message: 'API Error',
+        headers: {
+          'x-request-id': 'keep-this',
+          authorization: 'remove-this',
+          'x-rate-limit': 'keep-this-too',
+          'content-type': 'remove-this-too',
+        },
+      };
+
+      await tracer.handleLLMError(error as IDataObject, runId);
+
+      expect(error.headers).toEqual({
+        'x-request-id': 'keep-this',
+        'x-rate-limit': 'keep-this-too',
+      });
+    });
+  });
+
+  describe('handleLLMStart', () => {
+    it('should estimate tokens and create run details', async () => {
+      const tracer = new N8nLlmTracing(executionFunctions);
+      const runId = 'test-run-id';
+      const prompts = ['Prompt 1', 'Prompt 2'];
+
+      jest.spyOn(tracer, 'estimateTokensFromStringList').mockResolvedValue(100);
+
+      const llm = {
+        type: 'constructor',
+        kwargs: { model: 'test-model' },
+      };
+
+      await tracer.handleLLMStart(llm as unknown as Serialized, prompts, runId);
+
+      expect(tracer.estimateTokensFromStringList).toHaveBeenCalledWith(prompts);
+      expect(tracer.promptTokensEstimate).toBe(100);
+      expect(tracer.runsMap[runId]).toEqual({
+        index: 0,
+        options: { model: 'test-model' },
+        messages: prompts,
+      });
+      expect(executionFunctions.addInputData).toHaveBeenCalledWith(
+        'ai_languageModel',
+        [
+          [
+            {
+              json: {
+                messages: prompts,
+                estimatedTokens: 100,
+                options: { model: 'test-model' },
+              },
+            },
+          ],
+        ],
+        undefined,
+      );
+    });
+  });
+});
diff --git a/packages/@n8n/nodes-langchain/package.json b/packages/@n8n/nodes-langchain/package.json
index 61face1cab..1bc148c10a 100644
--- a/packages/@n8n/nodes-langchain/package.json
+++ b/packages/@n8n/nodes-langchain/package.json
@@ -76,6 +76,7 @@
     "dist/nodes/llms/LMChatAnthropic/LmChatAnthropic.node.js",
     "dist/nodes/llms/LmChatAzureOpenAi/LmChatAzureOpenAi.node.js",
     "dist/nodes/llms/LmChatAwsBedrock/LmChatAwsBedrock.node.js",
+    "dist/nodes/llms/LmChatCohere/LmChatCohere.node.js",
     "dist/nodes/llms/LmChatDeepSeek/LmChatDeepSeek.node.js",
     "dist/nodes/llms/LmChatGoogleGemini/LmChatGoogleGemini.node.js",
"dist/nodes/llms/LmChatGoogleVertex/LmChatGoogleVertex.node.js", @@ -155,7 +156,8 @@ "@types/temp": "^0.9.1", "fast-glob": "catalog:", "n8n-core": "workspace:*", - "tsup": "catalog:" + "tsup": "catalog:", + "jest-mock-extended": "^3.0.4" }, "dependencies": { "@aws-sdk/client-sso-oidc": "3.808.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9a9283007b..b6428fb68b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1095,6 +1095,9 @@ importers: fast-glob: specifier: 'catalog:' version: 3.2.12 + jest-mock-extended: + specifier: ^3.0.4 + version: 3.0.4(jest@29.6.2(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)))(typescript@5.8.3) n8n-core: specifier: workspace:* version: link:../../core