refactor: optimize functions queries

This commit is contained in:
swve 2025-03-16 11:34:41 +01:00
parent 6ebac01c61
commit 5e7ae54215
3 changed files with 115 additions and 106 deletions

View file

@ -92,24 +92,21 @@ async def get_activity(
current_user: PublicUser, current_user: PublicUser,
db_session: Session, db_session: Session,
): ):
statement = select(Activity).where(Activity.activity_uuid == activity_uuid) # Optimize by joining Activity with Course in a single query
activity = db_session.exec(statement).first() statement = (
select(Activity, Course)
.join(Course)
.where(Activity.activity_uuid == activity_uuid)
)
result = db_session.exec(statement).first()
if not activity: if not result:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Activity not found", detail="Activity not found",
) )
# Get course from that activity activity, course = result
statement = select(Course).where(Course.id == activity.course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(
status_code=404,
detail="Course not found",
)
# RBAC check # RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session) await rbac_check(request, course.course_uuid, current_user, "read", db_session)
@ -124,9 +121,8 @@ async def get_activity(
activity_read = ActivityRead.model_validate(activity) activity_read = ActivityRead.model_validate(activity)
activity_read.content = activity_read.content if has_paid_access else { "paid_access": False } activity_read.content = activity_read.content if has_paid_access else { "paid_access": False }
activity = activity_read
return activity return activity_read
async def get_activityby_id( async def get_activityby_id(
request: Request, request: Request,
@ -134,31 +130,26 @@ async def get_activityby_id(
current_user: PublicUser, current_user: PublicUser,
db_session: Session, db_session: Session,
): ):
statement = select(Activity).where(Activity.id == activity_id) # Optimize by joining Activity with Course in a single query
activity = db_session.exec(statement).first() statement = (
select(Activity, Course)
.join(Course)
.where(Activity.id == activity_id)
)
result = db_session.exec(statement).first()
if not activity: if not result:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail="Activity not found", detail="Activity not found",
) )
# Get course from that activity activity, course = result
statement = select(Course).where(Course.id == activity.course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(
status_code=404,
detail="Course not found",
)
# RBAC check # RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session) await rbac_check(request, course.course_uuid, current_user, "read", db_session)
activity = ActivityRead.model_validate(activity) return ActivityRead.model_validate(activity)
return activity
async def update_activity( async def update_activity(

View file

@ -27,6 +27,7 @@ from src.security.rbac.rbac import (
from src.services.courses.thumbnails import upload_thumbnail from src.services.courses.thumbnails import upload_thumbnail
from fastapi import HTTPException, Request, UploadFile from fastapi import HTTPException, Request, UploadFile
from datetime import datetime from datetime import datetime
import asyncio
async def get_course( async def get_course(
@ -106,6 +107,7 @@ async def get_course_meta(
# Avoid circular import # Avoid circular import
from src.services.courses.chapters import get_course_chapters from src.services.courses.chapters import get_course_chapters
# Get course with a single query
course_statement = select(Course).where(Course.course_uuid == course_uuid) course_statement = select(Course).where(Course.course_uuid == course_uuid)
course = db_session.exec(course_statement).first() course = db_session.exec(course_statement).first()
@ -118,36 +120,51 @@ async def get_course_meta(
# RBAC check # RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session) await rbac_check(request, course.course_uuid, current_user, "read", db_session)
# Get course authors # Start async tasks concurrently
tasks = []
# Task 1: Get course authors
async def get_authors():
authors_statement = ( authors_statement = (
select(User) select(User)
.join(ResourceAuthor) .join(ResourceAuthor)
.where(ResourceAuthor.resource_uuid == course.course_uuid) .where(ResourceAuthor.resource_uuid == course.course_uuid)
) )
authors = db_session.exec(authors_statement).all() return db_session.exec(authors_statement).all()
# convert from User to UserRead # Task 2: Get course chapters
authors = [UserRead.model_validate(author) for author in authors] async def get_chapters():
# Ensure course.id is not None
course = CourseRead(**course.model_dump(), authors=authors) if course.id is None:
return []
# Get course chapters return await get_course_chapters(request, course.id, db_session, current_user)
chapters = await get_course_chapters(request, course.id, db_session, current_user)
# Trail
trail = None
# Task 3: Get user trail (only for authenticated users)
async def get_trail():
if isinstance(current_user, AnonymousUser): if isinstance(current_user, AnonymousUser):
trail = None return None
else: return await get_user_trail_with_orgid(
trail = await get_user_trail_with_orgid(
request, current_user, course.org_id, db_session request, current_user, course.org_id, db_session
) )
# Add tasks to the list
tasks.append(get_authors())
tasks.append(get_chapters())
tasks.append(get_trail())
# Run all tasks concurrently
authors_raw, chapters, trail = await asyncio.gather(*tasks)
# Convert authors from User to UserRead
authors = [UserRead.model_validate(author) for author in authors_raw]
# Create course read model
course_read = CourseRead(**course.model_dump(), authors=authors)
return FullCourseReadWithTrail( return FullCourseReadWithTrail(
**course.model_dump(), **course_read.model_dump(),
chapters=chapters, chapters=chapters,
trail=trail if trail else None, trail=trail,
) )
async def get_courses_orgslug( async def get_courses_orgslug(
@ -197,18 +214,33 @@ async def get_courses_orgslug(
courses = db_session.exec(query).all() courses = db_session.exec(query).all()
# Fetch authors for each course if not courses:
return []
# Get all course UUIDs
course_uuids = [course.course_uuid for course in courses]
# Fetch all authors for all courses in a single query
authors_query = (
select(ResourceAuthor, User)
.join(User, ResourceAuthor.user_id == User.id) # type: ignore
.where(ResourceAuthor.resource_uuid.in_(course_uuids)) # type: ignore
)
author_results = db_session.exec(authors_query).all()
# Create a dictionary mapping course_uuid to list of authors
course_authors = {}
for resource_author, user in author_results:
if resource_author.resource_uuid not in course_authors:
course_authors[resource_author.resource_uuid] = []
course_authors[resource_author.resource_uuid].append(UserRead.model_validate(user))
# Create CourseRead objects with authors
course_reads = [] course_reads = []
for course in courses: for course in courses:
authors_query = (
select(User)
.join(ResourceAuthor, ResourceAuthor.user_id == User.id) # type: ignore
.where(ResourceAuthor.resource_uuid == course.course_uuid)
)
authors = db_session.exec(authors_query).all()
course_read = CourseRead.model_validate(course) course_read = CourseRead.model_validate(course)
course_read.authors = [UserRead.model_validate(author) for author in authors] course_read.authors = course_authors.get(course.course_uuid, [])
course_reads.append(course_read) course_reads.append(course_read)
return course_reads return course_reads

View file

@ -529,39 +529,31 @@ async def get_orgs_by_user_admin(
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
) -> list[OrganizationRead]: ) -> list[OrganizationRead]:
# Join Organization, UserOrganization and OrganizationConfig in a single query
statement = ( statement = (
select(Organization) select(Organization, OrganizationConfig)
.join(UserOrganization) .join(UserOrganization)
.outerjoin(OrganizationConfig)
.where( .where(
UserOrganization.user_id == user_id, UserOrganization.user_id == user_id,
UserOrganization.role_id == 1, # Only where the user is admin UserOrganization.role_id == 1, # Only where the user is admin
UserOrganization.org_id == Organization.id,
OrganizationConfig.org_id == Organization.id
) )
.offset((page - 1) * limit) .offset((page - 1) * limit)
.limit(limit) .limit(limit)
) )
# Get organizations where the user is an admin # Execute single query to get all data
result = db_session.exec(statement) result = db_session.exec(statement)
orgs = result.all() org_data = result.all()
# Process results in memory
orgsWithConfig = [] orgsWithConfig = []
for org, org_config in org_data:
for org in orgs:
# Get org config
statement = select(OrganizationConfig).where(
OrganizationConfig.org_id == org.id
)
result = db_session.exec(statement)
org_config = result.first()
config = OrganizationConfig.model_validate(org_config) if org_config else {} config = OrganizationConfig.model_validate(org_config) if org_config else {}
org_read = OrganizationRead(**org.model_dump(), config=config)
org = OrganizationRead(**org.model_dump(), config=config) orgsWithConfig.append(org_read)
orgsWithConfig.append(org)
return orgsWithConfig return orgsWithConfig
@ -573,36 +565,30 @@ async def get_orgs_by_user(
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
) -> list[OrganizationRead]: ) -> list[OrganizationRead]:
# Join Organization, UserOrganization and OrganizationConfig in a single query
statement = ( statement = (
select(Organization) select(Organization, OrganizationConfig)
.join(UserOrganization) .join(UserOrganization)
.where(UserOrganization.user_id == user_id) .outerjoin(OrganizationConfig)
.where(
UserOrganization.user_id == user_id,
UserOrganization.org_id == Organization.id,
OrganizationConfig.org_id == Organization.id
)
.offset((page - 1) * limit) .offset((page - 1) * limit)
.limit(limit) .limit(limit)
) )
# Get organizations where the user is an admin # Execute single query to get all data
result = db_session.exec(statement) result = db_session.exec(statement)
orgs = result.all() org_data = result.all()
# Process results in memory
orgsWithConfig = [] orgsWithConfig = []
for org, org_config in org_data:
for org in orgs:
# Get org config
statement = select(OrganizationConfig).where(
OrganizationConfig.org_id == org.id
)
result = db_session.exec(statement)
org_config = result.first()
config = OrganizationConfig.model_validate(org_config) if org_config else {} config = OrganizationConfig.model_validate(org_config) if org_config else {}
org_read = OrganizationRead(**org.model_dump(), config=config)
org = OrganizationRead(**org.model_dump(), config=config) orgsWithConfig.append(org_read)
orgsWithConfig.append(org)
return orgsWithConfig return orgsWithConfig