mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
507 lines
14 KiB
Python
507 lines
14 KiB
Python
from datetime import datetime
|
|
import logging
|
|
from typing import Literal
|
|
from uuid import uuid4
|
|
from fastapi import HTTPException, Request
|
|
from sqlmodel import Session, select
|
|
from src.security.features_utils.usage import (
|
|
check_limits_with_usage,
|
|
increase_feature_usage,
|
|
)
|
|
from src.security.rbac.rbac import (
|
|
authorization_verify_based_on_roles_and_authorship,
|
|
authorization_verify_if_user_is_anon,
|
|
)
|
|
from src.db.usergroup_resources import UserGroupResource
|
|
from src.db.usergroup_user import UserGroupUser
|
|
from src.db.organizations import Organization
|
|
from src.db.usergroups import UserGroup, UserGroupCreate, UserGroupRead, UserGroupUpdate
|
|
from src.db.users import AnonymousUser, InternalUser, PublicUser, User, UserRead
|
|
|
|
|
|
async def create_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_create: UserGroupCreate,
) -> UserGroupRead:
    """Create a new user group inside an organization.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for every read and write.
        current_user: The user performing the action.
        usergroup_create: Payload for the new group; its ``org_id`` selects
            the owning organization.

    Returns:
        The persisted group as a ``UserGroupRead``.

    Raises:
        HTTPException: 400 if the organization does not exist. Errors raised
            by the RBAC check or the usage-limit check propagate unchanged.
    """
    usergroup = UserGroup.model_validate(usergroup_create)

    # RBAC check -- the group has no uuid yet, so a placeholder uuid is used.
    await rbac_check(
        request,
        usergroup_uuid="usergroup_X",
        current_user=current_user,
        action="create",
        db_session=db_session,
    )

    # Check if Organization exists
    statement = select(Organization).where(Organization.id == usergroup_create.org_id)
    org = db_session.exec(statement).first()

    if not org or org.id is None:
        raise HTTPException(
            status_code=400,
            detail="Organization does not exist",
        )

    # Usage check: must target the same feature whose counter is increased
    # below ("usergroups"), otherwise the limit is enforced on the wrong quota.
    check_limits_with_usage("usergroups", org.id, db_session)

    # Complete the object; one timestamp so creation and update dates match.
    now = str(datetime.now())
    usergroup.usergroup_uuid = f"usergroup_{uuid4()}"
    usergroup.creation_date = now
    usergroup.update_date = now

    # Save the object
    db_session.add(usergroup)
    db_session.commit()
    db_session.refresh(usergroup)

    # Feature usage
    increase_feature_usage("usergroups", org.id, db_session)

    return UserGroupRead.model_validate(usergroup)
|
|
|
|
|
|
async def read_usergroup_by_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
) -> UserGroupRead:
    """Fetch one user group by its numeric primary key.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookup.
        current_user: The user asking to read the group.
        usergroup_id: Primary key of the group.

    Returns:
        The group as a ``UserGroupRead``.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    found = db_session.exec(
        select(UserGroup).where(UserGroup.id == usergroup_id)
    ).first()
    if not found:
        raise HTTPException(status_code=404, detail="UserGroup not found")

    # Authorization is evaluated against the concrete group uuid.
    await rbac_check(
        request,
        usergroup_uuid=found.usergroup_uuid,
        current_user=current_user,
        action="read",
        db_session=db_session,
    )

    return UserGroupRead.model_validate(found)
|
|
|
|
|
|
async def get_users_linked_to_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
) -> list[UserRead]:
    """List the users that are members of a user group.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the queries.
        current_user: The user asking for the member list.
        usergroup_id: Primary key of the group.

    Returns:
        One ``UserRead`` per ``UserGroupUser`` link whose user row exists.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()

    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="read",
        db_session=db_session,
    )

    # One joined query instead of a lookup per member. The inner join also
    # skips links pointing at a user row that no longer exists, which would
    # otherwise feed None into UserRead.model_validate and raise.
    statement = (
        select(User)
        .join(UserGroupUser, UserGroupUser.user_id == User.id)
        .where(UserGroupUser.usergroup_id == usergroup_id)
    )
    users = db_session.exec(statement).all()

    return [UserRead.model_validate(user) for user in users]
|
|
|
|
|
|
async def read_usergroups_by_org_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    org_id: int,
) -> list[UserGroupRead]:
    """List every user group belonging to an organization.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the query.
        current_user: The user asking for the list.
        org_id: Organization whose groups are returned.

    Returns:
        All groups with ``UserGroup.org_id == org_id`` as ``UserGroupRead``.

    Raises:
        Errors from the RBAC check propagate unchanged.
    """
    # Authorization uses a placeholder uuid, since no single group is targeted.
    await rbac_check(
        request,
        usergroup_uuid="usergroup_X",
        current_user=current_user,
        action="read",
        db_session=db_session,
    )

    rows = db_session.exec(select(UserGroup).where(UserGroup.org_id == org_id)).all()
    return list(map(UserGroupRead.model_validate, rows))
|
|
|
|
|
|
async def get_usergroups_by_resource(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    resource_uuid: str,
) -> list[UserGroupRead]:
    """List the user groups linked to a resource.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the queries.
        current_user: The user asking for the list.
        resource_uuid: Uuid of the resource to look up in ``UserGroupResource``.

    Returns:
        One ``UserGroupRead`` per ``UserGroupResource`` link whose group row
        exists.

    Raises:
        Errors from the RBAC check propagate unchanged.
    """
    # RBAC check first, so nothing is read for callers that will be rejected.
    await rbac_check(
        request,
        usergroup_uuid="usergroup_X",
        current_user=current_user,
        action="read",
        db_session=db_session,
    )

    # One joined query instead of a lookup per link. The inner join also
    # skips links pointing at a group row that no longer exists, which would
    # otherwise feed None into UserGroupRead.model_validate and raise.
    statement = (
        select(UserGroup)
        .join(UserGroupResource, UserGroupResource.usergroup_id == UserGroup.id)
        .where(UserGroupResource.resource_uuid == resource_uuid)
    )
    usergroups = db_session.exec(statement).all()

    return [UserGroupRead.model_validate(usergroup) for usergroup in usergroups]
|
|
|
|
|
|
async def update_usergroup_by_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    usergroup_update: UserGroupUpdate,
) -> UserGroupRead:
    """Overwrite a user group's name and description.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookup and the write.
        current_user: The user performing the update.
        usergroup_id: Primary key of the group to update.
        usergroup_update: Supplies the new ``name`` and ``description``.

    Returns:
        The refreshed group as a ``UserGroupRead``.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    lookup = select(UserGroup).where(UserGroup.id == usergroup_id)
    target = db_session.exec(lookup).first()
    if not target:
        raise HTTPException(status_code=404, detail="UserGroup not found")

    await rbac_check(
        request,
        usergroup_uuid=target.usergroup_uuid,
        current_user=current_user,
        action="update",
        db_session=db_session,
    )

    # Both fields are replaced unconditionally with the payload values.
    target.name, target.description = usergroup_update.name, usergroup_update.description
    target.update_date = str(datetime.now())

    db_session.add(target)
    db_session.commit()
    db_session.refresh(target)

    return UserGroupRead.model_validate(target)
|
|
|
|
|
|
async def delete_usergroup_by_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
) -> str:
    """Delete a user group and release its slot in the org's usage counter.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookup and the delete.
        current_user: The user performing the deletion.
        usergroup_id: Primary key of the group to delete.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    # Only this function needs the decrement helper.
    from src.security.features_utils.usage import decrease_feature_usage

    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()

    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="delete",
        db_session=db_session,
    )

    # Read org_id before the row is deleted and expired by the commit.
    org_id = usergroup.org_id

    db_session.delete(usergroup)
    db_session.commit()

    # Feature usage: deleting a group frees a slot, so the counter goes down.
    # This runs only after the delete committed, so a failed delete leaves
    # the counter untouched.
    decrease_feature_usage("usergroups", org_id, db_session)

    return "UserGroup deleted successfully"
|
|
|
|
|
|
async def add_users_to_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser | InternalUser,
    usergroup_id: int,
    user_ids: str,
) -> str:
    """Link a comma-separated list of users to a user group.

    The operation is best-effort per user. Unknown or malformed ids and users
    that are already members are logged and skipped. All new links are
    committed together at the end.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookups and inserts.
        current_user: The user performing the action; internal users bypass RBAC.
        usergroup_id: Primary key of the target group.
        user_ids: Comma-separated user ids, e.g. ``"1,2, 3"``.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()

    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="create",
        db_session=db_session,
    )

    # Normalize the list: strip whitespace, drop empty tokens (e.g. from a
    # trailing comma) and duplicates, and keep the request order.
    tokens = [t for t in dict.fromkeys(t.strip() for t in user_ids.split(",")) if t]

    now = str(datetime.now())
    for token in tokens:
        try:
            user_id = int(token)
        except ValueError:
            # A non-numeric id cannot match any user row.
            logging.error(f"User with id {token} not found")
            continue

        # Check if User is already Linked to UserGroup
        statement = select(UserGroupUser).where(
            UserGroupUser.usergroup_id == usergroup_id,
            UserGroupUser.user_id == user_id,
        )
        if db_session.exec(statement).first():
            logging.error(f"User with id {user_id} already exists in UserGroup")
            continue

        user = db_session.exec(select(User).where(User.id == user_id)).first()
        if not user or user.id is None:
            logging.error(f"User with id {user_id} not found")
            continue

        # Add user to UserGroup
        db_session.add(
            UserGroupUser(
                usergroup_id=usergroup_id,
                user_id=user.id,
                org_id=usergroup.org_id,
                creation_date=now,
                update_date=now,
            )
        )

    # Single commit: the batch is applied all-or-nothing.
    db_session.commit()

    return "Users added to UserGroup successfully"
|
|
|
|
|
|
async def remove_users_from_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser | InternalUser,
    usergroup_id: int,
    user_ids: str,
) -> str:
    """Unlink a comma-separated list of users from a user group.

    Ids that are malformed or not members of the group are logged and
    skipped. All deletions are committed together at the end.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookups and deletes.
        current_user: The user performing the action; internal users bypass RBAC.
        usergroup_id: Primary key of the target group.
        user_ids: Comma-separated user ids, e.g. ``"1,2, 3"``.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()

    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="delete",
        db_session=db_session,
    )

    # Normalize the list: strip whitespace, drop empty tokens and duplicates.
    tokens = [t for t in dict.fromkeys(t.strip() for t in user_ids.split(",")) if t]

    for token in tokens:
        try:
            user_id = int(token)
        except ValueError:
            logging.error(f"User with id {token} not found in UserGroup")
            continue

        statement = select(UserGroupUser).where(
            UserGroupUser.user_id == user_id, UserGroupUser.usergroup_id == usergroup_id
        )
        usergroup_user = db_session.exec(statement).first()

        if usergroup_user:
            db_session.delete(usergroup_user)
        else:
            logging.error(f"User with id {user_id} not found in UserGroup")

    # Single commit: the batch is applied all-or-nothing.
    db_session.commit()

    return "Users removed from UserGroup successfully"
|
|
|
|
|
|
async def add_resources_to_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    resources_uuids: str,
) -> str:
    """Link a comma-separated list of resource uuids to a user group.

    The request is validated before anything is written. If any resource is
    already linked to the group, a 400 is raised and no link is created.
    Otherwise all links are inserted and committed together.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookups and inserts.
        current_user: The user performing the action.
        usergroup_id: Primary key of the target group.
        resources_uuids: Comma-separated resource uuids.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``; 400 when a
            resource is already linked to the group. Errors from the RBAC
            check propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()

    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="create",
        db_session=db_session,
    )

    # Normalize the list: strip whitespace, drop empty tokens (which would
    # otherwise create a link with an empty uuid) and in-request duplicates.
    uuids = [
        u for u in dict.fromkeys(u.strip() for u in resources_uuids.split(",")) if u
    ]

    # Validate every uuid before writing, so a 400 never leaves a
    # partially-applied batch behind.
    for resource_uuid in uuids:
        statement = select(UserGroupResource).where(
            UserGroupResource.usergroup_id == usergroup_id,
            UserGroupResource.resource_uuid == resource_uuid,
        )
        if db_session.exec(statement).first():
            raise HTTPException(
                status_code=400,
                detail=f"Resource {resource_uuid} already exists in UserGroup",
            )

    # TODO : Find a way to check if resource really exists
    now = str(datetime.now())
    for resource_uuid in uuids:
        db_session.add(
            UserGroupResource(
                usergroup_id=usergroup_id,
                resource_uuid=resource_uuid,
                org_id=usergroup.org_id,
                creation_date=now,
                update_date=now,
            )
        )

    db_session.commit()

    return "Resources added to UserGroup successfully"
|
|
|
|
|
|
async def remove_resources_from_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    resources_uuids: str,
) -> str:
    """Unlink a comma-separated list of resource uuids from a user group.

    Only links belonging to ``usergroup_id`` are touched. Links between the
    same resource and other groups are left intact. Uuids with no link to
    this group are logged and skipped. All deletions are committed together.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Database session used for the lookups and deletes.
        current_user: The user performing the action.
        usergroup_id: Primary key of the target group.
        resources_uuids: Comma-separated resource uuids.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Errors from
            the RBAC check propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()

    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check (authorized against this group, so deletions must be
    # scoped to this group as well)
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="delete",
        db_session=db_session,
    )

    # Normalize the list: strip whitespace, drop empty tokens and duplicates.
    uuids = [
        u for u in dict.fromkeys(u.strip() for u in resources_uuids.split(",")) if u
    ]

    for resource_uuid in uuids:
        statement = select(UserGroupResource).where(
            UserGroupResource.usergroup_id == usergroup_id,
            UserGroupResource.resource_uuid == resource_uuid,
        )
        links = db_session.exec(statement).all()

        if not links:
            logging.error(f"resource with uuid {resource_uuid} not found in UserGroup")
            continue

        for link in links:
            db_session.delete(link)

    # Single commit: the batch is applied all-or-nothing.
    db_session.commit()

    return "Resources removed from UserGroup successfully"
|
|
|
|
|
|
## 🔒 RBAC Utils ##
|
|
|
|
|
|
async def rbac_check(
    request: Request,
    usergroup_uuid: str,
    current_user: PublicUser | AnonymousUser | InternalUser,
    action: Literal["create", "read", "update", "delete"],
    db_session: Session,
):
    """Authorize ``current_user`` to perform ``action`` on a user group.

    Internal users skip every check, and ``True`` is returned for them. Any
    other user goes through the anonymity check and then the role/authorship
    check for ``usergroup_uuid``, and ``None`` is returned. Presumably those
    helpers raise to deny access (not visible here -- confirm in
    ``src.security.rbac.rbac``).
    """
    if not isinstance(current_user, InternalUser):
        user_id = current_user.id
        await authorization_verify_if_user_is_anon(user_id)
        await authorization_verify_based_on_roles_and_authorship(
            request,
            user_id,
            action,
            usergroup_uuid,
            db_session,
        )
        return None

    return True
|
|
|
|
|
|
## 🔒 RBAC Utils ##
|