feat: protect paid courses

This commit is contained in:
swve 2024-11-02 23:16:33 +01:00
parent b7f09885df
commit 3988ee1d4b
9 changed files with 266 additions and 33 deletions

View file

@ -19,6 +19,7 @@ from src.services.payments.payments_courses import (
)
from src.services.payments.payments_webhook import handle_stripe_webhook
from src.services.payments.stripe import create_checkout_session
from src.services.payments.payments_access import check_course_paid_access
router = APIRouter()
@ -185,3 +186,22 @@ async def api_create_checkout_session(
db_session: Session = Depends(get_db_session),
):
return await create_checkout_session(request, org_id, product_id, redirect_uri, current_user, db_session)
@router.get("/{org_id}/courses/{course_id}/access")
async def api_check_course_paid_access(
request: Request,
org_id: int,
course_id: int,
current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session),
):
"""
Check if current user has paid access to a specific course
"""
return {
"has_access": await check_course_paid_access(
course_id=course_id,
user=current_user,
db_session=db_session
)
}

View file

@ -14,6 +14,8 @@ from fastapi import HTTPException, Request
from uuid import uuid4
from datetime import datetime
from src.services.payments.payments_access import check_activity_paid_access
####################################################
# CRUD
@ -112,7 +114,16 @@ async def get_activity(
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
activity = ActivityRead.model_validate(activity)
# Paid access check
has_paid_access = await check_activity_paid_access(
activity_id=activity.id if activity.id else 0,
user=current_user,
db_session=db_session
)
activity_read = ActivityRead.model_validate(activity)
activity_read.content = activity_read.content if has_paid_access else { "paid_access": False }
activity = activity_read
return activity
@ -258,30 +269,32 @@ async def get_activities(
async def rbac_check(
request: Request,
course_uuid: str,
element_uuid: str,
current_user: PublicUser | AnonymousUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
if action == "read":
if current_user.id == 0: # Anonymous user
res = await authorization_verify_if_element_is_public(
request, course_uuid, action, db_session
request, element_uuid, action, db_session
)
return res
else:
res = await authorization_verify_based_on_roles_and_authorship(
request, current_user.id, action, course_uuid, db_session
request, current_user.id, action, element_uuid, db_session
)
return res
else:
# For non-read actions, proceed with regular RBAC checks
await authorization_verify_if_user_is_anon(current_user.id)
await authorization_verify_based_on_roles_and_authorship(
request,
current_user.id,
action,
course_uuid,
element_uuid,
db_session,
)

View file

@ -0,0 +1,98 @@
from sqlmodel import Session, select
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
from src.db.users import PublicUser, AnonymousUser
from src.db.payments.payments_courses import PaymentsCourse
from src.db.courses.activities import Activity
from src.db.courses.courses import Course
from fastapi import HTTPException
async def check_activity_paid_access(
activity_id: int,
user: PublicUser | AnonymousUser,
db_session: Session,
) -> bool:
"""
Check if a user has access to a specific activity
Returns True if:
- User is an author of the course
- Activity is in a free course
- User has a valid subscription for the course
"""
# Get activity and associated course
statement = select(Activity).where(Activity.id == activity_id)
activity = db_session.exec(statement).first()
if not activity:
raise HTTPException(status_code=404, detail="Activity not found")
# Check if course exists
statement = select(Course).where(Course.id == activity.course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(status_code=404, detail="Course not found")
# Check if course is linked to a product
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
course_payment = db_session.exec(statement).first()
# If course is not linked to any product, it's free
if not course_payment:
return True
# Anonymous users have no access to paid activities
if isinstance(user, AnonymousUser):
return False
# Check if user has a valid subscription or payment
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == course_payment.payment_product_id,
PaymentsUser.status.in_( # type: ignore
[PaymentStatusEnum.ACTIVE, PaymentStatusEnum.COMPLETED]
),
)
access = db_session.exec(statement).first()
return bool(access)
async def check_course_paid_access(
course_id: int,
user: PublicUser | AnonymousUser,
db_session: Session,
) -> bool:
"""
Check if a user has paid access to a specific course
Returns True if:
- User is an author of the course
- Course is free (not linked to any product)
- User has a valid subscription for the course
"""
# Check if course exists
statement = select(Course).where(Course.id == course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(status_code=404, detail="Course not found")
# Check if course is linked to a product
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
course_payment = db_session.exec(statement).first()
# If course is not linked to any product, it's free
if not course_payment:
return True
# Check if user has a valid subscription
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == course_payment.payment_product_id,
PaymentsUser.status.in_( # type: ignore
[PaymentStatusEnum.ACTIVE, PaymentStatusEnum.COMPLETED]
),
)
subscription = db_session.exec(statement).first()
return bool(subscription)

View file

@ -121,5 +121,3 @@ async def get_courses_by_product(
courses = db_session.exec(statement).all()
return courses