Merge pull request #435 from learnhouse/feat/ai-improvements

AI performance improvements
This commit is contained in:
Badr B. 2025-02-19 17:00:22 +01:00 committed by GitHub
commit 000f1031e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 133 additions and 65 deletions

View file

@ -13,9 +13,9 @@ dependencies = [
"fastapi>=0.115.0", "fastapi>=0.115.0",
"fastapi-jwt-auth>=0.5.0", "fastapi-jwt-auth>=0.5.0",
"httpx>=0.27.0", "httpx>=0.27.0",
"langchain>=0.2.0", "langchain>=0.1.7",
"langchain-community>=0.2.0", "langchain-community>=0.0.20",
"langchain-openai>=0.1.7", "langchain-openai>=0.0.6",
"openai>=1.50.2", "openai>=1.50.2",
"passlib>=1.7.4", "passlib>=1.7.4",
"psycopg2-binary>=2.9.9", "psycopg2-binary>=2.9.9",

View file

@ -1,17 +1,15 @@
from typing import Optional from typing import Optional, Dict, Any
from uuid import uuid4 from uuid import uuid4
from langchain.agents import AgentExecutor from langchain.agents import AgentExecutor
from langchain_text_splitters import CharacterTextSplitter from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma from langchain_community.vectorstores import Chroma
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain_core.prompts import MessagesPlaceholder from langchain.prompts import MessagesPlaceholder
from langchain_community.chat_message_histories import RedisChatMessageHistory from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.messages import SystemMessage from langchain_core.messages import SystemMessage
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory, AgentTokenBufferMemory,
) )
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain.agents.agent_toolkits import ( from langchain.agents.agent_toolkits import (
create_retriever_tool, create_retriever_tool,
) )
@ -19,6 +17,7 @@ from langchain.agents.agent_toolkits import (
import chromadb import chromadb
from config.config import get_learnhouse_config from config.config import get_learnhouse_config
from src.services.ai.init import get_chromadb_client, get_embedding_function, get_llm
LH_CONFIG = get_learnhouse_config() LH_CONFIG = get_learnhouse_config()
client = ( client = (
@ -27,98 +26,112 @@ client = (
else chromadb.Client() else chromadb.Client()
) )
# Use efficient text splitter settings
chat_history = [] TEXT_SPLITTER = CharacterTextSplitter(
chunk_size=1000,
chunk_overlap=100,
separator="\n",
length_function=len,
)
def ask_ai( def ask_ai(
question: str, question: str,
message_history, message_history: Any,
text_reference: str, text_reference: str,
message_for_the_prompt: str, message_for_the_prompt: str,
embedding_model_name: str, embedding_model_name: str,
openai_model_name: str, openai_model_name: str,
): ) -> Dict[str, Any]:
# Get API Keys """
LH_CONFIG = get_learnhouse_config() Process an AI query with improved performance using cached components
openai_api_key = LH_CONFIG.ai_config.openai_api_key """
# Get embedding function
embedding_function = get_embedding_function(embedding_model_name)
if not embedding_function:
raise Exception(f"Embedding model {embedding_model_name} not found or API key not configured")
# split it into chunks # Split text into chunks efficiently
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) documents = TEXT_SPLITTER.create_documents([text_reference])
documents = text_splitter.create_documents([text_reference])
texts = text_splitter.split_documents(documents)
embedding_models = { # Create vector store
"text-embedding-ada-002": OpenAIEmbeddings, db = Chroma.from_documents(
} documents,
embedding_function,
embedding_function = None client=get_chromadb_client()
if embedding_model_name in embedding_models:
if embedding_model_name == "text-embedding-ada-002":
embedding_function = embedding_models[embedding_model_name](
model=embedding_model_name, api_key=openai_api_key
) )
else:
raise Exception("Embedding model not found")
# load it into Chroma and use it as a retriever # Create retriever tool
db = Chroma.from_documents(texts, embedding_function) retriever_tool = create_retriever_tool(
tool = create_retriever_tool( db.as_retriever(search_kwargs={"k": 3}),
db.as_retriever(),
"find_context_text", "find_context_text",
"Find associated text to get context about a course or a lecture", "Find associated text to get context about a course or a lecture",
) )
tools = [tool]
llm = ChatOpenAI( # Get LLM
temperature=0, api_key=openai_api_key, model_name=openai_model_name llm = get_llm(openai_model_name)
) if not llm:
raise Exception(f"LLM model {openai_model_name} not found or API key not configured")
memory_key = "history"
# Setup memory with optimized token limit
memory = AgentTokenBufferMemory( memory = AgentTokenBufferMemory(
memory_key=memory_key, memory_key="history",
llm=llm, llm=llm,
chat_memory=message_history, chat_memory=message_history,
max_token_limit=1000, max_token_limit=2000, # Increased for better context retention
) )
system_message = SystemMessage(content=(message_for_the_prompt)) # Create agent with system message
system_message = SystemMessage(content=message_for_the_prompt)
prompt = OpenAIFunctionsAgent.create_prompt( prompt = OpenAIFunctionsAgent.create_prompt(
system_message=system_message, system_message=system_message,
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)], extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
) )
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt) agent = OpenAIFunctionsAgent(
llm=llm,
tools=[retriever_tool],
prompt=prompt
)
# Create and execute agent
agent_executor = AgentExecutor( agent_executor = AgentExecutor(
agent=agent, agent=agent,
tools=tools, tools=[retriever_tool],
memory=memory, memory=memory,
verbose=True, verbose=True,
return_intermediate_steps=True, return_intermediate_steps=True,
handle_parsing_errors=True, handle_parsing_errors=True,
max_iterations=3, # Limit maximum iterations for better performance
) )
try:
return agent_executor({"input": question}) return agent_executor({"input": question})
except Exception as e:
raise Exception(f"Error processing AI request: {str(e)}")
def get_chat_session_history(aichat_uuid: Optional[str] = None) -> Dict[str, Any]:
def get_chat_session_history(aichat_uuid: Optional[str] = None): """Get or create a new chat session history"""
# Init Message History
session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}" session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}"
LH_CONFIG = get_learnhouse_config() LH_CONFIG = get_learnhouse_config()
redis_conn_string = LH_CONFIG.redis_config.redis_connection_string redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
if redis_conn_string: if redis_conn_string:
try:
message_history = RedisChatMessageHistory( message_history = RedisChatMessageHistory(
url=redis_conn_string, ttl=2160000, session_id=session_id url=redis_conn_string,
ttl=2160000, # 25 days
session_id=session_id
) )
except Exception:
print("Failed to connect to Redis, falling back to local memory")
message_history = []
else: else:
print("Redis connection string not found, using local memory") print("Redis connection string not found, using local memory")
message_history = [] message_history = []
return {"message_history": message_history, "aichat_uuid": session_id} return {
"message_history": message_history,
"aichat_uuid": session_id
}

View file

@ -0,0 +1,55 @@
from typing import Optional
from functools import lru_cache
import chromadb
from langchain_openai import OpenAIEmbeddings
from langchain_community.chat_models import ChatOpenAI
from config.config import get_learnhouse_config
@lru_cache()
def get_chromadb_client():
    """Return the ChromaDB client used by the AI services.

    The function is wrapped in ``lru_cache`` and takes no arguments, so the
    client built on the first call is reused by every later call.

    Returns:
        A ``chromadb.HttpClient`` pointed at the configured ``db_host`` on
        port 8000 when ``ai_config.chromadb_config`` exists, has a non-empty
        string ``db_host``, and has a truthy ``isSeparateDatabaseEnabled``
        attribute; otherwise a default ``chromadb.Client()``.
    """
    ai_settings = get_learnhouse_config().ai_config
    chroma_settings = getattr(ai_settings, "chromadb_config", None)

    # No ChromaDB section configured at all: use the default client.
    if not chroma_settings:
        return chromadb.Client()

    # The host must be a non-empty string to be usable for an HTTP client.
    remote_host = chroma_settings.db_host
    if not isinstance(remote_host, str) or not remote_host:
        return chromadb.Client()

    # The separate database has to be explicitly switched on.
    if not getattr(chroma_settings, "isSeparateDatabaseEnabled", False):
        return chromadb.Client()

    return chromadb.HttpClient(host=remote_host, port=8000)
@lru_cache()
def get_embedding_function(model_name: str) -> Optional[OpenAIEmbeddings]:
    """Return the embedding function for ``model_name``, if it can be built.

    Results are memoized by ``lru_cache`` per ``model_name``; a ``None``
    result is memoized as well.

    Args:
        model_name: Name of the embedding model. Only
            ``"text-embedding-ada-002"`` is recognized here.

    Returns:
        An ``OpenAIEmbeddings`` instance for the recognized model, or ``None``
        when ``ai_config.openai_api_key`` is missing/empty or the model name
        is not recognized.
    """
    ai_settings = get_learnhouse_config().ai_config
    openai_key = getattr(ai_settings, "openai_api_key", None)

    # Both a configured key and a recognized model name are required.
    if openai_key and model_name == "text-embedding-ada-002":
        return OpenAIEmbeddings(model=model_name, api_key=openai_key)
    return None
@lru_cache()
def get_llm(model_name: str, temperature: float = 0) -> Optional[ChatOpenAI]:
    """Return a chat model instance for ``model_name``, if a key is configured.

    Results are memoized by ``lru_cache`` per argument combination; a
    ``None`` result is memoized as well.

    Args:
        model_name: Model identifier passed to ``ChatOpenAI`` as ``model``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.

    Returns:
        A ``ChatOpenAI`` instance, or ``None`` when
        ``ai_config.openai_api_key`` is missing or empty.
    """
    ai_settings = get_learnhouse_config().ai_config
    openai_key = getattr(ai_settings, "openai_api_key", None)

    return (
        ChatOpenAI(
            temperature=temperature,
            api_key=openai_key,
            model=model_name,
        )
        if openai_key
        else None
    )

6
apps/api/uv.lock generated
View file

@ -1035,9 +1035,9 @@ requires-dist = [
{ name = "fastapi", specifier = ">=0.115.0" }, { name = "fastapi", specifier = ">=0.115.0" },
{ name = "fastapi-jwt-auth", specifier = ">=0.5.0" }, { name = "fastapi-jwt-auth", specifier = ">=0.5.0" },
{ name = "httpx", specifier = ">=0.27.0" }, { name = "httpx", specifier = ">=0.27.0" },
{ name = "langchain", specifier = ">=0.2.0" }, { name = "langchain", specifier = ">=0.1.7" },
{ name = "langchain-community", specifier = ">=0.2.0" }, { name = "langchain-community", specifier = ">=0.0.20" },
{ name = "langchain-openai", specifier = ">=0.1.7" }, { name = "langchain-openai", specifier = ">=0.0.6" },
{ name = "openai", specifier = ">=1.50.2" }, { name = "openai", specifier = ">=1.50.2" },
{ name = "passlib", specifier = ">=1.7.4" }, { name = "passlib", specifier = ">=1.7.4" },
{ name = "psycopg2-binary", specifier = ">=2.9.9" }, { name = "psycopg2-binary", specifier = ">=2.9.9" },