feat: Session is selector for memory nodes (#8736)

This commit is contained in:
Michael Kret
2024-02-27 15:01:15 +02:00
committed by GitHub
parent 5f6da7b84e
commit 2aaf211dfc
7 changed files with 188 additions and 22 deletions

View File

@@ -10,6 +10,8 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
import { BufferWindowMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
class MemoryChatBufferSingleton {
private static instance: MemoryChatBufferSingleton;
@@ -70,7 +72,7 @@ export class MemoryBufferWindow implements INodeType {
name: 'memoryBufferWindow',
icon: 'fa:database',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Stores in n8n memory, so no credentials required',
defaults: {
name: 'Window Buffer Memory',
@@ -119,6 +121,15 @@ export class MemoryBufferWindow implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
{
displayName: 'Context Window Length',
name: 'contextWindowLength',
@@ -130,12 +141,21 @@ export class MemoryBufferWindow implements INodeType {
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
const contextWindowLength = this.getNodeParameter('contextWindowLength', itemIndex) as number;
const workflowId = this.getWorkflow().id;
const memoryInstance = MemoryChatBufferSingleton.getInstance();
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionKey}`, {
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
}
const memory = await memoryInstance.getMemory(`${workflowId}__${sessionId}`, {
k: contextWindowLength,
inputKey: 'input',
memoryKey: 'chat_history',

View File

@@ -10,6 +10,8 @@ import {
import { MotorheadMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryMotorhead implements INodeType {
description: INodeTypeDescription = {
@@ -17,7 +19,7 @@ export class MemoryMotorhead implements INodeType {
name: 'memoryMotorhead',
icon: 'fa:file-export',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Use Motorhead Memory',
defaults: {
name: 'Motorhead',
@@ -72,13 +74,29 @@ export class MemoryMotorhead implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
],
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('motorheadApi');
const nodeVersion = this.getNode().typeVersion;
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const memory = new MotorheadMemory({
sessionId,

View File

@@ -14,6 +14,8 @@ import type { RedisClientOptions } from 'redis';
import { createClient } from 'redis';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryRedisChat implements INodeType {
description: INodeTypeDescription = {
@@ -21,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
name: 'memoryRedisChat',
icon: 'file:redis.svg',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Stores the chat history in Redis.',
defaults: {
name: 'Redis Chat Memory',
@@ -76,6 +78,15 @@ export class MemoryRedisChat implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
{
displayName: 'Session Time To Live',
name: 'sessionTTL',
@@ -89,9 +100,18 @@ export class MemoryRedisChat implements INodeType {
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('redis');
const sessionKey = this.getNodeParameter('sessionKey', itemIndex) as string;
const nodeVersion = this.getNode().typeVersion;
const sessionTTL = this.getNodeParameter('sessionTTL', itemIndex, 0) as number;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionKey', itemIndex) as string;
}
const redisOptions: RedisClientOptions = {
socket: {
host: credentials.host as string,
@@ -115,7 +135,7 @@ export class MemoryRedisChat implements INodeType {
const redisChatConfig: RedisChatMessageHistoryInput = {
client,
sessionId: sessionKey,
sessionId,
};
if (sessionTTL > 0) {

View File

@@ -6,13 +6,16 @@ import { BufferMemory } from 'langchain/memory';
import { BaseClient } from '@xata.io/client';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryXata implements INodeType {
description: INodeTypeDescription = {
displayName: 'Xata',
name: 'memoryXata',
icon: 'file:xata.svg',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Use Xata Memory',
defaults: {
name: 'Xata',
@@ -69,11 +72,29 @@ export class MemoryXata implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
],
};
async supplyData(this: IExecuteFunctions, itemIndex: number): Promise<SupplyData> {
const credentials = await this.getCredentials('xataApi');
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const xataClient = new BaseClient({
apiKey: credentials.apiKey as string,
@@ -81,8 +102,6 @@ export class MemoryXata implements INodeType {
databaseURL: credentials.databaseEndpoint as string,
});
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
const table = (credentials.databaseEndpoint as string).match(
/https:\/\/[^.]+\.[^.]+\.xata\.sh\/db\/([^\/:]+)/,
);
@@ -94,18 +113,21 @@ export class MemoryXata implements INodeType {
);
}
const chatHistory = new XataChatMessageHistory({
table: table[1],
sessionId,
client: xataClient,
apiKey: credentials.apiKey as string,
});
const memory = new BufferMemory({
chatHistory: new XataChatMessageHistory({
table: table[1],
sessionId,
client: xataClient,
apiKey: credentials.apiKey as string,
}),
chatHistory,
memoryKey: 'chat_history',
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
});
return {
response: logWrapper(memory, this),
};

View File

@@ -9,6 +9,8 @@ import {
import { ZepMemory } from 'langchain/memory/zep';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';
export class MemoryZep implements INodeType {
description: INodeTypeDescription = {
@@ -17,7 +19,7 @@ export class MemoryZep implements INodeType {
// eslint-disable-next-line n8n-nodes-base/node-class-description-icon-not-svg
icon: 'file:zep.png',
group: ['transform'],
version: [1, 1.1],
version: [1, 1.1, 1.2],
description: 'Use Zep Memory',
defaults: {
name: 'Zep',
@@ -72,6 +74,15 @@ export class MemoryZep implements INodeType {
},
},
},
{
...sessionIdOption,
displayOptions: {
show: {
'@version': [{ _cnd: { gte: 1.2 } }],
},
},
},
sessionKeyProperty,
],
};
@@ -81,8 +92,15 @@ export class MemoryZep implements INodeType {
apiUrl: string;
};
// TODO: Should it get executed once per item or not?
const sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
const nodeVersion = this.getNode().typeVersion;
let sessionId;
if (nodeVersion >= 1.2) {
sessionId = getSessionId(this, itemIndex);
} else {
sessionId = this.getNodeParameter('sessionId', itemIndex) as string;
}
const memory = new ZepMemory({
sessionId,

View File

@@ -0,0 +1,35 @@
import type { INodeProperties } from 'n8n-workflow';
export const sessionIdOption: INodeProperties = {
displayName: 'Session ID',
name: 'sessionIdType',
type: 'options',
options: [
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Take from previous node automatically',
value: 'fromInput',
description: 'Looks for an input field called sessionId',
},
{
// eslint-disable-next-line n8n-nodes-base/node-param-display-name-miscased
name: 'Define below',
value: 'customKey',
description: 'Use an expression to reference data in previous nodes or enter static text',
},
],
default: 'fromInput',
};
export const sessionKeyProperty: INodeProperties = {
displayName: 'Key',
name: 'sessionKey',
type: 'string',
default: '',
description: 'The key to use to store session ID in the memory',
displayOptions: {
show: {
sessionIdType: ['customKey'],
},
},
};