feat: Support thinking settings for Gemini models (#19591)

Author: Michael Drury
Date: 2025-09-18 10:21:07 +01:00
Committed by: GitHub
Parent: a0452f02dd
Commit: bb0cd86b28
12 changed files with 2279 additions and 934 deletions

View File

@@ -11,7 +11,7 @@ import type {
import { getConnectionHintNoticeField } from '@utils/sharedFields';
import { additionalOptions } from '../gemini-common/additional-options';
import { getAdditionalOptions } from '../gemini-common/additional-options';
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
import { N8nLlmTracing } from '../N8nLlmTracing';
@@ -119,7 +119,9 @@ export class LmChatGoogleGemini implements INodeType {
},
default: 'models/gemini-2.5-flash',
},
additionalOptions,
// Thinking budget is not supported in @langchain/google-genai,
// as it utilises the old Google Generative AI SDK
getAdditionalOptions({ supportsThinkingBudget: false }),
],
};
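
The shared additionalOptions constant is replaced by a getAdditionalOptions factory so that each node can declare whether its underlying SDK supports a thinking budget. A minimal sketch of the two call sites, assuming the factory defined later in this diff:

import { getAdditionalOptions } from '../gemini-common/additional-options';

// @langchain/google-genai wraps the old Google Generative AI SDK,
// so the Gemini node hides the thinking-budget field.
const geminiOptions = getAdditionalOptions({ supportsThinkingBudget: false });

// @langchain/google-vertexai supports it, so the Vertex node shows it.
const vertexOptions = getAdditionalOptions({ supportsThinkingBudget: true });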

View File

@@ -1,6 +1,6 @@
import type { SafetySetting } from '@google/generative-ai';
import { ProjectsClient } from '@google-cloud/resource-manager';
import { ChatVertexAI } from '@langchain/google-vertexai';
import type { GoogleAISafetySetting } from '@langchain/google-common';
import { ChatVertexAI, type ChatVertexAIInput } from '@langchain/google-vertexai';
import { formatPrivateKey } from 'n8n-nodes-base/dist/utils/utilities';
import {
NodeConnectionTypes,
@@ -11,12 +11,13 @@ import {
type ILoadOptionsFunctions,
type JsonObject,
NodeOperationError,
validateNodeParameters,
} from 'n8n-workflow';
import { getConnectionHintNoticeField } from '@utils/sharedFields';
import { makeErrorFromStatus } from './error-handling';
import { additionalOptions } from '../gemini-common/additional-options';
import { getAdditionalOptions } from '../gemini-common/additional-options';
import { makeN8nLlmFailedAttemptHandler } from '../n8nLlmFailedAttemptHandler';
import { N8nLlmTracing } from '../N8nLlmTracing';
@@ -90,7 +91,7 @@ export class LmChatGoogleVertex implements INodeType {
'The model which will generate the completion. <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models">Learn more</a>.',
default: 'gemini-2.5-flash',
},
additionalOptions,
getAdditionalOptions({ supportsThinkingBudget: true }),
],
};
@@ -143,21 +144,29 @@ export class LmChatGoogleVertex implements INodeType {
temperature: 0.4,
topK: 40,
topP: 0.9,
}) as {
maxOutputTokens: number;
temperature: number;
topK: number;
topP: number;
};
});
// Validate options parameter
validateNodeParameters(
options,
{
maxOutputTokens: { type: 'number', required: false },
temperature: { type: 'number', required: false },
topK: { type: 'number', required: false },
topP: { type: 'number', required: false },
thinkingBudget: { type: 'number', required: false },
},
this.getNode(),
);
const safetySettings = this.getNodeParameter(
'options.safetySettings.values',
itemIndex,
null,
) as SafetySetting[];
) as GoogleAISafetySetting[];
try {
const model = new ChatVertexAI({
const modelConfig: ChatVertexAIInput = {
authOptions: {
projectId,
credentials: {
@@ -186,7 +195,14 @@ export class LmChatGoogleVertex implements INodeType {
throw error;
}),
});
};
// Add thinkingBudget if specified
if (options.thinkingBudget !== undefined) {
modelConfig.thinkingBudget = options.thinkingBudget;
}
const model = new ChatVertexAI(modelConfig);
return {
response: model,
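
The Vertex node now validates the untyped options object at runtime and only attaches thinkingBudget when the user actually set it. A condensed sketch of that pattern, with the auth, tracing, and error-handling details from the surrounding diff omitted:

import { ChatVertexAI, type ChatVertexAIInput } from '@langchain/google-vertexai';

const options: { maxOutputTokens?: number; temperature?: number; thinkingBudget?: number } = {
	maxOutputTokens: 2048,
	thinkingBudget: 1024,
};

// Build the base config first, then attach optional fields.
const modelConfig: ChatVertexAIInput = {
	model: 'gemini-2.5-flash',
	maxOutputTokens: options.maxOutputTokens,
	temperature: options.temperature,
};

// Omitting the key entirely (rather than passing undefined) leaves the
// SDK's default automatic thinking behaviour untouched.
if (options.thinkingBudget !== undefined) {
	modelConfig.thinkingBudget = options.thinkingBudget;
}

const model = new ChatVertexAI(modelConfig);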

View File

@@ -0,0 +1,149 @@
import { ChatVertexAI } from '@langchain/google-vertexai';
import { createMockExecuteFunction } from 'n8n-nodes-base/test/nodes/Helpers';
import type { INode, ISupplyDataFunctions } from 'n8n-workflow';
import { makeN8nLlmFailedAttemptHandler } from '../../n8nLlmFailedAttemptHandler';
import { N8nLlmTracing } from '../../N8nLlmTracing';
import { LmChatGoogleVertex } from '../LmChatGoogleVertex.node';
jest.mock('@langchain/google-vertexai');
jest.mock('../../N8nLlmTracing');
jest.mock('../../n8nLlmFailedAttemptHandler');
jest.mock('n8n-nodes-base/dist/utils/utilities', () => ({
formatPrivateKey: jest.fn().mockImplementation((key: string) => key),
}));
const MockedChatVertexAI = jest.mocked(ChatVertexAI);
const MockedN8nLlmTracing = jest.mocked(N8nLlmTracing);
const mockedMakeN8nLlmFailedAttemptHandler = jest.mocked(makeN8nLlmFailedAttemptHandler);
describe('LmChatGoogleVertex - Thinking Budget', () => {
let lmChatGoogleVertex: LmChatGoogleVertex;
let mockContext: jest.Mocked<ISupplyDataFunctions>;
const mockNode: INode = {
id: '1',
name: 'Google Vertex Chat Model',
typeVersion: 1,
type: 'n8n-nodes-langchain.lmChatGoogleVertex',
position: [0, 0],
parameters: {},
};
const setupMockContext = () => {
mockContext = createMockExecuteFunction<ISupplyDataFunctions>(
{},
mockNode,
) as jest.Mocked<ISupplyDataFunctions>;
mockContext.getCredentials = jest.fn().mockResolvedValue({
privateKey: 'test-private-key',
email: 'test@n8n.io',
region: 'us-central1',
});
mockContext.getNode = jest.fn().mockReturnValue(mockNode);
mockContext.getNodeParameter = jest.fn();
MockedN8nLlmTracing.mockImplementation(() => ({}) as unknown as N8nLlmTracing);
mockedMakeN8nLlmFailedAttemptHandler.mockReturnValue(jest.fn());
return mockContext;
};
beforeEach(() => {
lmChatGoogleVertex = new LmChatGoogleVertex();
jest.clearAllMocks();
});
afterEach(() => {
jest.clearAllMocks();
});
describe('supplyData - thinking budget parameter passing', () => {
it('should not include thinkingBudget in model config when not specified', async () => {
const mockContext = setupMockContext();
mockContext.getNodeParameter = jest.fn().mockImplementation((paramName: string) => {
if (paramName === 'modelName') return 'gemini-2.5-flash';
if (paramName === 'projectId') return 'test-project';
if (paramName === 'options') {
// Return options without thinkingBudget
return {
maxOutputTokens: 2048,
temperature: 0.4,
topK: 40,
topP: 0.9,
};
}
if (paramName === 'options.safetySettings.values') return null;
return undefined;
});
await lmChatGoogleVertex.supplyData.call(mockContext, 0);
expect(MockedChatVertexAI).toHaveBeenCalledTimes(1);
const callArgs = MockedChatVertexAI.mock.calls[0][0];
expect(callArgs).not.toHaveProperty('thinkingBudget');
expect(callArgs).toMatchObject({
authOptions: {
projectId: 'test-project',
credentials: {
client_email: 'test@n8n.io',
private_key: 'test-private-key',
},
},
location: 'us-central1',
model: 'gemini-2.5-flash',
topK: 40,
topP: 0.9,
temperature: 0.4,
maxOutputTokens: 2048,
});
});
it('should include thinkingBudget in model config when specified', async () => {
const mockContext = setupMockContext();
const expectedThinkingBudget = 1024;
mockContext.getNodeParameter = jest.fn().mockImplementation((paramName: string) => {
if (paramName === 'modelName') return 'gemini-2.5-flash';
if (paramName === 'projectId') return 'test-project';
if (paramName === 'options') {
// Return options with thinkingBudget
return {
maxOutputTokens: 2048,
temperature: 0.4,
topK: 40,
topP: 0.9,
thinkingBudget: expectedThinkingBudget,
};
}
if (paramName === 'options.safetySettings.values') return null;
return undefined;
});
await lmChatGoogleVertex.supplyData.call(mockContext, 0);
expect(MockedChatVertexAI).toHaveBeenCalledWith(
expect.objectContaining({
authOptions: {
projectId: 'test-project',
credentials: {
client_email: 'test@n8n.io',
private_key: 'test-private-key',
},
},
location: 'us-central1',
model: 'gemini-2.5-flash',
topK: 40,
topP: 0.9,
temperature: 0.4,
maxOutputTokens: 2048,
thinkingBudget: expectedThinkingBudget,
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
callbacks: expect.arrayContaining([expect.any(Object)]),
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
onFailedAttempt: expect.any(Function),
}),
);
});
});
});
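
Both tests inspect the constructor through jest.mocked, which keeps mock.calls typed as the real constructor arguments. A standalone sketch of that inspection pattern, using a hypothetical makeModel wrapper as the code under test:

import { ChatVertexAI } from '@langchain/google-vertexai';

jest.mock('@langchain/google-vertexai');
const MockedChatVertexAI = jest.mocked(ChatVertexAI);

// Hypothetical code under test: forwards its config to the constructor.
function makeModel(config: ConstructorParameters<typeof ChatVertexAI>[0]) {
	return new ChatVertexAI(config);
}

it('forwards only the fields it was given', () => {
	makeModel({ model: 'gemini-2.5-flash' });
	// mock.calls[0][0] is the first argument of the first construction.
	expect(MockedChatVertexAI.mock.calls[0][0]).not.toHaveProperty('thinkingBudget');
});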

View File

@@ -1,88 +1,107 @@
import type { HarmBlockThreshold, HarmCategory } from '@google/generative-ai';
import type { HarmBlockThreshold, HarmCategory } from '@google/genai';
import type { INodeProperties } from 'n8n-workflow';
import { harmCategories, harmThresholds } from './safety-options';
export const additionalOptions: INodeProperties = {
displayName: 'Options',
name: 'options',
placeholder: 'Add Option',
description: 'Additional options to add',
type: 'collection',
default: {},
options: [
{
displayName: 'Maximum Number of Tokens',
name: 'maxOutputTokens',
default: 2048,
description: 'The maximum number of tokens to generate in the completion',
type: 'number',
},
{
displayName: 'Sampling Temperature',
name: 'temperature',
default: 0.4,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
type: 'number',
},
{
displayName: 'Top K',
name: 'topK',
default: 32,
typeOptions: { maxValue: 40, minValue: -1, numberPrecision: 1 },
description:
'Used to remove "long tail" low probability responses. Defaults to -1, which disables it.',
type: 'number',
},
{
displayName: 'Top P',
name: 'topP',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered. We generally recommend altering this or temperature but not both.',
type: 'number',
},
// Safety Settings
{
displayName: 'Safety Settings',
name: 'safetySettings',
type: 'fixedCollection',
typeOptions: { multipleValues: true },
default: {
values: {
category: harmCategories[0].name as HarmCategory,
threshold: harmThresholds[0].name as HarmBlockThreshold,
},
export function getAdditionalOptions({
supportsThinkingBudget,
}: { supportsThinkingBudget: boolean }) {
const baseOptions: INodeProperties = {
displayName: 'Options',
name: 'options',
placeholder: 'Add Option',
description: 'Additional options to add',
type: 'collection',
default: {},
options: [
{
displayName: 'Maximum Number of Tokens',
name: 'maxOutputTokens',
default: 2048,
description: 'The maximum number of tokens to generate in the completion',
type: 'number',
},
placeholder: 'Add Option',
options: [
{
name: 'values',
displayName: 'Values',
values: [
{
displayName: 'Safety Category',
name: 'category',
type: 'options',
description: 'The category of harmful content to block',
default: 'HARM_CATEGORY_UNSPECIFIED',
options: harmCategories,
},
{
displayName: 'Safety Threshold',
name: 'threshold',
type: 'options',
description: 'The threshold of harmful content to block',
default: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
options: harmThresholds,
},
],
{
displayName: 'Sampling Temperature',
name: 'temperature',
default: 0.4,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive.',
type: 'number',
},
{
displayName: 'Top K',
name: 'topK',
default: 32,
typeOptions: { maxValue: 40, minValue: -1, numberPrecision: 1 },
description:
'Used to remove "long tail" low probability responses. Defaults to -1, which disables it.',
type: 'number',
},
{
displayName: 'Top P',
name: 'topP',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered. We generally recommend altering this or temperature but not both.',
type: 'number',
},
// Safety Settings
{
displayName: 'Safety Settings',
name: 'safetySettings',
type: 'fixedCollection',
typeOptions: { multipleValues: true },
default: {
values: {
category: harmCategories[0].name as HarmCategory,
threshold: harmThresholds[0].name as HarmBlockThreshold,
},
},
],
},
],
};
placeholder: 'Add Option',
options: [
{
name: 'values',
displayName: 'Values',
values: [
{
displayName: 'Safety Category',
name: 'category',
type: 'options',
description: 'The category of harmful content to block',
default: 'HARM_CATEGORY_UNSPECIFIED',
options: harmCategories,
},
{
displayName: 'Safety Threshold',
name: 'threshold',
type: 'options',
description: 'The threshold of harmful content to block',
default: 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
options: harmThresholds,
},
],
},
],
},
],
};
// Thinking budget is only supported in the new Google GenAI SDK
if (supportsThinkingBudget) {
baseOptions.options?.push({
displayName: 'Thinking Budget',
name: 'thinkingBudget',
default: undefined,
description:
'Controls reasoning tokens for thinking models. Set to 0 to disable automatic thinking. Set to -1 for dynamic thinking. Leave empty for auto mode.',
type: 'number',
typeOptions: {
minValue: -1,
numberPrecision: 0,
},
});
}
return baseOptions;
}
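
Because the factory returns a single collection property, callers drop it straight into a node's properties array. A short usage sketch, assuming the file's relative import path:

import type { INodeProperties } from 'n8n-workflow';
import { getAdditionalOptions } from './additional-options';

const properties: INodeProperties[] = [
	getAdditionalOptions({ supportsThinkingBudget: true }),
];

// ['maxOutputTokens', 'temperature', 'topK', 'topP', 'safetySettings', 'thinkingBudget']
console.log((properties[0].options ?? []).map((option) => option.name));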

View File

@@ -80,7 +80,7 @@ describe('GoogleGemini Node', () => {
expect(apiRequestMock).toHaveBeenCalledWith(
'POST',
'/v1beta/models/gemini-2.5-flash:generateContent',
{
expect.objectContaining({
body: {
contents: [
{
@@ -107,6 +107,76 @@ describe('GoogleGemini Node', () => {
parts: [{ text: 'You are a helpful assistant.' }],
},
},
}),
);
});
it('should include thinking options when the thinking budget is specified', async () => {
executeFunctionsMock.getNodeParameter.mockImplementation((parameter: string) => {
switch (parameter) {
case 'modelId':
return 'models/gemini-2.5-flash';
case 'messages.values':
return [{ role: 'user', content: 'Hello, world!' }];
case 'simplify':
return true;
case 'jsonOutput':
return false;
case 'options':
return {
thinkingBudget: 1024,
maxOutputTokens: 100,
temperature: 0.5,
};
default:
return undefined;
}
});
executeFunctionsMock.getNodeInputs.mockReturnValue([{ type: 'main' }]);
apiRequestMock.mockResolvedValue({
candidates: [
{
content: {
parts: [{ text: 'Hello with thinking!' }],
role: 'model',
},
},
],
});
const result = await text.message.execute.call(executeFunctionsMock, 0);
expect(result).toEqual([
{
json: {
content: {
parts: [{ text: 'Hello with thinking!' }],
role: 'model',
},
},
pairedItem: { item: 0 },
},
]);
expect(apiRequestMock).toHaveBeenCalledWith(
'POST',
'/v1beta/models/gemini-2.5-flash:generateContent',
{
body: {
contents: [
{
parts: [{ text: 'Hello, world!' }],
role: 'user',
},
],
tools: [],
generationConfig: {
maxOutputTokens: 100,
temperature: 0.5,
thinkingConfig: {
thinkingBudget: 1024,
},
},
},
},
);
});
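
The body asserted here matches the public Gemini REST API, where thinkingConfig nests inside generationConfig. A standalone fetch sketch against that endpoint, assuming a GEMINI_API_KEY environment variable and omitting error handling:

const response = await fetch(
	'https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent',
	{
		method: 'POST',
		headers: {
			'Content-Type': 'application/json',
			'x-goog-api-key': process.env.GEMINI_API_KEY ?? '',
		},
		body: JSON.stringify({
			contents: [{ role: 'user', parts: [{ text: 'Hello, world!' }] }],
			generationConfig: {
				maxOutputTokens: 100,
				temperature: 0.5,
				thinkingConfig: { thinkingBudget: 1024 },
			},
		}),
	},
);
console.log(await response.json());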

View File

@@ -1,7 +1,11 @@
import type { IExecuteFunctions, INodeExecutionData, INodeProperties } from 'n8n-workflow';
import { updateDisplayOptions } from 'n8n-workflow';
import type { Content, GenerateContentResponse } from '../../helpers/interfaces';
import type {
Content,
GenerateContentRequest,
GenerateContentResponse,
} from '../../helpers/interfaces';
import { downloadFile, uploadFile } from '../../helpers/utils';
import { apiRequest } from '../../transport';
import { modelRLC } from '../descriptions';
@@ -157,7 +161,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
}${options.endTime ? ` to ${options.endTime as string}` : ''}`;
contents[0].parts.push({ text });
const body = {
const body: GenerateContentRequest = {
contents,
};

View File

@@ -1,7 +1,12 @@
import type { IExecuteFunctions, INodeExecutionData, INodeProperties } from 'n8n-workflow';
import { NodeOperationError, updateDisplayOptions } from 'n8n-workflow';
import type { GenerateContentResponse, ImagenResponse } from '../../helpers/interfaces';
import {
type GenerateContentRequest,
type GenerateContentResponse,
type ImagenResponse,
Modality,
} from '../../helpers/interfaces';
import { apiRequest } from '../../transport';
import { modelRLC } from '../descriptions';
@@ -67,9 +72,9 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
if (model.includes('gemini')) {
const generationConfig = {
responseModalities: ['IMAGE', 'TEXT'],
responseModalities: [Modality.IMAGE, Modality.TEXT],
};
const body = {
const body: GenerateContentRequest = {
contents: [
{
role: 'user',

View File

@@ -1,15 +1,22 @@
import type {
IDataObject,
IExecuteFunctions,
INodeExecutionData,
INodeProperties,
import {
type IDataObject,
type IExecuteFunctions,
type INodeExecutionData,
type INodeProperties,
validateNodeParameters,
} from 'n8n-workflow';
import { updateDisplayOptions } from 'n8n-workflow';
import zodToJsonSchema from 'zod-to-json-schema';
import { getConnectedTools } from '@utils/helpers';
import type { GenerateContentResponse, Content, Tool } from '../../helpers/interfaces';
import type {
GenerateContentRequest,
GenerateContentResponse,
Content,
Tool,
GenerateContentGenerationConfig,
} from '../../helpers/interfaces';
import { apiRequest } from '../../transport';
import { modelRLC } from '../descriptions';
@@ -186,6 +193,18 @@ const properties: INodeProperties[] = [
numberPrecision: 0,
},
},
{
displayName: 'Thinking Budget',
name: 'thinkingBudget',
type: 'number',
default: undefined,
description:
'Controls reasoning tokens for thinking models. Set to 0 to disable automatic thinking. Set to -1 for dynamic thinking. Leave empty for auto mode.',
typeOptions: {
minValue: -1,
numberPrecision: 0,
},
},
{
displayName: 'Max Tool Calls Iterations',
name: 'maxToolsIterations',
@@ -224,8 +243,25 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
const simplify = this.getNodeParameter('simplify', i, true) as boolean;
const jsonOutput = this.getNodeParameter('jsonOutput', i, false) as boolean;
const options = this.getNodeParameter('options', i, {});
validateNodeParameters(
options,
{
systemMessage: { type: 'string', required: false },
codeExecution: { type: 'boolean', required: false },
frequencyPenalty: { type: 'number', required: false },
maxOutputTokens: { type: 'number', required: false },
candidateCount: { type: 'number', required: false },
presencePenalty: { type: 'number', required: false },
temperature: { type: 'number', required: false },
topP: { type: 'number', required: false },
topK: { type: 'number', required: false },
thinkingBudget: { type: 'number', required: false },
maxToolsIterations: { type: 'number', required: false },
},
this.getNode(),
);
const generationConfig = {
const generationConfig: GenerateContentGenerationConfig = {
frequencyPenalty: options.frequencyPenalty,
maxOutputTokens: options.maxOutputTokens,
candidateCount: options.candidateCount,
@@ -236,6 +272,13 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
responseMimeType: jsonOutput ? 'application/json' : undefined,
};
// Add thinkingConfig if thinkingBudget is specified
if (options.thinkingBudget !== undefined) {
generationConfig.thinkingConfig = {
thinkingBudget: options.thinkingBudget,
};
}
const nodeInputs = this.getNodeInputs();
const availableTools = nodeInputs.some((i) => i.type === 'ai_tool')
? await getConnectedTools(this, true)
@@ -267,7 +310,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
parts: [{ text: m.content }],
role: m.role,
}));
const body = {
const body: GenerateContentRequest = {
tools,
contents,
generationConfig,
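
The option description encodes three modes: 0 disables thinking, -1 requests a dynamic budget, and an empty field leaves the model in auto mode. A hypothetical toThinkingConfig helper that makes the mapping from the diff explicit:

// Hypothetical helper mirroring the conditional above.
function toThinkingConfig(thinkingBudget?: number): { thinkingBudget: number } | undefined {
	// undefined means auto: omit thinkingConfig entirely so the model decides.
	if (thinkingBudget === undefined) return undefined;
	// 0 disables thinking; -1 requests a dynamic budget; any other value caps
	// the number of reasoning tokens.
	return { thinkingBudget };
}

console.log(toThinkingConfig()); // undefined (auto)
console.log(toThinkingConfig(0)); // { thinkingBudget: 0 } (disabled)
console.log(toThinkingConfig(-1)); // { thinkingBudget: -1 } (dynamic)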

View File

@@ -1,4 +1,8 @@
import type { IExecuteFunctions, INodeExecutionData } from 'n8n-workflow';
import {
validateNodeParameters,
type IExecuteFunctions,
type INodeExecutionData,
} from 'n8n-workflow';
import type { Content, GenerateContentResponse } from './interfaces';
import { downloadFile, uploadFile } from './utils';
@@ -15,7 +19,11 @@ export async function baseAnalyze(
const text = this.getNodeParameter('text', i, '') as string;
const simplify = this.getNodeParameter('simplify', i, true) as boolean;
const options = this.getNodeParameter('options', i, {});
validateNodeParameters(
options,
{ maxOutputTokens: { type: 'number', required: false } },
this.getNode(),
);
const generationConfig = {
maxOutputTokens: options.maxOutputTokens,
};
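
validateNodeParameters replaces the blanket as casts that previously hid type errors: it checks the runtime shape of the options object against a schema and raises an error on the given node if a key has the wrong type. A minimal standalone sketch, with an illustrative hand-built INode in place of this.getNode():

import { validateNodeParameters, type INode } from 'n8n-workflow';

// Illustrative node stub; in the node code this comes from this.getNode().
const node: INode = {
	id: '1',
	name: 'Google Gemini',
	type: 'n8n-nodes-langchain.googleGemini',
	typeVersion: 1,
	position: [0, 0],
	parameters: {},
};

const options = { maxOutputTokens: 2048 };

// Throws if a provided key has the wrong type or a required key is missing.
validateNodeParameters(
	options,
	{ maxOutputTokens: { type: 'number', required: false } },
	node,
);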

View File

@@ -1,4 +1,44 @@
import type {
GenerateContentConfig,
GenerationConfig,
GenerateContentParameters,
} from '@google/genai';
import type { IDataObject } from 'n8n-workflow';
export { Modality } from '@google/genai';
/* Type created based on: https://ai.google.dev/api/generate-content#generationconfig */
export type GenerateContentGenerationConfig = Pick<
GenerationConfig,
| 'stopSequences'
| 'responseMimeType'
| 'responseSchema'
| 'responseJsonSchema'
| 'responseModalities'
| 'candidateCount'
| 'maxOutputTokens'
| 'temperature'
| 'topP'
| 'topK'
| 'seed'
| 'presencePenalty'
| 'frequencyPenalty'
| 'responseLogprobs'
| 'logprobs'
| 'speechConfig'
| 'thinkingConfig'
| 'mediaResolution'
>;
/* Type created based on: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent */
export interface GenerateContentRequest extends IDataObject {
contents: GenerateContentParameters['contents'];
tools?: GenerateContentConfig['tools'];
toolConfig?: GenerateContentConfig['toolConfig'];
systemInstruction?: GenerateContentConfig['systemInstruction'];
safetySettings?: GenerateContentConfig['safetySettings'];
generationConfig?: GenerateContentGenerationConfig;
cachedContent?: string;
}
export interface GenerateContentResponse {
candidates: Array<{
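
With these interfaces in place, request bodies are type-checked instead of being assembled as loose objects. A short sketch building a GenerateContentRequest, assuming this file's exports:

import type { GenerateContentRequest } from './interfaces';

const body: GenerateContentRequest = {
	contents: [{ role: 'user', parts: [{ text: 'Hello, world!' }] }],
	generationConfig: {
		maxOutputTokens: 100,
		// A misspelled key here (e.g. thinkingBudget at this level) now fails to compile.
		thinkingConfig: { thinkingBudget: 1024 },
	},
};

console.log(JSON.stringify(body));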

View File

@@ -168,9 +168,9 @@
"@azure/identity": "4.3.0",
"@getzep/zep-cloud": "1.0.12",
"@getzep/zep-js": "0.9.0",
"@google-ai/generativelanguage": "2.6.0",
"@google-cloud/resource-manager": "5.3.0",
"@google/generative-ai": "0.21.0",
"@google/genai": "1.19.0",
"@huggingface/inference": "4.0.5",
"@langchain/anthropic": "catalog:",
"@langchain/aws": "0.1.11",

pnpm-lock.yaml (generated, 2629 lines changed): diff suppressed because it is too large.
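
The dependency change above swaps the older @google/generative-ai package for @google/genai 1.19.0. As the earlier hunks show, existing type-only imports migrate with just a module-specifier change, and runtime values such as the Modality enum become available:

// Before (old SDK, removed from package.json):
// import type { HarmBlockThreshold, HarmCategory } from '@google/generative-ai';

// After (new SDK):
import { Modality, type HarmBlockThreshold, type HarmCategory } from '@google/genai';

console.log(Modality.IMAGE, Modality.TEXT);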