feat: refactor RBAC authorization functions to include usergroups

This commit is contained in:
swve 2024-03-26 19:56:14 +00:00
parent e1b3b62e40
commit 0df250c729
14 changed files with 392 additions and 37 deletions

View file

@ -0,0 +1,18 @@
from typing import Optional
from sqlalchemy import Column, ForeignKey, Integer
from sqlmodel import Field, SQLModel
class UserGroupUser(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
usergroup_id: int = Field(
sa_column=Column(Integer, ForeignKey("usergroup.id", ondelete="CASCADE"))
)
user_id: int = Field(
sa_column=Column(Integer, ForeignKey("user.id", ondelete="CASCADE"))
)
org_id: int = Field(
sa_column=Column(Integer, ForeignKey("organization.id", ondelete="CASCADE"))
)
creation_date: str = ""
update_date: str = ""

View file

@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, Request, UploadFile
from sqlmodel import Session
from src.services.orgs.invites import (
create_invite_code,
create_invite_code_with_usergroup,
delete_invite_code,
get_invite_code,
get_invite_codes,
@ -162,6 +163,22 @@ async def api_create_invite_code(
return await create_invite_code(request, org_id, current_user, db_session)
@router.post("/{org_id}/invites_with_usergroups")
async def api_create_invite_code_with_ug(
request: Request,
org_id: int,
usergroup_id: int,
current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session),
):
"""
Create invite code
"""
return await create_invite_code_with_usergroup(
request, org_id, usergroup_id, current_user, db_session
)
@router.get("/{org_id}/invites")
async def api_get_invite_codes(
request: Request,

View file

@ -4,7 +4,16 @@ from sqlmodel import Session
from src.services.users.users import delete_user_by_id
from src.db.usergroups import UserGroupCreate, UserGroupRead, UserGroupUpdate
from src.db.users import PublicUser
from src.services.users.usergroups import create_usergroup, delete_usergroup_by_id, read_usergroup_by_id, update_usergroup_by_id
from src.services.users.usergroups import (
add_ressources_to_usergroup,
add_users_to_usergroup,
create_usergroup,
delete_usergroup_by_id,
read_usergroup_by_id,
remove_ressources_from_usergroup,
remove_users_from_usergroup,
update_usergroup_by_id,
)
from src.services.orgs.orgs import get_org_join_mechanism
from src.security.auth import get_current_user
from src.core.events.database import get_db_session
@ -13,8 +22,8 @@ from src.core.events.database import get_db_session
router = APIRouter()
@router.post("/", response_model=UserGroupCreate, tags=["usergroups"])
async def api_create_user_without_org(
@router.post("/", response_model=UserGroupRead, tags=["usergroups"])
async def api_create_usergroup(
*,
request: Request,
db_session: Session = Depends(get_db_session),
@ -40,6 +49,7 @@ async def api_get_usergroup(
"""
return await read_usergroup_by_id(request, db_session, current_user, usergroup_id)
@router.put("/{usergroup_id}", response_model=UserGroupRead, tags=["usergroups"])
async def api_update_usergroup(
*,
@ -52,7 +62,10 @@ async def api_update_usergroup(
"""
Update UserGroup
"""
return await update_usergroup_by_id(request, db_session, current_user, usergroup_id, usergroup_object)
return await update_usergroup_by_id(
request, db_session, current_user, usergroup_id, usergroup_object
)
@router.delete("/{usergroup_id}", tags=["usergroups"])
async def api_delete_usergroup(
@ -66,3 +79,71 @@ async def api_delete_usergroup(
Delete UserGroup
"""
return await delete_usergroup_by_id(request, db_session, current_user, usergroup_id)
@router.post("/{usergroup_id}/add_users", tags=["usergroups"])
async def api_add_users_to_usergroup(
*,
request: Request,
db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
usergroup_id: int,
user_ids: str,
) -> str:
"""
Add Users to UserGroup
"""
return await add_users_to_usergroup(
request, db_session, current_user, usergroup_id, user_ids
)
@router.delete("/{usergroup_id}/remove_users", tags=["usergroups"])
async def api_delete_users_from_usergroup(
*,
request: Request,
db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
usergroup_id: int,
user_ids: str,
) -> str:
"""
Delete Users from UserGroup
"""
return await remove_users_from_usergroup(
request, db_session, current_user, usergroup_id, user_ids
)
@router.post("/{usergroup_id}/add_ressources", tags=["usergroups"])
async def api_add_ressources_to_usergroup(
*,
request: Request,
db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
usergroup_id: int,
ressource_uuids: str,
) -> str:
"""
Add Ressources to UserGroup
"""
return await add_ressources_to_usergroup(
request, db_session, current_user, usergroup_id, ressource_uuids
)
@router.delete("/{usergroup_id}/remove_ressources", tags=["usergroups"])
async def api_delete_ressources_from_usergroup(
*,
request: Request,
db_session: Session = Depends(get_db_session),
current_user: PublicUser = Depends(get_current_user),
usergroup_id: int,
ressource_uuids: str,
) -> str:
"""
Delete Ressources from UserGroup
"""
return await remove_ressources_from_usergroup(
request, db_session, current_user, usergroup_id, ressource_uuids
)

View file

@ -143,7 +143,7 @@ async def authorization_verify_based_on_org_admin_status(
# Tested and working
async def authorization_verify_based_on_roles_and_authorship(
async def authorization_verify_based_on_roles_and_authorship_and_usergroups(
request: Request,
user_id: int,
action: Literal["read", "update", "delete", "create"],

View file

@ -3,7 +3,7 @@ from sqlmodel import Session, select
from src.db.courses import Course
from src.db.chapters import Chapter
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_element_is_public,
authorization_verify_if_user_is_anon,
)
@ -238,14 +238,14 @@ async def rbac_check(
)
return res
else:
res = await authorization_verify_based_on_roles_and_authorship(
res = await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, course_uuid, db_session
)
return res
else:
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request,
current_user.id,
action,

View file

@ -3,7 +3,7 @@ from src.db.courses import Course
from src.db.organizations import Organization
from sqlmodel import Session, select
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_user_is_anon,
)
from src.db.chapters import Chapter
@ -150,7 +150,7 @@ async def rbac_check(
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request,
current_user.id,
action,

View file

@ -5,7 +5,7 @@ from src.db.organizations import Organization
from pydantic import BaseModel
from sqlmodel import Session, select
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_user_is_anon,
)
from src.db.chapters import Chapter
@ -232,7 +232,7 @@ async def rbac_check(
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request,
current_user.id,
action,

View file

@ -4,7 +4,7 @@ from uuid import uuid4
from sqlmodel import Session, select
from src.db.users import AnonymousUser
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_element_is_public,
authorization_verify_if_user_is_anon,
)
@ -562,14 +562,14 @@ async def rbac_check(
print("res", res)
return res
else:
res = await authorization_verify_based_on_roles_and_authorship(
res = await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, course_uuid, db_session
)
return res
else:
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request,
current_user.id,
action,

View file

@ -4,7 +4,7 @@ from uuid import uuid4
from sqlmodel import Session, select
from src.db.users import AnonymousUser
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_element_is_public,
authorization_verify_if_user_is_anon,
)
@ -297,14 +297,14 @@ async def rbac_check(
detail="User rights : You are not allowed to read this collection",
)
else:
res = await authorization_verify_based_on_roles_and_authorship(
res = await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, collection_uuid, db_session
)
return res
else:
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request,
current_user.id,
action,

View file

@ -3,7 +3,6 @@ from uuid import uuid4
from sqlmodel import Session, select
from src.db.organizations import Organization
from src.db.trails import TrailRead
from src.services.trail.trail import get_user_trail_with_orgid
from src.db.resource_authors import ResourceAuthor, ResourceAuthorshipEnum
from src.db.users import PublicUser, AnonymousUser, User, UserRead
@ -15,7 +14,7 @@ from src.db.courses import (
FullCourseReadWithTrail,
)
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_element_is_public,
authorization_verify_if_user_is_anon,
)
@ -142,7 +141,7 @@ async def create_course(
if thumbnail_file and thumbnail_file.filename:
name_in_disk = f"{course.course_uuid}_thumbnail_{uuid4()}.{thumbnail_file.filename.split('.')[-1]}"
await upload_thumbnail(
thumbnail_file, name_in_disk, org.org_uuid, course.course_uuid
thumbnail_file, name_in_disk, org.org_uuid, course.course_uuid # type: ignore
)
course.thumbnail_image = name_in_disk
@ -213,7 +212,7 @@ async def update_course_thumbnail(
if thumbnail_file and thumbnail_file.filename:
name_in_disk = f"{course_uuid}_thumbnail_{uuid4()}.{thumbnail_file.filename.split('.')[-1]}"
await upload_thumbnail(
thumbnail_file, name_in_disk, org.org_uuid, course.course_uuid
thumbnail_file, name_in_disk, org.org_uuid, course.course_uuid # type: ignore
)
# Update course
@ -381,14 +380,14 @@ async def rbac_check(
)
return res
else:
res = await authorization_verify_based_on_roles_and_authorship(
res = await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, course_uuid, db_session
)
return res
else:
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request,
current_user.id,
action,

View file

@ -94,6 +94,85 @@ async def create_invite_code(
return inviteCodeObject
async def create_invite_code_with_usergroup(
request: Request,
org_id: int,
usergroup_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
# Redis init
LH_CONFIG = get_learnhouse_config()
redis_conn_string = LH_CONFIG.redis_config.redis_connection_string
if not redis_conn_string:
raise HTTPException(
status_code=500,
detail="Redis connection string not found",
)
statement = select(Organization).where(Organization.id == org_id)
result = db_session.exec(statement)
org = result.first()
if not org:
raise HTTPException(
status_code=404,
detail="Organization not found",
)
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
# Connect to Redis
r = redis.Redis.from_url(redis_conn_string)
if not r:
raise HTTPException(
status_code=500,
detail="Could not connect to Redis",
)
# Check if this org has more than 6 invite codes
invite_codes = r.keys(f"*:org:{org.org_uuid}:code:*")
if len(invite_codes) >= 6:
raise HTTPException(
status_code=400,
detail="Organization has reached the maximum number of invite codes",
)
# Generate invite code
def generate_code(length=5):
letters_and_digits = string.ascii_letters + string.digits
return "".join(random.choice(letters_and_digits) for _ in range(length))
generated_invite_code = generate_code()
invite_code_uuid = f"org_invite_code_{uuid.uuid4()}"
# time to live in days to seconds
ttl = int(timedelta(days=365).total_seconds())
inviteCodeObject = {
"invite_code": generated_invite_code,
"invite_code_uuid": invite_code_uuid,
"invite_code_expires": ttl,
"usergroup_id": usergroup_id,
"invite_code_type": "signup",
"created_at": datetime.now().isoformat(),
"created_by": current_user.user_uuid,
}
r.set(
f"{invite_code_uuid}:org:{org.org_uuid}:code:{generated_invite_code}",
json.dumps(inviteCodeObject),
ex=ttl,
)
return inviteCodeObject
async def get_invite_codes(
request: Request,
org_id: int,
@ -136,9 +215,15 @@ async def get_invite_codes(
# Get invite codes
invite_codes = r.keys(f"org_invite_code_*:org:{org.org_uuid}:code:*")
if not invite_codes:
raise HTTPException(
status_code=404,
detail="Invite codes not found",
)
invite_codes_list = []
for invite_code in invite_codes:
for invite_code in invite_codes: # type: ignore
invite_code = r.get(invite_code)
invite_code = json.loads(invite_code) # type: ignore
invite_codes_list.append(invite_code)

View file

@ -2,7 +2,7 @@ from typing import Literal
from uuid import uuid4
from sqlmodel import Session, select
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_user_is_anon,
)
from src.db.users import AnonymousUser, PublicUser
@ -133,7 +133,7 @@ async def rbac_check(
):
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, role_uuid, db_session
)

View file

@ -1,10 +1,13 @@
from datetime import datetime
import logging
from uuid import uuid4
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.usergroup_ressources import UserGroupRessource
from src.db.usergroup_user import UserGroupUser
from src.db.organizations import Organization
from src.db.usergroups import UserGroup, UserGroupCreate, UserGroupRead, UserGroupUpdate
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, PublicUser, User
async def create_usergroup(
@ -112,3 +115,151 @@ async def delete_usergroup_by_id(
db_session.commit()
return "UserGroup deleted successfully"
async def add_users_to_usergroup(
request: Request,
db_session: Session,
current_user: PublicUser | AnonymousUser,
usergroup_id: int,
user_ids: str,
) -> str:
statement = select(UserGroup).where(UserGroup.id == usergroup_id)
usergroup = db_session.exec(statement).first()
if not usergroup:
raise HTTPException(
status_code=404,
detail="UserGroup not found",
)
user_ids_array = user_ids.split(",")
for user_id in user_ids_array:
statement = select(User).where(User.id == user_id)
user = db_session.exec(statement).first()
if user:
# Add user to UserGroup
if user.id is not None:
usergroup_obj = UserGroupUser(
usergroup_id=usergroup_id,
user_id=user.id,
org_id=usergroup.org_id,
creation_date=str(datetime.now()),
update_date=str(datetime.now()),
)
db_session.add(usergroup_obj)
db_session.commit()
db_session.refresh(usergroup_obj)
else:
logging.error(f"User with id {user_id} not found")
return "Users added to UserGroup successfully"
async def remove_users_from_usergroup(
request: Request,
db_session: Session,
current_user: PublicUser | AnonymousUser,
usergroup_id: int,
user_ids: str,
) -> str:
statement = select(UserGroup).where(UserGroup.id == usergroup_id)
usergroup = db_session.exec(statement).first()
if not usergroup:
raise HTTPException(
status_code=404,
detail="UserGroup not found",
)
user_ids_array = user_ids.split(",")
for user_id in user_ids_array:
statement = select(UserGroupUser).where(UserGroupUser.user_id == user_id)
usergroup_user = db_session.exec(statement).first()
if usergroup_user:
db_session.delete(usergroup_user)
db_session.commit()
else:
logging.error(f"User with id {user_id} not found in UserGroup")
return "Users removed from UserGroup successfully"
async def add_ressources_to_usergroup(
request: Request,
db_session: Session,
current_user: PublicUser | AnonymousUser,
usergroup_id: int,
ressources_uuids: str,
) -> str:
statement = select(UserGroup).where(UserGroup.id == usergroup_id)
usergroup = db_session.exec(statement).first()
if not usergroup:
raise HTTPException(
status_code=404,
detail="UserGroup not found",
)
ressources_uuids_array = ressources_uuids.split(",")
for ressource_uuid in ressources_uuids_array:
# TODO : Find a way to check if ressource exists
usergroup_obj = UserGroupRessource(
usergroup_id=usergroup_id,
ressource_uuid=ressource_uuid,
org_id=usergroup.org_id,
creation_date=str(datetime.now()),
update_date=str(datetime.now()),
)
db_session.add(usergroup_obj)
db_session.commit()
db_session.refresh(usergroup_obj)
return "Ressources added to UserGroup successfully"
async def remove_ressources_from_usergroup(
request: Request,
db_session: Session,
current_user: PublicUser | AnonymousUser,
usergroup_id: int,
ressources_uuids: str,
) -> str:
statement = select(UserGroup).where(UserGroup.id == usergroup_id)
usergroup = db_session.exec(statement).first()
if not usergroup:
raise HTTPException(
status_code=404,
detail="UserGroup not found",
)
ressources_uuids_array = ressources_uuids.split(",")
for ressource_uuid in ressources_uuids_array:
statement = select(UserGroupRessource).where(
UserGroupRessource.ressource_uuid == ressource_uuid
)
usergroup_ressource = db_session.exec(statement).first()
if usergroup_ressource:
db_session.delete(usergroup_ressource)
db_session.commit()
else:
logging.error(
f"Ressource with uuid {ressource_uuid} not found in UserGroup"
)
return "Ressources removed from UserGroup successfully"

View file

@ -10,7 +10,7 @@ from src.services.orgs.invites import get_invite_code
from src.services.users.avatars import upload_avatar
from src.db.roles import Role, RoleRead
from src.security.rbac.rbac import (
authorization_verify_based_on_roles_and_authorship,
authorization_verify_based_on_roles_and_authorship_and_usergroups,
authorization_verify_if_user_is_anon,
)
from src.db.organizations import Organization, OrganizationRead
@ -124,11 +124,15 @@ async def create_user_with_invite(
):
# Check if invite code exists
isInviteCodeCorrect = await get_invite_code(
inviteCOde = await get_invite_code(
request, org_id, invite_code, current_user, db_session
)
if not isInviteCodeCorrect:
# Check if invite code contains UserGroup
#TODO
if not inviteCOde:
raise HTTPException(
status_code=400,
detail="Invite code is incorrect",
@ -463,7 +467,7 @@ async def authorize_user_action(
)
# RBAC check
authorized = await authorization_verify_based_on_roles_and_authorship(
authorized = await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, ressource_uuid, db_session
)
@ -535,7 +539,7 @@ async def rbac_check(
if current_user.id == 0: # if user is anonymous
return True
else:
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, "create", "user_x", db_session
)
@ -546,7 +550,7 @@ async def rbac_check(
if current_user.user_uuid == user_uuid:
return True
await authorization_verify_based_on_roles_and_authorship(
await authorization_verify_based_on_roles_and_authorship_and_usergroups(
request, current_user.id, action, user_uuid, db_session
)