mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
Merge pull request #135 from learnhouse/fix/post-migration-bugs
Post Migration Bug Fixes
This commit is contained in:
commit
a7e2bda41e
44 changed files with 580 additions and 336 deletions
|
|
@ -11,6 +11,8 @@ botocore
|
|||
python-jose
|
||||
passlib
|
||||
fastapi-jwt-auth
|
||||
pytest
|
||||
httpx
|
||||
faker
|
||||
requests
|
||||
pyyaml
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class CourseUpdate(CourseBase):
|
|||
class CourseRead(CourseBase):
|
||||
id: int
|
||||
org_id: int = Field(default=None, foreign_key="organization.id")
|
||||
authors: List[UserRead]
|
||||
authors: List[UserRead]
|
||||
course_uuid: str
|
||||
creation_date: str
|
||||
update_date: str
|
||||
|
|
@ -49,22 +49,22 @@ class CourseRead(CourseBase):
|
|||
|
||||
class FullCourseRead(CourseBase):
|
||||
id: int
|
||||
course_uuid: str
|
||||
creation_date: str
|
||||
update_date: str
|
||||
course_uuid: Optional[str]
|
||||
creation_date: Optional[str]
|
||||
update_date: Optional[str]
|
||||
# Chapters, Activities
|
||||
chapters: List[ChapterRead]
|
||||
authors: List[UserRead]
|
||||
authors: List[UserRead]
|
||||
pass
|
||||
|
||||
|
||||
class FullCourseReadWithTrail(CourseBase):
|
||||
id: int
|
||||
course_uuid: str
|
||||
creation_date: str
|
||||
update_date: str
|
||||
course_uuid: Optional[str]
|
||||
creation_date: Optional[str]
|
||||
update_date: Optional[str]
|
||||
org_id: int = Field(default=None, foreign_key="organization.id")
|
||||
authors: List[UserRead]
|
||||
authors: List[UserRead]
|
||||
# Chapters, Activities
|
||||
chapters: List[ChapterRead]
|
||||
# Trail
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ class TrailRunRead(BaseModel):
|
|||
org_id: int = Field(default=None, foreign_key="organization.id")
|
||||
user_id: int = Field(default=None, foreign_key="user.id")
|
||||
# course object
|
||||
course: dict
|
||||
course: Optional[dict]
|
||||
# timestamps
|
||||
creation_date: str
|
||||
update_date: str
|
||||
creation_date: Optional[str]
|
||||
update_date: Optional[str]
|
||||
# number of activities in course
|
||||
course_total_steps: int
|
||||
steps: list[TrailStep]
|
||||
|
|
|
|||
|
|
@ -23,11 +23,11 @@ class TrailCreate(TrailBase):
|
|||
# trick because Lists are not supported in SQLModel (runs: list[TrailRun] )
|
||||
class TrailRead(BaseModel):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
trail_uuid: str
|
||||
trail_uuid: Optional[str]
|
||||
org_id: int = Field(default=None, foreign_key="organization.id")
|
||||
user_id: int = Field(default=None, foreign_key="user.id")
|
||||
creation_date: str
|
||||
update_date: str
|
||||
creation_date: Optional[str]
|
||||
update_date: Optional[str]
|
||||
runs: list[TrailRunRead]
|
||||
|
||||
class Config:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from src.db.roles import RoleRead
|
||||
from src.db.organizations import OrganizationRead
|
||||
|
||||
|
||||
class UserBase(SQLModel):
|
||||
username: str
|
||||
|
|
@ -33,14 +37,27 @@ class UserRead(UserBase):
|
|||
id: int
|
||||
user_uuid: str
|
||||
|
||||
|
||||
class PublicUser(UserRead):
|
||||
pass
|
||||
|
||||
|
||||
class UserRoleWithOrg(BaseModel):
|
||||
role: RoleRead
|
||||
org: OrganizationRead
|
||||
|
||||
|
||||
class UserSession(BaseModel):
|
||||
user: UserRead
|
||||
roles: list[UserRoleWithOrg]
|
||||
|
||||
|
||||
class AnonymousUser(SQLModel):
|
||||
id: int = 0
|
||||
user_uuid: str = "user_anonymous"
|
||||
username: str = "anonymous"
|
||||
|
||||
|
||||
class User(UserBase, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
password: str = ""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from src.db.users import (
|
|||
User,
|
||||
UserCreate,
|
||||
UserRead,
|
||||
UserSession,
|
||||
UserUpdate,
|
||||
UserUpdatePassword,
|
||||
)
|
||||
|
|
@ -17,6 +18,7 @@ from src.services.users.users import (
|
|||
create_user,
|
||||
create_user_without_org,
|
||||
delete_user_by_id,
|
||||
get_user_session,
|
||||
read_user_by_id,
|
||||
read_user_by_uuid,
|
||||
update_user,
|
||||
|
|
@ -35,6 +37,18 @@ async def api_get_current_user(current_user: User = Depends(get_current_user)):
|
|||
return current_user.dict()
|
||||
|
||||
|
||||
@router.get("/session")
|
||||
async def api_get_current_user_session(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_db_session),
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
) -> UserSession:
|
||||
"""
|
||||
Get current user
|
||||
"""
|
||||
return await get_user_session(request, db_session, current_user)
|
||||
|
||||
|
||||
@router.get("/authorize/ressource/{ressource_uuid}/action/{action}")
|
||||
async def api_get_authorization_status(
|
||||
request: Request,
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ async def authorization_verify_if_element_is_public(
|
|||
# Verifies if the element is public
|
||||
if element_nature == ("courses" or "collections") and action == "read":
|
||||
if element_nature == "courses":
|
||||
print("looking for course")
|
||||
statement = select(Course).where(
|
||||
Course.public is True, Course.course_uuid == element_uuid
|
||||
)
|
||||
|
|
@ -28,10 +29,7 @@ async def authorization_verify_if_element_is_public(
|
|||
if course:
|
||||
return True
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User rights (public content) : You don't have the right to perform this action",
|
||||
)
|
||||
return False
|
||||
|
||||
if element_nature == "collections":
|
||||
statement = select(Collection).where(
|
||||
|
|
@ -42,15 +40,9 @@ async def authorization_verify_if_element_is_public(
|
|||
if collection:
|
||||
return True
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User rights (public content) : You don't have the right to perform this action",
|
||||
)
|
||||
return False
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User rights (public content) : You don't have the right to perform this action",
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
# Tested and working
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from sqlmodel import Session, select
|
|||
from src.db.chapters import Chapter
|
||||
from src.security.rbac.rbac import (
|
||||
authorization_verify_based_on_roles_and_authorship,
|
||||
authorization_verify_if_element_is_public,
|
||||
authorization_verify_if_user_is_anon,
|
||||
)
|
||||
from src.db.activities import ActivityCreate, Activity, ActivityRead, ActivityUpdate
|
||||
|
|
@ -212,20 +213,33 @@ async def get_activities(
|
|||
|
||||
async def rbac_check(
|
||||
request: Request,
|
||||
course_id: str,
|
||||
course_uuid: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
action: Literal["create", "read", "update", "delete"],
|
||||
db_session: Session,
|
||||
):
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
if action == "read":
|
||||
if current_user.id == 0: # Anonymous user
|
||||
res = await authorization_verify_if_element_is_public(
|
||||
request, course_uuid, action, db_session
|
||||
)
|
||||
print('res',res)
|
||||
return res
|
||||
else:
|
||||
res = await authorization_verify_based_on_roles_and_authorship(
|
||||
request, current_user.id, action, course_uuid, db_session
|
||||
)
|
||||
return res
|
||||
else:
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
course_id,
|
||||
db_session,
|
||||
)
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
course_uuid,
|
||||
db_session,
|
||||
)
|
||||
|
||||
|
||||
## 🔒 RBAC Utils ##
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlmodel import Session, select
|
|||
from src.db.users import AnonymousUser
|
||||
from src.security.rbac.rbac import (
|
||||
authorization_verify_based_on_roles_and_authorship,
|
||||
authorization_verify_if_element_is_public,
|
||||
authorization_verify_if_user_is_anon,
|
||||
)
|
||||
from src.db.course_chapters import CourseChapter
|
||||
|
|
@ -207,6 +208,10 @@ async def get_course_chapters(
|
|||
page: int = 1,
|
||||
limit: int = 10,
|
||||
) -> List[ChapterRead]:
|
||||
|
||||
statement = select(Course).where(Course.id == course_id)
|
||||
course = db_session.exec(statement).first()
|
||||
|
||||
statement = (
|
||||
select(Chapter)
|
||||
.join(CourseChapter, Chapter.id == CourseChapter.chapter_id)
|
||||
|
|
@ -220,7 +225,7 @@ async def get_course_chapters(
|
|||
chapters = [ChapterRead(**chapter.dict(), activities=[]) for chapter in chapters]
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, "chapter_x", current_user, "read", db_session)
|
||||
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
|
||||
|
||||
# Get activities for each chapter
|
||||
for chapter in chapters:
|
||||
|
|
@ -532,20 +537,33 @@ async def reorder_chapters_and_activities(
|
|||
|
||||
async def rbac_check(
|
||||
request: Request,
|
||||
course_id: str,
|
||||
course_uuid: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
action: Literal["create", "read", "update", "delete"],
|
||||
db_session: Session,
|
||||
):
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
if action == "read":
|
||||
if current_user.id == 0: # Anonymous user
|
||||
res = await authorization_verify_if_element_is_public(
|
||||
request, course_uuid, action, db_session
|
||||
)
|
||||
print('res',res)
|
||||
return res
|
||||
else:
|
||||
res = await authorization_verify_based_on_roles_and_authorship(
|
||||
request, current_user.id, action, course_uuid, db_session
|
||||
)
|
||||
return res
|
||||
else:
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
course_id,
|
||||
db_session,
|
||||
)
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
course_uuid,
|
||||
db_session,
|
||||
)
|
||||
|
||||
|
||||
## 🔒 RBAC Utils ##
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlmodel import Session, select
|
|||
from src.db.users import AnonymousUser
|
||||
from src.security.rbac.rbac import (
|
||||
authorization_verify_based_on_roles_and_authorship,
|
||||
authorization_verify_if_element_is_public,
|
||||
authorization_verify_if_user_is_anon,
|
||||
)
|
||||
from src.db.collections import (
|
||||
|
|
@ -245,20 +246,34 @@ async def get_collections(
|
|||
|
||||
async def rbac_check(
|
||||
request: Request,
|
||||
course_id: str,
|
||||
collection_uuid: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
action: Literal["create", "read", "update", "delete"],
|
||||
db_session: Session,
|
||||
):
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
if action == "read":
|
||||
if current_user.id == 0: # Anonymous user
|
||||
res = await authorization_verify_if_element_is_public(
|
||||
request, collection_uuid, action, db_session
|
||||
)
|
||||
print('res',res)
|
||||
return res
|
||||
else:
|
||||
res = await authorization_verify_based_on_roles_and_authorship(
|
||||
request, current_user.id, action, collection_uuid, db_session
|
||||
)
|
||||
return res
|
||||
else:
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
course_id,
|
||||
db_session,
|
||||
)
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
collection_uuid,
|
||||
db_session,
|
||||
)
|
||||
|
||||
|
||||
## 🔒 RBAC Utils ##
|
||||
|
||||
|
|
|
|||
|
|
@ -96,11 +96,16 @@ async def get_course_meta(
|
|||
chapters = await get_course_chapters(request, course.id, db_session, current_user)
|
||||
|
||||
# Trail
|
||||
trail = await get_user_trail_with_orgid(
|
||||
request, current_user, course.org_id, db_session
|
||||
)
|
||||
trail = None
|
||||
|
||||
if isinstance(current_user, AnonymousUser):
|
||||
trail = None
|
||||
else:
|
||||
trail = await get_user_trail_with_orgid(
|
||||
request, current_user, course.org_id, db_session
|
||||
)
|
||||
trail = TrailRead.from_orm(trail)
|
||||
|
||||
trail = TrailRead.from_orm(trail)
|
||||
|
||||
return FullCourseReadWithTrail(
|
||||
**course.dict(),
|
||||
|
|
@ -359,7 +364,6 @@ async def get_courses_orgslug(
|
|||
|
||||
## 🔒 RBAC Utils ##
|
||||
|
||||
|
||||
async def rbac_check(
|
||||
request: Request,
|
||||
course_uuid: str,
|
||||
|
|
@ -369,13 +373,16 @@ async def rbac_check(
|
|||
):
|
||||
if action == "read":
|
||||
if current_user.id == 0: # Anonymous user
|
||||
await authorization_verify_if_element_is_public(
|
||||
res = await authorization_verify_if_element_is_public(
|
||||
request, course_uuid, action, db_session
|
||||
)
|
||||
print('res',res)
|
||||
return res
|
||||
else:
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
res = await authorization_verify_based_on_roles_and_authorship(
|
||||
request, current_user.id, action, course_uuid, db_session
|
||||
)
|
||||
return res
|
||||
else:
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ async def update_install_instance(
|
|||
|
||||
|
||||
# Install Default roles
|
||||
async def install_default_elements(request: Request, data: dict, db_session: Session):
|
||||
async def install_default_elements( data: dict, db_session: Session):
|
||||
# remove all default roles
|
||||
statement = select(Role).where(Role.role_type == RoleTypeEnum.TYPE_GLOBAL)
|
||||
roles = db_session.exec(statement).all()
|
||||
|
|
@ -279,7 +279,7 @@ async def install_default_elements(request: Request, data: dict, db_session: Ses
|
|||
|
||||
# Organization creation
|
||||
async def install_create_organization(
|
||||
request: Request, org_object: OrganizationCreate, db_session: Session
|
||||
org_object: OrganizationCreate, db_session: Session
|
||||
):
|
||||
org = Organization.from_orm(org_object)
|
||||
|
||||
|
|
@ -296,7 +296,7 @@ async def install_create_organization(
|
|||
|
||||
|
||||
async def install_create_organization_user(
|
||||
request: Request, user_object: UserCreate, org_slug: str, db_session: Session
|
||||
user_object: UserCreate, org_slug: str, db_session: Session
|
||||
):
|
||||
user = User.from_orm(user_object)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from src.db.courses import Course
|
|||
from src.db.trail_runs import TrailRun, TrailRunRead
|
||||
from src.db.trail_steps import TrailStep
|
||||
from src.db.trails import Trail, TrailCreate, TrailRead
|
||||
from src.db.users import PublicUser
|
||||
from src.db.users import AnonymousUser, PublicUser
|
||||
|
||||
|
||||
async def create_user_trail(
|
||||
|
|
@ -17,7 +17,7 @@ async def create_user_trail(
|
|||
trail_object: TrailCreate,
|
||||
db_session: Session,
|
||||
) -> Trail:
|
||||
statement = select(Trail).where(Trail.org_id == trail_object.org_id)
|
||||
statement = select(Trail).where(Trail.org_id == trail_object.org_id, Trail.user_id == user.id)
|
||||
trail = db_session.exec(statement).first()
|
||||
|
||||
if trail:
|
||||
|
|
@ -103,7 +103,7 @@ async def check_trail_presence(
|
|||
user: PublicUser,
|
||||
db_session: Session,
|
||||
):
|
||||
statement = select(Trail).where(Trail.org_id == org_id, Trail.user_id == user.id)
|
||||
statement = select(Trail).where(Trail.org_id == org_id, Trail.user_id == user_id)
|
||||
trail = db_session.exec(statement).first()
|
||||
|
||||
if not trail:
|
||||
|
|
@ -122,9 +122,15 @@ async def check_trail_presence(
|
|||
|
||||
|
||||
async def get_user_trail_with_orgid(
|
||||
request: Request, user: PublicUser, org_id: int, db_session: Session
|
||||
request: Request, user: PublicUser | AnonymousUser, org_id: int, db_session: Session
|
||||
) -> TrailRead:
|
||||
|
||||
if isinstance(user, AnonymousUser):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Anonymous users cannot access this endpoint",
|
||||
)
|
||||
|
||||
trail = await check_trail_presence(
|
||||
org_id=org_id,
|
||||
user_id=user.id,
|
||||
|
|
|
|||
|
|
@ -3,17 +3,20 @@ from typing import Literal
|
|||
from uuid import uuid4
|
||||
from fastapi import HTTPException, Request, status
|
||||
from sqlmodel import Session, select
|
||||
from src.db.roles import Role, RoleRead
|
||||
from src.security.rbac.rbac import (
|
||||
authorization_verify_based_on_roles_and_authorship,
|
||||
authorization_verify_if_user_is_anon,
|
||||
)
|
||||
from src.db.organizations import Organization
|
||||
from src.db.organizations import Organization, OrganizationRead
|
||||
from src.db.users import (
|
||||
AnonymousUser,
|
||||
PublicUser,
|
||||
User,
|
||||
UserCreate,
|
||||
UserRead,
|
||||
UserRoleWithOrg,
|
||||
UserSession,
|
||||
UserUpdate,
|
||||
UserUpdatePassword,
|
||||
)
|
||||
|
|
@ -279,6 +282,57 @@ async def read_user_by_uuid(
|
|||
return user
|
||||
|
||||
|
||||
async def get_user_session(
|
||||
request: Request,
|
||||
db_session: Session,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
) -> UserSession:
|
||||
# Get user
|
||||
statement = select(User).where(User.user_uuid == current_user.user_uuid)
|
||||
user = db_session.exec(statement).first()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="User does not exist",
|
||||
)
|
||||
|
||||
user = UserRead.from_orm(user)
|
||||
|
||||
# Get roles and orgs
|
||||
statement = (
|
||||
select(UserOrganization)
|
||||
.where(UserOrganization.user_id == user.id)
|
||||
.join(Organization)
|
||||
)
|
||||
user_organizations = db_session.exec(statement).all()
|
||||
|
||||
roles = []
|
||||
|
||||
for user_organization in user_organizations:
|
||||
role_statement = select(Role).where(Role.id == user_organization.role_id)
|
||||
role = db_session.exec(role_statement).first()
|
||||
|
||||
org_statement = select(Organization).where(
|
||||
Organization.id == user_organization.org_id
|
||||
)
|
||||
org = db_session.exec(org_statement).first()
|
||||
|
||||
roles.append(
|
||||
UserRoleWithOrg(
|
||||
role=RoleRead.from_orm(role),
|
||||
org=OrganizationRead.from_orm(org),
|
||||
)
|
||||
)
|
||||
|
||||
user_session = UserSession(
|
||||
user=user,
|
||||
roles=roles,
|
||||
)
|
||||
|
||||
return user_session
|
||||
|
||||
|
||||
async def authorize_user_action(
|
||||
request: Request,
|
||||
db_session: Session,
|
||||
|
|
|
|||
50
apps/api/src/tests/test_main.py
Normal file
50
apps/api/src/tests/test_main.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import SQLModel, Session
|
||||
from src.tests.utils.init_data_for_tests import create_initial_data_for_tests
|
||||
from src.core.events.database import get_db_session
|
||||
import pytest
|
||||
import asyncio
|
||||
from app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# TODO : fix this later https://stackoverflow.com/questions/10253826/path-issue-with-pytest-importerror-no-module-named
|
||||
|
||||
|
||||
@pytest.fixture(name="session", scope="session")
|
||||
def session_fixture():
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
def client_fixture(session: Session):
|
||||
def get_session_override():
|
||||
return session
|
||||
|
||||
app.dependency_overrides[get_db_session] = get_session_override
|
||||
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def execute_before_all_tests(session: Session):
|
||||
# This function will run once before all tests.
|
||||
asyncio.run(create_initial_data_for_tests(session))
|
||||
|
||||
|
||||
def test_create_default_elements(client: TestClient, session: Session):
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/orgs/slug/wayne",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
0
apps/api/src/tests/test_rbac.py
Normal file
0
apps/api/src/tests/test_rbac.py
Normal file
0
apps/api/src/tests/utils/__init__.py
Normal file
0
apps/api/src/tests/utils/__init__.py
Normal file
57
apps/api/src/tests/utils/init_data_for_tests.py
Normal file
57
apps/api/src/tests/utils/init_data_for_tests.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from sqlmodel import Session, select
|
||||
from src.db.user_organizations import UserOrganization
|
||||
from src.db.organizations import OrganizationCreate
|
||||
from src.db.users import User, UserCreate
|
||||
from src.services.install.install import (
|
||||
install_create_organization,
|
||||
install_create_organization_user,
|
||||
install_default_elements,
|
||||
)
|
||||
|
||||
|
||||
async def create_initial_data_for_tests(db_session: Session):
|
||||
# Install default elements
|
||||
await install_default_elements({}, db_session)
|
||||
|
||||
# Initiate test Organization
|
||||
test_org = OrganizationCreate(
|
||||
name="Wayne Enterprises",
|
||||
description=None,
|
||||
slug="wayne",
|
||||
email="hello@wayne.dev",
|
||||
logo_image=None,
|
||||
)
|
||||
|
||||
# Create test organization
|
||||
await install_create_organization(test_org, db_session)
|
||||
|
||||
users = [
|
||||
UserCreate(
|
||||
username="batman",
|
||||
first_name="Bruce",
|
||||
last_name="Wayne",
|
||||
email="bruce@wayne.com",
|
||||
password="imbatman",
|
||||
),
|
||||
UserCreate(
|
||||
username="robin",
|
||||
first_name="Richard John",
|
||||
last_name="Grayson",
|
||||
email="robin@wayne.com",
|
||||
password="secret",
|
||||
),
|
||||
]
|
||||
|
||||
# Create 2 users in that Organization
|
||||
for user in users:
|
||||
await install_create_organization_user(user, "wayne", db_session)
|
||||
|
||||
# Make robin a normal user
|
||||
statement = select(UserOrganization).join(User).where(User.username == "robin")
|
||||
user_org = db_session.exec(statement).first()
|
||||
|
||||
user_org.role_id = 3 # type: ignore
|
||||
db_session.add(user_org)
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
Loading…
Add table
Add a link
Reference in a new issue