mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: init ai activity chat session
This commit is contained in:
parent
ddab6d6483
commit
f7d76eea1e
10 changed files with 305 additions and 8 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import os
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -23,6 +24,11 @@ class SecurityConfig(BaseModel):
|
||||||
auth_jwt_secret_key: str
|
auth_jwt_secret_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class AIConfig(BaseModel):
    """AI-related configuration (OpenAI access and feature toggle)."""

    # OpenAI API key; None when no key is configured.
    openai_api_key: str | None
    # Whether AI features are enabled for this deployment; None when unset.
    is_ai_enabled: bool | None
|
||||||
|
|
||||||
|
|
||||||
class S3ApiConfig(BaseModel):
|
class S3ApiConfig(BaseModel):
|
||||||
bucket_name: str | None
|
bucket_name: str | None
|
||||||
endpoint_url: str | None
|
endpoint_url: str | None
|
||||||
|
|
@ -58,9 +64,13 @@ class LearnHouseConfig(BaseModel):
|
||||||
hosting_config: HostingConfig
|
hosting_config: HostingConfig
|
||||||
database_config: DatabaseConfig
|
database_config: DatabaseConfig
|
||||||
security_config: SecurityConfig
|
security_config: SecurityConfig
|
||||||
|
ai_config: AIConfig
|
||||||
|
|
||||||
|
|
||||||
def get_learnhouse_config() -> LearnHouseConfig:
|
def get_learnhouse_config() -> LearnHouseConfig:
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
# Get the YAML file
|
# Get the YAML file
|
||||||
yaml_path = os.path.join(os.path.dirname(__file__), "config.yaml")
|
yaml_path = os.path.join(os.path.dirname(__file__), "config.yaml")
|
||||||
|
|
||||||
|
|
@ -173,6 +183,16 @@ def get_learnhouse_config() -> LearnHouseConfig:
|
||||||
"mongo_connection_string"
|
"mongo_connection_string"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# AI Config
|
||||||
|
env_openai_api_key = os.environ.get("LEARNHOUSE_OPENAI_API_KEY")
|
||||||
|
env_is_ai_enabled = os.environ.get("LEARNHOUSE_IS_AI_ENABLED")
|
||||||
|
openai_api_key = env_openai_api_key or yaml_config.get("ai_config", {}).get(
|
||||||
|
"openai_api_key"
|
||||||
|
)
|
||||||
|
is_ai_enabled = env_is_ai_enabled or yaml_config.get("ai_config", {}).get(
|
||||||
|
"is_ai_enabled"
|
||||||
|
)
|
||||||
|
|
||||||
# Sentry config
|
# Sentry config
|
||||||
# check if the sentry config is provided in the YAML file
|
# check if the sentry config is provided in the YAML file
|
||||||
sentry_config_verif = (
|
sentry_config_verif = (
|
||||||
|
|
@ -217,6 +237,12 @@ def get_learnhouse_config() -> LearnHouseConfig:
|
||||||
mongo_connection_string=mongo_connection_string,
|
mongo_connection_string=mongo_connection_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# AI Config
|
||||||
|
ai_config = AIConfig(
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
is_ai_enabled=bool(is_ai_enabled),
|
||||||
|
)
|
||||||
|
|
||||||
# Create LearnHouseConfig object
|
# Create LearnHouseConfig object
|
||||||
config = LearnHouseConfig(
|
config = LearnHouseConfig(
|
||||||
site_name=site_name,
|
site_name=site_name,
|
||||||
|
|
@ -228,6 +254,7 @@ def get_learnhouse_config() -> LearnHouseConfig:
|
||||||
hosting_config=hosting_config,
|
hosting_config=hosting_config,
|
||||||
database_config=database_config,
|
database_config=database_config,
|
||||||
security_config=SecurityConfig(auth_jwt_secret_key=auth_jwt_secret_key),
|
security_config=SecurityConfig(auth_jwt_secret_key=auth_jwt_secret_key),
|
||||||
|
ai_config=ai_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
|
||||||
|
|
@ -17,4 +17,10 @@ faker
|
||||||
requests
|
requests
|
||||||
pyyaml
|
pyyaml
|
||||||
sentry-sdk[fastapi]
|
sentry-sdk[fastapi]
|
||||||
pydantic[email]>=1.8.0,<2.0.0
|
pydantic[email]>=1.8.0,<2.0.0
|
||||||
|
langchain
|
||||||
|
tiktoken
|
||||||
|
openai
|
||||||
|
chromadb
|
||||||
|
sentence-transformers
|
||||||
|
python-dotenv
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ class CourseUpdate(CourseBase):
|
||||||
class CourseRead(CourseBase):
|
class CourseRead(CourseBase):
|
||||||
id: int
|
id: int
|
||||||
org_id: int = Field(default=None, foreign_key="organization.id")
|
org_id: int = Field(default=None, foreign_key="organization.id")
|
||||||
authors: List[UserRead]
|
authors: Optional[List[UserRead]]
|
||||||
course_uuid: str
|
course_uuid: str
|
||||||
creation_date: str
|
creation_date: str
|
||||||
update_date: str
|
update_date: str
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from src.routers import blocks, dev, trail, users, auth, orgs, roles
|
from src.routers import blocks, dev, trail, users, auth, orgs, roles
|
||||||
|
from src.routers.ai import ai
|
||||||
from src.routers.courses import chapters, collections, courses, activities
|
from src.routers.courses import chapters, collections, courses, activities
|
||||||
from src.routers.install import install
|
from src.routers.install import install
|
||||||
from src.services.dev.dev import isDevModeEnabledOrRaise
|
from src.services.dev.dev import isDevModeEnabledOrRaise
|
||||||
|
|
@ -18,14 +19,16 @@ v1_router.include_router(blocks.router, prefix="/blocks", tags=["blocks"])
|
||||||
v1_router.include_router(courses.router, prefix="/courses", tags=["courses"])
|
v1_router.include_router(courses.router, prefix="/courses", tags=["courses"])
|
||||||
v1_router.include_router(chapters.router, prefix="/chapters", tags=["chapters"])
|
v1_router.include_router(chapters.router, prefix="/chapters", tags=["chapters"])
|
||||||
v1_router.include_router(activities.router, prefix="/activities", tags=["activities"])
|
v1_router.include_router(activities.router, prefix="/activities", tags=["activities"])
|
||||||
v1_router.include_router(
|
v1_router.include_router(collections.router, prefix="/collections", tags=["collections"])
|
||||||
collections.router, prefix="/collections", tags=["collections"]
|
|
||||||
)
|
|
||||||
v1_router.include_router(trail.router, prefix="/trail", tags=["trail"])
|
v1_router.include_router(trail.router, prefix="/trail", tags=["trail"])
|
||||||
|
v1_router.include_router(ai.router, prefix="/ai", tags=["ai"])
|
||||||
|
|
||||||
# Dev Routes
|
# Dev Routes
|
||||||
v1_router.include_router(
|
v1_router.include_router(
|
||||||
dev.router, prefix="/dev", tags=["dev"], dependencies=[Depends(isDevModeEnabledOrRaise)]
|
dev.router,
|
||||||
|
prefix="/dev",
|
||||||
|
tags=["dev"],
|
||||||
|
dependencies=[Depends(isDevModeEnabledOrRaise)],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Install Routes
|
# Install Routes
|
||||||
|
|
@ -35,4 +38,3 @@ v1_router.include_router(
|
||||||
tags=["install"],
|
tags=["install"],
|
||||||
dependencies=[Depends(isInstallModeEnabled)],
|
dependencies=[Depends(isInstallModeEnabled)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
25
apps/api/src/routers/ai/ai.py
Normal file
25
apps/api/src/routers/ai/ai.py
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from sqlmodel import Session
|
||||||
|
from src.services.ai.ai import ai_start_activity_chat_session
|
||||||
|
from src.services.ai.schemas.ai import StartActivityAIChatSession
|
||||||
|
from src.core.events.database import get_db_session
|
||||||
|
from src.db.users import PublicUser
|
||||||
|
from src.security.auth import get_current_user
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start/activity_chat_session")
async def api_ai_start_activity_chat_session(
    request: Request,
    chat_session_object: StartActivityAIChatSession,
    current_user: PublicUser = Depends(get_current_user),
    db_session: Session = Depends(get_db_session),
):
    """
    Start a new AI Chat session with a Course Activity.

    Thin HTTP wrapper: delegates entirely to the service-layer
    ai_start_activity_chat_session and returns its result unchanged.
    """
    chat_response = ai_start_activity_chat_session(
        request,
        chat_session_object,
        current_user,
        db_session,
    )
    return chat_response
|
||||||
65
apps/api/src/services/ai/ai.py
Normal file
65
apps/api/src/services/ai/ai.py
Normal file
|
|
@ -0,0 +1,65 @@
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from src.db.courses import Course, CourseRead
|
||||||
|
from src.core.events.database import get_db_session
|
||||||
|
from src.db.users import PublicUser
|
||||||
|
from src.db.activities import Activity, ActivityRead
|
||||||
|
from src.security.auth import get_current_user
|
||||||
|
from src.services.ai.base import ask_ai
|
||||||
|
|
||||||
|
from src.services.ai.schemas.ai import StartActivityAIChatSession
|
||||||
|
from src.services.courses.activities.utils import (
|
||||||
|
serialize_activity_text_to_ai_comprehensible_text,
|
||||||
|
structure_activity_content_by_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ai_start_activity_chat_session(
    request: Request,
    chat_session_object: StartActivityAIChatSession,
    current_user: PublicUser = Depends(get_current_user),
    db_session: Session = Depends(get_db_session),
):
    """
    Start a new AI Chat session with a Course Activity.

    Looks up the Activity referenced by chat_session_object.activity_uuid and
    the Course it belongs to, serializes the activity content into a text the
    AI can use as context, and forwards the user's message to the AI agent.

    Returns:
        The agent's answer (the 'output' field of the ask_ai response).

    Raises:
        HTTPException 404: if the activity or its course does not exist.
    """
    # Get the Activity
    statement = select(Activity).where(
        Activity.activity_uuid == chat_session_object.activity_uuid
    )
    activity = db_session.exec(statement).first()

    # Check existence BEFORE from_orm: the original converted first, so a
    # missing activity crashed inside pydantic instead of returning 404.
    if not activity:
        raise HTTPException(
            status_code=404,
            detail="Activity not found",
        )

    activity = ActivityRead.from_orm(activity)

    # Get the Course the activity belongs to
    statement = select(Course).join(Activity).where(
        Activity.activity_uuid == chat_session_object.activity_uuid
    )
    course = db_session.exec(statement).first()

    # Same guard for the course (was previously unchecked).
    if not course:
        raise HTTPException(
            status_code=404,
            detail="Course not found",
        )

    course = CourseRead.from_orm(course)

    # Get Activity Content Blocks
    content = activity.content

    # Serialize Activity Content Blocks to a text comprehensible by the AI
    structured = structure_activity_content_by_type(content)
    ai_friendly_text = serialize_activity_text_to_ai_comprehensible_text(
        structured, course, activity
    )

    response = ask_ai(
        chat_session_object.message,
        [],
        ai_friendly_text,
        "You are a helpful Education Assistant, and you are helping a student with the associated Course. "
        "Use the available tools to get context about this question even if the question is not specific enough."
        "For context, this is the Course name :" + course.name + " and this is the Lecture name :" + activity.name + "."
        "Use your knowledge to help the student.",
    )

    return response['output']
|
||||||
82
apps/api/src/services/ai/base.py
Normal file
82
apps/api/src/services/ai/base.py
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
from langchain.agents import AgentExecutor
|
||||||
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
|
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
||||||
|
from langchain.vectorstores import Chroma
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
|
from langchain.prompts import MessagesPlaceholder
|
||||||
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
from langchain_core.messages import SystemMessage
|
||||||
|
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
|
||||||
|
AgentTokenBufferMemory,
|
||||||
|
)
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.agents.agent_toolkits import (
|
||||||
|
create_retriever_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
from config.config import get_learnhouse_config
|
||||||
|
|
||||||
|
client = chromadb.Client()
|
||||||
|
|
||||||
|
|
||||||
|
chat_history = []
|
||||||
|
|
||||||
|
|
||||||
|
def ask_ai(
    question: str,
    chat_history: list[BaseMessage],
    text_reference: str,
    message_for_the_prompt: str,
):
    """
    Ask the AI a question, grounded in a reference text.

    Splits `text_reference` into chunks, embeds them, loads them into an
    in-memory Chroma store, and exposes that store as a retriever tool to an
    OpenAI-functions agent whose system prompt is `message_for_the_prompt`.

    NOTE(review): the `chat_history` parameter is accepted but never used in
    this body — conversation memory comes from AgentTokenBufferMemory instead.

    Returns:
        The raw agent-executor result dict (includes 'output' and, because
        return_intermediate_steps=True, 'intermediate_steps').
    """
    # Get API Keys
    LH_CONFIG = get_learnhouse_config()
    openai_api_key = LH_CONFIG.ai_config.openai_api_key

    # split it into chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    documents = text_splitter.create_documents([text_reference])
    texts = text_splitter.split_documents(documents)

    # create the open-source embedding function (local model, no API call)
    embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

    # load it into Chroma and use it as a retriever
    db = Chroma.from_documents(texts, embedding_function)
    tool = create_retriever_tool(
        db.as_retriever(),
        "find_context_text",
        "Find associated text to get context about a course or a lecture",
    )
    tools = [tool]

    # temperature=0 for deterministic, factual answers
    llm = ChatOpenAI(
        temperature=0, api_key=openai_api_key
    )

    # memory_key must match the MessagesPlaceholder variable_name below
    memory_key = "history"

    memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)

    system_message = SystemMessage(content=(message_for_the_prompt))

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=system_message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
    )

    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)

    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
    )

    return agent_executor({"input": question})
|
||||||
12
apps/api/src/services/ai/schemas/ai.py
Normal file
12
apps/api/src/services/ai/schemas/ai.py
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class StartActivityAIChatSession(BaseModel):
    """Request payload to open a new AI chat session about a course activity."""

    # UUID of the activity the chat session is anchored to.
    activity_uuid: str
    # First user message sent to the AI.
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class SendActivityAIChatMessage(BaseModel):
    """Request payload to send a follow-up message in an existing AI chat session."""

    # UUID of the existing AI chat session.
    aichat_uuid: str
    # UUID of the activity the session is anchored to.
    activity_uuid: str
    # The user's message.
    message: str
78
apps/api/src/services/courses/activities/utils.py
Normal file
78
apps/api/src/services/courses/activities/utils.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
from src.db.activities import ActivityRead
|
||||||
|
from src.db.courses import CourseRead
|
||||||
|
|
||||||
|
def structure_activity_content_by_type(activity):
    """Extract heading, callout and paragraph text from an activity document.

    Parameters:
        activity: dict with a "content" list of block dicts (editor JSON);
            each block has a "type" and usually a "content" list of text nodes.

    Returns:
        A list of three dicts in a fixed order (consumed positionally by
        serialize_activity_text_to_ai_comprehensible_text):
        [{"Headings": [...]}, {"Callouts": [...]}, {"Paragraphs": [...]}]
    """
    headings = []
    callouts = []
    paragraphs = []

    # Single pass instead of the original three passes over the same list.
    # Blocks with no text nodes are skipped instead of raising: the original
    # indexed item["content"][0]["text"] unconditionally, which crashed on
    # empty heading/paragraph blocks.
    for item in activity.get("content", []):
        block_type = item.get("type")
        children = item.get("content") or []

        if block_type == "heading":
            if children and "text" in children[0]:
                headings.append(children[0]["text"])
        elif block_type == "calloutInfo":
            # Concatenate every text node inside the callout.
            callouts.append("".join(c.get("text", "") for c in children))
        elif block_type == "paragraph":
            if children and "text" in children[0]:
                paragraphs.append(children[0]["text"])

    # TODO: Get Questions and Answers (if any)

    return [
        {"Headings": headings},
        {"Callouts": callouts},
        {"Paragraphs": paragraphs},
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_activity_text_to_ai_comprehensible_text(data_array, course: CourseRead, activity: ActivityRead):
    """Flatten structured activity content into a single AI context string.

    Parameters:
        data_array: output of structure_activity_content_by_type, consumed
            positionally: [0]=Headings, [1]=Callouts, [2]=Paragraphs.
        course: course whose `name` is embedded in the text.
        activity: activity (lecture) whose `name` is embedded in the text.

    Returns:
        One prose string naming the course/lecture and listing headings,
        callouts and paragraphs, each item followed by a trailing space.
    """
    # Each fragment keeps its trailing space (matches the original output).
    joined_headings = "".join(h + " " for h in data_array[0]["Headings"])
    joined_callouts = "".join(c + " " for c in data_array[1]["Callouts"])
    joined_paragraphs = "".join(p + " " for p in data_array[2]["Paragraphs"])

    return (
        f'Use this as a context This is a course about "{course.name}". '
        f'This is a lecture about "{activity.name}". '
        f'These are the headings: "{joined_headings}" '
        f'These are the callouts: "{joined_callouts}" '
        f'These are the paragraphs: "{joined_paragraphs}"'
    )
|
||||||
|
|
@ -7,7 +7,7 @@ from src.db.roles import Role, RoleRead
|
||||||
from src.security.rbac.rbac import (
|
from src.security.rbac.rbac import (
|
||||||
authorization_verify_based_on_roles_and_authorship,
|
authorization_verify_based_on_roles_and_authorship,
|
||||||
authorization_verify_if_user_is_anon,
|
authorization_verify_if_user_is_anon,
|
||||||
)
|
)
|
||||||
from src.db.organizations import Organization, OrganizationRead
|
from src.db.organizations import Organization, OrganizationRead
|
||||||
from src.db.users import (
|
from src.db.users import (
|
||||||
AnonymousUser,
|
AnonymousUser,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue