feat: init memory into ai chat messaging

This commit is contained in:
swve 2023-12-31 16:31:43 +00:00
parent f7d76eea1e
commit cf681b2260
9 changed files with 163 additions and 22 deletions

View file

@ -1,3 +1,5 @@
from typing import Optional
from uuid import uuid4
from langchain.agents import AgentExecutor
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
@ -5,7 +7,7 @@ from langchain.vectorstores import Chroma
from langchain_core.messages import BaseMessage
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.prompts import MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import SystemMessage
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
@ -27,7 +29,7 @@ chat_history = []
def ask_ai(
question: str,
chat_history: list[BaseMessage],
message_history,
text_reference: str,
message_for_the_prompt: str,
):
@ -52,14 +54,13 @@ def ask_ai(
)
tools = [tool]
llm = ChatOpenAI(
temperature=0, api_key=openai_api_key
)
llm = ChatOpenAI(temperature=0, api_key=openai_api_key)
memory_key = "history"
memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)
memory = AgentTokenBufferMemory(
memory_key=memory_key, llm=llm, chat_memory=message_history
)
system_message = SystemMessage(content=(message_for_the_prompt))
@ -70,7 +71,6 @@ def ask_ai(
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
agent=agent,
tools=tools,
@ -80,3 +80,21 @@ def ask_ai(
)
return agent_executor({"input": question})
def get_chat_session_history(aichat_uuid: Optional[str] = None):
# Init Message History
session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}"
LH_CONFIG = get_learnhouse_config()
redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
if redis_conn_string:
message_history = RedisChatMessageHistory(
url=redis_conn_string, ttl=2160000, session_id=session_id
)
else:
print("Redis connection string not found, using local memory")
message_history = []
return {"message_history": message_history, "aichat_uuid": session_id}