mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: protect paid courses
This commit is contained in:
parent
b7f09885df
commit
3988ee1d4b
9 changed files with 266 additions and 33 deletions
|
|
@ -19,6 +19,7 @@ from src.services.payments.payments_courses import (
|
|||
)
|
||||
from src.services.payments.payments_webhook import handle_stripe_webhook
|
||||
from src.services.payments.stripe import create_checkout_session
|
||||
from src.services.payments.payments_access import check_course_paid_access
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -185,3 +186,22 @@ async def api_create_checkout_session(
|
|||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
return await create_checkout_session(request, org_id, product_id, redirect_uri, current_user, db_session)
|
||||
|
||||
@router.get("/{org_id}/courses/{course_id}/access")
|
||||
async def api_check_course_paid_access(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
course_id: int,
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
"""
|
||||
Check if current user has paid access to a specific course
|
||||
"""
|
||||
return {
|
||||
"has_access": await check_course_paid_access(
|
||||
course_id=course_id,
|
||||
user=current_user,
|
||||
db_session=db_session
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ from fastapi import HTTPException, Request
|
|||
from uuid import uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from src.services.payments.payments_access import check_activity_paid_access
|
||||
|
||||
|
||||
####################################################
|
||||
# CRUD
|
||||
|
|
@ -112,7 +114,16 @@ async def get_activity(
|
|||
# RBAC check
|
||||
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
|
||||
|
||||
activity = ActivityRead.model_validate(activity)
|
||||
# Paid access check
|
||||
has_paid_access = await check_activity_paid_access(
|
||||
activity_id=activity.id if activity.id else 0,
|
||||
user=current_user,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
activity_read = ActivityRead.model_validate(activity)
|
||||
activity_read.content = activity_read.content if has_paid_access else { "paid_access": False }
|
||||
activity = activity_read
|
||||
|
||||
return activity
|
||||
|
||||
|
|
@ -258,30 +269,32 @@ async def get_activities(
|
|||
|
||||
async def rbac_check(
|
||||
request: Request,
|
||||
course_uuid: str,
|
||||
element_uuid: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
action: Literal["create", "read", "update", "delete"],
|
||||
db_session: Session,
|
||||
):
|
||||
|
||||
|
||||
if action == "read":
|
||||
if current_user.id == 0: # Anonymous user
|
||||
res = await authorization_verify_if_element_is_public(
|
||||
request, course_uuid, action, db_session
|
||||
request, element_uuid, action, db_session
|
||||
)
|
||||
return res
|
||||
else:
|
||||
res = await authorization_verify_based_on_roles_and_authorship(
|
||||
request, current_user.id, action, course_uuid, db_session
|
||||
request, current_user.id, action, element_uuid, db_session
|
||||
)
|
||||
return res
|
||||
else:
|
||||
# For non-read actions, proceed with regular RBAC checks
|
||||
await authorization_verify_if_user_is_anon(current_user.id)
|
||||
|
||||
await authorization_verify_based_on_roles_and_authorship(
|
||||
request,
|
||||
current_user.id,
|
||||
action,
|
||||
course_uuid,
|
||||
element_uuid,
|
||||
db_session,
|
||||
)
|
||||
|
||||
|
|
|
|||
98
apps/api/src/services/payments/payments_access.py
Normal file
98
apps/api/src/services/payments/payments_access.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
from sqlmodel import Session, select
|
||||
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
|
||||
from src.db.users import PublicUser, AnonymousUser
|
||||
from src.db.payments.payments_courses import PaymentsCourse
|
||||
from src.db.courses.activities import Activity
|
||||
from src.db.courses.courses import Course
|
||||
from fastapi import HTTPException
|
||||
|
||||
async def check_activity_paid_access(
|
||||
activity_id: int,
|
||||
user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has access to a specific activity
|
||||
Returns True if:
|
||||
- User is an author of the course
|
||||
- Activity is in a free course
|
||||
- User has a valid subscription for the course
|
||||
"""
|
||||
|
||||
|
||||
# Get activity and associated course
|
||||
statement = select(Activity).where(Activity.id == activity_id)
|
||||
activity = db_session.exec(statement).first()
|
||||
|
||||
if not activity:
|
||||
raise HTTPException(status_code=404, detail="Activity not found")
|
||||
|
||||
# Check if course exists
|
||||
statement = select(Course).where(Course.id == activity.course_id)
|
||||
course = db_session.exec(statement).first()
|
||||
|
||||
if not course:
|
||||
raise HTTPException(status_code=404, detail="Course not found")
|
||||
|
||||
# Check if course is linked to a product
|
||||
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
|
||||
course_payment = db_session.exec(statement).first()
|
||||
|
||||
# If course is not linked to any product, it's free
|
||||
if not course_payment:
|
||||
return True
|
||||
|
||||
# Anonymous users have no access to paid activities
|
||||
if isinstance(user, AnonymousUser):
|
||||
return False
|
||||
|
||||
# Check if user has a valid subscription or payment
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user.id,
|
||||
PaymentsUser.payment_product_id == course_payment.payment_product_id,
|
||||
PaymentsUser.status.in_( # type: ignore
|
||||
[PaymentStatusEnum.ACTIVE, PaymentStatusEnum.COMPLETED]
|
||||
),
|
||||
)
|
||||
access = db_session.exec(statement).first()
|
||||
|
||||
return bool(access)
|
||||
|
||||
async def check_course_paid_access(
|
||||
course_id: int,
|
||||
user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has paid access to a specific course
|
||||
Returns True if:
|
||||
- User is an author of the course
|
||||
- Course is free (not linked to any product)
|
||||
- User has a valid subscription for the course
|
||||
"""
|
||||
# Check if course exists
|
||||
statement = select(Course).where(Course.id == course_id)
|
||||
course = db_session.exec(statement).first()
|
||||
|
||||
if not course:
|
||||
raise HTTPException(status_code=404, detail="Course not found")
|
||||
|
||||
# Check if course is linked to a product
|
||||
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
|
||||
course_payment = db_session.exec(statement).first()
|
||||
|
||||
# If course is not linked to any product, it's free
|
||||
if not course_payment:
|
||||
return True
|
||||
|
||||
# Check if user has a valid subscription
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user.id,
|
||||
PaymentsUser.payment_product_id == course_payment.payment_product_id,
|
||||
PaymentsUser.status.in_( # type: ignore
|
||||
[PaymentStatusEnum.ACTIVE, PaymentStatusEnum.COMPLETED]
|
||||
),
|
||||
)
|
||||
subscription = db_session.exec(statement).first()
|
||||
|
||||
return bool(subscription)
|
||||
|
|
@ -121,5 +121,3 @@ async def get_courses_by_product(
|
|||
courses = db_session.exec(statement).all()
|
||||
|
||||
return courses
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue