Mirror of https://github.com/rzmk/learnhouse.git (synced 2025-12-18 11:59:26 +00:00)

Merge pull request #435 from learnhouse/feat/ai-improvements
AI performance improvements

Commit 000f1031e7
4 changed files with 133 additions and 65 deletions
apps/api/pyproject.toml

@@ -13,9 +13,9 @@ dependencies = [
     "fastapi>=0.115.0",
     "fastapi-jwt-auth>=0.5.0",
     "httpx>=0.27.0",
-    "langchain>=0.2.0",
-    "langchain-community>=0.2.0",
-    "langchain-openai>=0.1.7",
+    "langchain>=0.1.7",
+    "langchain-community>=0.0.20",
+    "langchain-openai>=0.0.6",
     "openai>=1.50.2",
     "passlib>=1.7.4",
     "psycopg2-binary>=2.9.9",

@@ -1,17 +1,15 @@
-from typing import Optional
+from typing import Optional, Dict, Any
 from uuid import uuid4
 from langchain.agents import AgentExecutor
-from langchain_text_splitters import CharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter
 from langchain_community.vectorstores import Chroma
 from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
-from langchain_core.prompts import MessagesPlaceholder
+from langchain.prompts import MessagesPlaceholder
 from langchain_community.chat_message_histories import RedisChatMessageHistory
 from langchain_core.messages import SystemMessage
 from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
     AgentTokenBufferMemory,
 )
-from langchain_openai import OpenAIEmbeddings
-from langchain_openai import ChatOpenAI
 from langchain.agents.agent_toolkits import (
     create_retriever_tool,
 )
@@ -19,6 +17,7 @@ from langchain.agents.agent_toolkits import (
 import chromadb
 
 from config.config import get_learnhouse_config
+from src.services.ai.init import get_chromadb_client, get_embedding_function, get_llm
 
 LH_CONFIG = get_learnhouse_config()
 client = (
@@ -27,98 +26,112 @@ client = (
     else chromadb.Client()
 )
 
 chat_history = []
 
+# Use efficient text splitter settings
+TEXT_SPLITTER = CharacterTextSplitter(
+    chunk_size=1000,
+    chunk_overlap=100,
+    separator="\n",
+    length_function=len,
+)
+
 
 def ask_ai(
     question: str,
-    message_history,
+    message_history: Any,
     text_reference: str,
     message_for_the_prompt: str,
     embedding_model_name: str,
     openai_model_name: str,
-):
-    # Get API Keys
-    LH_CONFIG = get_learnhouse_config()
-    openai_api_key = LH_CONFIG.ai_config.openai_api_key
-
-    # split it into chunks
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    documents = text_splitter.create_documents([text_reference])
-    texts = text_splitter.split_documents(documents)
-
-    embedding_models = {
-        "text-embedding-ada-002": OpenAIEmbeddings,
-    }
-
-    embedding_function = None
-
-    if embedding_model_name in embedding_models:
-        if embedding_model_name == "text-embedding-ada-002":
-            embedding_function = embedding_models[embedding_model_name](
-                model=embedding_model_name, api_key=openai_api_key
-            )
-    else:
-        raise Exception("Embedding model not found")
-
-    # load it into Chroma and use it as a retriever
-    db = Chroma.from_documents(texts, embedding_function)
-    tool = create_retriever_tool(
-        db.as_retriever(),
+) -> Dict[str, Any]:
+    """
+    Process an AI query with improved performance using cached components
+    """
+    # Get embedding function
+    embedding_function = get_embedding_function(embedding_model_name)
+    if not embedding_function:
+        raise Exception(f"Embedding model {embedding_model_name} not found or API key not configured")
+
+    # Split text into chunks efficiently
+    documents = TEXT_SPLITTER.create_documents([text_reference])
+
+    # Create vector store
+    db = Chroma.from_documents(
+        documents,
+        embedding_function,
+        client=get_chromadb_client()
+    )
+
+    # Create retriever tool
+    retriever_tool = create_retriever_tool(
+        db.as_retriever(search_kwargs={"k": 3}),
         "find_context_text",
         "Find associated text to get context about a course or a lecture",
     )
-    tools = [tool]
-
-    llm = ChatOpenAI(
-        temperature=0, api_key=openai_api_key, model_name=openai_model_name
-    )
-
-    memory_key = "history"
-
+
+    # Get LLM
+    llm = get_llm(openai_model_name)
+    if not llm:
+        raise Exception(f"LLM model {openai_model_name} not found or API key not configured")
+
+    # Setup memory with optimized token limit
     memory = AgentTokenBufferMemory(
-        memory_key=memory_key,
+        memory_key="history",
         llm=llm,
         chat_memory=message_history,
-        max_token_limit=1000,
+        max_token_limit=2000,  # Increased for better context retention
     )
 
-    system_message = SystemMessage(content=(message_for_the_prompt))
-
+    # Create agent with system message
+    system_message = SystemMessage(content=message_for_the_prompt)
     prompt = OpenAIFunctionsAgent.create_prompt(
         system_message=system_message,
-        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
+        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
     )
 
-    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
+    agent = OpenAIFunctionsAgent(
+        llm=llm,
+        tools=[retriever_tool],
+        prompt=prompt
+    )
 
+    # Create and execute agent
     agent_executor = AgentExecutor(
         agent=agent,
-        tools=tools,
+        tools=[retriever_tool],
         memory=memory,
         verbose=True,
         return_intermediate_steps=True,
+        handle_parsing_errors=True,
+        max_iterations=3,  # Limit maximum iterations for better performance
     )
 
-    return agent_executor({"input": question})
+    try:
+        return agent_executor({"input": question})
+    except Exception as e:
+        raise Exception(f"Error processing AI request: {str(e)}")
 
 
-def get_chat_session_history(aichat_uuid: Optional[str] = None):
-    # Init Message History
+def get_chat_session_history(aichat_uuid: Optional[str] = None) -> Dict[str, Any]:
+    """Get or create a new chat session history"""
     session_id = aichat_uuid if aichat_uuid else f"aichat_{uuid4()}"
 
     LH_CONFIG = get_learnhouse_config()
     redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
 
     if redis_conn_string:
-        message_history = RedisChatMessageHistory(
-            url=redis_conn_string, ttl=2160000, session_id=session_id
-        )
+        try:
+            message_history = RedisChatMessageHistory(
+                url=redis_conn_string,
+                ttl=2160000,  # 25 days
+                session_id=session_id
+            )
+        except Exception:
+            print("Failed to connect to Redis, falling back to local memory")
+            message_history = []
     else:
+        print("Redis connection string not found, using local memory")
         message_history = []
 
-    return {"message_history": message_history, "aichat_uuid": session_id}
+    return {
+        "message_history": message_history,
+        "aichat_uuid": session_id
+    }
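Taken together, the two functions above are this module's public surface: get_chat_session_history provisions a Redis-backed message history (with an in-memory fallback) and ask_ai runs one retrieval-augmented question against it. A minimal sketch of a caller, assuming both functions are imported from this module; the question text and the OpenAI model name are placeholder assumptions, while "text-embedding-ada-002" is the only embedding model the new get_embedding_function accepts:

# Hypothetical caller, for illustration only; assumes ask_ai and
# get_chat_session_history are imported from the module diffed above.

# Provision a chat session: Redis-backed when configured, otherwise an
# in-memory fallback, keyed by an "aichat_<uuid4>" session id.
session = get_chat_session_history()

# Ask one question. The question text and OpenAI model name below are
# placeholder assumptions, not values taken from this PR.
result = ask_ai(
    question="What topics does this lecture cover?",
    message_history=session["message_history"],
    text_reference="...full lecture or course text to index...",
    message_for_the_prompt="You are a helpful course assistant.",
    embedding_model_name="text-embedding-ada-002",
    openai_model_name="gpt-3.5-turbo",
)

# AgentExecutor's __call__ returns a dict; the answer is under "output",
# and intermediate steps are present because of return_intermediate_steps.
print(result["output"])

Note that the in-memory fallback does not persist across processes, so multi-turn continuity between requests relies on the Redis path being configured.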

apps/api/src/services/ai/init.py (new file, 55 lines)

@@ -0,0 +1,55 @@
+from typing import Optional
+from functools import lru_cache
+import chromadb
+from langchain_openai import OpenAIEmbeddings
+from langchain_community.chat_models import ChatOpenAI
+from config.config import get_learnhouse_config
+
+@lru_cache()
+def get_chromadb_client():
+    """Get cached ChromaDB client instance"""
+    LH_CONFIG = get_learnhouse_config()
+    chromadb_config = getattr(LH_CONFIG.ai_config, 'chromadb_config', None)
+
+    if (
+        chromadb_config
+        and isinstance(chromadb_config.db_host, str)
+        and chromadb_config.db_host
+        and getattr(chromadb_config, 'isSeparateDatabaseEnabled', False)
+    ):
+        return chromadb.HttpClient(
+            host=chromadb_config.db_host,
+            port=8000
+        )
+    return chromadb.Client()
+
+@lru_cache()
+def get_embedding_function(model_name: str) -> Optional[OpenAIEmbeddings]:
+    """Get cached embedding function"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+
+    if not api_key:
+        return None
+
+    if model_name == "text-embedding-ada-002":
+        return OpenAIEmbeddings(
+            model=model_name,
+            api_key=api_key
+        )
+    return None
+
+@lru_cache()
+def get_llm(model_name: str, temperature: float = 0) -> Optional[ChatOpenAI]:
+    """Get cached LLM instance"""
+    LH_CONFIG = get_learnhouse_config()
+    api_key = getattr(LH_CONFIG.ai_config, 'openai_api_key', None)
+
+    if not api_key:
+        return None
+
+    return ChatOpenAI(
+        temperature=temperature,
+        api_key=api_key,
+        model=model_name
+    )
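The @lru_cache decorators above carry most of the performance win: each helper constructs its ChromaDB client, embedding function, or LLM wrapper once per distinct argument tuple and hands back the same object on every later call, instead of rebuilding them inside each ask_ai request. A standalone illustration of that caching behavior (not from the PR):

# Standalone illustration of the lru_cache pattern used above; not part
# of the PR. Each distinct argument tuple builds its value once, and
# later calls receive the identical cached object.
from functools import lru_cache

@lru_cache()
def get_client(model_name: str):
    print(f"building client for {model_name}")  # printed once per model
    return object()  # stand-in for an expensive client or LLM wrapper

a = get_client("text-embedding-ada-002")
b = get_client("text-embedding-ada-002")  # cache hit: nothing rebuilt
assert a is b

c = get_client("gpt-3.5-turbo")  # new argument tuple, new cached instance
assert c is not a

Because lru_cache keys on the full argument tuple, get_llm(model_name, temperature) keeps one cached ChatOpenAI per (model_name, temperature) pair for the life of the process.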

apps/api/uv.lock (generated)

@@ -1035,9 +1035,9 @@ requires-dist = [
     { name = "fastapi", specifier = ">=0.115.0" },
     { name = "fastapi-jwt-auth", specifier = ">=0.5.0" },
     { name = "httpx", specifier = ">=0.27.0" },
-    { name = "langchain", specifier = ">=0.2.0" },
-    { name = "langchain-community", specifier = ">=0.2.0" },
-    { name = "langchain-openai", specifier = ">=0.1.7" },
+    { name = "langchain", specifier = ">=0.1.7" },
+    { name = "langchain-community", specifier = ">=0.0.20" },
+    { name = "langchain-openai", specifier = ">=0.0.6" },
     { name = "openai", specifier = ">=1.50.2" },
     { name = "passlib", specifier = ">=1.7.4" },
     { name = "psycopg2-binary", specifier = ">=2.9.9" },