diff --git a/apps/api/src/services/ai/base.py b/apps/api/src/services/ai/base.py
index cf5c029f..30965297 100644
--- a/apps/api/src/services/ai/base.py
+++ b/apps/api/src/services/ai/base.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, Any
 from uuid import uuid4
 from langchain.agents import AgentExecutor
 from langchain_text_splitters import CharacterTextSplitter
@@ -10,8 +10,6 @@ from langchain_core.messages import SystemMessage
 from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
-from langchain_openai import OpenAIEmbeddings
-from langchain_openai import ChatOpenAI
 from langchain.agents.agent_toolkits import (
     create_retriever_tool,
 )
@@ -19,6 +17,7 @@ from langchain.agents.agent_toolkits import (
 
 import chromadb
 from config.config import get_learnhouse_config
+from src.services.ai.init import get_chromadb_client, get_embedding_function, get_llm
 
 LH_CONFIG = get_learnhouse_config()
 client = (
@@ -27,98 +26,112 @@ client = (
     else chromadb.Client()
 )
 
-
-chat_history = []
-
+# Use efficient text splitter settings
+TEXT_SPLITTER = CharacterTextSplitter(
+    chunk_size=1000,
+    chunk_overlap=100,
+    separator="\n",
+    length_function=len,
+)
 
 def ask_ai(
     question: str,
-    message_history,
+    message_history: Any,
     text_reference: str,
     message_for_the_prompt: str,
     embedding_model_name: str,
     openai_model_name: str,
-):
-    # Get API Keys
-    LH_CONFIG = get_learnhouse_config()
-    openai_api_key = LH_CONFIG.ai_config.openai_api_key
+) -> Dict[str, Any]:
+    """
+    Process an AI query with improved performance using cached components
+    """
+    # Get embedding function
+    embedding_function = get_embedding_function(embedding_model_name)
+    if not embedding_function:
+        raise Exception(f"Embedding model {embedding_model_name} not found or API key not configured")
 
-    # split it into chunks
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    documents = text_splitter.create_documents([text_reference])
-    texts = text_splitter.split_documents(documents)
-
-    embedding_models = {
-        "text-embedding-ada-002": OpenAIEmbeddings,
-    }
-
-    embedding_function = None
-
-    if embedding_model_name in embedding_models:
-        if embedding_model_name == "text-embedding-ada-002":
-            embedding_function = embedding_models[embedding_model_name](
-                model=embedding_model_name, api_key=openai_api_key
-            )
-    else:
-        raise Exception("Embedding model not found")
-
-    # load it into Chroma and use it as a retriever
-    db = Chroma.from_documents(texts, embedding_function)
-    tool = create_retriever_tool(
-        db.as_retriever(),
+    # Split text into chunks efficiently
+    documents = TEXT_SPLITTER.create_documents([text_reference])
+
+    # Create vector store
+    db = Chroma.from_documents(
+        documents,
+        embedding_function,
+        client=get_chromadb_client()
+    )
+
+    # Create retriever tool
+    retriever_tool = create_retriever_tool(
+        db.as_retriever(search_kwargs={"k": 3}),
         "find_context_text",
         "Find associated text to get context about a course or a lecture",
     )
-    tools = [tool]
-
-    llm = ChatOpenAI(
-        temperature=0, api_key=openai_api_key, model_name=openai_model_name
-    )
-
-    memory_key = "history"
+
+    # Get LLM
+    llm = get_llm(openai_model_name)
+    if not llm:
+        raise Exception(f"LLM model {openai_model_name} not found or API key not configured")
 
+    # Setup memory with optimized token limit
     memory = AgentTokenBufferMemory(
-        memory_key=memory_key,
+        memory_key="history",
         llm=llm,
         chat_memory=message_history,
-        max_token_limit=1000,
+        max_token_limit=2000,  # Increased for better context retention
     )
 
-    system_message = SystemMessage(content=(message_for_the_prompt))
-
+    # Create agent with system message
+    system_message = SystemMessage(content=message_for_the_prompt)
     prompt = OpenAIFunctionsAgent.create_prompt(
         system_message=system_message,
-        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
+        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
     )
 
-    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+    agent = OpenAIFunctionsAgent(
+        llm=llm,
+        tools=[retriever_tool],
+        prompt=prompt
+    )
 
+    # Create and execute agent
     agent_executor = AgentExecutor(
         agent=agent,
-        tools=tools,
+        tools=[retriever_tool],
         memory=memory,
         verbose=True,
         return_intermediate_steps=True,
         handle_parsing_errors=True,
+        max_iterations=3,  # Limit maximum iterations for better performance
     )
 
-    return agent_executor({"input": question})
+    try:
+        return agent_executor({"input": question})
+    except Exception as e:
+        raise Exception(f"Error processing AI request: {str(e)}")
 
 
-
-def get_chat_session_history(aichat_uuid: Optional[str] = None):
-    # Init Message History
+def get_chat_session_history(aichat_uuid: Optional[str] = None) -> Dict[str, Any]:
+    """Get or create a new chat session history"""
     session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}"
-
+
     LH_CONFIG = get_learnhouse_config()
     redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
 
     if redis_conn_string:
-        message_history = RedisChatMessageHistory(
-            url=redis_conn_string, ttl=2160000, session_id=session_id
-        )
+        try:
+            message_history = RedisChatMessageHistory(
+                url=redis_conn_string,
+                ttl=2160000,  # 25 days
+                session_id=session_id
+            )
+        except Exception:
+            print("Failed to connect to Redis, falling back to local memory")
+            message_history = []
     else:
         print("Redis connection string not found, using local memory")
         message_history = []
 
-    return {"message_history": message_history, "aichat_uuid": session_id}
+    return {
+        "message_history": message_history,
+        "aichat_uuid": session_id
+    }
diff --git a/apps/api/src/services/ai/init.py b/apps/api/src/services/ai/init.py
new file mode 100644
index 00000000..47f6733f
--- /dev/null
+++ b/apps/api/src/services/ai/init.py
@@ -0,0 +1,54 @@
+from typing import Optional
+from functools import lru_cache
+import chromadb
+from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+from config.config import get_learnhouse_config
+
+@lru_cache()
+def get_chromadb_client():
+    """Get cached ChromaDB client instance"""
+    LH_CONFIG = get_learnhouse_config()
+    chromadb_config = getattr(LH_CONFIG.ai_config, 'chromadb_config', None)
+
+    if (
+        chromadb_config
+        and isinstance(chromadb_config.db_host, str)
+        and chromadb_config.db_host
+        and getattr(chromadb_config, 'isSeparateDatabaseEnabled', False)
+    ):
+        return chromadb.HttpClient(
+            host=chromadb_config.db_host,
+            port=8000
+        )
+    return chromadb.Client()
+
+@lru_cache()
+def get_embedding_function(model_name: str) -> Optional[OpenAIEmbeddings]:
+    """Get cached embedding function"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+
+    if not api_key:
+        return None
+
+    if model_name == "text-embedding-ada-002":
+        return OpenAIEmbeddings(
+            model=model_name,
+            api_key=api_key
+        )
+    return None
+
+@lru_cache()
+def get_llm(model_name: str, temperature: float = 0) -> Optional[ChatOpenAI]:
+    """Get cached LLM instance"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+
+    if not api_key:
+        return None
+
+    return ChatOpenAI(
+        temperature=temperature,
+        api_key=api_key,
+        model=model_name
+    )
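
For reviewers, a minimal usage sketch of the refactored entry points, assuming they are importable as `src.services.ai.base` per the imports in this diff; the question, reference text, prompt, and model names below are illustrative only, and `result["output"]` assumes the standard LangChain `AgentExecutor` output key:

```python
# Hypothetical review sketch, not part of this PR: exercises the refactored
# ask_ai() / get_chat_session_history() flow end to end. Requires an OpenAI
# API key (and optionally Redis) in the LearnHouse config.
from src.services.ai.base import ask_ai, get_chat_session_history

# Creates a new session; falls back to local memory when Redis is absent.
# Pass an existing aichat_uuid to resume a prior conversation instead.
session = get_chat_session_history()

result = ask_ai(
    question="What does this lecture say about retrieval?",  # illustrative
    message_history=session["message_history"],
    text_reference="...raw course or lecture text to index in Chroma...",
    message_for_the_prompt="You are a helpful course assistant.",  # illustrative
    embedding_model_name="text-embedding-ada-002",  # only model wired up in init.py
    openai_model_name="gpt-3.5-turbo",  # any ChatOpenAI-compatible model name
)

# AgentExecutor returns a dict: "output" holds the final answer and
# "intermediate_steps" the retriever calls (return_intermediate_steps=True).
print(result["output"])
```

Note that because `get_embedding_function` and `get_llm` are wrapped in `lru_cache`, repeated calls with the same model name reuse one client instance instead of rebuilding it per request, which is where the claimed performance improvement comes from.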