From b9030d45dead945342be88726cde2d0043ff4ab7 Mon Sep 17 00:00:00 2001
From: oleg
Date: Wed, 26 Mar 2025 14:26:09 +0100
Subject: [PATCH] fix(Basic LLM Chain Node): Prevent incorrect wrapping of output (#14183)

---
 .../nodes/chains/ChainLLM/ChainLlm.node.ts    |   4 +-
 .../chains/ChainLLM/methods/chainExecutor.ts  |  52 +++++++-
 .../ChainLLM/methods/responseFormatter.ts     |  14 ++-
 .../ChainLLM/test/chainExecutor.test.ts       | 111 +++++++++++++++++-
 .../ChainLLM/test/responseFormatter.test.ts   |  18 ++-
 5 files changed, 174 insertions(+), 25 deletions(-)

diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts
index 1419e5823c..24ae67b20b 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/ChainLlm.node.ts
@@ -34,7 +34,7 @@ export class ChainLlm implements INodeType {
 		icon: 'fa:link',
 		iconColor: 'black',
 		group: ['transform'],
-		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5],
+		version: [1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6],
 		description: 'A simple chain to prompt a large language model',
 		defaults: {
 			name: 'Basic LLM Chain',
@@ -119,7 +119,7 @@ export class ChainLlm implements INodeType {
 			// Process each response and add to return data
 			responses.forEach((response) => {
 				returnData.push({
-					json: formatResponse(response),
+					json: formatResponse(response, this.getNode().typeVersion),
 				});
 			});
 		} catch (error) {
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts
index 60f2c3fd00..0f3b5dc120 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/chainExecutor.ts
@@ -1,5 +1,6 @@
 import type { BaseLanguageModel } from '@langchain/core/language_models/base';
-import { StringOutputParser } from '@langchain/core/output_parsers';
+import type { BaseLLMOutputParser } from '@langchain/core/output_parsers';
+import { JsonOutputParser, StringOutputParser } from '@langchain/core/output_parsers';
 import type { ChatPromptTemplate, PromptTemplate } from '@langchain/core/prompts';
 import type { IExecuteFunctions } from 'n8n-workflow';
 
@@ -8,6 +9,46 @@ import { getTracingConfig } from '@utils/tracing';
 import { createPromptTemplate } from './promptUtils';
 import type { ChainExecutionParams } from './types';
 
+/**
+ * Type guard to check if the LLM has a modelKwargs property (OpenAI)
+ */
+export function isModelWithResponseFormat(
+	llm: BaseLanguageModel,
+): llm is BaseLanguageModel & { modelKwargs: { response_format: { type: string } } } {
+	return (
+		'modelKwargs' in llm &&
+		!!llm.modelKwargs &&
+		typeof llm.modelKwargs === 'object' &&
+		'response_format' in llm.modelKwargs
+	);
+}
+
+/**
+ * Type guard to check if the LLM has a format property (Ollama)
+ */
+export function isModelWithFormat(
+	llm: BaseLanguageModel,
+): llm is BaseLanguageModel & { format: string } {
+	return 'format' in llm && typeof llm.format !== 'undefined';
+}
+
+/**
+ * Determines if an LLM is configured to output JSON and returns the appropriate output parser
+ */
+export function getOutputParserForLLM(
+	llm: BaseLanguageModel,
+): BaseLLMOutputParser<string | Record<string, unknown>> {
+	if (isModelWithResponseFormat(llm) && llm.modelKwargs?.response_format?.type === 'json_object') {
+		return new JsonOutputParser();
+	}
+
+	if (isModelWithFormat(llm) && llm.format === 'json') {
+		return new JsonOutputParser();
+	}
+
+	return new StringOutputParser();
+}
+
 /**
  * Creates a simple chain for LLMs without output parsers
  */
@@ -21,11 +62,10 @@
 	llm: BaseLanguageModel;
 	query: string;
 	prompt: ChatPromptTemplate | PromptTemplate;
-}): Promise<string[]> {
-	const chain = prompt
-		.pipe(llm)
-		.pipe(new StringOutputParser())
-		.withConfig(getTracingConfig(context));
+}) {
+	const outputParser = getOutputParserForLLM(llm);
+
+	const chain = prompt.pipe(llm).pipe(outputParser).withConfig(getTracingConfig(context));
 
 	// Execute the chain
 	const response = await chain.invoke({
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/responseFormatter.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/responseFormatter.ts
index e045f5fab9..6f1c65b79e 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/responseFormatter.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/methods/responseFormatter.ts
@@ -3,12 +3,10 @@ import type { IDataObject } from 'n8n-workflow';
 /**
  * Formats the response from the LLM chain into a consistent structure
  */
-export function formatResponse(response: unknown): IDataObject {
+export function formatResponse(response: unknown, version: number): IDataObject {
 	if (typeof response === 'string') {
 		return {
-			response: {
-				text: response.trim(),
-			},
+			text: response.trim(),
 		};
 	}
 
@@ -19,7 +17,13 @@ export function formatResponse(response: unknown): IDataObject {
 	}
 
 	if (response instanceof Object) {
-		return response as IDataObject;
+		if (version >= 1.6) {
+			return response as IDataObject;
+		}
+
+		return {
+			text: JSON.stringify(response),
+		};
 	}
 
 	return {
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/chainExecutor.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/chainExecutor.test.ts
index 7cb3456d89..4ca92c7222 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/chainExecutor.test.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/chainExecutor.test.ts
@@ -1,4 +1,5 @@
-import { StringOutputParser } from '@langchain/core/output_parsers';
+import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import { JsonOutputParser, StringOutputParser } from '@langchain/core/output_parsers';
 import { ChatPromptTemplate, PromptTemplate } from '@langchain/core/prompts';
 import { FakeLLM, FakeChatModel } from '@langchain/core/utils/testing';
 import { mock } from 'jest-mock-extended';
@@ -8,6 +9,7 @@ import type { N8nOutputParser } from '@utils/output_parsers/N8nOutputParser';
 import * as tracing from '@utils/tracing';
 
 import { executeChain } from '../methods/chainExecutor';
+import * as chainExecutor from '../methods/chainExecutor';
 import * as promptUtils from '../methods/promptUtils';
 
 jest.mock('@utils/tracing', () => ({
@@ -27,6 +29,41 @@ describe('chainExecutor', () => {
 		jest.clearAllMocks();
 	});
 
+	describe('getOutputParserForLLM', () => {
+		it('should return JsonOutputParser for OpenAI-like models with json_object response format', () => {
+			const openAILikeModel = {
+				modelKwargs: {
+					response_format: {
+						type: 'json_object',
+					},
+				},
+			};
+
+			const parser = chainExecutor.getOutputParserForLLM(
+				openAILikeModel as unknown as BaseChatModel,
+			);
+			expect(parser).toBeInstanceOf(JsonOutputParser);
+		});
+
+		it('should return JsonOutputParser for Ollama models with json format', () => {
+			const ollamaLikeModel = {
+				format: 'json',
+			};
+
+			const parser = chainExecutor.getOutputParserForLLM(
+				ollamaLikeModel as unknown as BaseChatModel,
+			);
+			expect(parser).toBeInstanceOf(JsonOutputParser);
+		});
+
+		it('should return StringOutputParser for models without JSON format settings', () => {
+			const regularModel = new FakeLLM({});
+
+			const parser = chainExecutor.getOutputParserForLLM(regularModel);
+			expect(parser).toBeInstanceOf(StringOutputParser);
+		});
+	});
+
 	describe('executeChain', () => {
 		it('should execute a simple chain without output parsers', async () => {
 			const fakeLLM = new FakeLLM({ response: 'Test response' });
@@ -219,5 +256,77 @@
 
 			expect(result).toEqual(['Test chat response']);
 		});
+
+		it('should use JsonOutputParser for OpenAI models with json_object response format', async () => {
+			const fakeOpenAIModel = new FakeChatModel({});
+			(
+				fakeOpenAIModel as unknown as { modelKwargs: { response_format: { type: string } } }
+			).modelKwargs = {
+				response_format: { type: 'json_object' },
+			};
+
+			const mockPromptTemplate = new PromptTemplate({
+				template: '{query}',
+				inputVariables: ['query'],
+			});
+
+			const mockChain = {
+				invoke: jest.fn().mockResolvedValue('{"result": "json data"}'),
+			};
+
+			const withConfigMock = jest.fn().mockReturnValue(mockChain);
+			const pipeOutputParserMock = jest.fn().mockReturnValue({
+				withConfig: withConfigMock,
+			});
+
+			mockPromptTemplate.pipe = jest.fn().mockReturnValue({
+				pipe: pipeOutputParserMock,
+			});
+
+			(promptUtils.createPromptTemplate as jest.Mock).mockResolvedValue(mockPromptTemplate);
+
+			await executeChain({
+				context: mockContext,
+				itemIndex: 0,
+				query: 'Hello',
+				llm: fakeOpenAIModel,
+			});
+
+			expect(pipeOutputParserMock).toHaveBeenCalledWith(expect.any(JsonOutputParser));
+		});
+
+		it('should use JsonOutputParser for Ollama models with json format', async () => {
+			const fakeOllamaModel = new FakeChatModel({});
+			(fakeOllamaModel as unknown as { format: string }).format = 'json';
+
+			const mockPromptTemplate = new PromptTemplate({
+				template: '{query}',
+				inputVariables: ['query'],
+			});
+
+			const mockChain = {
+				invoke: jest.fn().mockResolvedValue('{"result": "json data"}'),
+			};
+
+			const withConfigMock = jest.fn().mockReturnValue(mockChain);
+			const pipeOutputParserMock = jest.fn().mockReturnValue({
+				withConfig: withConfigMock,
+			});
+
+			mockPromptTemplate.pipe = jest.fn().mockReturnValue({
+				pipe: pipeOutputParserMock,
+			});
+
+			(promptUtils.createPromptTemplate as jest.Mock).mockResolvedValue(mockPromptTemplate);
+
+			await executeChain({
+				context: mockContext,
+				itemIndex: 0,
+				query: 'Hello',
+				llm: fakeOllamaModel,
+			});
+
+			expect(pipeOutputParserMock).toHaveBeenCalledWith(expect.any(JsonOutputParser));
+		});
 	});
 });
diff --git a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/responseFormatter.test.ts b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/responseFormatter.test.ts
index ec8cf59858..08bc587d05 100644
--- a/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/responseFormatter.test.ts
+++ b/packages/@n8n/nodes-langchain/nodes/chains/ChainLLM/test/responseFormatter.test.ts
@@ -3,38 +3,34 @@ import { formatResponse } from '../methods/responseFormatter';
 describe('responseFormatter', () => {
 	describe('formatResponse', () => {
 		it('should format string responses', () => {
-			const result = formatResponse('Test response');
+			const result = formatResponse('Test response', 1.6);
 			expect(result).toEqual({
-				response: {
-					text: 'Test response',
-				},
+				text: 'Test response',
 			});
 		});
 
 		it('should trim string responses', () => {
-			const result = formatResponse(' Test response with whitespace ');
+			const result = formatResponse(' Test response with whitespace ', 1.6);
 			expect(result).toEqual({
-				response: {
-					text: 'Test response with whitespace',
-				},
+				text: 'Test response with whitespace',
 			});
 		});
 
 		it('should handle array responses', () => {
 			const testArray = [{ item: 1 }, { item: 2 }];
-			const result = formatResponse(testArray);
+			const result = formatResponse(testArray, 1.6);
 			expect(result).toEqual({ data: testArray });
 		});
 
 		it('should handle object responses', () => {
 			const testObject = { key: 'value', nested: { key: 'value' } };
-			const result = formatResponse(testObject);
+			const result = formatResponse(testObject, 1.6);
 			expect(result).toEqual(testObject);
 		});
 
 		it('should handle primitive non-string responses', () => {
 			const testNumber = 42;
-			const result = formatResponse(testNumber);
+			const result = formatResponse(testNumber, 1.6);
 			expect(result).toEqual({
 				response: {
 					text: 42,