From 1bff401e739906324c827a9d9a927433e92b32ab Mon Sep 17 00:00:00 2001 From: swve Date: Fri, 1 Nov 2024 20:51:52 +0100 Subject: [PATCH] feat: add course checkout UI and stripe integration and webhook wip --- apps/api/cli.py | 2 + apps/api/src/core/events/database.py | 39 ++- apps/api/src/db/payments/payments_courses.py | 12 +- apps/api/src/db/payments/payments_users.py | 36 ++- apps/api/src/db/users.py | 5 + apps/api/src/routers/ee/payments.py | 41 ++- apps/api/src/services/orgs/orgs.py | 8 +- .../src/services/payments/payments_config.py | 4 +- .../src/services/payments/payments_courses.py | 20 +- .../services/payments/payments_products.py | 37 ++- .../src/services/payments/payments_users.py | 181 ++++++++++++ .../src/services/payments/payments_webhook.py | 260 ++++++++++++++++++ apps/api/src/services/payments/stripe.py | 95 ++++++- .../(withmenu)/course/[courseuuid]/course.tsx | 104 +------ .../CourseActions/CoursePaidOptions.tsx | 161 +++++++++++ .../Courses/CourseActions/CoursesActions.tsx | 195 +++++++++++++ .../CourseUpdates/CourseUpdates.tsx | 0 apps/web/services/payments/products.ts | 17 ++ 18 files changed, 1086 insertions(+), 131 deletions(-) create mode 100644 apps/api/src/services/payments/payments_users.py create mode 100644 apps/api/src/services/payments/payments_webhook.py create mode 100644 apps/web/components/Objects/Courses/CourseActions/CoursePaidOptions.tsx create mode 100644 apps/web/components/Objects/Courses/CourseActions/CoursesActions.tsx rename apps/web/components/Objects/{ => Courses}/CourseUpdates/CourseUpdates.tsx (100%) diff --git a/apps/api/cli.py b/apps/api/cli.py index 603dbf0b..b4238ac4 100644 --- a/apps/api/cli.py +++ b/apps/api/cli.py @@ -48,6 +48,7 @@ def install( slug="default", email="", logo_image="", + thumbnail_image="", ) install_create_organization(org, db_session) print("Default organization created ✅") @@ -89,6 +90,7 @@ def install( slug=slug.lower(), email="", logo_image="", + thumbnail_image="", ) install_create_organization(org, db_session) print(orgname + " Organization created ✅") diff --git a/apps/api/src/core/events/database.py b/apps/api/src/core/events/database.py index 01318298..53120f7a 100644 --- a/apps/api/src/core/events/database.py +++ b/apps/api/src/core/events/database.py @@ -1,26 +1,55 @@ import logging +import os +import importlib +from typing import Optional from config.config import get_learnhouse_config from fastapi import FastAPI -from sqlmodel import SQLModel, Session, create_engine +from sqlmodel import Field, SQLModel, Session, create_engine + +def import_all_models(): + base_dir = 'src/db' + base_module_path = 'src.db' + + # Recursively walk through the base directory + for root, dirs, files in os.walk(base_dir): + # Filter out __init__.py and non-Python files + module_files = [f for f in files if f.endswith('.py') and f != '__init__.py'] + + # Calculate the module's base path from its directory structure + path_diff = os.path.relpath(root, base_dir) + if path_diff == '.': + current_module_base = base_module_path + else: + current_module_base = f"{base_module_path}.{path_diff.replace(os.sep, '.')}" + + # Dynamically import each module + for file_name in module_files: + module_name = file_name[:-3] # Remove the '.py' extension + full_module_path = f"{current_module_base}.{module_name}" + importlib.import_module(full_module_path) + +# Import all models before creating engine +import_all_models() learnhouse_config = get_learnhouse_config() engine = create_engine( - learnhouse_config.database_config.sql_connection_string, echo=False, pool_pre_ping=True # type: ignore + learnhouse_config.database_config.sql_connection_string, # type: ignore + echo=False, + pool_pre_ping=True # type: ignore ) -SQLModel.metadata.create_all(engine) +# Create all tables after importing all models +SQLModel.metadata.create_all(engine) async def connect_to_db(app: FastAPI): app.db_engine = engine # type: ignore logging.info("LearnHouse database has been started.") SQLModel.metadata.create_all(engine) - def get_db_session(): with Session(engine) as session: yield session - async def close_database(app: FastAPI): logging.info("LearnHouse has been shut down.") return app diff --git a/apps/api/src/db/payments/payments_courses.py b/apps/api/src/db/payments/payments_courses.py index f3f05687..bdcfbead 100644 --- a/apps/api/src/db/payments/payments_courses.py +++ b/apps/api/src/db/payments/payments_courses.py @@ -2,12 +2,14 @@ from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey from typing import Optional from datetime import datetime -class PaymentCourseBase(SQLModel): +class PaymentsCourseBase(SQLModel): course_id: int = Field(sa_column=Column(BigInteger, ForeignKey("course.id", ondelete="CASCADE"))) - payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE"))) - org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))) - -class PaymentCourse(PaymentCourseBase, table=True): + +class PaymentsCourse(PaymentsCourseBase, table=True): id: Optional[int] = Field(default=None, primary_key=True) + payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE"))) + org_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")) + ) creation_date: datetime = Field(default=datetime.now()) update_date: datetime = Field(default=datetime.now()) \ No newline at end of file diff --git a/apps/api/src/db/payments/payments_users.py b/apps/api/src/db/payments/payments_users.py index 4b90157e..ce3b3f68 100644 --- a/apps/api/src/db/payments/payments_users.py +++ b/apps/api/src/db/payments/payments_users.py @@ -1,19 +1,37 @@ -from enum import Enum -from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey +from openai import BaseModel +from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey, JSON from typing import Optional from datetime import datetime +from enum import Enum -class PaymentUserStatusEnum(str, Enum): +class PaymentStatusEnum(str, Enum): + PENDING = "pending" + COMPLETED = "completed" ACTIVE = "active" - INACTIVE = "inactive" - + CANCELLED = "cancelled" + FAILED = "failed" + REFUNDED = "refunded" + + +class ProviderSpecificData(BaseModel): + stripe_customer: dict | None = None + custom_customer: dict | None = None + class PaymentsUserBase(SQLModel): - user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE"))) - status: PaymentUserStatusEnum = PaymentUserStatusEnum.ACTIVE - payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE"))) - org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))) + status: PaymentStatusEnum = PaymentStatusEnum.PENDING + provider_specific_data: dict = Field(default={}, sa_column=Column(JSON)) class PaymentsUser(PaymentsUserBase, table=True): id: Optional[int] = Field(default=None, primary_key=True) + user_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE")) + ) + org_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")) + ) + payment_product_id: int = Field( + sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")) + ) creation_date: datetime = Field(default=datetime.now()) update_date: datetime = Field(default=datetime.now()) + diff --git a/apps/api/src/db/users.py b/apps/api/src/db/users.py index 8d98c3e7..d7412e30 100644 --- a/apps/api/src/db/users.py +++ b/apps/api/src/db/users.py @@ -59,6 +59,11 @@ class AnonymousUser(SQLModel): user_uuid: str = "user_anonymous" username: str = "anonymous" +class InternalUser(SQLModel): + id: int = 0 + user_uuid: str = "user_internal" + username: str = "internal" + class User(UserBase, table=True): id: Optional[int] = Field(default=None, primary_key=True) diff --git a/apps/api/src/routers/ee/payments.py b/apps/api/src/routers/ee/payments.py index 526e4311..2ae0e48d 100644 --- a/apps/api/src/routers/ee/payments.py +++ b/apps/api/src/routers/ee/payments.py @@ -11,12 +11,14 @@ from src.services.payments.payments_config import ( delete_payments_config, ) from src.db.payments.payments_products import PaymentsProductCreate, PaymentsProductRead, PaymentsProductUpdate -from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, list_payments_products, update_payments_product +from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, get_products_by_course, list_payments_products, update_payments_product from src.services.payments.payments_courses import ( link_course_to_product, unlink_course_from_product, - get_courses_by_product + get_courses_by_product, ) +from src.services.payments.payments_webhook import handle_stripe_webhook +from src.services.payments.stripe import create_checkout_session router = APIRouter() @@ -148,3 +150,38 @@ async def api_get_courses_by_product( return await get_courses_by_product( request, org_id, product_id, current_user, db_session ) + +@router.get("/{org_id}/courses/{course_id}/products") +async def api_get_products_by_course( + request: Request, + org_id: int, + course_id: int, + current_user: PublicUser = Depends(get_current_user), + db_session: Session = Depends(get_db_session), +): + return await get_products_by_course( + request, org_id, course_id, current_user, db_session + ) + +# Payments webhooks + +@router.post("/{org_id}/stripe/webhook") +async def api_handle_stripe_webhook( + request: Request, + org_id: int, + db_session: Session = Depends(get_db_session), +): + return await handle_stripe_webhook(request, org_id, db_session) + +# Payments checkout + +@router.post("/{org_id}/stripe/checkout/product/{product_id}") +async def api_create_checkout_session( + request: Request, + org_id: int, + product_id: int, + redirect_uri: str, + current_user: PublicUser = Depends(get_current_user), + db_session: Session = Depends(get_db_session), +): + return await create_checkout_session(request, org_id, product_id, redirect_uri, current_user, db_session) diff --git a/apps/api/src/services/orgs/orgs.py b/apps/api/src/services/orgs/orgs.py index c7e92000..13c084ed 100644 --- a/apps/api/src/services/orgs/orgs.py +++ b/apps/api/src/services/orgs/orgs.py @@ -26,7 +26,7 @@ from src.security.rbac.rbac import ( authorization_verify_based_on_org_admin_status, authorization_verify_if_user_is_anon, ) -from src.db.users import AnonymousUser, PublicUser +from src.db.users import AnonymousUser, InternalUser, PublicUser from src.db.user_organizations import UserOrganization from src.db.organizations import ( Organization, @@ -682,13 +682,17 @@ async def get_org_join_mechanism( async def rbac_check( request: Request, org_uuid: str, - current_user: PublicUser | AnonymousUser, + current_user: PublicUser | AnonymousUser | InternalUser, action: Literal["create", "read", "update", "delete"], db_session: Session, ): # Organizations are readable by anyone if action == "read": return True + + # Internal users can do anything + if isinstance(current_user, InternalUser): + return True else: isUserAnon = await authorization_verify_if_user_is_anon(current_user.id) diff --git a/apps/api/src/services/payments/payments_config.py b/apps/api/src/services/payments/payments_config.py index 848cf15a..074a68d8 100644 --- a/apps/api/src/services/payments/payments_config.py +++ b/apps/api/src/services/payments/payments_config.py @@ -6,7 +6,7 @@ from src.db.payments.payments import ( PaymentsConfigUpdate, PaymentsConfigRead, ) -from src.db.users import PublicUser, AnonymousUser +from src.db.users import PublicUser, AnonymousUser, InternalUser from src.db.organizations import Organization from src.services.orgs.orgs import rbac_check @@ -48,7 +48,7 @@ async def create_payments_config( async def get_payments_config( request: Request, org_id: int, - current_user: PublicUser | AnonymousUser, + current_user: PublicUser | AnonymousUser | InternalUser, db_session: Session, ) -> list[PaymentsConfigRead]: # Check if organization exists diff --git a/apps/api/src/services/payments/payments_courses.py b/apps/api/src/services/payments/payments_courses.py index e83aa10a..2e54c8c7 100644 --- a/apps/api/src/services/payments/payments_courses.py +++ b/apps/api/src/services/payments/payments_courses.py @@ -1,7 +1,7 @@ from datetime import datetime from fastapi import HTTPException, Request from sqlmodel import Session, select -from src.db.payments.payments_courses import PaymentCourse +from src.db.payments.payments_courses import PaymentsCourse from src.db.payments.payments_products import PaymentsProduct from src.db.courses.courses import Course from src.db.users import PublicUser, AnonymousUser @@ -36,7 +36,7 @@ async def link_course_to_product( raise HTTPException(status_code=404, detail="Product not found") # Check if course is already linked to another product - statement = select(PaymentCourse).where(PaymentCourse.course_id == course.id) + statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id) existing_link = db_session.exec(statement).first() if existing_link: @@ -46,7 +46,7 @@ async def link_course_to_product( ) # Create new payment course link - payment_course = PaymentCourse( + payment_course = PaymentsCourse( course_id=course.id, # type: ignore payment_product_id=product_id, org_id=org_id, @@ -75,9 +75,9 @@ async def unlink_course_from_product( await rbac_check(request, course.course_uuid, current_user, "update", db_session) # Find and delete the payment course link - statement = select(PaymentCourse).where( - PaymentCourse.course_id == course.id, - PaymentCourse.org_id == org_id + statement = select(PaymentsCourse).where( + PaymentsCourse.course_id == course.id, + PaymentsCourse.org_id == org_id ) payment_course = db_session.exec(statement).first() @@ -113,12 +113,14 @@ async def get_courses_by_product( statement = ( select(Course) .select_from(Course) - .join(PaymentCourse, Course.id == PaymentCourse.course_id) # type: ignore + .join(PaymentsCourse, Course.id == PaymentsCourse.course_id) # type: ignore .where( - PaymentCourse.payment_product_id == product_id, - PaymentCourse.org_id == org_id + PaymentsCourse.payment_product_id == product_id, + PaymentsCourse.org_id == org_id ) ) courses = db_session.exec(statement).all() return courses + + diff --git a/apps/api/src/services/payments/payments_products.py b/apps/api/src/services/payments/payments_products.py index 3db0136b..cd3f5b8b 100644 --- a/apps/api/src/services/payments/payments_products.py +++ b/apps/api/src/services/payments/payments_products.py @@ -1,6 +1,8 @@ from fastapi import HTTPException, Request from sqlmodel import Session, select +from src.db.courses.courses import Course from src.db.payments.payments import PaymentsConfig +from src.db.payments.payments_courses import PaymentsCourse from src.db.payments.payments_products import ( PaymentsProduct, PaymentsProductCreate, @@ -12,7 +14,7 @@ from src.db.organizations import Organization from src.services.orgs.orgs import rbac_check from datetime import datetime -from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product +from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product async def create_payments_product( request: Request, @@ -163,3 +165,36 @@ async def list_payments_products( products = db_session.exec(statement).all() return [PaymentsProductRead.model_validate(product) for product in products] + +async def get_products_by_course( + request: Request, + org_id: int, + course_id: int, + current_user: PublicUser | AnonymousUser, + db_session: Session, +) -> list[PaymentsProductRead]: + # Check if course exists and user has permission + statement = select(Course).where(Course.id == course_id) + course = db_session.exec(statement).first() + + if not course: + raise HTTPException(status_code=404, detail="Course not found") + + # RBAC check + await rbac_check(request, course.course_uuid, current_user, "read", db_session) + + # Get all products linked to this course with explicit join + statement = ( + select(PaymentsProduct) + .select_from(PaymentsProduct) + .join(PaymentsCourse, PaymentsProduct.id == PaymentsCourse.payment_product_id) # type: ignore + .where( + PaymentsCourse.course_id == course_id, + PaymentsCourse.org_id == org_id + ) + ) + products = db_session.exec(statement).all() + + return [PaymentsProductRead.model_validate(product) for product in products] + + diff --git a/apps/api/src/services/payments/payments_users.py b/apps/api/src/services/payments/payments_users.py new file mode 100644 index 00000000..253364bd --- /dev/null +++ b/apps/api/src/services/payments/payments_users.py @@ -0,0 +1,181 @@ +from fastapi import HTTPException, Request +from sqlmodel import Session, select +from typing import Any +from src.db.payments.payments_users import PaymentsUser, PaymentStatusEnum, ProviderSpecificData +from src.db.payments.payments_products import PaymentsProduct +from src.db.users import InternalUser, PublicUser, AnonymousUser +from src.db.organizations import Organization +from src.services.orgs.orgs import rbac_check +from datetime import datetime + +async def create_payment_user( + request: Request, + org_id: int, + user_id: int, + product_id: int, + status: PaymentStatusEnum, + provider_data: Any, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +) -> PaymentsUser: + # Check if organization exists + statement = select(Organization).where(Organization.id == org_id) + org = db_session.exec(statement).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # RBAC check + await rbac_check(request, org.org_uuid, current_user, "create", db_session) + + # Check if product exists + statement = select(PaymentsProduct).where( + PaymentsProduct.id == product_id, + PaymentsProduct.org_id == org_id + ) + product = db_session.exec(statement).first() + if not product: + raise HTTPException(status_code=404, detail="Product not found") + + provider_specific_data = ProviderSpecificData( + stripe_customer=provider_data if provider_data else None, + ) + + # Check if user already has a payment user + statement = select(PaymentsUser).where( + PaymentsUser.user_id == user_id, + PaymentsUser.org_id == org_id + ) + existing_payment_user = db_session.exec(statement).first() + + if existing_payment_user: + raise HTTPException(status_code=400, detail="User already has purchase for this product") + + # Create new payment user + payment_user = PaymentsUser( + user_id=user_id, + org_id=org_id, + payment_product_id=product_id, + provider_specific_data=provider_specific_data.model_dump(), + status=status + ) + + db_session.add(payment_user) + db_session.commit() + db_session.refresh(payment_user) + + return payment_user + +async def get_payment_user( + request: Request, + org_id: int, + payment_user_id: int, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +) -> PaymentsUser: + # Check if organization exists + statement = select(Organization).where(Organization.id == org_id) + org = db_session.exec(statement).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # RBAC check + await rbac_check(request, org.org_uuid, current_user, "read", db_session) + + # Get payment user + statement = select(PaymentsUser).where( + PaymentsUser.id == payment_user_id, + PaymentsUser.org_id == org_id + ) + payment_user = db_session.exec(statement).first() + if not payment_user: + raise HTTPException(status_code=404, detail="Payment user not found") + + return payment_user + +async def update_payment_user_status( + request: Request, + org_id: int, + payment_user_id: int, + status: PaymentStatusEnum, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +) -> PaymentsUser: + # Check if organization exists + statement = select(Organization).where(Organization.id == org_id) + org = db_session.exec(statement).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # RBAC check + await rbac_check(request, org.org_uuid, current_user, "update", db_session) + + # Get existing payment user + statement = select(PaymentsUser).where( + PaymentsUser.id == payment_user_id, + PaymentsUser.org_id == org_id + ) + payment_user = db_session.exec(statement).first() + if not payment_user: + raise HTTPException(status_code=404, detail="Payment user not found") + + # Update status + payment_user.status = status + payment_user.update_date = datetime.now() + + db_session.add(payment_user) + db_session.commit() + db_session.refresh(payment_user) + + return payment_user + +async def list_payment_users( + request: Request, + org_id: int, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +) -> list[PaymentsUser]: + # Check if organization exists + statement = select(Organization).where(Organization.id == org_id) + org = db_session.exec(statement).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # RBAC check + await rbac_check(request, org.org_uuid, current_user, "read", db_session) + + # Get all payment users for org ordered by id + statement = select(PaymentsUser).where( + PaymentsUser.org_id == org_id + ).order_by(PaymentsUser.id.desc()) # type: ignore + payment_users = list(db_session.exec(statement).all()) # Convert to list + + return payment_users + +async def delete_payment_user( + request: Request, + org_id: int, + payment_user_id: int, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +) -> None: + # Check if organization exists + statement = select(Organization).where(Organization.id == org_id) + org = db_session.exec(statement).first() + if not org: + raise HTTPException(status_code=404, detail="Organization not found") + + # RBAC check + await rbac_check(request, org.org_uuid, current_user, "delete", db_session) + + # Get existing payment user + statement = select(PaymentsUser).where( + PaymentsUser.id == payment_user_id, + PaymentsUser.org_id == org_id + ) + payment_user = db_session.exec(statement).first() + if not payment_user: + raise HTTPException(status_code=404, detail="Payment user not found") + + # Delete payment user + db_session.delete(payment_user) + db_session.commit() diff --git a/apps/api/src/services/payments/payments_webhook.py b/apps/api/src/services/payments/payments_webhook.py new file mode 100644 index 00000000..9ff3bb4a --- /dev/null +++ b/apps/api/src/services/payments/payments_webhook.py @@ -0,0 +1,260 @@ +from fastapi import HTTPException, Request +from sqlmodel import Session, select +import stripe +from datetime import datetime +from typing import Callable, Dict +import logging + +from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser +from src.db.payments.payments_products import PaymentsProduct +from src.db.users import InternalUser, User +from src.services.payments.payments_users import create_payment_user, update_payment_user_status +from src.services.payments.stripe import get_stripe_credentials + +logger = logging.getLogger(__name__) + +async def get_user_from_customer(customer_id: str, db_session: Session) -> User: + """Helper function to get user from Stripe customer ID""" + try: + customer = stripe.Customer.retrieve(customer_id) + statement = select(User).where(User.email == customer.email) + user = db_session.exec(statement).first() + if not user: + raise HTTPException(status_code=404, detail=f"User not found for customer {customer_id}") + return user + except stripe.StripeError as e: + logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}") + raise HTTPException(status_code=400, detail="Error retrieving customer information") + +async def get_product_from_stripe_id(product_id: str, db_session: Session) -> PaymentsProduct: + """Helper function to get product from Stripe product ID""" + statement = select(PaymentsProduct).where(PaymentsProduct.provider_product_id == product_id) + product = db_session.exec(statement).first() + if not product: + raise HTTPException(status_code=404, detail=f"Product not found: {product_id}") + return product + +async def handle_stripe_webhook( + request: Request, + org_id: int, + db_session: Session, +) -> dict: + # Get Stripe credentials for the organization + creds = await get_stripe_credentials(request, org_id, InternalUser(), db_session) + + # Get the webhook secret and API key from credentials + webhook_secret = creds.get('stripe_webhook_secret') + stripe.api_key = creds.get('stripe_secret_key') # Set API key globally + + if not webhook_secret: + raise HTTPException(status_code=400, detail="Stripe webhook secret not configured") + + # Get the raw request body + payload = await request.body() + sig_header = request.headers.get('stripe-signature') + + try: + # Verify webhook signature and construct event + event = stripe.Webhook.construct_event( + payload, sig_header, webhook_secret + ) + + # Get the appropriate handler + handler = STRIPE_EVENT_HANDLERS.get(event.type) + if handler: + await handler(request, event.data.object, org_id, db_session) + return {"status": "success", "event": event.type} + else: + logger.info(f"Unhandled event type: {event.type}") + return {"status": "ignored", "event": event.type} + + except Exception as e: + logger.error(f"Error processing webhook: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}") + +async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session): + # Get the customer and product details from the session + customer_email = session.customer_details.email + product_id = session.line_items.data[0].price.product + + # Use helper functions + user = await get_user_from_customer(session.customer, db_session) + product = await get_product_from_stripe_id(product_id, db_session) + + # Find payment user record + statement = select(PaymentsUser).where( + PaymentsUser.user_id == user.id, + PaymentsUser.payment_product_id == product.id + ) + payment_user = db_session.exec(statement).first() + + # Update status to completed + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user.id, # type: ignore + status=PaymentStatusEnum.COMPLETED, + current_user=InternalUser(), + db_session=db_session + ) + +async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session): + customer_id = subscription.customer + + # Get product_id from metadata + product_id = subscription.metadata.get('product_id') + if not product_id: + logger.error(f"No product_id found in subscription metadata: {subscription.id}") + raise HTTPException(status_code=400, detail="No product_id found in subscription metadata") + + # Get customer email from Stripe + customer = stripe.Customer.retrieve(customer_id) + + # Find user and create/update payment record + statement = select(User).where(User.email == customer.email) + user = db_session.exec(statement).first() + + if user: + payment_user = await create_payment_user( + request=request, + org_id=org_id, + user_id=user.id, # type: ignore + product_id=int(product_id), # Convert string from metadata to int + current_user=InternalUser(), + db_session=db_session + ) + + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user.id, # type: ignore + status=PaymentStatusEnum.ACTIVE, + current_user=InternalUser(), + db_session=db_session + ) + +async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session): + customer_id = subscription.customer + + # Get product_id from metadata + product_id = subscription.metadata.get('product_id') + if not product_id: + logger.error(f"No product_id found in subscription metadata: {subscription.id}") + raise HTTPException(status_code=400, detail="No product_id found in subscription metadata") + + customer = stripe.Customer.retrieve(customer_id) + + statement = select(User).where(User.email == customer.email) + user = db_session.exec(statement).first() + + if user: + statement = select(PaymentsUser).where( + PaymentsUser.user_id == user.id, + PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int + ) + payment_user = db_session.exec(statement).first() + + if payment_user: + status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user.id, # type: ignore + status=status, + current_user=InternalUser(), + db_session=db_session + ) + +async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session): + customer_id = subscription.customer + + # Get product_id from metadata + product_id = subscription.metadata.get('product_id') + if not product_id: + logger.error(f"No product_id found in subscription metadata: {subscription.id}") + raise HTTPException(status_code=400, detail="No product_id found in subscription metadata") + + customer = stripe.Customer.retrieve(customer_id) + + statement = select(User).where(User.email == customer.email) + user = db_session.exec(statement).first() + + if user: + statement = select(PaymentsUser).where( + PaymentsUser.user_id == user.id, + PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int + ) + payment_user = db_session.exec(statement).first() + + if payment_user: + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user.id, # type: ignore + status=PaymentStatusEnum.FAILED, + current_user=InternalUser(), + db_session=db_session + ) + +async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session): + customer_id = payment_intent.customer + + customer = stripe.Customer.retrieve(customer_id) + + statement = select(User).where(User.email == customer.email) + user = db_session.exec(statement).first() + + # Get product_id directly from metadata + product_id = payment_intent.metadata.get('product_id') + if not product_id: + logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}") + raise HTTPException(status_code=400, detail="No product_id found in payment metadata") + + if user: + await create_payment_user( + request=request, + org_id=org_id, + user_id=user.id, # type: ignore + product_id=int(product_id), # Convert string from metadata to int + status=PaymentStatusEnum.COMPLETED, + provider_data=customer, + current_user=InternalUser(), + db_session=db_session + ) + +async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session): + # Update payment status to failed + customer_id = payment_intent.customer + + customer = stripe.Customer.retrieve(customer_id) + + statement = select(User).where(User.email == customer.email) + user = db_session.exec(statement).first() + + if user: + statement = select(PaymentsUser).where( + PaymentsUser.user_id == user.id, + PaymentsUser.org_id == org_id, + PaymentsUser.status == PaymentStatusEnum.PENDING + ) + payment_user = db_session.exec(statement).first() + + if payment_user: + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user.id, # type: ignore + status=PaymentStatusEnum.FAILED, + current_user=InternalUser(), + db_session=db_session + ) + +# Create event handler mapping +STRIPE_EVENT_HANDLERS = { + 'checkout.session.completed': handle_checkout_session_completed, + 'customer.subscription.created': handle_subscription_created, + 'customer.subscription.updated': handle_subscription_updated, + 'customer.subscription.deleted': handle_subscription_deleted, + 'payment_intent.succeeded': handle_payment_succeeded, + 'payment_intent.payment_failed': handle_payment_failed, +} diff --git a/apps/api/src/services/payments/stripe.py b/apps/api/src/services/payments/stripe.py index f351d0bc..e987ddff 100644 --- a/apps/api/src/services/payments/stripe.py +++ b/apps/api/src/services/payments/stripe.py @@ -1,15 +1,16 @@ from fastapi import HTTPException, Request from sqlmodel import Session import stripe +from config.config import get_learnhouse_config from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct -from src.db.users import AnonymousUser, PublicUser +from src.db.users import AnonymousUser, InternalUser, PublicUser from src.services.payments.payments_config import get_payments_config - +from sqlmodel import select async def get_stripe_credentials( request: Request, org_id: int, - current_user: PublicUser | AnonymousUser, + current_user: PublicUser | AnonymousUser | InternalUser, db_session: Session, ): configs = await get_payments_config(request, org_id, current_user, db_session) @@ -149,6 +150,94 @@ async def update_stripe_product( except stripe.StripeError as e: raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}") +async def create_checkout_session( + request: Request, + org_id: int, + product_id: int, + redirect_uri: str, + current_user: PublicUser | AnonymousUser, + db_session: Session, +): + # Get Stripe credentials + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + stripe.api_key = creds.get('stripe_secret_key') + + # Get product details + statement = select(PaymentsProduct).where( + PaymentsProduct.id == product_id, + PaymentsProduct.org_id == org_id + ) + product = db_session.exec(statement).first() + + if not product: + raise HTTPException(status_code=404, detail="Product not found") + + + success_url = redirect_uri + cancel_url = redirect_uri + + # Get the default price for the product + stripe_product = stripe.Product.retrieve(product.provider_product_id) + line_items = [{ + "price": stripe_product.default_price, + "quantity": 1 + }] + + # Create or retrieve Stripe customer + try: + customers = stripe.Customer.list(email=current_user.email) + if customers.data: + customer = customers.data[0] + else: + customer = stripe.Customer.create( + email=current_user.email, + metadata={ + "user_id": str(current_user.id), + "org_id": str(org_id) + } + ) + except stripe.StripeError as e: + raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}") + + # Create checkout session with customer + try: + checkout_session_params = { + "success_url": success_url, + "cancel_url": cancel_url, + "mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription', + "line_items": line_items, + "customer": customer.id, + "metadata": { + "product_id": str(product.id) + } + } + + # Add payment_intent_data only for one-time payments + if product.product_type == PaymentProductTypeEnum.ONE_TIME: + checkout_session_params["payment_intent_data"] = { + "metadata": { + "product_id": str(product.id) + } + } + # Add subscription_data for subscription payments + else: + checkout_session_params["subscription_data"] = { + "metadata": { + "product_id": str(product.id) + } + } + + checkout_session = stripe.checkout.Session.create(**checkout_session_params) + + return { + "checkout_url": checkout_session.url, + "session_id": checkout_session.id + } + + except stripe.StripeError as e: + print(f"Error creating checkout session: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + diff --git a/apps/web/app/orgs/[orgslug]/(withmenu)/course/[courseuuid]/course.tsx b/apps/web/app/orgs/[orgslug]/(withmenu)/course/[courseuuid]/course.tsx index b3311d8c..01b8e753 100644 --- a/apps/web/app/orgs/[orgslug]/(withmenu)/course/[courseuuid]/course.tsx +++ b/apps/web/app/orgs/[orgslug]/(withmenu)/course/[courseuuid]/course.tsx @@ -1,5 +1,4 @@ 'use client' -import { removeCourse, startCourse } from '@services/courses/activity' import Link from 'next/link' import React, { useEffect, useState } from 'react' import { getUriWithOrg } from '@services/config/config' @@ -15,15 +14,13 @@ import { import { ArrowRight, Backpack, Check, File, Sparkles, Video } from 'lucide-react' import { useOrg } from '@components/Contexts/OrgContext' import UserAvatar from '@components/Objects/UserAvatar' -import CourseUpdates from '@components/Objects/CourseUpdates/CourseUpdates' +import CourseUpdates from '@components/Objects/Courses/CourseUpdates/CourseUpdates' import { CourseProvider } from '@components/Contexts/CourseContext' -import { useLHSession } from '@components/Contexts/LHSessionContext' import { useMediaQuery } from 'usehooks-ts' +import CoursesActions from '@components/Objects/Courses/CourseActions/CoursesActions' const CourseClient = (props: any) => { - const [user, setUser] = useState({}) const [learnings, setLearnings] = useState([]) - const session = useLHSession() as any; const courseuuid = props.courseuuid const orgslug = props.orgslug const course = props.course @@ -37,33 +34,6 @@ const CourseClient = (props: any) => { setLearnings(learnings) } - async function startCourseUI() { - // Create activity - await startCourse('course_' + courseuuid, orgslug, session.data?.tokens?.access_token) - await revalidateTags(['courses'], orgslug) - router.refresh() - - // refresh page (FIX for Next.js BUG) - // window.location.reload(); - } - - function isCourseStarted() { - const runs = course.trail?.runs - if (!runs) return false - return runs.some( - (run: any) => - run.status === 'STATUS_IN_PROGRESS' && run.course_id === course.id - ) - } - - async function quitCourse() { - // Close activity - let activity = await removeCourse('course_' + courseuuid, orgslug, session.data?.tokens?.access_token) - // Mutate course - await revalidateTags(['courses'], orgslug) - router.refresh() - } - useEffect(() => { getLearningTags() }, [org, course]) @@ -80,7 +50,7 @@ const CourseClient = (props: any) => {

{course.name}

- {!isMobile && + {!isMobile && }
@@ -113,11 +83,11 @@ const CourseClient = (props: any) => { course={course} /> -
-
-

Description

+
+
+

About

-

{course.description}

+

{course.about}

{learnings.length > 0 && learnings[0] !== 'null' && ( @@ -187,7 +157,7 @@ const CourseClient = (props: any) => { />
)} - {activity.activity_type === + {activity.activity_type === 'TYPE_ASSIGNMENT' && (
{ )} - {activity.activity_type === + {activity.activity_type === 'TYPE_ASSIGNMENT' && ( <> { })}
-
- {user && ( -
- -
-
- Author -
-
- {course.authors[0].first_name && - course.authors[0].last_name && ( -
-

- {course.authors[0].first_name + - ' ' + - course.authors[0].last_name} -

- - {' '} - @{course.authors[0].username} - -
- )} - {!course.authors[0].first_name && - !course.authors[0].last_name && ( -
-

@{course.authors[0].username}

-
- )} -
-
-
- )} - - {isCourseStarted() ? ( - - ) : ( - - )} +
+
diff --git a/apps/web/components/Objects/Courses/CourseActions/CoursePaidOptions.tsx b/apps/web/components/Objects/Courses/CourseActions/CoursePaidOptions.tsx new file mode 100644 index 00000000..0dd77dfe --- /dev/null +++ b/apps/web/components/Objects/Courses/CourseActions/CoursePaidOptions.tsx @@ -0,0 +1,161 @@ +import React, { useState } from 'react' +import { useOrg } from '@components/Contexts/OrgContext' +import { useLHSession } from '@components/Contexts/LHSessionContext' +import useSWR from 'swr' +import { getProductsByCourse, getStripeProductCheckoutSession } from '@services/payments/products' +import { RefreshCcw, SquareCheck, ChevronDown, ChevronUp } from 'lucide-react' +import { Badge } from '@components/ui/badge' +import { Button } from '@components/ui/button' +import toast from 'react-hot-toast' +import { useRouter } from 'next/navigation' +import { getUriWithOrg } from '@services/config/config' + +interface CoursePaidOptionsProps { + course: { + id: string; + org_id: number; + } +} + +function CoursePaidOptions({ course }: CoursePaidOptionsProps) { + const org = useOrg() as any + const session = useLHSession() as any + const [expandedProducts, setExpandedProducts] = useState<{ [key: string]: boolean }>({}) + const [isProcessing, setIsProcessing] = useState<{ [key: string]: boolean }>({}) + const router = useRouter() + + const { data: linkedProducts, error } = useSWR( + () => org && session ? [`/payments/${course.org_id}/courses/${course.id}/products`, session.data?.tokens?.access_token] : null, + ([url, token]) => getProductsByCourse(course.org_id, course.id, token) + ) + + const handleCheckout = async (productId: number) => { + if (!session.data?.user) { + // Redirect to login if user is not authenticated + router.push(`/signup?orgslug=${org.slug}`) + return + } + + try { + setIsProcessing(prev => ({ ...prev, [productId]: true })) + const redirect_uri = getUriWithOrg(org.slug, '/courses') + const response = await getStripeProductCheckoutSession( + course.org_id, + productId, + redirect_uri, + session.data?.tokens?.access_token + ) + + if (response.success) { + router.push(response.data.checkout_url) + } else { + toast.error('Failed to initiate checkout process') + } + } catch (error) { + toast.error('An error occurred while processing your request') + } finally { + setIsProcessing(prev => ({ ...prev, [productId]: false })) + } + } + + const toggleProductExpansion = (productId: string) => { + setExpandedProducts(prev => ({ + ...prev, + [productId]: !prev[productId] + })) + } + + if (error) return
Failed to load product options
+ if (!linkedProducts) return
Loading...
+ + return ( +
+ {linkedProducts.data.map((product: any) => ( +
+
+
+ + {product.product_type === 'subscription' ? : } + + {product.product_type === 'subscription' ? 'Subscription' : 'One-time payment'} + {product.product_type === 'subscription' && ' (per month)'} + + +

{product.name}

+
+
+ +
+
+

+ {product.description} +

+ {product.benefits && ( +
+

Benefits:

+

+ {product.benefits} +

+
+ )} +
+
+ +
+ +
+ +
+ + {product.price_type === 'customer_choice' ? 'Minimum Price:' : 'Price:'} + +
+ + {new Intl.NumberFormat('en-US', { + style: 'currency', + currency: product.currency + }).format(product.amount)} + {product.product_type === 'subscription' && /month} + + {product.price_type === 'customer_choice' && ( + Choose your price + )} +
+
+ + +
+ ))} +
+ ) +} + +export default CoursePaidOptions diff --git a/apps/web/components/Objects/Courses/CourseActions/CoursesActions.tsx b/apps/web/components/Objects/Courses/CourseActions/CoursesActions.tsx new file mode 100644 index 00000000..861b7731 --- /dev/null +++ b/apps/web/components/Objects/Courses/CourseActions/CoursesActions.tsx @@ -0,0 +1,195 @@ +import React, { useState, useEffect } from 'react' +import UserAvatar from '../../UserAvatar' +import { getUserAvatarMediaDirectory } from '@services/media/media' +import { removeCourse, startCourse } from '@services/courses/activity' +import { revalidateTags } from '@services/utils/ts/requests' +import { useRouter } from 'next/navigation' +import { useLHSession } from '@components/Contexts/LHSessionContext' +import { useMediaQuery } from 'usehooks-ts' +import { getUriWithOrg } from '@services/config/config' +import { getProductsByCourse } from '@services/payments/products' +import { LogIn, LogOut, ShoppingCart, AlertCircle } from 'lucide-react' +import Modal from '@components/StyledElements/Modal/Modal' +import CourseCTA from './CoursePaidOptions' +import CoursePaidOptions from './CoursePaidOptions' + +interface Author { + user_uuid: string + avatar_image: string + first_name: string + last_name: string + username: string +} + +interface CourseRun { + status: string + course_id: string +} + +interface Course { + id: string + authors: Author[] + trail?: { + runs: CourseRun[] + } +} + +interface CourseActionsProps { + courseuuid: string + orgslug: string + course: Course & { + org_id: number + } +} + +// Separate component for author display +const AuthorInfo = ({ author, isMobile }: { author: Author, isMobile: boolean }) => ( +
+ +
+
Author
+
+ {(author.first_name && author.last_name) ? ( +
+

{`${author.first_name} ${author.last_name}`}

+ + @{author.username} + +
+ ) : ( +
+

@{author.username}

+
+ )} +
+
+
+) + +const Actions = ({ courseuuid, orgslug, course }: CourseActionsProps) => { + const router = useRouter() + const session = useLHSession() as any + const [linkedProducts, setLinkedProducts] = useState([]) + const [isLoading, setIsLoading] = useState(true) + const [isModalOpen, setIsModalOpen] = useState(false) + + const isStarted = course.trail?.runs?.some( + (run) => run.status === 'STATUS_IN_PROGRESS' && run.course_id === course.id + ) ?? false + + useEffect(() => { + const fetchLinkedProducts = async () => { + try { + const response = await getProductsByCourse( + course.org_id, + course.id, + session.data?.tokens?.access_token + ) + setLinkedProducts(response.data || []) + } catch (error) { + console.error('Failed to fetch linked products') + } finally { + setIsLoading(false) + } + } + + fetchLinkedProducts() + }, [course.id, course.org_id, session.data?.tokens?.access_token]) + + const handleCourseAction = async () => { + if (!session.data?.user) { + router.push(getUriWithOrg(orgslug, '/signup?orgslug=' + orgslug)) + return + } + const action = isStarted ? removeCourse : startCourse + await action('course_' + courseuuid, orgslug, session.data?.tokens?.access_token) + await revalidateTags(['courses'], orgslug) + router.refresh() + } + + if (isLoading) { + return
+ } + + if (linkedProducts.length > 0) { + return ( +
+
+
+ +

Paid Course

+
+

+ This course requires purchase to access its content. +

+
+ } + dialogTitle="Purchase Course" + dialogDescription="Select a payment option to access this course" + minWidth="sm" + /> + +
+ ) + } + + return ( + + ) +} + +function CoursesActions({ courseuuid, orgslug, course }: CourseActionsProps) { + const router = useRouter() + const session = useLHSession() as any + const isMobile = useMediaQuery('(max-width: 768px)') + + + return ( +
+ +
+ +
+
+ ) +} + +export default CoursesActions \ No newline at end of file diff --git a/apps/web/components/Objects/CourseUpdates/CourseUpdates.tsx b/apps/web/components/Objects/Courses/CourseUpdates/CourseUpdates.tsx similarity index 100% rename from apps/web/components/Objects/CourseUpdates/CourseUpdates.tsx rename to apps/web/components/Objects/Courses/CourseUpdates/CourseUpdates.tsx diff --git a/apps/web/services/payments/products.ts b/apps/web/services/payments/products.ts index 63f5b422..c39eb669 100644 --- a/apps/web/services/payments/products.ts +++ b/apps/web/services/payments/products.ts @@ -73,4 +73,21 @@ export async function getCoursesLinkedToProduct(orgId: number, productId: string return res; } +export async function getProductsByCourse(orgId: number, courseId: string, access_token: string) { + const result = await fetch( + `${getAPIUrl()}payments/${orgId}/courses/${courseId}/products`, + RequestBodyWithAuthHeader('GET', null, null, access_token) + ); + const res = await getResponseMetadata(result); + return res; +} + +export async function getStripeProductCheckoutSession(orgId: number, productId: number, redirect_uri: string, access_token: string) { + const result = await fetch( + `${getAPIUrl()}payments/${orgId}/stripe/checkout/product/${productId}?redirect_uri=${redirect_uri}`, + RequestBodyWithAuthHeader('POST', null, null, access_token) + ); + const res = await getResponseMetadata(result); + return res; +}