feat(Cohere Chat Model Node): Add Cohere Chat Model node (#16888)

This commit is contained in:
oleg
2025-07-09 13:20:25 +02:00
committed by GitHub
parent 59704b4010
commit c37397cb2b
10 changed files with 604 additions and 10 deletions

View File

@@ -21,6 +21,12 @@ export class CohereApi implements ICredentialType {
 			required: true,
 			default: '',
 		},
+		{
+			displayName: 'Base URL',
+			name: 'url',
+			type: 'hidden',
+			default: 'https://api.cohere.ai',
+		},
 	];

 	authenticate: IAuthenticateGeneric = {
@@ -34,7 +40,7 @@ export class CohereApi implements ICredentialType {
 	test: ICredentialTestRequest = {
 		request: {
-			baseURL: 'https://api.cohere.ai',
+			baseURL: '={{ $credentials.url }}',
 			url: '/v1/models?page_size=1',
 		},
 	};

View File

@@ -65,7 +65,10 @@ function createAgentExecutor(
 		fallbackAgent ? agent.withFallbacks([fallbackAgent]) : agent,
 		getAgentStepsParser(outputParser, memory),
 		fixEmptyContentMessage,
-	]);
+	]) as AgentRunnableSequence;
+
+	runnableAgent.singleAction = false;
+	runnableAgent.streamRunnable = false;

 	return AgentExecutor.fromAgentAndTools({
 		agent: runnableAgent,

View File

@@ -285,8 +285,11 @@ export class LmChatAnthropic implements INodeType {
 		};
 		let invocationKwargs = {};

-		const tokensUsageParser = (llmOutput: LLMResult['llmOutput']) => {
-			const usage = (llmOutput?.usage as { input_tokens: number; output_tokens: number }) ?? {
+		const tokensUsageParser = (result: LLMResult) => {
+			const usage = (result?.llmOutput?.usage as {
+				input_tokens: number;
+				output_tokens: number;
+			}) ?? {
 				input_tokens: 0,
 				output_tokens: 0,
 			};

View File

@@ -0,0 +1,177 @@
import { ChatCohere } from '@langchain/cohere';
import type { LLMResult } from '@langchain/core/outputs';
import type {
INodeType,
INodeTypeDescription,
ISupplyDataFunctions,
SupplyData,
} from 'n8n-workflow';
import { getConnectionHintNoticeField } from '@utils/sharedFields';
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
import { N8nLlmTracing } from '../N8nLlmTracing';
export function tokensUsageParser(result: LLMResult): {
completionTokens: number;
promptTokens: number;
totalTokens: number;
} {
let totalInputTokens = 0;
let totalOutputTokens = 0;
result.generations?.forEach((generationArray) => {
generationArray.forEach((gen) => {
const inputTokens = gen.generationInfo?.meta?.tokens?.inputTokens ?? 0;
const outputTokens = gen.generationInfo?.meta?.tokens?.outputTokens ?? 0;
totalInputTokens += inputTokens;
totalOutputTokens += outputTokens;
});
});
return {
completionTokens: totalOutputTokens,
promptTokens: totalInputTokens,
totalTokens: totalInputTokens + totalOutputTokens,
};
}
export class LmChatCohere implements INodeType {
description: INodeTypeDescription = {
displayName: 'Cohere Chat Model',
name: 'lmChatCohere',
icon: { light: 'file:cohere.svg', dark: 'file:cohere.dark.svg' },
group: ['transform'],
version: [1],
description: 'For advanced usage with an AI chain',
defaults: {
name: 'Cohere Chat Model',
},
codex: {
categories: ['AI'],
subcategories: {
AI: ['Language Models', 'Root Nodes'],
'Language Models': ['Chat Models (Recommended)'],
},
resources: {
primaryDocumentation: [
{
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/sub-nodes/n8n-nodes-langchain.lmchatcohere/',
},
],
},
},
inputs: [],
outputs: ['ai_languageModel'],
outputNames: ['Model'],
credentials: [
{
name: 'cohereApi',
required: true,
},
],
requestDefaults: {
baseURL: '={{$credentials?.url}}',
headers: {
accept: 'application/json',
authorization: '=Bearer {{$credentials?.apiKey}}',
},
},
properties: [
getConnectionHintNoticeField(['ai_chain', 'ai_agent']),
{
displayName: 'Model',
name: 'model',
type: 'options',
description:
'The model which will generate the completion. <a href="https://docs.cohere.com/docs/models">Learn more</a>.',
typeOptions: {
loadOptions: {
routing: {
request: {
method: 'GET',
url: '/v1/models?page_size=100&endpoint=chat',
},
output: {
postReceive: [
{
type: 'rootProperty',
properties: {
property: 'models',
},
},
{
type: 'setKeyValue',
properties: {
name: '={{$responseItem.name}}',
value: '={{$responseItem.name}}',
description: '={{$responseItem.description}}',
},
},
{
type: 'sort',
properties: {
key: 'name',
},
},
],
},
},
},
},
default: 'command-a-03-2025',
},
{
displayName: 'Options',
name: 'options',
placeholder: 'Add Option',
description: 'Additional options to add',
type: 'collection',
default: {},
options: [
{
displayName: 'Sampling Temperature',
name: 'temperature',
default: 0.7,
typeOptions: { maxValue: 2, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
type: 'number',
},
{
displayName: 'Max Retries',
name: 'maxRetries',
default: 2,
description: 'Maximum number of retries to attempt',
type: 'number',
},
],
},
],
};
async supplyData(this: ISupplyDataFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials<{ url?: string; apiKey?: string }>('cohereApi');
const modelName = this.getNodeParameter('model', itemIndex) as string;
const options = this.getNodeParameter('options', itemIndex, {}) as {
maxRetries: number;
temperature?: number;
};
const model = new ChatCohere({
apiKey: credentials.apiKey,
model: modelName,
temperature: options.temperature,
maxRetries: options.maxRetries ?? 2,
callbacks: [new N8nLlmTracing(this, { tokensUsageParser })],
onFailedAttempt: makeN8nLlmFailedAttemptHandler(this),
});
return {
response: model,
};
}
}

View File

@@ -0,0 +1,5 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.96 23.84C14.0267 23.84 16.16 23.7867 19.1467 22.56C22.6133 21.12 29.44 18.56 34.4 15.8933C37.8667 14.0267 39.36 11.5733 39.36 8.26667C39.36 3.73333 35.68 0 31.0933 0H11.8933C5.33333 0 0 5.33333 0 11.8933C0 18.4533 5.01333 23.84 12.96 23.84Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M16.2134 31.9999C16.2134 28.7999 18.1334 25.8666 21.12 24.6399L27.1467 22.1333C33.28 19.6266 40 24.1066 40 30.7199C40 35.8399 35.84 39.9999 30.72 39.9999H24.16C19.7867 39.9999 16.2134 36.4266 16.2134 31.9999Z" fill="white"/>
<path d="M6.88 25.3867C3.09333 25.3867 0 28.4801 0 32.2667V33.1734C0 36.9067 3.09333 40.0001 6.88 40.0001C10.6667 40.0001 13.76 36.9067 13.76 33.1201V32.2134C13.7067 28.4801 10.6667 25.3867 6.88 25.3867Z" fill="white"/>
</svg>


View File

@@ -0,0 +1,5 @@
<svg width="40" height="40" viewBox="0 0 40 40" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M12.96 23.84C14.0267 23.84 16.16 23.7867 19.1467 22.56C22.6133 21.12 29.44 18.56 34.4 15.8933C37.8667 14.0267 39.36 11.5733 39.36 8.26667C39.36 3.73333 35.68 0 31.0933 0H11.8933C5.33333 0 0 5.33333 0 11.8933C0 18.4533 5.01333 23.84 12.96 23.84Z" fill="#39594D"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M16.2134 31.9999C16.2134 28.7999 18.1334 25.8666 21.12 24.6399L27.1467 22.1333C33.28 19.6266 40 24.1066 40 30.7199C40 35.8399 35.84 39.9999 30.72 39.9999H24.16C19.7867 39.9999 16.2134 36.4266 16.2134 31.9999Z" fill="#D18EE2"/>
<path d="M6.88 25.3867C3.09333 25.3867 0 28.4801 0 32.2667V33.1734C0 36.9067 3.09333 40.0001 6.88 40.0001C10.6667 40.0001 13.76 36.9067 13.76 33.1201V32.2134C13.7067 28.4801 10.6667 25.3867 6.88 25.3867Z" fill="#FF7759"/>
</svg>


View File

@@ -15,7 +15,7 @@ import { NodeConnectionTypes, NodeError, NodeOperationError } from 'n8n-workflow
 import { logAiEvent } from '@utils/helpers';
 import { estimateTokensFromStringList } from '@utils/tokenizer/token-estimator';

-type TokensUsageParser = (llmOutput: LLMResult['llmOutput']) => {
+type TokensUsageParser = (result: LLMResult) => {
 	completionTokens: number;
 	promptTokens: number;
 	totalTokens: number;
@@ -53,9 +53,9 @@ export class N8nLlmTracing extends BaseCallbackHandler {
 	options = {
 		// Default(OpenAI format) parser
-		tokensUsageParser: (llmOutput: LLMResult['llmOutput']) => {
-			const completionTokens = (llmOutput?.tokenUsage?.completionTokens as number) ?? 0;
-			const promptTokens = (llmOutput?.tokenUsage?.promptTokens as number) ?? 0;
+		tokensUsageParser: (result: LLMResult) => {
+			const completionTokens = (result?.llmOutput?.tokenUsage?.completionTokens as number) ?? 0;
+			const promptTokens = (result?.llmOutput?.tokenUsage?.promptTokens as number) ?? 0;
 			return {
 				completionTokens,
@@ -101,7 +101,7 @@ export class N8nLlmTracing extends BaseCallbackHandler {
 			promptTokens: 0,
 			totalTokens: 0,
 		};
-		const tokenUsage = this.options.tokensUsageParser(output.llmOutput);
+		const tokenUsage = this.options.tokensUsageParser(output);
 		if (output.generations.length > 0) {
 			tokenUsageEstimate.completionTokens = await this.estimateTokensFromGeneration(

View File

@@ -0,0 +1,390 @@
/* eslint-disable @typescript-eslint/no-unsafe-member-access */
/* eslint-disable @typescript-eslint/unbound-method */
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
import type { Serialized } from '@langchain/core/load/serializable';
import type { LLMResult } from '@langchain/core/outputs';
import { mock } from 'jest-mock-extended';
import type { IDataObject, ISupplyDataFunctions } from 'n8n-workflow';
import { NodeOperationError, NodeApiError } from 'n8n-workflow';
import { N8nLlmTracing } from '../N8nLlmTracing';
describe('N8nLlmTracing', () => {
const executionFunctions = mock<ISupplyDataFunctions>({
addInputData: jest.fn().mockReturnValue({ index: 0 }),
addOutputData: jest.fn(),
getNode: jest.fn().mockReturnValue({ name: 'TestNode' }),
getNextRunIndex: jest.fn().mockReturnValue(1),
});
beforeEach(() => {
jest.clearAllMocks();
});
describe('tokensUsageParser', () => {
it('should parse OpenAI format tokens correctly', () => {
const tracer = new N8nLlmTracing(executionFunctions);
const llmResult: LLMResult = {
generations: [],
llmOutput: {
tokenUsage: {
completionTokens: 100,
promptTokens: 50,
},
},
};
const result = tracer.options.tokensUsageParser(llmResult);
expect(result).toEqual({
completionTokens: 100,
promptTokens: 50,
totalTokens: 150,
});
});
it('should handle missing token data', () => {
const tracer = new N8nLlmTracing(executionFunctions);
const llmResult: LLMResult = {
generations: [],
};
const result = tracer.options.tokensUsageParser(llmResult);
expect(result).toEqual({
completionTokens: 0,
promptTokens: 0,
totalTokens: 0,
});
});
it('should handle undefined llmOutput', () => {
const tracer = new N8nLlmTracing(executionFunctions);
const llmResult: LLMResult = {
generations: [],
llmOutput: undefined,
};
const result = tracer.options.tokensUsageParser(llmResult);
expect(result).toEqual({
completionTokens: 0,
promptTokens: 0,
totalTokens: 0,
});
});
it('should use custom tokensUsageParser when provided', () => {
// Custom parser for Cohere format
const customParser = (result: LLMResult) => {
let totalInputTokens = 0;
let totalOutputTokens = 0;
result.generations?.forEach((generationArray) => {
generationArray.forEach((gen) => {
const inputTokens = gen.generationInfo?.meta?.tokens?.inputTokens ?? 0;
const outputTokens = gen.generationInfo?.meta?.tokens?.outputTokens ?? 0;
totalInputTokens += inputTokens;
totalOutputTokens += outputTokens;
});
});
return {
completionTokens: totalOutputTokens,
promptTokens: totalInputTokens,
totalTokens: totalInputTokens + totalOutputTokens,
};
};
const tracer = new N8nLlmTracing(executionFunctions, {
tokensUsageParser: customParser,
});
const llmResult: LLMResult = {
generations: [
[
{
text: 'Response 1',
generationInfo: {
meta: {
tokens: {
inputTokens: 30,
outputTokens: 40,
},
},
},
},
],
[
{
text: 'Response 2',
generationInfo: {
meta: {
tokens: {
inputTokens: 20,
outputTokens: 60,
},
},
},
},
],
],
};
const result = tracer.options.tokensUsageParser(llmResult);
expect(result).toEqual({
completionTokens: 100, // 40 + 60
promptTokens: 50, // 30 + 20
totalTokens: 150,
});
});
it('should handle Anthropic format with custom parser', () => {
const anthropicParser = (result: LLMResult) => {
const usage = (result?.llmOutput?.usage as {
input_tokens: number;
output_tokens: number;
}) ?? {
input_tokens: 0,
output_tokens: 0,
};
return {
completionTokens: usage.output_tokens,
promptTokens: usage.input_tokens,
totalTokens: usage.input_tokens + usage.output_tokens,
};
};
const tracer = new N8nLlmTracing(executionFunctions, {
tokensUsageParser: anthropicParser,
});
const llmResult: LLMResult = {
generations: [],
llmOutput: {
usage: {
input_tokens: 75,
output_tokens: 125,
},
},
};
const result = tracer.options.tokensUsageParser(llmResult);
expect(result).toEqual({
completionTokens: 125,
promptTokens: 75,
totalTokens: 200,
});
});
});
describe('handleLLMEnd', () => {
it('should process LLM output and use token usage when available', async () => {
const tracer = new N8nLlmTracing(executionFunctions);
const runId = 'test-run-id';
// Set up run details
tracer.runsMap[runId] = {
index: 0,
messages: ['Test prompt'],
options: { model: 'test-model' },
};
const output: LLMResult = {
generations: [
[
{
text: 'Test response',
generationInfo: { meta: {} },
},
],
],
llmOutput: {
tokenUsage: {
completionTokens: 50,
promptTokens: 25,
},
},
};
await tracer.handleLLMEnd(output, runId);
expect(executionFunctions.addOutputData).toHaveBeenCalledWith(
'ai_languageModel',
0,
[
[
{
json: expect.objectContaining({
response: { generations: output.generations },
tokenUsage: {
completionTokens: 50,
promptTokens: 25,
totalTokens: 75,
},
}),
},
],
],
undefined,
undefined,
);
});
it('should use token estimates when actual usage is not available', async () => {
const tracer = new N8nLlmTracing(executionFunctions);
const runId = 'test-run-id';
// Set up run details and prompt estimate
tracer.runsMap[runId] = {
index: 0,
messages: ['Test prompt'],
options: { model: 'test-model' },
};
tracer.promptTokensEstimate = 30;
const output: LLMResult = {
generations: [
[
{
text: 'Test response',
generationInfo: { meta: {} },
},
],
],
llmOutput: {},
};
jest.spyOn(tracer, 'estimateTokensFromGeneration').mockResolvedValue(45);
await tracer.handleLLMEnd(output, runId);
expect(executionFunctions.addOutputData).toHaveBeenCalledWith(
'ai_languageModel',
0,
[
[
{
json: expect.objectContaining({
response: { generations: output.generations },
tokenUsageEstimate: {
completionTokens: 45,
promptTokens: 30,
totalTokens: 75,
},
}),
},
],
],
undefined,
undefined,
);
});
});
describe('handleLLMError', () => {
it('should handle NodeError with custom error description mapper', async () => {
const customMapper = jest.fn().mockReturnValue('Mapped error description');
const tracer = new N8nLlmTracing(executionFunctions, {
errorDescriptionMapper: customMapper,
});
const runId = 'test-run-id';
tracer.runsMap[runId] = { index: 0, messages: [], options: {} };
const error = new NodeApiError(executionFunctions.getNode(), {
message: 'Test error',
description: 'Original description',
});
await tracer.handleLLMError(error, runId);
expect(customMapper).toHaveBeenCalledWith(error);
expect(error.description).toBe('Mapped error description');
expect(executionFunctions.addOutputData).toHaveBeenCalledWith('ai_languageModel', 0, error);
});
it('should wrap non-NodeError in NodeOperationError', async () => {
const tracer = new N8nLlmTracing(executionFunctions);
const runId = 'test-run-id';
tracer.runsMap[runId] = { index: 0, messages: [], options: {} };
const error = new Error('Regular error');
await tracer.handleLLMError(error, runId);
expect(executionFunctions.addOutputData).toHaveBeenCalledWith(
'ai_languageModel',
0,
expect.any(NodeOperationError),
);
});
it('should filter out non-x- headers from error objects', async () => {
const tracer = new N8nLlmTracing(executionFunctions);
const runId = 'test-run-id';
tracer.runsMap[runId] = { index: 0, messages: [], options: {} };
const error = {
message: 'API Error',
headers: {
'x-request-id': 'keep-this',
authorization: 'remove-this',
'x-rate-limit': 'keep-this-too',
'content-type': 'remove-this-too',
},
};
await tracer.handleLLMError(error as IDataObject, runId);
expect(error.headers).toEqual({
'x-request-id': 'keep-this',
'x-rate-limit': 'keep-this-too',
});
});
});
describe('handleLLMStart', () => {
it('should estimate tokens and create run details', async () => {
const tracer = new N8nLlmTracing(executionFunctions);
const runId = 'test-run-id';
const prompts = ['Prompt 1', 'Prompt 2'];
jest.spyOn(tracer, 'estimateTokensFromStringList').mockResolvedValue(100);
const llm = {
type: 'constructor',
kwargs: { model: 'test-model' },
};
await tracer.handleLLMStart(llm as unknown as Serialized, prompts, runId);
expect(tracer.estimateTokensFromStringList).toHaveBeenCalledWith(prompts);
expect(tracer.promptTokensEstimate).toBe(100);
expect(tracer.runsMap[runId]).toEqual({
index: 0,
options: { model: 'test-model' },
messages: prompts,
});
expect(executionFunctions.addInputData).toHaveBeenCalledWith(
'ai_languageModel',
[
[
{
json: {
messages: prompts,
estimatedTokens: 100,
options: { model: 'test-model' },
},
},
],
],
undefined,
);
});
});
});

View File

@@ -76,6 +76,7 @@
 		"dist/nodes/llms/LMChatAnthropic/LmChatAnthropic.node.js",
 		"dist/nodes/llms/LmChatAzureOpenAi/LmChatAzureOpenAi.node.js",
 		"dist/nodes/llms/LmChatAwsBedrock/LmChatAwsBedrock.node.js",
+		"dist/nodes/llms/LmChatCohere/LmChatCohere.node.js",
 		"dist/nodes/llms/LmChatDeepSeek/LmChatDeepSeek.node.js",
 		"dist/nodes/llms/LmChatGoogleGemini/LmChatGoogleGemini.node.js",
 		"dist/nodes/llms/LmChatGoogleVertex/LmChatGoogleVertex.node.js",
@@ -155,7 +156,8 @@
 		"@types/temp": "^0.9.1",
 		"fast-glob": "catalog:",
 		"n8n-core": "workspace:*",
-		"tsup": "catalog:"
+		"tsup": "catalog:",
+		"jest-mock-extended": "^3.0.4"
 	},
 	"dependencies": {
 		"@aws-sdk/client-sso-oidc": "3.808.0",

pnpm-lock.yaml (generated, 3 lines changed)
View File

@@ -1095,6 +1095,9 @@ importers:
       fast-glob:
         specifier: 'catalog:'
         version: 3.2.12
+      jest-mock-extended:
+        specifier: ^3.0.4
+        version: 3.0.4(jest@29.6.2(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)))(typescript@5.8.3)
       n8n-core:
         specifier: workspace:*
         version: link:../../core