feat: implement authorization with roles

This commit is contained in:
swve 2023-11-28 20:25:14 +01:00
parent 0595bfdb3f
commit 7738316200
19 changed files with 596 additions and 170 deletions

View file

@ -46,7 +46,6 @@ class Activity(ActivityBase, table=True):
class ActivityCreate(ActivityBase): class ActivityCreate(ActivityBase):
order: int
org_id: int = Field(default=None, foreign_key="organization.id") org_id: int = Field(default=None, foreign_key="organization.id")
course_id: int = Field(default=None, foreign_key="course.id") course_id: int = Field(default=None, foreign_key="course.id")
chapter_id: int chapter_id: int

View file

@ -17,7 +17,7 @@ class Collection(CollectionBase, table=True):
class CollectionCreate(CollectionBase): class CollectionCreate(CollectionBase):
courses: list courses: list[int]
org_id: int = Field(default=None, foreign_key="organization.id") org_id: int = Field(default=None, foreign_key="organization.id")
pass pass

View file

@ -1,6 +1,6 @@
from typing import List, Optional from typing import List, Optional
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
from src.db.trails import TrailRead
from src.db.chapters import ChapterRead from src.db.chapters import ChapterRead
@ -39,6 +39,7 @@ class CourseUpdate(CourseBase):
class CourseRead(CourseBase): class CourseRead(CourseBase):
id: int id: int
org_id: int = Field(default=None, foreign_key="organization.id")
course_uuid: str course_uuid: str
creation_date: str creation_date: str
update_date: str update_date: str
@ -53,3 +54,15 @@ class FullCourseRead(CourseBase):
# Chapters, Activities # Chapters, Activities
chapters: List[ChapterRead] chapters: List[ChapterRead]
pass pass
class FullCourseReadWithTrail(CourseBase):
id: int
course_uuid: str
creation_date: str
update_date: str
# Chapters, Activities
chapters: List[ChapterRead]
# Trail
trail: TrailRead
pass

View file

@ -32,12 +32,14 @@ class UserUpdatePassword(SQLModel):
class UserRead(UserBase): class UserRead(UserBase):
id: int id: int
user_uuid: str
class PublicUser(UserRead): class PublicUser(UserRead):
pass pass
class AnonymousUser(SQLModel): class AnonymousUser(SQLModel):
id: str = "anonymous" id: int = 0
user_uuid: str = "user_anonymous"
username: str = "anonymous" username: str = "anonymous"
class User(UserBase, table=True): class User(UserBase, table=True):

View file

@ -4,7 +4,6 @@ from src.core.events.database import get_db_session
from src.db.users import PublicUser from src.db.users import PublicUser
from src.db.courses import CourseCreate, CourseUpdate from src.db.courses import CourseCreate, CourseUpdate
from src.security.auth import get_current_user from src.security.auth import get_current_user
from src.services.courses.courses import ( from src.services.courses.courses import (
create_course, create_course,
get_course, get_course,
@ -46,9 +45,7 @@ async def api_create_course(
learnings=learnings, learnings=learnings,
tags=tags, tags=tags,
) )
return await create_course( return await create_course(request, course, current_user, db_session, thumbnail)
request, course, current_user, db_session, thumbnail
)
@router.put("/thumbnail/{course_id}") @router.put("/thumbnail/{course_id}")
@ -85,7 +82,7 @@ async def api_get_course(
@router.get("/meta/{course_id}") @router.get("/meta/{course_id}")
async def api_get_course_meta( async def api_get_course_meta(
request: Request, request: Request,
course_id: str, course_id: int,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user), current_user: PublicUser = Depends(get_current_user),
): ):
@ -109,7 +106,9 @@ async def api_get_course_by_orgslug(
""" """
Get houses by page and limit Get houses by page and limit
""" """
return await get_courses_orgslug(request, current_user, page, limit, org_slug) return await get_courses_orgslug(
request, current_user, org_slug, db_session, page, limit
)
@router.put("/") @router.put("/")

View file

@ -41,7 +41,7 @@ async def api_get_org(
""" """
Get single Org by ID Get single Org by ID
""" """
return await get_organization(request, org_id, db_session) return await get_organization(request, org_id, db_session, current_user)
@router.get("/slug/{org_slug}") @router.get("/slug/{org_slug}")
@ -54,7 +54,7 @@ async def api_get_org_by_slug(
""" """
Get single Org by Slug Get single Org by Slug
""" """
return await get_organization_by_slug(request, org_slug, db_session) return await get_organization_by_slug(request, org_slug, db_session, current_user)
@router.put("/{org_id}/logo") @router.put("/{org_id}/logo")
@ -109,7 +109,7 @@ async def api_update_org(
@router.delete("/{org_id}") @router.delete("/{org_id}")
async def api_delete_org( async def api_delete_org(
request: Request, request: Request,
org_id: str, org_id: int,
current_user: PublicUser = Depends(get_current_user), current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
): ):

View file

@ -1,9 +1,11 @@
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from sqlmodel import Session from sqlmodel import Session
from src.security.rbac.rbac import authorization_verify_based_on_roles, authorization_verify_if_element_is_public, authorization_verify_if_user_is_author
from src.security.auth import get_current_user from src.security.auth import get_current_user
from src.core.events.database import get_db_session from src.core.events.database import get_db_session
from src.db.users import ( from src.db.users import (
PublicUser,
User, User,
UserCreate, UserCreate,
UserRead, UserRead,
@ -37,13 +39,14 @@ async def api_create_user_with_orgid(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
user_object: UserCreate, user_object: UserCreate,
org_id: int, org_id: int,
) -> UserRead: ) -> UserRead:
""" """
Create User with Org ID Create User with Org ID
""" """
return await create_user(request, db_session, None, user_object, org_id) return await create_user(request, db_session, current_user, user_object, org_id)
@router.post("/", response_model=UserRead, tags=["users"]) @router.post("/", response_model=UserRead, tags=["users"])
@ -51,12 +54,13 @@ async def api_create_user_without_org(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
user_object: UserCreate, user_object: UserCreate,
) -> UserRead: ) -> UserRead:
""" """
Create User Create User
""" """
return await create_user_without_org(request, db_session, None, user_object) return await create_user_without_org(request, db_session, current_user, user_object)
@router.get("/user_id/{user_id}", response_model=UserRead, tags=["users"]) @router.get("/user_id/{user_id}", response_model=UserRead, tags=["users"])
@ -64,12 +68,13 @@ async def api_get_user_by_id(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
user_id: int, user_id: int,
) -> UserRead: ) -> UserRead:
""" """
Get User by ID Get User by ID
""" """
return await read_user_by_id(request, db_session, None, user_id) return await read_user_by_id(request, db_session, current_user, user_id)
@router.get("/user_uuid/{user_uuid}", response_model=UserRead, tags=["users"]) @router.get("/user_uuid/{user_uuid}", response_model=UserRead, tags=["users"])
@ -77,12 +82,13 @@ async def api_get_user_by_uuid(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
user_uuid: str, user_uuid: str,
) -> UserRead: ) -> UserRead:
""" """
Get User by UUID Get User by UUID
""" """
return await read_user_by_uuid(request, db_session, None, user_uuid) return await read_user_by_uuid(request, db_session, current_user, user_uuid)
@router.put("/", response_model=UserRead, tags=["users"]) @router.put("/", response_model=UserRead, tags=["users"])
@ -90,12 +96,13 @@ async def api_update_user(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
user_object: UserUpdate, user_object: UserUpdate,
) -> UserRead: ) -> UserRead:
""" """
Update User Update User
""" """
return await update_user(request, db_session, None, user_object) return await update_user(request, db_session, current_user, user_object)
@router.put("/change_password/", response_model=UserRead, tags=["users"]) @router.put("/change_password/", response_model=UserRead, tags=["users"])
@ -103,12 +110,13 @@ async def api_update_user_password(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
form: UserUpdatePassword, form: UserUpdatePassword,
) -> UserRead: ) -> UserRead:
""" """
Update User Password Update User Password
""" """
return await update_user_password(request, db_session, None, form) return await update_user_password(request, db_session, current_user, form)
@router.delete("/user_id/{user_id}", tags=["users"]) @router.delete("/user_id/{user_id}", tags=["users"])
@ -116,9 +124,10 @@ async def api_delete_user(
*, *,
request: Request, request: Request,
db_session: Session = Depends(get_db_session), db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
user_id: int, user_id: int,
): ):
""" """
Delete User Delete User
""" """
return await delete_user_by_id(request, db_session, None, user_id) return await delete_user_by_id(request, db_session, current_user, user_id)

View file

@ -1,6 +1,6 @@
from sqlmodel import Session from sqlmodel import Session
from src.core.events.database import get_db_session from src.core.events.database import get_db_session
from src.db.users import AnonymousUser, User, UserRead from src.db.users import AnonymousUser, PublicUser, User, UserRead
from src.services.users.users import security_get_user from src.services.users.users import security_get_user
from config.config import get_learnhouse_config from config.config import get_learnhouse_config
from pydantic import BaseModel from pydantic import BaseModel
@ -94,7 +94,7 @@ async def get_current_user(
user = await security_get_user(request, db_session, email=token_data.username) # type: ignore # treated as an email user = await security_get_user(request, db_session, email=token_data.username) # type: ignore # treated as an email
if user is None: if user is None:
raise credentials_exception raise credentials_exception
return UserRead(**user.dict()) return PublicUser(**user.dict())
else: else:
return AnonymousUser() return AnonymousUser()

View file

@ -11,27 +11,21 @@ from src.db.user_organizations import UserOrganization
from src.security.rbac.utils import check_element_type from src.security.rbac.utils import check_element_type
# Tested and working
async def authorization_verify_if_element_is_public( async def authorization_verify_if_element_is_public(
request, request,
element_uuid: str, element_uuid: str,
user_id: str,
action: Literal["read"], action: Literal["read"],
db_session: Session, db_session: Session,
): ):
element_nature = await check_element_type(element_uuid) element_nature = await check_element_type(element_uuid)
# Verifies if the element is public # Verifies if the element is public
if ( if element_nature == ("courses" or "collections") and action == "read":
element_nature == ("courses" or "collections")
and action == "read"
and user_id == "anonymous"
):
if element_nature == "courses": if element_nature == "courses":
statement = select(Course).where( statement = select(Course).where(
Course.public == True, Course.course_uuid == element_uuid Course.public == True, Course.course_uuid == element_uuid
) )
course = db_session.exec(statement).first() course = db_session.exec(statement).first()
if course: if course:
return True return True
else: else:
@ -60,9 +54,10 @@ async def authorization_verify_if_element_is_public(
) )
# Tested and working
async def authorization_verify_if_user_is_author( async def authorization_verify_if_user_is_author(
request, request,
user_id: str, user_id: int,
action: Literal["read", "update", "delete", "create"], action: Literal["read", "update", "delete", "create"],
element_uuid: str, element_uuid: str,
db_session: Session, db_session: Session,
@ -74,26 +69,23 @@ async def authorization_verify_if_user_is_author(
resource_author = db_session.exec(statement).first() resource_author = db_session.exec(statement).first()
if resource_author: if resource_author:
if resource_author.user_id == user_id: if resource_author.user_id == int(user_id):
if (resource_author.authorship == ResourceAuthorshipEnum.CREATOR) or ( if (resource_author.authorship == ResourceAuthorshipEnum.CREATOR) or (
resource_author.authorship == ResourceAuthorshipEnum.MAINTAINER resource_author.authorship == ResourceAuthorshipEnum.MAINTAINER
): ):
return True return True
else:
return False
else: else:
raise HTTPException( return False
status_code=status.HTTP_403_FORBIDDEN,
detail="User rights (authorship) : You don't have the right to perform this action",
)
else: else:
raise HTTPException( return False
status_code=status.HTTP_403_FORBIDDEN,
detail="Wrong action (create)",
)
# Tested and working
async def authorization_verify_based_on_roles( async def authorization_verify_based_on_roles(
request: Request, request: Request,
user_id: str, user_id: int,
action: Literal["read", "update", "delete", "create"], action: Literal["read", "update", "delete", "create"],
element_uuid: str, element_uuid: str,
db_session: Session, db_session: Session,
@ -104,8 +96,8 @@ async def authorization_verify_based_on_roles(
statement = ( statement = (
select(Role) select(Role)
.join(UserOrganization) .join(UserOrganization)
.where((UserOrganization.org_id == Role.org_id) | (Role.org_id == null()))
.where(UserOrganization.user_id == user_id) .where(UserOrganization.user_id == user_id)
.where((UserOrganization.id == Role.org_id) | (UserOrganization.id == null))
) )
user_roles_in_organization_and_standard_roles = db_session.exec(statement).all() user_roles_in_organization_and_standard_roles = db_session.exec(statement).all()
@ -120,15 +112,13 @@ async def authorization_verify_based_on_roles(
else: else:
return False return False
else: else:
raise HTTPException( return False
status_code=status.HTTP_403_FORBIDDEN,
detail="User rights (roles) : You don't have the right to perform this action",
)
# Tested and working
async def authorization_verify_based_on_roles_and_authorship( async def authorization_verify_based_on_roles_and_authorship(
request: Request, request: Request,
user_id: str, user_id: int,
action: Literal["read", "update", "delete", "create"], action: Literal["read", "update", "delete", "create"],
element_uuid: str, element_uuid: str,
db_session: Session, db_session: Session,
@ -150,8 +140,8 @@ async def authorization_verify_based_on_roles_and_authorship(
) )
async def authorization_verify_if_user_is_anon(user_id: str): async def authorization_verify_if_user_is_anon(user_id: int):
if user_id == "anonymous": if user_id == 0:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="You should be logged in to perform this action", detail="You should be logged in to perform this action",

View file

@ -1,8 +1,13 @@
from typing import Literal
from sqlmodel import Session, select from sqlmodel import Session, select
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.organizations import Organization from src.db.organizations import Organization
from src.db.activities import ActivityCreate, Activity, ActivityRead, ActivityUpdate from src.db.activities import ActivityCreate, Activity, ActivityRead, ActivityUpdate
from src.db.chapter_activities import ChapterActivity from src.db.chapter_activities import ChapterActivity
from src.db.users import PublicUser from src.db.users import AnonymousUser, PublicUser
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from uuid import uuid4 from uuid import uuid4
from datetime import datetime from datetime import datetime
@ -16,7 +21,7 @@ from datetime import datetime
async def create_activity( async def create_activity(
request: Request, request: Request,
activity_object: ActivityCreate, activity_object: ActivityCreate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
activity = Activity.from_orm(activity_object) activity = Activity.from_orm(activity_object)
@ -31,6 +36,9 @@ async def create_activity(
detail="Organization not found", detail="Organization not found",
) )
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
activity.activity_uuid = str(f"activity_{uuid4()}") activity.activity_uuid = str(f"activity_{uuid4()}")
activity.creation_date = str(datetime.now()) activity.creation_date = str(datetime.now())
activity.update_date = str(datetime.now()) activity.update_date = str(datetime.now())
@ -85,13 +93,16 @@ async def get_activity(
detail="Activity not found", detail="Activity not found",
) )
# RBAC check
await rbac_check(request, activity.activity_uuid, current_user, "read", db_session)
return activity return activity
async def update_activity( async def update_activity(
request: Request, request: Request,
activity_object: ActivityUpdate, activity_object: ActivityUpdate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Activity).where(Activity.id == activity_object.activity_id) statement = select(Activity).where(Activity.id == activity_object.activity_id)
@ -103,6 +114,11 @@ async def update_activity(
detail="Activity not found", detail="Activity not found",
) )
# RBAC check
await rbac_check(
request, activity.activity_uuid, current_user, "update", db_session
)
del activity_object.activity_id del activity_object.activity_id
# Update only the fields that were passed in # Update only the fields that were passed in
@ -120,7 +136,7 @@ async def update_activity(
async def delete_activity( async def delete_activity(
request: Request, request: Request,
activity_id: str, activity_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Activity).where(Activity.id == activity_id) statement = select(Activity).where(Activity.id == activity_id)
@ -132,6 +148,11 @@ async def delete_activity(
detail="Activity not found", detail="Activity not found",
) )
# RBAC check
await rbac_check(
request, activity.activity_uuid, current_user, "delete", db_session
)
# Delete activity from chapter # Delete activity from chapter
statement = select(ChapterActivity).where( statement = select(ChapterActivity).where(
ChapterActivity.activity_id == activity_id ChapterActivity.activity_id == activity_id
@ -159,7 +180,7 @@ async def delete_activity(
async def get_activities( async def get_activities(
request: Request, request: Request,
coursechapter_id: str, coursechapter_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(ChapterActivity).where( statement = select(ChapterActivity).where(
@ -173,4 +194,31 @@ async def get_activities(
detail="No activities found", detail="No activities found",
) )
# RBAC check
await rbac_check(request, "activity_x", current_user, "read", db_session)
return activities return activities
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request,
current_user.id,
action,
course_id,
db_session,
)
## 🔒 RBAC Utils ##

View file

@ -1,4 +1,9 @@
from typing import Literal
from sqlmodel import Session, select from sqlmodel import Session, select
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.chapters import Chapter from src.db.chapters import Chapter
from src.db.activities import ( from src.db.activities import (
Activity, Activity,
@ -8,7 +13,7 @@ from src.db.activities import (
) )
from src.db.chapter_activities import ChapterActivity from src.db.chapter_activities import ChapterActivity
from src.db.course_chapters import CourseChapter from src.db.course_chapters import CourseChapter
from src.db.users import PublicUser from src.db.users import AnonymousUser, PublicUser
from src.services.courses.activities.uploads.pdfs import upload_pdf from src.services.courses.activities.uploads.pdfs import upload_pdf
from fastapi import HTTPException, status, UploadFile, Request from fastapi import HTTPException, status, UploadFile, Request
from uuid import uuid4 from uuid import uuid4
@ -19,10 +24,13 @@ async def create_documentpdf_activity(
request: Request, request: Request,
name: str, name: str,
chapter_id: str, chapter_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
pdf_file: UploadFile | None = None, pdf_file: UploadFile | None = None,
): ):
# RBAC check
await rbac_check(request, "activity_x", current_user, "create", db_session)
# get chapter_id # get chapter_id
statement = select(Chapter).where(Chapter.id == chapter_id) statement = select(Chapter).where(Chapter.id == chapter_id)
chapter = db_session.exec(statement).first() chapter = db_session.exec(statement).first()
@ -94,7 +102,7 @@ async def create_documentpdf_activity(
# Add activity to chapter # Add activity to chapter
activity_chapter = ChapterActivity( activity_chapter = ChapterActivity(
chapter_id=(int(chapter_id)), chapter_id=(int(chapter_id)),
activity_id=activity.id is not None, activity_id=activity.id, # type: ignore
course_id=coursechapter.course_id, course_id=coursechapter.course_id,
org_id=coursechapter.org_id, org_id=coursechapter.org_id,
creation_date=str(datetime.now()), creation_date=str(datetime.now()),
@ -113,3 +121,27 @@ async def create_documentpdf_activity(
db_session.refresh(activity_chapter) db_session.refresh(activity_chapter)
return ActivityRead.from_orm(activity) return ActivityRead.from_orm(activity)
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request,
current_user.id,
action,
course_id,
db_session,
)
## 🔒 RBAC Utils ##

View file

@ -2,11 +2,20 @@ from typing import Literal
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import Session, select from sqlmodel import Session, select
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.chapters import Chapter from src.db.chapters import Chapter
from src.db.activities import Activity, ActivityRead, ActivitySubTypeEnum, ActivityTypeEnum from src.db.activities import (
Activity,
ActivityRead,
ActivitySubTypeEnum,
ActivityTypeEnum,
)
from src.db.chapter_activities import ChapterActivity from src.db.chapter_activities import ChapterActivity
from src.db.course_chapters import CourseChapter from src.db.course_chapters import CourseChapter
from src.db.users import PublicUser from src.db.users import AnonymousUser, PublicUser
from src.services.courses.activities.uploads.videos import upload_video from src.services.courses.activities.uploads.videos import upload_video
from fastapi import HTTPException, status, UploadFile, Request from fastapi import HTTPException, status, UploadFile, Request
from uuid import uuid4 from uuid import uuid4
@ -21,6 +30,9 @@ async def create_video_activity(
db_session: Session, db_session: Session,
video_file: UploadFile | None = None, video_file: UploadFile | None = None,
): ):
# RBAC check
await rbac_check(request, "activity_x", current_user, "create", db_session)
# get chapter_id # get chapter_id
statement = select(Chapter).where(Chapter.id == chapter_id) statement = select(Chapter).where(Chapter.id == chapter_id)
chapter = db_session.exec(statement).first() chapter = db_session.exec(statement).first()
@ -95,8 +107,8 @@ async def create_video_activity(
# update chapter # update chapter
chapter_activity_object = ChapterActivity( chapter_activity_object = ChapterActivity(
chapter_id=coursechapter.id is not None, chapter_id=chapter.id, # type: ignore
activity_id=activity.id is not None, activity_id=activity.id, # type: ignore
course_id=coursechapter.course_id, course_id=coursechapter.course_id,
org_id=coursechapter.org_id, org_id=coursechapter.org_id,
creation_date=str(datetime.now()), creation_date=str(datetime.now()),
@ -111,6 +123,7 @@ async def create_video_activity(
return ActivityRead.from_orm(activity) return ActivityRead.from_orm(activity)
class ExternalVideo(BaseModel): class ExternalVideo(BaseModel):
name: str name: str
uri: str uri: str
@ -124,10 +137,13 @@ class ExternalVideoInDB(BaseModel):
async def create_external_video_activity( async def create_external_video_activity(
request: Request, request: Request,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
data: ExternalVideo, data: ExternalVideo,
db_session: Session, db_session: Session,
): ):
# RBAC check
await rbac_check(request, "activity_x", current_user, "create", db_session)
# get chapter_id # get chapter_id
statement = select(Chapter).where(Chapter.id == data.chapter_id) statement = select(Chapter).where(Chapter.id == data.chapter_id)
chapter = db_session.exec(statement).first() chapter = db_session.exec(statement).first()
@ -174,8 +190,8 @@ async def create_external_video_activity(
# update chapter # update chapter
chapter_activity_object = ChapterActivity( chapter_activity_object = ChapterActivity(
chapter_id=coursechapter.id is not None, chapter_id=coursechapter.id, # type: ignore
activity_id=activity.id is not None, activity_id=activity.id, # type: ignore
creation_date=str(datetime.now()), creation_date=str(datetime.now()),
update_date=str(datetime.now()), update_date=str(datetime.now()),
order=1, order=1,
@ -186,3 +202,24 @@ async def create_external_video_activity(
db_session.commit() db_session.commit()
return ActivityRead.from_orm(activity) return ActivityRead.from_orm(activity)
async def rbac_check(
request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request,
current_user.id,
action,
course_id,
db_session,
)
## 🔒 RBAC Utils ##

View file

@ -1,7 +1,12 @@
from datetime import datetime from datetime import datetime
from typing import List from typing import List, Literal
from uuid import uuid4 from uuid import uuid4
from sqlmodel import Session, select from sqlmodel import Session, select
from src.db.users import AnonymousUser
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.course_chapters import CourseChapter from src.db.course_chapters import CourseChapter
from src.db.activities import Activity, ActivityRead from src.db.activities import Activity, ActivityRead
from src.db.chapter_activities import ChapterActivity from src.db.chapter_activities import ChapterActivity
@ -26,11 +31,14 @@ from fastapi import HTTPException, status, Request
async def create_chapter( async def create_chapter(
request: Request, request: Request,
chapter_object: ChapterCreate, chapter_object: ChapterCreate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
) -> ChapterRead: ) -> ChapterRead:
chapter = Chapter.from_orm(chapter_object) chapter = Chapter.from_orm(chapter_object)
# RBAC check
await rbac_check(request, "chapter_x", current_user, "create", db_session)
# complete chapter object # complete chapter object
chapter.course_id = chapter_object.course_id chapter.course_id = chapter_object.course_id
chapter.chapter_uuid = f"chapter_{uuid4()}" chapter.chapter_uuid = f"chapter_{uuid4()}"
@ -87,7 +95,7 @@ async def create_chapter(
async def get_chapter( async def get_chapter(
request: Request, request: Request,
chapter_id: int, chapter_id: int,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
) -> ChapterRead: ) -> ChapterRead:
statement = select(Chapter).where(Chapter.id == chapter_id) statement = select(Chapter).where(Chapter.id == chapter_id)
@ -98,6 +106,9 @@ async def get_chapter(
status_code=status.HTTP_409_CONFLICT, detail="Chapter does not exist" status_code=status.HTTP_409_CONFLICT, detail="Chapter does not exist"
) )
# RBAC check
await rbac_check(request, chapter.chapter_uuid, current_user, "read", db_session)
# Get activities for this chapter # Get activities for this chapter
statement = ( statement = (
select(Activity) select(Activity)
@ -119,7 +130,7 @@ async def get_chapter(
async def update_chapter( async def update_chapter(
request: Request, request: Request,
chapter_object: ChapterUpdate, chapter_object: ChapterUpdate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
) -> ChapterRead: ) -> ChapterRead:
statement = select(Chapter).where(Chapter.id == chapter_object.chapter_id) statement = select(Chapter).where(Chapter.id == chapter_object.chapter_id)
@ -130,6 +141,9 @@ async def update_chapter(
status_code=status.HTTP_409_CONFLICT, detail="Chapter does not exist" status_code=status.HTTP_409_CONFLICT, detail="Chapter does not exist"
) )
# RBAC check
await rbac_check(request, chapter.chapter_uuid, current_user, "update", db_session)
# Update only the fields that were passed in # Update only the fields that were passed in
for var, value in vars(chapter_object).items(): for var, value in vars(chapter_object).items():
if value is not None: if value is not None:
@ -148,7 +162,7 @@ async def update_chapter(
async def delete_chapter( async def delete_chapter(
request: Request, request: Request,
chapter_id: str, chapter_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Chapter).where(Chapter.id == chapter_id) statement = select(Chapter).where(Chapter.id == chapter_id)
@ -159,6 +173,9 @@ async def delete_chapter(
status_code=status.HTTP_409_CONFLICT, detail="Chapter does not exist" status_code=status.HTTP_409_CONFLICT, detail="Chapter does not exist"
) )
# RBAC check
await rbac_check(request, chapter.chapter_uuid, current_user, "delete", db_session)
db_session.delete(chapter) db_session.delete(chapter)
db_session.commit() db_session.commit()
@ -173,15 +190,12 @@ async def delete_chapter(
return {"detail": "chapter deleted"} return {"detail": "chapter deleted"}
####################################################
# Misc
####################################################
async def get_course_chapters( async def get_course_chapters(
request: Request, request: Request,
course_id: int, course_id: int,
db_session: Session, db_session: Session,
current_user: PublicUser | AnonymousUser,
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
) -> List[ChapterRead]: ) -> List[ChapterRead]:
@ -195,6 +209,9 @@ async def get_course_chapters(
chapters = [ChapterRead(**chapter.dict(), activities=[]) for chapter in chapters] chapters = [ChapterRead(**chapter.dict(), activities=[]) for chapter in chapters]
# RBAC check
await rbac_check(request, "chapter_x", current_user, "read", db_session)
# Get activities for each chapter # Get activities for each chapter
for chapter in chapters: for chapter in chapters:
statement = ( statement = (
@ -233,6 +250,9 @@ async def get_depreceated_course_chapters(
status_code=status.HTTP_409_CONFLICT, detail="Course does not exist" status_code=status.HTTP_409_CONFLICT, detail="Course does not exist"
) )
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
# Get chapters that are linked to his course and order them by order, using the order field in the CourseChapter table # Get chapters that are linked to his course and order them by order, using the order field in the CourseChapter table
statement = ( statement = (
select(Chapter) select(Chapter)
@ -310,6 +330,9 @@ async def reorder_chapters_and_activities(
status_code=status.HTTP_409_CONFLICT, detail="Course does not exist" status_code=status.HTTP_409_CONFLICT, detail="Course does not exist"
) )
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
########### ###########
# Chapters # Chapters
########### ###########
@ -469,3 +492,27 @@ async def reorder_chapters_and_activities(
db_session.commit() db_session.commit()
return {"detail": "Chapters reordered"} return {"detail": "Chapters reordered"}
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request,
current_user.id,
action,
course_id,
db_session,
)
## 🔒 RBAC Utils ##

View file

@ -1,7 +1,12 @@
from datetime import datetime from datetime import datetime
from typing import List from typing import List, Literal
from uuid import uuid4 from uuid import uuid4
from sqlmodel import Session, select from sqlmodel import Session, select
from src.db.users import AnonymousUser
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.collections import ( from src.db.collections import (
Collection, Collection,
CollectionCreate, CollectionCreate,
@ -37,6 +42,11 @@ async def get_collection(
status_code=status.HTTP_409_CONFLICT, detail="Collection does not exist" status_code=status.HTTP_409_CONFLICT, detail="Collection does not exist"
) )
# RBAC check
await rbac_check(
request, collection.collection_uuid, current_user, "read", db_session
)
# get courses in collection # get courses in collection
statement = ( statement = (
select(Course) select(Course)
@ -58,6 +68,9 @@ async def create_collection(
) -> CollectionRead: ) -> CollectionRead:
collection = Collection.from_orm(collection_object) collection = Collection.from_orm(collection_object)
# RBAC check
await rbac_check(request, "collection_x", current_user, "create", db_session)
# Complete the collection object # Complete the collection object
collection.collection_uuid = f"collection_{uuid4()}" collection.collection_uuid = f"collection_{uuid4()}"
collection.creation_date = str(datetime.now()) collection.creation_date = str(datetime.now())
@ -70,16 +83,17 @@ async def create_collection(
db_session.refresh(collection) db_session.refresh(collection)
# Link courses to collection # Link courses to collection
for course in collection_object.courses: if collection:
collection_course = CollectionCourse( for course_id in collection_object.courses:
collection_id=int(collection.id is not None), collection_course = CollectionCourse(
course_id=int(course), collection_id=int(collection.id), # type: ignore
org_id=int(collection_object.org_id), course_id=course_id,
creation_date=str(datetime.now()), org_id=int(collection_object.org_id),
update_date=str(datetime.now()), creation_date=str(datetime.now()),
) update_date=str(datetime.now()),
# Add collection_course to database )
db_session.add(collection_course) # Add collection_course to database
db_session.add(collection_course)
db_session.commit() db_session.commit()
db_session.refresh(collection) db_session.refresh(collection)
@ -113,6 +127,11 @@ async def update_collection(
status_code=status.HTTP_409_CONFLICT, detail="Collection does not exist" status_code=status.HTTP_409_CONFLICT, detail="Collection does not exist"
) )
# RBAC check
await rbac_check(
request, collection.collection_uuid, current_user, "update", db_session
)
courses = collection_object.courses courses = collection_object.courses
del collection_object.collection_id del collection_object.collection_id
@ -142,7 +161,7 @@ async def update_collection(
# Add new collection_courses # Add new collection_courses
for course in courses or []: for course in courses or []:
collection_course = CollectionCourse( collection_course = CollectionCourse(
collection_id=int(collection.id is not None), collection_id=int(collection.id), # type: ignore
course_id=int(course), course_id=int(course),
org_id=int(collection.org_id), org_id=int(collection.org_id),
creation_date=str(datetime.now()), creation_date=str(datetime.now()),
@ -180,6 +199,11 @@ async def delete_collection(
detail="Collection not found", detail="Collection not found",
) )
# RBAC check
await rbac_check(
request, collection.collection_uuid, current_user, "delete", db_session
)
# delete collection from database # delete collection from database
db_session.delete(collection) db_session.delete(collection)
db_session.commit() db_session.commit()
@ -195,11 +219,14 @@ async def delete_collection(
async def get_collections( async def get_collections(
request: Request, request: Request,
org_id: str, org_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
) -> List[CollectionRead]: ) -> List[CollectionRead]:
# RBAC check
await rbac_check(request, "collection_x", current_user, "read", db_session)
statement = ( statement = (
select(Collection).where(Collection.org_id == org_id).distinct(Collection.id) select(Collection).where(Collection.org_id == org_id).distinct(Collection.id)
) )
@ -223,3 +250,27 @@ async def get_collections(
collections_with_courses.append(collection) collections_with_courses.append(collection)
return collections_with_courses return collections_with_courses
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request,
current_user.id,
action,
course_id,
db_session,
)
## 🔒 RBAC Utils ##

View file

@ -1,12 +1,28 @@
from calendar import c
import json import json
from queue import Full
import resource import resource
from typing import Literal from typing import Literal
from uuid import uuid4 from uuid import uuid4
from sqlmodel import Session, select from sqlmodel import Session, select
from src.db import chapters
from src.db.activities import Activity, ActivityRead
from src.db.chapter_activities import ChapterActivity
from src.db.chapters import Chapter, ChapterRead
from src.db.organizations import Organization
from src.db.trails import TrailRead
from src.services.trail.trail import get_user_trail_with_orgid
from src import db from src import db
from src.db.resource_authors import ResourceAuthor, ResourceAuthorshipEnum from src.db.resource_authors import ResourceAuthor, ResourceAuthorshipEnum
from src.db.users import PublicUser, AnonymousUser from src.db.users import PublicUser, AnonymousUser
from src.db.courses import Course, CourseCreate, CourseRead, CourseUpdate from src.db.courses import (
Course,
CourseCreate,
CourseRead,
CourseUpdate,
FullCourseReadWithTrail,
)
from src.security.rbac.rbac import ( from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship, authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_element_is_public, authorization_verify_if_element_is_public,
@ -18,7 +34,10 @@ from datetime import datetime
async def get_course( async def get_course(
request: Request, course_id: str, current_user: PublicUser, db_session: Session request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
db_session: Session,
): ):
statement = select(Course).where(Course.id == course_id) statement = select(Course).where(Course.id == course_id)
course = db_session.exec(statement).first() course = db_session.exec(statement).first()
@ -29,12 +48,21 @@ async def get_course(
detail="Course not found", detail="Course not found",
) )
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
return course return course
async def get_course_meta( async def get_course_meta(
request: Request, course_id: str, current_user: PublicUser, db_session: Session request: Request,
): course_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
) -> FullCourseReadWithTrail:
# Avoid circular import
from src.services.courses.chapters import get_course_chapters
course_statement = select(Course).where(Course.id == course_id) course_statement = select(Course).where(Course.id == course_id)
course = db_session.exec(course_statement).first() course = db_session.exec(course_statement).first()
@ -44,22 +72,40 @@ async def get_course_meta(
detail="Course not found", detail="Course not found",
) )
# todo : get course chapters # RBAC check
# todo : get course activities await rbac_check(request, course.course_uuid, current_user, "read", db_session)
# todo : get trail
return course course = CourseRead.from_orm(course)
# Get course chapters
chapters = await get_course_chapters(request, course.id, db_session, current_user)
# Trail
trail = await get_user_trail_with_orgid(
request, current_user, course.org_id, db_session
)
trail = TrailRead.from_orm(trail)
return FullCourseReadWithTrail(
**course.dict(),
chapters=chapters,
trail=trail,
)
async def create_course( async def create_course(
request: Request, request: Request,
course_object: CourseCreate, course_object: CourseCreate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
thumbnail_file: UploadFile | None = None, thumbnail_file: UploadFile | None = None,
): ):
course = Course.from_orm(course_object) course = Course.from_orm(course_object)
# RBAC check
await rbac_check(request, "course_x", current_user, "create", db_session)
# Complete course object # Complete course object
course.org_id = course.org_id course.org_id = course.org_id
course.course_uuid = str(f"course_{uuid4()}") course.course_uuid = str(f"course_{uuid4()}")
@ -69,7 +115,9 @@ async def create_course(
# Upload thumbnail # Upload thumbnail
if thumbnail_file and thumbnail_file.filename: if thumbnail_file and thumbnail_file.filename:
name_in_disk = f"{course.course_uuid}_thumbnail_{uuid4()}.{thumbnail_file.filename.split('.')[-1]}" name_in_disk = f"{course.course_uuid}_thumbnail_{uuid4()}.{thumbnail_file.filename.split('.')[-1]}"
await upload_thumbnail(thumbnail_file, name_in_disk, course_object.org_id, course.course_uuid) await upload_thumbnail(
thumbnail_file, name_in_disk, course_object.org_id, course.course_uuid
)
course_object.thumbnail = name_in_disk course_object.thumbnail = name_in_disk
# Insert course # Insert course
@ -97,7 +145,7 @@ async def create_course(
async def update_course_thumbnail( async def update_course_thumbnail(
request: Request, request: Request,
course_id: str, course_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
thumbnail_file: UploadFile | None = None, thumbnail_file: UploadFile | None = None,
): ):
@ -112,6 +160,9 @@ async def update_course_thumbnail(
detail="Course not found", detail="Course not found",
) )
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
# Upload thumbnail # Upload thumbnail
if thumbnail_file and thumbnail_file.filename: if thumbnail_file and thumbnail_file.filename:
name_in_disk = ( name_in_disk = (
@ -143,7 +194,7 @@ async def update_course_thumbnail(
async def update_course( async def update_course(
request: Request, request: Request,
course_object: CourseUpdate, course_object: CourseUpdate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Course).where(Course.id == course_object.course_id) statement = select(Course).where(Course.id == course_object.course_id)
@ -155,6 +206,9 @@ async def update_course(
detail="Course not found", detail="Course not found",
) )
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
del course_object.course_id del course_object.course_id
# Update only the fields that were passed in # Update only the fields that were passed in
@ -173,7 +227,10 @@ async def update_course(
async def delete_course( async def delete_course(
request: Request, course_id: str, current_user: PublicUser, db_session: Session request: Request,
course_id: str,
current_user: PublicUser | AnonymousUser,
db_session: Session,
): ):
statement = select(Course).where(Course.id == course_id) statement = select(Course).where(Course.id == course_id)
course = db_session.exec(statement).first() course = db_session.exec(statement).first()
@ -184,92 +241,74 @@ async def delete_course(
detail="Course not found", detail="Course not found",
) )
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "delete", db_session)
db_session.delete(course) db_session.delete(course)
db_session.commit() db_session.commit()
return {"detail": "Course deleted"} return {"detail": "Course deleted"}
####################################################
# Misc
####################################################
async def get_courses_orgslug( async def get_courses_orgslug(
request: Request, request: Request,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
org_slug: str,
db_session: Session,
page: int = 1, page: int = 1,
limit: int = 10, limit: int = 10,
org_slug: str | None = None,
): ):
courses = request.app.db["courses"] statement_public = (
orgs = request.app.db["organizations"] select(Course)
.join(Organization)
.where(Organization.slug == org_slug, Course.public == True)
)
statement_all = (
select(Course).join(Organization).where(Organization.slug == org_slug)
)
# get org_id from slug if current_user.id == 0:
org = await orgs.find_one({"slug": org_slug}) statement = statement_public
if not org:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Organization does not exist"
)
# show only public courses if user is not logged in
if current_user.id == "anonymous":
all_courses = (
courses.find({"org_id": org["org_id"], "public": True})
.sort("name", 1)
.skip(10 * (page - 1))
.limit(limit)
)
else: else:
all_courses = ( # RBAC check
courses.find({"org_id": org["org_id"]}) await authorization_verify_if_user_is_anon(current_user.id)
.sort("name", 1)
.skip(10 * (page - 1))
.limit(limit)
)
return [ statement = statement_all
json.loads(json.dumps(course, default=str))
for course in await all_courses.to_list(length=100) courses = db_session.exec(statement)
]
return courses
#### Security #################################################### ## 🔒 RBAC Utils ##
async def verify_rights( async def rbac_check(
request: Request, request: Request,
course_id: str, course_uuid: str,
current_user: PublicUser | AnonymousUser, current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"], action: Literal["create", "read", "update", "delete"],
db_session: Session, db_session: Session,
): ):
if action == "read": if action == "read":
if current_user.id == "anonymous": if current_user.id == 0: # Anonymous user
await authorization_verify_if_element_is_public( await authorization_verify_if_element_is_public(
request, course_id, str(current_user.id), action, db_session request, course_uuid, action, db_session
) )
else: else:
await authorization_verify_based_on_roles_and_authorship( await authorization_verify_based_on_roles_and_authorship(
request, request, current_user.id, action, course_uuid, db_session
str(current_user.id),
action,
course_id,
db_session,
) )
else: else:
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_if_user_is_anon(str(current_user.id))
await authorization_verify_based_on_roles_and_authorship( await authorization_verify_based_on_roles_and_authorship(
request, request,
str(current_user.id), current_user.id,
action, action,
course_id, course_uuid,
db_session, db_session,
) )
#### Security #################################################### ## 🔒 RBAC Utils ##

View file

@ -1,7 +1,12 @@
from datetime import datetime from datetime import datetime
from typing import Literal
from uuid import uuid4 from uuid import uuid4
from sqlmodel import Session, select from sqlmodel import Session, select
from src.db.users import PublicUser from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.users import AnonymousUser, PublicUser
from src.db.user_organizations import UserOrganization from src.db.user_organizations import UserOrganization
from src.db.organizations import ( from src.db.organizations import (
Organization, Organization,
@ -13,7 +18,12 @@ from src.services.orgs.logos import upload_org_logo
from fastapi import HTTPException, UploadFile, status, Request from fastapi import HTTPException, UploadFile, status, Request
async def get_organization(request: Request, org_id: str, db_session: Session): async def get_organization(
request: Request,
org_id: str,
db_session: Session,
current_user: PublicUser | AnonymousUser,
):
statement = select(Organization).where(Organization.id == org_id) statement = select(Organization).where(Organization.id == org_id)
result = db_session.exec(statement) result = db_session.exec(statement)
@ -25,11 +35,17 @@ async def get_organization(request: Request, org_id: str, db_session: Session):
detail="Organization not found", detail="Organization not found",
) )
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
return org return org
async def get_organization_by_slug( async def get_organization_by_slug(
request: Request, org_slug: str, db_session: Session request: Request,
org_slug: str,
db_session: Session,
current_user: PublicUser | AnonymousUser,
): ):
statement = select(Organization).where(Organization.slug == org_slug) statement = select(Organization).where(Organization.slug == org_slug)
result = db_session.exec(statement) result = db_session.exec(statement)
@ -42,13 +58,16 @@ async def get_organization_by_slug(
detail="Organization not found", detail="Organization not found",
) )
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
return org return org
async def create_org( async def create_org(
request: Request, request: Request,
org_object: OrganizationCreate, org_object: OrganizationCreate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Organization).where(Organization.slug == org_object.slug) statement = select(Organization).where(Organization.slug == org_object.slug)
@ -64,6 +83,9 @@ async def create_org(
org = Organization.from_orm(org_object) org = Organization.from_orm(org_object)
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
# Complete the org object # Complete the org object
org.org_uuid = f"org_{uuid4()}" org.org_uuid = f"org_{uuid4()}"
org.creation_date = str(datetime.now()) org.creation_date = str(datetime.now())
@ -92,7 +114,7 @@ async def create_org(
async def update_org( async def update_org(
request: Request, request: Request,
org_object: OrganizationUpdate, org_object: OrganizationUpdate,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Organization).where(Organization.id == org_object.org_id) statement = select(Organization).where(Organization.id == org_object.org_id)
@ -106,6 +128,9 @@ async def update_org(
detail="Organization slug not found", detail="Organization slug not found",
) )
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
org = Organization.from_orm(org_object) org = Organization.from_orm(org_object)
# Verify if the new slug is already in use # Verify if the new slug is already in use
@ -142,7 +167,7 @@ async def update_org_logo(
request: Request, request: Request,
logo_file: UploadFile, logo_file: UploadFile,
org_id: str, org_id: str,
current_user: PublicUser, current_user: PublicUser | AnonymousUser,
db_session: Session, db_session: Session,
): ):
statement = select(Organization).where(Organization.id == org_id) statement = select(Organization).where(Organization.id == org_id)
@ -156,6 +181,9 @@ async def update_org_logo(
detail="Organization not found", detail="Organization not found",
) )
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
# Upload logo # Upload logo
name_in_disk = await upload_org_logo(logo_file, org_id) name_in_disk = await upload_org_logo(logo_file, org_id)
@ -173,7 +201,10 @@ async def update_org_logo(
async def delete_org( async def delete_org(
request: Request, org_id: str, current_user: PublicUser, db_session: Session request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
): ):
statement = select(Organization).where(Organization.id == org_id) statement = select(Organization).where(Organization.id == org_id)
result = db_session.exec(statement) result = db_session.exec(statement)
@ -186,6 +217,9 @@ async def delete_org(
detail="Organization not found", detail="Organization not found",
) )
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "delete", db_session)
db_session.delete(org) db_session.delete(org)
db_session.commit() db_session.commit()
@ -224,3 +258,28 @@ async def get_orgs_by_user(
orgs = result.all() orgs = result.all()
return orgs return orgs
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
org_id: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
# Organizations are readable by anyone
if action == "read":
return True
else:
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request, current_user.id, action, org_id, db_session
)
## 🔒 RBAC Utils ##

View file

@ -1,6 +1,12 @@
from typing import Literal
from uuid import uuid4 from uuid import uuid4
from sqlmodel import Session, select from sqlmodel import Session, select
from src.db.users import PublicUser from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
authorization_verify_if_user_is_author,
)
from src.db.users import AnonymousUser, PublicUser
from src.db.roles import Role, RoleCreate, RoleUpdate from src.db.roles import Role, RoleCreate, RoleUpdate
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from datetime import datetime from datetime import datetime
@ -14,6 +20,9 @@ async def create_role(
): ):
role = Role.from_orm(role_object) role = Role.from_orm(role_object)
# RBAC check
await rbac_check(request, current_user, "create", "role_xxx", db_session)
# Complete the role object # Complete the role object
role.role_uuid = f"role_{uuid4()}" role.role_uuid = f"role_{uuid4()}"
role.creation_date = str(datetime.now()) role.creation_date = str(datetime.now())
@ -40,6 +49,9 @@ async def read_role(
detail="Role not found", detail="Role not found",
) )
# RBAC check
await rbac_check(request, current_user, "read", role.role_uuid, db_session)
return role return role
@ -60,6 +72,9 @@ async def update_role(
detail="Role not found", detail="Role not found",
) )
# RBAC check
await rbac_check(request, current_user, "update", role.role_uuid, db_session)
# Complete the role object # Complete the role object
role.update_date = str(datetime.now()) role.update_date = str(datetime.now())
@ -81,6 +96,9 @@ async def update_role(
async def delete_role( async def delete_role(
request: Request, db_session: Session, role_id: str, current_user: PublicUser request: Request, db_session: Session, role_id: str, current_user: PublicUser
): ):
# RBAC check
await rbac_check(request, current_user, "delete", role_id, db_session)
statement = select(Role).where(Role.id == role_id) statement = select(Role).where(Role.id == role_id)
result = db_session.exec(statement) result = db_session.exec(statement)
@ -96,3 +114,23 @@ async def delete_role(
db_session.commit() db_session.commit()
return "Role deleted" return "Role deleted"
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
role_uuid: str,
db_session: Session,
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request, current_user.id, action, role_uuid, db_session
)
## 🔒 RBAC Utils ##

View file

@ -6,7 +6,7 @@ from src.db.courses import Course
from src.db.trail_runs import TrailRun, TrailRunRead from src.db.trail_runs import TrailRun, TrailRunRead
from src.db.trail_steps import TrailStep from src.db.trail_steps import TrailStep
from src.db.trails import Trail, TrailCreate, TrailRead from src.db.trails import Trail, TrailCreate, TrailRead
from src.db.users import PublicUser from src.db.users import AnonymousUser, PublicUser
async def create_user_trail( async def create_user_trail(
@ -80,7 +80,7 @@ async def get_user_trails(
async def get_user_trail_with_orgid( async def get_user_trail_with_orgid(
request: Request, user: PublicUser, org_id: int, db_session: Session request: Request, user: PublicUser | AnonymousUser, org_id: int, db_session: Session
) -> TrailRead: ) -> TrailRead:
statement = select(Trail).where(Trail.org_id == org_id, Trail.user_id == user.id) statement = select(Trail).where(Trail.org_id == org_id, Trail.user_id == user.id)
trail = db_session.exec(statement).first() trail = db_session.exec(statement).first()

View file

@ -1,9 +1,17 @@
from datetime import datetime from datetime import datetime
from typing import Literal
from uuid import uuid4 from uuid import uuid4
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
from sqlmodel import Session, select from sqlmodel import Session, select
from src import db
from src.security.rbac.rbac import (
authorization_verify_based_on_roles,
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.organizations import Organization from src.db.organizations import Organization
from src.db.users import ( from src.db.users import (
AnonymousUser,
PublicUser, PublicUser,
User, User,
UserCreate, UserCreate,
@ -18,12 +26,15 @@ from src.security.security import security_hash_password, security_verify_passwo
async def create_user( async def create_user(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
user_object: UserCreate, user_object: UserCreate,
org_id: int, org_id: int,
): ):
user = User.from_orm(user_object) user = User.from_orm(user_object)
# RBAC check
await rbac_check(request, current_user, "create", "user_x", db_session)
# Complete the user object # Complete the user object
user.user_uuid = f"user_{uuid4()}" user.user_uuid = f"user_{uuid4()}"
user.password = await security_hash_password(user_object.password) user.password = await security_hash_password(user_object.password)
@ -94,11 +105,14 @@ async def create_user(
async def create_user_without_org( async def create_user_without_org(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
user_object: UserCreate, user_object: UserCreate,
): ):
user = User.from_orm(user_object) user = User.from_orm(user_object)
# RBAC check
await rbac_check(request, current_user, "create", "user_x", db_session)
# Complete the user object # Complete the user object
user.user_uuid = f"user_{uuid4()}" user.user_uuid = f"user_{uuid4()}"
user.password = await security_hash_password(user_object.password) user.password = await security_hash_password(user_object.password)
@ -146,7 +160,7 @@ async def create_user_without_org(
async def update_user( async def update_user(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
user_object: UserUpdate, user_object: UserUpdate,
): ):
# Get user # Get user
@ -159,6 +173,9 @@ async def update_user(
detail="User does not exist", detail="User does not exist",
) )
# RBAC check
await rbac_check(request, current_user, "update", user.user_uuid, db_session)
# Update user # Update user
user_data = user_object.dict(exclude_unset=True) user_data = user_object.dict(exclude_unset=True)
for key, value in user_data.items(): for key, value in user_data.items():
@ -179,7 +196,7 @@ async def update_user(
async def update_user_password( async def update_user_password(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
form: UserUpdatePassword, form: UserUpdatePassword,
): ):
# Get user # Get user
@ -192,6 +209,9 @@ async def update_user_password(
detail="User does not exist", detail="User does not exist",
) )
# RBAC check
await rbac_check(request, current_user, "update", user.user_uuid, db_session)
if not await security_verify_password(form.old_password, user.password): if not await security_verify_password(form.old_password, user.password):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password" status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password"
@ -214,7 +234,7 @@ async def update_user_password(
async def read_user_by_id( async def read_user_by_id(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
user_id: int, user_id: int,
): ):
# Get user # Get user
@ -227,6 +247,9 @@ async def read_user_by_id(
detail="User does not exist", detail="User does not exist",
) )
# RBAC check
await rbac_check(request, current_user, "read", user.user_uuid, db_session)
user = UserRead.from_orm(user) user = UserRead.from_orm(user)
return user return user
@ -235,11 +258,11 @@ async def read_user_by_id(
async def read_user_by_uuid( async def read_user_by_uuid(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
uuid: str, user_uuid: str,
): ):
# Get user # Get user
statement = select(User).where(User.user_uuid == uuid) statement = select(User).where(User.user_uuid == user_uuid)
user = db_session.exec(statement).first() user = db_session.exec(statement).first()
if not user: if not user:
@ -248,6 +271,9 @@ async def read_user_by_uuid(
detail="User does not exist", detail="User does not exist",
) )
# RBAC check
await rbac_check(request, current_user, "read", user.user_uuid, db_session)
user = UserRead.from_orm(user) user = UserRead.from_orm(user)
return user return user
@ -256,7 +282,7 @@ async def read_user_by_uuid(
async def delete_user_by_id( async def delete_user_by_id(
request: Request, request: Request,
db_session: Session, db_session: Session,
current_user: PublicUser | None, current_user: PublicUser | AnonymousUser,
user_id: int, user_id: int,
): ):
# Get user # Get user
@ -269,6 +295,9 @@ async def delete_user_by_id(
detail="User does not exist", detail="User does not exist",
) )
# RBAC check
await rbac_check(request, current_user, "delete", user.user_uuid, db_session)
# Delete user # Delete user
db_session.delete(user) db_session.delete(user)
db_session.commit() db_session.commit()
@ -293,3 +322,37 @@ async def security_get_user(request: Request, db_session: Session, email: str) -
user = User(**user.dict()) user = User(**user.dict())
return user return user
## 🔒 RBAC Utils ##
async def rbac_check(
request: Request,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
user_uuid: str,
db_session: Session,
):
if action == "create":
if current_user.id == 0: # if user is anonymous
return True
else:
res = await authorization_verify_based_on_roles_and_authorship(
request, current_user.id, "create", "user_x", db_session
)
else:
await authorization_verify_if_user_is_anon(current_user.id)
# if user is the same as the one being read
if current_user.user_uuid == user_uuid:
return True
await authorization_verify_based_on_roles_and_authorship(
request, current_user.id, "read", action, db_session
)
## 🔒 RBAC Utils ##