Mirror of https://github.com/rzmk/learnhouse.git, synced 2025-12-19 04:19:25 +00:00
feat: AI performance improvements

commit 66e317e0e8
parent 09f3078f2b

2 changed files with 124 additions and 57 deletions
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, Any
 from uuid import uuid4
 from langchain.agents import AgentExecutor
 from langchain_text_splitters import CharacterTextSplitter
@@ -10,8 +10,6 @@ from langchain_core.messages import SystemMessage
 from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
-from langchain_openai import OpenAIEmbeddings
-from langchain_openai import ChatOpenAI
 from langchain.agents.agent_toolkits import (
     create_retriever_tool,
 )
@@ -19,6 +17,7 @@ from langchain.agents.agent_toolkits import (
 import chromadb
 
 from config.config import get_learnhouse_config
+from src.services.ai.init import get_chromadb_client, get_embedding_function, get_llm
 
 LH_CONFIG = get_learnhouse_config()
 client = (
@@ -27,98 +26,112 @@ client = (
     else chromadb.Client()
 )
 
-chat_history = []
+# Use efficient text splitter settings
+TEXT_SPLITTER = CharacterTextSplitter(
+    chunk_size=1000,
+    chunk_overlap=100,
+    separator="\n",
+    length_function=len,
+)
 
 
 def ask_ai(
     question: str,
-    message_history,
+    message_history: Any,
     text_reference: str,
     message_for_the_prompt: str,
     embedding_model_name: str,
     openai_model_name: str,
-):
-    # Get API Keys
-    LH_CONFIG = get_learnhouse_config()
-    openai_api_key = LH_CONFIG.ai_config.openai_api_key
+) -> Dict[str, Any]:
+    """
+    Process an AI query with improved performance using cached components
+    """
+    # Get embedding function
+    embedding_function = get_embedding_function(embedding_model_name)
+    if not embedding_function:
+        raise Exception(f"Embedding model {embedding_model_name} not found or API key not configured")
 
-    # split it into chunks
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    documents = text_splitter.create_documents([text_reference])
-    texts = text_splitter.split_documents(documents)
+    # Split text into chunks efficiently
+    documents = TEXT_SPLITTER.create_documents([text_reference])
 
-    embedding_models = {
-        "text-embedding-ada-002": OpenAIEmbeddings,
-    }
-
-    embedding_function = None
-
-    if embedding_model_name in embedding_models:
-        if embedding_model_name == "text-embedding-ada-002":
-            embedding_function = embedding_models[embedding_model_name](
-                model=embedding_model_name, api_key=openai_api_key
-            )
-    else:
-        raise Exception("Embedding model not found")
+    # Create vector store
+    db = Chroma.from_documents(
+        documents,
+        embedding_function,
+        client=get_chromadb_client()
+    )
 
-    # load it into Chroma and use it as a retriever
-    db = Chroma.from_documents(texts, embedding_function)
-    tool = create_retriever_tool(
-        db.as_retriever(),
+    # Create retriever tool
+    retriever_tool = create_retriever_tool(
+        db.as_retriever(search_kwargs={"k": 3}),
         "find_context_text",
         "Find associated text to get context about a course or a lecture",
     )
-    tools = [tool]
 
-    llm = ChatOpenAI(
-        temperature=0, api_key=openai_api_key, model_name=openai_model_name
-    )
-    memory_key = "history"
+    # Get LLM
+    llm = get_llm(openai_model_name)
+    if not llm:
+        raise Exception(f"LLM model {openai_model_name} not found or API key not configured")
 
+    # Setup memory with optimized token limit
     memory = AgentTokenBufferMemory(
-        memory_key=memory_key,
+        memory_key="history",
         llm=llm,
         chat_memory=message_history,
-        max_token_limit=1000,
+        max_token_limit=2000,  # Increased for better context retention
    )
 
-    system_message = SystemMessage(content=(message_for_the_prompt))
+    # Create agent with system message
+    system_message = SystemMessage(content=message_for_the_prompt)
     prompt = OpenAIFunctionsAgent.create_prompt(
         system_message=system_message,
-        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
+        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
     )
 
-    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+    agent = OpenAIFunctionsAgent(
+        llm=llm,
+        tools=[retriever_tool],
+        prompt=prompt
+    )
 
+    # Create and execute agent
     agent_executor = AgentExecutor(
         agent=agent,
-        tools=tools,
+        tools=[retriever_tool],
         memory=memory,
         verbose=True,
         return_intermediate_steps=True,
         handle_parsing_errors=True,
+        max_iterations=3,  # Limit maximum iterations for better performance
     )
 
-    return agent_executor({"input": question})
+    try:
+        return agent_executor({"input": question})
+    except Exception as e:
+        raise Exception(f"Error processing AI request: {str(e)}")
 
 
-def get_chat_session_history(aichat_uuid: Optional[str] = None):
-    # Init Message History
+def get_chat_session_history(aichat_uuid: Optional[str] = None) -> Dict[str, Any]:
+    """Get or create a new chat session history"""
     session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}"
 
     LH_CONFIG = get_learnhouse_config()
     redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
 
     if redis_conn_string:
-        message_history = RedisChatMessageHistory(
-            url=redis_conn_string, ttl=2160000, session_id=session_id
-        )
+        try:
+            message_history = RedisChatMessageHistory(
+                url=redis_conn_string,
+                ttl=2160000,  # 25 days
+                session_id=session_id
+            )
+        except Exception:
+            print("Failed to connect to Redis, falling back to local memory")
+            message_history = []
     else:
         print("Redis connection string not found, using local memory")
         message_history = []
 
-    return {"message_history": message_history, "aichat_uuid": session_id}
+    return {
+        "message_history": message_history,
+        "aichat_uuid": session_id
+    }
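Read end to end, the refactored flow is: obtain a chat session (its message history plus a stable aichat_uuid), hand that history to ask_ai, and read the agent's answer from the returned dict. The sketch below shows how a caller might wire the two functions together; it is an illustration only, since the import path src.services.ai.base is an assumption (this view does not name the first changed file) and the course text and model names are placeholders.

# Hypothetical usage sketch; the import path, course text, and model
# names are assumptions for illustration, not taken from the diff.
from src.services.ai.base import ask_ai, get_chat_session_history  # path assumed

course_text = "Lesson 2 covers variables, loops, and functions."  # placeholder

# No aichat_uuid passed, so a fresh "aichat_<uuid4>" session id is generated
session = get_chat_session_history()

result = ask_ai(
    question="What does lesson 2 cover?",
    message_history=session["message_history"],
    text_reference=course_text,
    message_for_the_prompt="You are a helpful course assistant.",
    embedding_model_name="text-embedding-ada-002",
    openai_model_name="gpt-3.5-turbo",  # any model name accepted by get_llm
)

# AgentExecutor returns a dict; "output" holds the final answer
print(session["aichat_uuid"], result["output"])

Because the embedding function, LLM, and ChromaDB client are now fetched through cached factories, repeated calls like this reuse the same instances instead of rebuilding them on every request.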
apps/api/src/services/ai/init.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+from typing import Optional
+from functools import lru_cache
+import chromadb
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from config.config import get_learnhouse_config
+
+@lru_cache()
+def get_chromadb_client():
+    """Get cached ChromaDB client instance"""
+    LH_CONFIG = get_learnhouse_config()
+    chromadb_config = getattr(LH_CONFIG.ai_config, 'chromadb_config', None)
+
+    if (
+        chromadb_config
+        and isinstance(chromadb_config.db_host, str)
+        and chromadb_config.db_host
+        and getattr(chromadb_config, 'isSeparateDatabaseEnabled', False)
+    ):
+        return chromadb.HttpClient(
+            host=chromadb_config.db_host,
+            port=8000
+        )
+    return chromadb.Client()
+
+@lru_cache()
+def get_embedding_function(model_name: str) -> Optional[OpenAIEmbeddings]:
+    """Get cached embedding function"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+
+    if not api_key:
+        return None
+
+    if model_name == "text-embedding-ada-002":
+        return OpenAIEmbeddings(
+            model=model_name,
+            api_key=api_key
+        )
+    return None
+
+@lru_cache()
+def get_llm(model_name: str, temperature: float = 0) -> Optional[ChatOpenAI]:
+    """Get cached LLM instance"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+
+    if not api_key:
+        return None
+
+    return ChatOpenAI(
+        temperature=temperature,
+        api_key=api_key,
+        model=model_name
+    )
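The new module leans entirely on functools.lru_cache to turn each factory into a per-process cache keyed by its arguments, which is what lets ask_ai reuse one ChromaDB client, one embedding function, and one LLM per model name. A self-contained sketch of that pattern, using a stand-in class instead of the real ChromaDB and OpenAI objects:

# Standalone illustration of the lru_cache pattern used in init.py;
# FakeClient stands in for chromadb.Client / ChatOpenAI.
from functools import lru_cache

class FakeClient:
    pass

def make_client() -> FakeClient:
    return FakeClient()

@lru_cache()
def make_client_cached() -> FakeClient:
    return FakeClient()

assert make_client() is not make_client()            # new object each call
assert make_client_cached() is make_client_cached()  # one shared instance

# Each distinct argument tuple gets its own cache entry, mirroring
# how get_llm(model_name) keeps one instance per model name.
@lru_cache()
def make_named(name: str) -> FakeClient:
    client = FakeClient()
    client.name = name
    return client

assert make_named("gpt-a") is make_named("gpt-a")
assert make_named("gpt-a") is not make_named("gpt-b")

One side effect of this design: get_learnhouse_config() is read inside the cached call, so configuration changes such as a rotated API key are not picked up until the process restarts or the caches are cleared with cache_clear().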