Merge pull request #435 from learnhouse/feat/ai-improvements

AI performance improvements
Badr B. 2025-02-19 17:00:22 +01:00 committed by GitHub
commit 000f1031e7
4 changed files with 133 additions and 65 deletions

apps/api/pyproject.toml

@@ -13,9 +13,9 @@ dependencies = [
     "fastapi>=0.115.0",
     "fastapi-jwt-auth>=0.5.0",
     "httpx>=0.27.0",
-    "langchain>=0.2.0",
-    "langchain-community>=0.2.0",
-    "langchain-openai>=0.1.7",
+    "langchain>=0.1.7",
+    "langchain-community>=0.0.20",
+    "langchain-openai>=0.0.6",
     "openai>=1.50.2",
     "passlib>=1.7.4",
     "psycopg2-binary>=2.9.9",


@@ -1,17 +1,15 @@
-from typing import Optional
+from typing import Optional, Dict, Any
 from uuid import uuid4
 from langchain.agents import AgentExecutor
-from langchain_text_splitters import CharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
-from langchain_core.prompts import MessagesPlaceholder
+from langchain.prompts import MessagesPlaceholder
 from langchain_community.chat_message_histories import RedisChatMessageHistory
 from langchain_core.messages import SystemMessage
 from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
-from langchain_openai import OpenAIEmbeddings
-from langchain_openai import ChatOpenAI
 from langchain.agents.agent_toolkits import (
     create_retriever_tool,
 )
@@ -19,6 +17,7 @@ from langchain.agents.agent_toolkits import (
 import chromadb
 from config.config import get_learnhouse_config
+from src.services.ai.init import get_chromadb_client, get_embedding_function, get_llm
 
 LH_CONFIG = get_learnhouse_config()
 
 client = (
@@ -27,98 +26,112 @@ client = (
     else chromadb.Client()
 )
 chat_history = []
+
+# Use efficient text splitter settings
+TEXT_SPLITTER = CharacterTextSplitter(
+    chunk_size=1000,
+    chunk_overlap=100,
+    separator="\n",
+    length_function=len,
+)
 
 
 def ask_ai(
     question: str,
-    message_history,
+    message_history: Any,
     text_reference: str,
     message_for_the_prompt: str,
     embedding_model_name: str,
     openai_model_name: str,
-):
-    # Get API Keys
-    LH_CONFIG = get_learnhouse_config()
-    openai_api_key = LH_CONFIG.ai_config.openai_api_key
-
-    # split it into chunks
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    documents = text_splitter.create_documents([text_reference])
-    texts = text_splitter.split_documents(documents)
-
-    embedding_models = {
-        "text-embedding-ada-002": OpenAIEmbeddings,
-    }
-
-    embedding_function = None
-
-    if embedding_model_name in embedding_models:
-        if embedding_model_name == "text-embedding-ada-002":
-            embedding_function = embedding_models[embedding_model_name](
-                model=embedding_model_name, api_key=openai_api_key
-            )
-    else:
-        raise Exception("Embedding model not found")
-
-    # load it into Chroma and use it as a retriever
-    db = Chroma.from_documents(texts, embedding_function)
-    tool = create_retriever_tool(
-        db.as_retriever(),
+) -> Dict[str, Any]:
+    """
+    Process an AI query with improved performance using cached components
+    """
+    # Get embedding function
+    embedding_function = get_embedding_function(embedding_model_name)
+    if not embedding_function:
+        raise Exception(f"Embedding model {embedding_model_name} not found or API key not configured")
+
+    # Split text into chunks efficiently
+    documents = TEXT_SPLITTER.create_documents([text_reference])
+
+    # Create vector store
+    db = Chroma.from_documents(
+        documents,
+        embedding_function,
+        client=get_chromadb_client()
+    )
+
+    # Create retriever tool
+    retriever_tool = create_retriever_tool(
+        db.as_retriever(search_kwargs={"k": 3}),
         "find_context_text",
         "Find associated text to get context about a course or a lecture",
     )
-    tools = [tool]
 
-    llm = ChatOpenAI(
-        temperature=0, api_key=openai_api_key, model_name=openai_model_name
-    )
-
-    memory_key = "history"
+    # Get LLM
+    llm = get_llm(openai_model_name)
+    if not llm:
+        raise Exception(f"LLM model {openai_model_name} not found or API key not configured")
 
+    # Setup memory with optimized token limit
     memory = AgentTokenBufferMemory(
-        memory_key=memory_key,
+        memory_key="history",
         llm=llm,
         chat_memory=message_history,
-        max_token_limit=1000,
+        max_token_limit=2000,  # Increased for better context retention
     )
 
-    system_message = SystemMessage(content=(message_for_the_prompt))
+    # Create agent with system message
+    system_message = SystemMessage(content=message_for_the_prompt)
     prompt = OpenAIFunctionsAgent.create_prompt(
         system_message=system_message,
-        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
+        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
     )
-    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+    agent = OpenAIFunctionsAgent(
+        llm=llm,
+        tools=[retriever_tool],
+        prompt=prompt
+    )
 
+    # Create and execute agent
     agent_executor = AgentExecutor(
         agent=agent,
-        tools=tools,
+        tools=[retriever_tool],
         memory=memory,
         verbose=True,
         return_intermediate_steps=True,
         handle_parsing_errors=True,
+        max_iterations=3,  # Limit maximum iterations for better performance
     )
 
-    return agent_executor({"input": question})
+    try:
+        return agent_executor({"input": question})
+    except Exception as e:
+        raise Exception(f"Error processing AI request: {str(e)}")
 
 
-def get_chat_session_history(aichat_uuid: Optional[str] = None):
-    # Init Message History
+def get_chat_session_history(aichat_uuid: Optional[str] = None) -> Dict[str, Any]:
+    """Get or create a new chat session history"""
    session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}"
 
     LH_CONFIG = get_learnhouse_config()
     redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
 
     if redis_conn_string:
-        message_history = RedisChatMessageHistory(
-            url=redis_conn_string, ttl=2160000, session_id=session_id
-        )
+        try:
+            message_history = RedisChatMessageHistory(
+                url=redis_conn_string,
+                ttl=2160000,  # 25 days
+                session_id=session_id
+            )
+        except Exception:
+            print("Failed to connect to Redis, falling back to local memory")
+            message_history = []
     else:
+        print("Redis connection string not found, using local memory")
         message_history = []
 
-    return {"message_history": message_history, "aichat_uuid": session_id}
+    return {
+        "message_history": message_history,
+        "aichat_uuid": session_id
+    }
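With these changes, ask_ai resolves its embedding function, Chroma client, and LLM through the cached helpers in src.services.ai.init instead of rebuilding them on every request. A minimal calling sketch follows; the importing module path and the model names are assumptions for illustration, not taken from this diff:

    # Hypothetical usage sketch; module path and model names are assumptions.
    from src.services.ai.ai import ask_ai, get_chat_session_history

    session = get_chat_session_history()  # new session: Redis-backed if configured
    result = ask_ai(
        question="What does this lecture cover?",
        message_history=session["message_history"],
        text_reference="Lecture transcript text goes here...",
        message_for_the_prompt="You are a helpful course assistant.",
        embedding_model_name="text-embedding-ada-002",
        openai_model_name="gpt-3.5-turbo",
    )
    print(result["output"])  # AgentExecutor results expose an "output" key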

apps/api/src/services/ai/init.py (new file)

@@ -0,0 +1,55 @@
+from typing import Optional
+from functools import lru_cache
+
+import chromadb
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.chat_models import ChatOpenAI
+
+from config.config import get_learnhouse_config
+
+
+@lru_cache()
+def get_chromadb_client():
+    """Get cached ChromaDB client instance"""
+    LH_CONFIG = get_learnhouse_config()
+    chromadb_config = getattr(LH_CONFIG.ai_config, 'chromadb_config', None)
+    if (
+        chromadb_config
+        and isinstance(chromadb_config.db_host, str)
+        and chromadb_config.db_host
+        and getattr(chromadb_config, 'isSeparateDatabaseEnabled', False)
+    ):
+        return chromadb.HttpClient(
+            host=chromadb_config.db_host,
+            port=8000
+        )
+    return chromadb.Client()
+
+
+@lru_cache()
+def get_embedding_function(model_name: str) -> Optional[OpenAIEmbeddings]:
+    """Get cached embedding function"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+    if not api_key:
+        return None
+    if model_name == "text-embedding-ada-002":
+        return OpenAIEmbeddings(
+            model=model_name,
+            api_key=api_key
+        )
+    return None
+
+
+@lru_cache()
+def get_llm(model_name: str, temperature: float = 0) -> Optional[ChatOpenAI]:
+    """Get cached LLM instance"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+    if not api_key:
+        return None
+    return ChatOpenAI(
+        temperature=temperature,
+        api_key=api_key,
+        model=model_name
+    )
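Because each helper is wrapped in @lru_cache and takes only hashable arguments, repeated calls with the same model name return the same instance, so the ChromaDB client, embeddings, and LLM are constructed once per process rather than once per request. A small sketch of the behavior this buys; the import path is inferred from the import added in this commit, and the model name is illustrative:

    # Sketch: lru_cache returns the same cached object on repeated calls.
    from src.services.ai.init import get_llm  # path inferred from this commit's import

    llm_a = get_llm("gpt-3.5-turbo")  # first call constructs the ChatOpenAI instance
    llm_b = get_llm("gpt-3.5-turbo")  # second call is a cache hit
    assert llm_a is llm_b             # identical object, no re-initialization
    print(get_llm.cache_info())       # e.g. CacheInfo(hits=1, misses=1, ...)

One caveat: get_llm caches on the full argument tuple (model_name, temperature), so calling it with a different temperature constructs and caches a separate instance.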

apps/api/uv.lock (generated)

@@ -1035,9 +1035,9 @@ requires-dist = [
     { name = "fastapi", specifier = ">=0.115.0" },
     { name = "fastapi-jwt-auth", specifier = ">=0.5.0" },
     { name = "httpx", specifier = ">=0.27.0" },
-    { name = "langchain", specifier = ">=0.2.0" },
-    { name = "langchain-community", specifier = ">=0.2.0" },
-    { name = "langchain-openai", specifier = ">=0.1.7" },
+    { name = "langchain", specifier = ">=0.1.7" },
+    { name = "langchain-community", specifier = ">=0.0.20" },
+    { name = "langchain-openai", specifier = ">=0.0.6" },
     { name = "openai", specifier = ">=1.50.2" },
     { name = "passlib", specifier = ">=1.7.4" },
     { name = "psycopg2-binary", specifier = ">=2.9.9" },