learnhouse/apps/api/src/services/users/usergroups.py
2024-11-28 22:11:28 +01:00

507 lines
14 KiB
Python

from datetime import datetime
import logging
from typing import Literal
from uuid import uuid4
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.security.features_utils.usage import (
check_limits_with_usage,
increase_feature_usage,
)
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_if_user_is_anon,
)
from src.db.usergroup_resources import UserGroupResource
from src.db.usergroup_user import UserGroupUser
from src.db.organizations import Organization
from src.db.usergroups import UserGroup, UserGroupCreate, UserGroupRead, UserGroupUpdate
from src.db.users import AnonymousUser, InternalUser, PublicUser, User, UserRead
async def create_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_create: UserGroupCreate,
) -> UserGroupRead:
    """Create a user group inside an existing organization.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; the new group is committed here.
        current_user: User performing the creation.
        usergroup_create: Payload for the new group (its ``org_id`` selects
            the owning organization).

    Returns:
        The persisted group as ``UserGroupRead``.

    Raises:
        HTTPException: 400 if ``usergroup_create.org_id`` matches no
            organization. Exceptions from ``rbac_check`` and
            ``check_limits_with_usage`` propagate unchanged.
    """
    usergroup = UserGroup.model_validate(usergroup_create)

    # RBAC check. The group has no uuid yet (it is generated below), so the
    # generic "usergroup_X" placeholder is passed.
    await rbac_check(
        request,
        usergroup_uuid="usergroup_X",
        current_user=current_user,
        action="create",
        db_session=db_session,
    )

    # Check if Organization exists
    statement = select(Organization).where(Organization.id == usergroup_create.org_id)
    org = db_session.exec(statement).first()
    if not org or org.id is None:
        raise HTTPException(
            status_code=400,
            detail="Organization does not exist",
        )

    # Usage check: must target the same feature that is incremented after the
    # insert below ("usergroups", not "courses").
    check_limits_with_usage("usergroups", org.id, db_session)

    # Complete the object; creation and update share a single timestamp.
    now = str(datetime.now())
    usergroup.usergroup_uuid = f"usergroup_{uuid4()}"
    usergroup.creation_date = now
    usergroup.update_date = now

    # Save the object
    db_session.add(usergroup)
    db_session.commit()
    db_session.refresh(usergroup)

    # Feature usage
    increase_feature_usage("usergroups", org.id, db_session)

    return UserGroupRead.model_validate(usergroup)
async def read_usergroup_by_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
) -> UserGroupRead:
    """Fetch one user group by its numeric id.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session.
        current_user: User requesting the group.
        usergroup_id: Primary key of the group.

    Returns:
        The group as ``UserGroupRead``.

    Raises:
        HTTPException: 404 when no group has ``usergroup_id``. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    found = db_session.exec(
        select(UserGroup).where(UserGroup.id == usergroup_id)
    ).first()
    if not found:
        raise HTTPException(status_code=404, detail="UserGroup not found")

    # Authorize against the concrete group before returning it.
    await rbac_check(
        request,
        usergroup_uuid=found.usergroup_uuid,
        current_user=current_user,
        action="read",
        db_session=db_session,
    )
    return UserGroupRead.model_validate(found)
async def get_users_linked_to_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
) -> list[UserRead]:
    """List the users that are members of a user group.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session.
        current_user: User requesting the member list.
        usergroup_id: Primary key of the group.

    Returns:
        One ``UserRead`` per membership row whose user row still exists, in
        the order the membership rows were returned.

    Raises:
        HTTPException: 404 if the group does not exist. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()
    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="read",
        db_session=db_session,
    )

    statement = select(UserGroupUser).where(UserGroupUser.usergroup_id == usergroup_id)
    usergroup_users = db_session.exec(statement).all()

    users: list[UserRead] = []
    for usergroup_user in usergroup_users:
        statement = select(User).where(User.id == usergroup_user.user_id)
        user = db_session.exec(statement).first()
        if user is None:
            # Nothing here guarantees the linked user row still exists. Skip a
            # dangling link instead of failing the whole request inside
            # UserRead.model_validate(None).
            logging.warning(
                "User with id %s linked to UserGroup %s not found",
                usergroup_user.user_id,
                usergroup_id,
            )
            continue
        users.append(UserRead.model_validate(user))
    return users
async def read_usergroups_by_org_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    org_id: int,
) -> list[UserGroupRead]:
    """Return every user group whose ``org_id`` equals ``org_id``.

    The RBAC check is given the generic placeholder ``"usergroup_X"``
    rather than a concrete group uuid.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session.
        current_user: User requesting the list.
        org_id: Organization whose groups are listed.

    Returns:
        The matching groups as ``UserGroupRead`` objects (possibly empty).
    """
    # Authorize first; the query below is independent of the check.
    await rbac_check(
        request,
        usergroup_uuid="usergroup_X",
        current_user=current_user,
        action="read",
        db_session=db_session,
    )
    rows = db_session.exec(select(UserGroup).where(UserGroup.org_id == org_id)).all()
    return [UserGroupRead.model_validate(row) for row in rows]
async def get_usergroups_by_resource(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    resource_uuid: str,
) -> list[UserGroupRead]:
    """List the user groups linked to a resource.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session.
        current_user: User requesting the list.
        resource_uuid: Uuid of the resource whose group links are listed.

    Returns:
        One ``UserGroupRead`` per ``UserGroupResource`` link whose group still
        exists, in the order the link rows were returned.

    Raises:
        Exceptions from ``rbac_check`` propagate unchanged.
    """
    # RBAC check (generic placeholder uuid, as no single group is targeted).
    # Done before any read so nothing is queried for an unauthorized caller.
    await rbac_check(
        request,
        usergroup_uuid="usergroup_X",
        current_user=current_user,
        action="read",
        db_session=db_session,
    )

    statement = select(UserGroupResource).where(
        UserGroupResource.resource_uuid == resource_uuid
    )
    usergroup_resources = db_session.exec(statement).all()

    usergroups: list[UserGroupRead] = []
    for link in usergroup_resources:
        statement = select(UserGroup).where(UserGroup.id == link.usergroup_id)
        usergroup = db_session.exec(statement).first()
        if usergroup is None:
            # A link may outlive its group (deleting a group does not
            # necessarily remove its links); skip it instead of crashing in
            # UserGroupRead.model_validate(None).
            logging.warning(
                "UserGroup with id %s linked to resource %s not found",
                link.usergroup_id,
                resource_uuid,
            )
            continue
        usergroups.append(UserGroupRead.model_validate(usergroup))
    return usergroups
async def update_usergroup_by_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    usergroup_update: UserGroupUpdate,
) -> UserGroupRead:
    """Overwrite a user group's name and description.

    ``update_date`` is set to the current (naive, local) time.

    NOTE(review): ``name`` and ``description`` are copied unconditionally. If
    ``UserGroupUpdate`` declares them optional, omitting one stores None;
    confirm that is intended.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; the change is committed here.
        current_user: User performing the update.
        usergroup_id: Primary key of the group to update.
        usergroup_update: New field values.

    Returns:
        The refreshed group as ``UserGroupRead``.

    Raises:
        HTTPException: 404 if no group has ``usergroup_id``. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    existing = db_session.exec(
        select(UserGroup).where(UserGroup.id == usergroup_id)
    ).first()
    if not existing:
        raise HTTPException(status_code=404, detail="UserGroup not found")

    await rbac_check(
        request,
        usergroup_uuid=existing.usergroup_uuid,
        current_user=current_user,
        action="update",
        db_session=db_session,
    )

    existing.name, existing.description = (
        usergroup_update.name,
        usergroup_update.description,
    )
    existing.update_date = str(datetime.now())

    db_session.add(existing)
    db_session.commit()
    db_session.refresh(existing)
    return UserGroupRead.model_validate(existing)
async def delete_usergroup_by_id(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
) -> str:
    """Delete a user group together with its membership and resource links.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; the deletion is committed here.
        current_user: User performing the deletion.
        usergroup_id: Primary key of the group to delete.

    Returns:
        A fixed success message.

    Raises:
        HTTPException: 404 if the group does not exist. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    # Local import keeps this fix self-contained. NOTE(review):
    # decrease_feature_usage is expected to live next to
    # increase_feature_usage in this module; confirm.
    from src.security.features_utils.usage import decrease_feature_usage

    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()
    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="delete",
        db_session=db_session,
    )

    # Capture before deleting: after commit the instance is deleted/expired,
    # and reading its attributes would try to reload a row that is gone.
    org_id = usergroup.org_id

    # Remove link rows pointing at this group so they cannot dangle.
    # Readers of these tables otherwise look up a group that no longer exists.
    for link_model in (UserGroupUser, UserGroupResource):
        links = db_session.exec(
            select(link_model).where(link_model.usergroup_id == usergroup_id)
        ).all()
        for link in links:
            db_session.delete(link)

    db_session.delete(usergroup)
    db_session.commit()

    # Feature usage: a deletion releases quota, so the counter is decremented
    # (it was previously incremented here), and only after the commit succeeded.
    decrease_feature_usage("usergroups", org_id, db_session)

    return "UserGroup deleted successfully"
async def add_users_to_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser | InternalUser,
    usergroup_id: int,
    user_ids: str,
) -> str:
    """Add users to a user group.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; all new links are committed once.
        current_user: User performing the change (internal users bypass RBAC).
        usergroup_id: Primary key of the target group.
        user_ids: Comma-separated user ids, e.g. ``"1,2,3"``. Whitespace
            around entries and empty entries are ignored. Non-integer
            entries, unknown users and existing members are logged and
            skipped.

    Returns:
        A fixed success message. Skipped ids are not reported to the caller.

    Raises:
        HTTPException: 404 if the group does not exist. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()
    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="create",
        db_session=db_session,
    )

    now = str(datetime.now())
    seen: set[int] = set()
    for raw_id in user_ids.split(","):
        token = raw_id.strip()
        if not token:
            # Tolerate stray separators such as a trailing comma.
            continue
        try:
            user_id = int(token)
        except ValueError:
            # Never send a non-numeric token to the integer id column.
            logging.error(f"User with id {token} not found")
            continue
        if user_id in seen:
            continue
        seen.add(user_id)

        # Check if User is already Linked to UserGroup
        statement = select(UserGroupUser).where(
            UserGroupUser.usergroup_id == usergroup_id,
            UserGroupUser.user_id == user_id,
        )
        if db_session.exec(statement).first():
            logging.error(f"User with id {user_id} already exists in UserGroup")
            continue

        statement = select(User).where(User.id == user_id)
        user = db_session.exec(statement).first()
        if user is None or user.id is None:
            logging.error(f"User with id {user_id} not found")
            continue

        db_session.add(
            UserGroupUser(
                usergroup_id=usergroup_id,
                user_id=user.id,
                org_id=usergroup.org_id,
                creation_date=now,
                update_date=now,
            )
        )

    # One commit for the whole batch instead of one per user.
    db_session.commit()
    return "Users added to UserGroup successfully"
async def remove_users_from_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    user_ids: str,
) -> str:
    """Remove users from a user group.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; all removals are committed once.
        current_user: User performing the change.
        usergroup_id: Primary key of the target group.
        user_ids: Comma-separated user ids, e.g. ``"1,2,3"``. Whitespace
            around entries and empty entries are ignored. Non-integer
            entries and non-members are logged and skipped.

    Returns:
        A fixed success message. Skipped ids are not reported to the caller.

    Raises:
        HTTPException: 404 if the group does not exist. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()
    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="delete",
        db_session=db_session,
    )

    seen: set[int] = set()
    for raw_id in user_ids.split(","):
        token = raw_id.strip()
        if not token:
            # Tolerate stray separators such as a trailing comma.
            continue
        try:
            user_id = int(token)
        except ValueError:
            # Never send a non-numeric token to the integer id column.
            logging.error(f"User with id {token} not found in UserGroup")
            continue
        if user_id in seen:
            continue
        seen.add(user_id)

        statement = select(UserGroupUser).where(
            UserGroupUser.user_id == user_id,
            UserGroupUser.usergroup_id == usergroup_id,
        )
        membership = db_session.exec(statement).first()
        if membership:
            db_session.delete(membership)
        else:
            logging.error(f"User with id {user_id} not found in UserGroup")

    # One commit for the whole batch instead of one per user.
    db_session.commit()
    return "Users removed from UserGroup successfully"
async def add_resources_to_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    resources_uuids: str,
) -> str:
    """Link resources to a user group.

    The whole batch is validated before anything is written. A resource that
    is already linked rejects the request without linking any of the others.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; all new links are committed once.
        current_user: User performing the change.
        usergroup_id: Primary key of the target group.
        resources_uuids: Comma-separated resource uuids. Whitespace around
            entries, empty entries and repeated entries are ignored.

    Returns:
        A fixed success message.

    Raises:
        HTTPException: 404 if the group does not exist; 400 if any resource
            is already linked to the group. Exceptions from ``rbac_check``
            propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()
    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="create",
        db_session=db_session,
    )

    # Trim, drop blanks (e.g. from a trailing comma) and de-duplicate while
    # keeping the caller's order.
    requested = list(
        dict.fromkeys(
            token.strip() for token in resources_uuids.split(",") if token.strip()
        )
    )

    # Check for existing links first, so a rejection cannot leave earlier
    # items of the same request already committed.
    for resource_uuid in requested:
        statement = select(UserGroupResource).where(
            UserGroupResource.usergroup_id == usergroup_id,
            UserGroupResource.resource_uuid == resource_uuid,
        )
        if db_session.exec(statement).first():
            raise HTTPException(
                status_code=400,
                detail=f"Resource {resource_uuid} already exists in UserGroup",
            )

    now = str(datetime.now())
    for resource_uuid in requested:
        # TODO : Find a way to check if resource really exists
        db_session.add(
            UserGroupResource(
                usergroup_id=usergroup_id,
                resource_uuid=resource_uuid,
                org_id=usergroup.org_id,
                creation_date=now,
                update_date=now,
            )
        )

    db_session.commit()
    return "Resources added to UserGroup successfully"
async def remove_resources_from_usergroup(
    request: Request,
    db_session: Session,
    current_user: PublicUser | AnonymousUser,
    usergroup_id: int,
    resources_uuids: str,
) -> str:
    """Unlink resources from one specific user group.

    Args:
        request: Incoming request, forwarded to the RBAC check.
        db_session: Active database session; all removals are committed once.
        current_user: User performing the change.
        usergroup_id: Primary key of the group whose links are removed.
        resources_uuids: Comma-separated resource uuids. Whitespace around
            entries, empty entries and repeated entries are ignored. Uuids
            not linked to this group are logged and skipped.

    Returns:
        A fixed success message.

    Raises:
        HTTPException: 404 if the group does not exist. Exceptions from
            ``rbac_check`` propagate unchanged.
    """
    statement = select(UserGroup).where(UserGroup.id == usergroup_id)
    usergroup = db_session.exec(statement).first()
    if not usergroup:
        raise HTTPException(
            status_code=404,
            detail="UserGroup not found",
        )

    # RBAC check
    await rbac_check(
        request,
        usergroup_uuid=usergroup.usergroup_uuid,
        current_user=current_user,
        action="delete",
        db_session=db_session,
    )

    seen: set[str] = set()
    for raw_uuid in resources_uuids.split(","):
        resource_uuid = raw_uuid.strip()
        if not resource_uuid or resource_uuid in seen:
            continue
        seen.add(resource_uuid)

        # Scope the lookup to THIS group. The RBAC check above only authorized
        # changes to `usergroup`; without the usergroup_id filter, the first
        # matching link of any group (possibly in another organization) was
        # deleted.
        statement = select(UserGroupResource).where(
            UserGroupResource.usergroup_id == usergroup_id,
            UserGroupResource.resource_uuid == resource_uuid,
        )
        usergroup_resource = db_session.exec(statement).first()
        if usergroup_resource:
            db_session.delete(usergroup_resource)
        else:
            logging.error(f"resource with uuid {resource_uuid} not found in UserGroup")

    # One commit for the whole batch instead of one per resource.
    db_session.commit()
    return "Resources removed from UserGroup successfully"
## 🔒 RBAC Utils ##
async def rbac_check(
    request: Request,
    usergroup_uuid: str,
    current_user: PublicUser | AnonymousUser | InternalUser,
    action: Literal["create", "read", "update", "delete"],
    db_session: Session,
):
    """Authorize ``current_user`` to perform ``action`` on a user group.

    Internal users skip every check and get ``True`` back. Any other user is
    first passed to the anonymous-user verifier, then to the role/authorship
    verifier. Any exception those helpers raise propagates to the caller;
    on success this path returns None.

    Args:
        request: Incoming request, forwarded to the role/authorship verifier.
        usergroup_uuid: Uuid of the target group, or a generic placeholder
            supplied by the caller.
        current_user: User whose permissions are checked.
        action: CRUD action being authorized.
        db_session: Active database session, forwarded to the verifier.
    """
    if not isinstance(current_user, InternalUser):
        user_id = current_user.id
        await authorization_verify_if_user_is_anon(user_id)
        await authorization_verify_based_on_roles_and_authorship(
            request,
            user_id,
            action,
            usergroup_uuid,
            db_session,
        )
        return None
    return True
## 🔒 RBAC Utils ##