feat: add course checkout UI and stripe integration and webhook wip

This commit is contained in:
swve 2024-11-01 20:51:52 +01:00
parent d8913d1a60
commit 1bff401e73
18 changed files with 1086 additions and 131 deletions

View file

@ -48,6 +48,7 @@ def install(
slug="default",
email="",
logo_image="",
thumbnail_image="",
)
install_create_organization(org, db_session)
print("Default organization created ✅")
@ -89,6 +90,7 @@ def install(
slug=slug.lower(),
email="",
logo_image="",
thumbnail_image="",
)
install_create_organization(org, db_session)
print(orgname + " Organization created ✅")

View file

@ -1,26 +1,55 @@
import logging
import os
import importlib
from typing import Optional
from config.config import get_learnhouse_config
from fastapi import FastAPI
from sqlmodel import SQLModel, Session, create_engine
from sqlmodel import Field, SQLModel, Session, create_engine
def import_all_models():
base_dir = 'src/db'
base_module_path = 'src.db'
# Recursively walk through the base directory
for root, dirs, files in os.walk(base_dir):
# Filter out __init__.py and non-Python files
module_files = [f for f in files if f.endswith('.py') and f != '__init__.py']
# Calculate the module's base path from its directory structure
path_diff = os.path.relpath(root, base_dir)
if path_diff == '.':
current_module_base = base_module_path
else:
current_module_base = f"{base_module_path}.{path_diff.replace(os.sep, '.')}"
# Dynamically import each module
for file_name in module_files:
module_name = file_name[:-3] # Remove the '.py' extension
full_module_path = f"{current_module_base}.{module_name}"
importlib.import_module(full_module_path)
# Import all models before creating engine
import_all_models()
learnhouse_config = get_learnhouse_config()
engine = create_engine(
learnhouse_config.database_config.sql_connection_string, echo=False, pool_pre_ping=True # type: ignore
learnhouse_config.database_config.sql_connection_string, # type: ignore
echo=False,
pool_pre_ping=True # type: ignore
)
SQLModel.metadata.create_all(engine)
# Create all tables after importing all models
SQLModel.metadata.create_all(engine)
async def connect_to_db(app: FastAPI):
app.db_engine = engine # type: ignore
logging.info("LearnHouse database has been started.")
SQLModel.metadata.create_all(engine)
def get_db_session():
with Session(engine) as session:
yield session
async def close_database(app: FastAPI):
logging.info("LearnHouse has been shut down.")
return app

View file

@ -2,12 +2,14 @@ from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey
from typing import Optional
from datetime import datetime
class PaymentCourseBase(SQLModel):
class PaymentsCourseBase(SQLModel):
course_id: int = Field(sa_column=Column(BigInteger, ForeignKey("course.id", ondelete="CASCADE")))
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")))
class PaymentCourse(PaymentCourseBase, table=True):
class PaymentsCourse(PaymentsCourseBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
org_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
)
creation_date: datetime = Field(default=datetime.now())
update_date: datetime = Field(default=datetime.now())

View file

@ -1,19 +1,37 @@
from enum import Enum
from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey
from openai import BaseModel
from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey, JSON
from typing import Optional
from datetime import datetime
from enum import Enum
class PaymentUserStatusEnum(str, Enum):
class PaymentStatusEnum(str, Enum):
PENDING = "pending"
COMPLETED = "completed"
ACTIVE = "active"
INACTIVE = "inactive"
CANCELLED = "cancelled"
FAILED = "failed"
REFUNDED = "refunded"
class ProviderSpecificData(BaseModel):
stripe_customer: dict | None = None
custom_customer: dict | None = None
class PaymentsUserBase(SQLModel):
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE")))
status: PaymentUserStatusEnum = PaymentUserStatusEnum.ACTIVE
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")))
status: PaymentStatusEnum = PaymentStatusEnum.PENDING
provider_specific_data: dict = Field(default={}, sa_column=Column(JSON))
class PaymentsUser(PaymentsUserBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE"))
)
org_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
)
payment_product_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE"))
)
creation_date: datetime = Field(default=datetime.now())
update_date: datetime = Field(default=datetime.now())

View file

@ -59,6 +59,11 @@ class AnonymousUser(SQLModel):
user_uuid: str = "user_anonymous"
username: str = "anonymous"
class InternalUser(SQLModel):
id: int = 0
user_uuid: str = "user_internal"
username: str = "internal"
class User(UserBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

View file

@ -11,12 +11,14 @@ from src.services.payments.payments_config import (
delete_payments_config,
)
from src.db.payments.payments_products import PaymentsProductCreate, PaymentsProductRead, PaymentsProductUpdate
from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, list_payments_products, update_payments_product
from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, get_products_by_course, list_payments_products, update_payments_product
from src.services.payments.payments_courses import (
link_course_to_product,
unlink_course_from_product,
get_courses_by_product
get_courses_by_product,
)
from src.services.payments.payments_webhook import handle_stripe_webhook
from src.services.payments.stripe import create_checkout_session
router = APIRouter()
@ -148,3 +150,38 @@ async def api_get_courses_by_product(
return await get_courses_by_product(
request, org_id, product_id, current_user, db_session
)
@router.get("/{org_id}/courses/{course_id}/products")
async def api_get_products_by_course(
request: Request,
org_id: int,
course_id: int,
current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session),
):
return await get_products_by_course(
request, org_id, course_id, current_user, db_session
)
# Payments webhooks
@router.post("/{org_id}/stripe/webhook")
async def api_handle_stripe_webhook(
request: Request,
org_id: int,
db_session: Session = Depends(get_db_session),
):
return await handle_stripe_webhook(request, org_id, db_session)
# Payments checkout
@router.post("/{org_id}/stripe/checkout/product/{product_id}")
async def api_create_checkout_session(
request: Request,
org_id: int,
product_id: int,
redirect_uri: str,
current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session),
):
return await create_checkout_session(request, org_id, product_id, redirect_uri, current_user, db_session)

View file

@ -26,7 +26,7 @@ from src.security.rbac.rbac import (
authorization_verify_based_on_org_admin_status,
authorization_verify_if_user_is_anon,
)
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.db.user_organizations import UserOrganization
from src.db.organizations import (
Organization,
@ -682,7 +682,7 @@ async def get_org_join_mechanism(
async def rbac_check(
request: Request,
org_uuid: str,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
@ -690,6 +690,10 @@ async def rbac_check(
if action == "read":
return True
# Internal users can do anything
if isinstance(current_user, InternalUser):
return True
else:
isUserAnon = await authorization_verify_if_user_is_anon(current_user.id)

View file

@ -6,7 +6,7 @@ from src.db.payments.payments import (
PaymentsConfigUpdate,
PaymentsConfigRead,
)
from src.db.users import PublicUser, AnonymousUser
from src.db.users import PublicUser, AnonymousUser, InternalUser
from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
@ -48,7 +48,7 @@ async def create_payments_config(
async def get_payments_config(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> list[PaymentsConfigRead]:
# Check if organization exists

View file

@ -1,7 +1,7 @@
from datetime import datetime
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.payments.payments_courses import PaymentCourse
from src.db.payments.payments_courses import PaymentsCourse
from src.db.payments.payments_products import PaymentsProduct
from src.db.courses.courses import Course
from src.db.users import PublicUser, AnonymousUser
@ -36,7 +36,7 @@ async def link_course_to_product(
raise HTTPException(status_code=404, detail="Product not found")
# Check if course is already linked to another product
statement = select(PaymentCourse).where(PaymentCourse.course_id == course.id)
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
existing_link = db_session.exec(statement).first()
if existing_link:
@ -46,7 +46,7 @@ async def link_course_to_product(
)
# Create new payment course link
payment_course = PaymentCourse(
payment_course = PaymentsCourse(
course_id=course.id, # type: ignore
payment_product_id=product_id,
org_id=org_id,
@ -75,9 +75,9 @@ async def unlink_course_from_product(
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
# Find and delete the payment course link
statement = select(PaymentCourse).where(
PaymentCourse.course_id == course.id,
PaymentCourse.org_id == org_id
statement = select(PaymentsCourse).where(
PaymentsCourse.course_id == course.id,
PaymentsCourse.org_id == org_id
)
payment_course = db_session.exec(statement).first()
@ -113,12 +113,14 @@ async def get_courses_by_product(
statement = (
select(Course)
.select_from(Course)
.join(PaymentCourse, Course.id == PaymentCourse.course_id) # type: ignore
.join(PaymentsCourse, Course.id == PaymentsCourse.course_id) # type: ignore
.where(
PaymentCourse.payment_product_id == product_id,
PaymentCourse.org_id == org_id
PaymentsCourse.payment_product_id == product_id,
PaymentsCourse.org_id == org_id
)
)
courses = db_session.exec(statement).all()
return courses

View file

@ -1,6 +1,8 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.courses.courses import Course
from src.db.payments.payments import PaymentsConfig
from src.db.payments.payments_courses import PaymentsCourse
from src.db.payments.payments_products import (
PaymentsProduct,
PaymentsProductCreate,
@ -12,7 +14,7 @@ from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
from datetime import datetime
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product
async def create_payments_product(
request: Request,
@ -163,3 +165,36 @@ async def list_payments_products(
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]
async def get_products_by_course(
request: Request,
org_id: int,
course_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
) -> list[PaymentsProductRead]:
# Check if course exists and user has permission
statement = select(Course).where(Course.id == course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(status_code=404, detail="Course not found")
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
# Get all products linked to this course with explicit join
statement = (
select(PaymentsProduct)
.select_from(PaymentsProduct)
.join(PaymentsCourse, PaymentsProduct.id == PaymentsCourse.payment_product_id) # type: ignore
.where(
PaymentsCourse.course_id == course_id,
PaymentsCourse.org_id == org_id
)
)
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]

View file

@ -0,0 +1,181 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from typing import Any
from src.db.payments.payments_users import PaymentsUser, PaymentStatusEnum, ProviderSpecificData
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, PublicUser, AnonymousUser
from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
from datetime import datetime
async def create_payment_user(
request: Request,
org_id: int,
user_id: int,
product_id: int,
status: PaymentStatusEnum,
provider_data: Any,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
# Check if product exists
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id,
PaymentsProduct.org_id == org_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Product not found")
provider_specific_data = ProviderSpecificData(
stripe_customer=provider_data if provider_data else None,
)
# Check if user already has a payment user
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user_id,
PaymentsUser.org_id == org_id
)
existing_payment_user = db_session.exec(statement).first()
if existing_payment_user:
raise HTTPException(status_code=400, detail="User already has purchase for this product")
# Create new payment user
payment_user = PaymentsUser(
user_id=user_id,
org_id=org_id,
payment_product_id=product_id,
provider_specific_data=provider_specific_data.model_dump(),
status=status
)
db_session.add(payment_user)
db_session.commit()
db_session.refresh(payment_user)
return payment_user
async def get_payment_user(
request: Request,
org_id: int,
payment_user_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
return payment_user
async def update_payment_user_status(
request: Request,
org_id: int,
payment_user_id: int,
status: PaymentStatusEnum,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
# Get existing payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
# Update status
payment_user.status = status
payment_user.update_date = datetime.now()
db_session.add(payment_user)
db_session.commit()
db_session.refresh(payment_user)
return payment_user
async def list_payment_users(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> list[PaymentsUser]:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get all payment users for org ordered by id
statement = select(PaymentsUser).where(
PaymentsUser.org_id == org_id
).order_by(PaymentsUser.id.desc()) # type: ignore
payment_users = list(db_session.exec(statement).all()) # Convert to list
return payment_users
async def delete_payment_user(
request: Request,
org_id: int,
payment_user_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> None:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "delete", db_session)
# Get existing payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
# Delete payment user
db_session.delete(payment_user)
db_session.commit()

View file

@ -0,0 +1,260 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
import stripe
from datetime import datetime
from typing import Callable, Dict
import logging
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, User
from src.services.payments.payments_users import create_payment_user, update_payment_user_status
from src.services.payments.stripe import get_stripe_credentials
logger = logging.getLogger(__name__)
async def get_user_from_customer(customer_id: str, db_session: Session) -> User:
"""Helper function to get user from Stripe customer ID"""
try:
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if not user:
raise HTTPException(status_code=404, detail=f"User not found for customer {customer_id}")
return user
except stripe.StripeError as e:
logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}")
raise HTTPException(status_code=400, detail="Error retrieving customer information")
async def get_product_from_stripe_id(product_id: str, db_session: Session) -> PaymentsProduct:
"""Helper function to get product from Stripe product ID"""
statement = select(PaymentsProduct).where(PaymentsProduct.provider_product_id == product_id)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail=f"Product not found: {product_id}")
return product
async def handle_stripe_webhook(
request: Request,
org_id: int,
db_session: Session,
) -> dict:
# Get Stripe credentials for the organization
creds = await get_stripe_credentials(request, org_id, InternalUser(), db_session)
# Get the webhook secret and API key from credentials
webhook_secret = creds.get('stripe_webhook_secret')
stripe.api_key = creds.get('stripe_secret_key') # Set API key globally
if not webhook_secret:
raise HTTPException(status_code=400, detail="Stripe webhook secret not configured")
# Get the raw request body
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
# Verify webhook signature and construct event
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
# Get the appropriate handler
handler = STRIPE_EVENT_HANDLERS.get(event.type)
if handler:
await handler(request, event.data.object, org_id, db_session)
return {"status": "success", "event": event.type}
else:
logger.info(f"Unhandled event type: {event.type}")
return {"status": "ignored", "event": event.type}
except Exception as e:
logger.error(f"Error processing webhook: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}")
async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session):
# Get the customer and product details from the session
customer_email = session.customer_details.email
product_id = session.line_items.data[0].price.product
# Use helper functions
user = await get_user_from_customer(session.customer, db_session)
product = await get_product_from_stripe_id(product_id, db_session)
# Find payment user record
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == product.id
)
payment_user = db_session.exec(statement).first()
# Update status to completed
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.COMPLETED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
# Get customer email from Stripe
customer = stripe.Customer.retrieve(customer_id)
# Find user and create/update payment record
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
payment_user = await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
current_user=InternalUser(),
db_session=db_session
)
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.ACTIVE,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=status,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session):
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
# Get product_id directly from metadata
product_id = payment_intent.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}")
raise HTTPException(status_code=400, detail="No product_id found in payment metadata")
if user:
await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
status=PaymentStatusEnum.COMPLETED,
provider_data=customer,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session):
# Update payment status to failed
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.org_id == org_id,
PaymentsUser.status == PaymentStatusEnum.PENDING
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
# Create event handler mapping
STRIPE_EVENT_HANDLERS = {
'checkout.session.completed': handle_checkout_session_completed,
'customer.subscription.created': handle_subscription_created,
'customer.subscription.updated': handle_subscription_updated,
'customer.subscription.deleted': handle_subscription_deleted,
'payment_intent.succeeded': handle_payment_succeeded,
'payment_intent.payment_failed': handle_payment_failed,
}

View file

@ -1,15 +1,16 @@
from fastapi import HTTPException, Request
from sqlmodel import Session
import stripe
from config.config import get_learnhouse_config
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.services.payments.payments_config import get_payments_config
from sqlmodel import select
async def get_stripe_credentials(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
):
configs = await get_payments_config(request, org_id, current_user, db_session)
@ -149,6 +150,94 @@ async def update_stripe_product(
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")
async def create_checkout_session(
request: Request,
org_id: int,
product_id: int,
redirect_uri: str,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
# Get Stripe credentials
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
stripe.api_key = creds.get('stripe_secret_key')
# Get product details
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id,
PaymentsProduct.org_id == org_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Product not found")
success_url = redirect_uri
cancel_url = redirect_uri
# Get the default price for the product
stripe_product = stripe.Product.retrieve(product.provider_product_id)
line_items = [{
"price": stripe_product.default_price,
"quantity": 1
}]
# Create or retrieve Stripe customer
try:
customers = stripe.Customer.list(email=current_user.email)
if customers.data:
customer = customers.data[0]
else:
customer = stripe.Customer.create(
email=current_user.email,
metadata={
"user_id": str(current_user.id),
"org_id": str(org_id)
}
)
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
# Create checkout session with customer
try:
checkout_session_params = {
"success_url": success_url,
"cancel_url": cancel_url,
"mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription',
"line_items": line_items,
"customer": customer.id,
"metadata": {
"product_id": str(product.id)
}
}
# Add payment_intent_data only for one-time payments
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
checkout_session_params["payment_intent_data"] = {
"metadata": {
"product_id": str(product.id)
}
}
# Add subscription_data for subscription payments
else:
checkout_session_params["subscription_data"] = {
"metadata": {
"product_id": str(product.id)
}
}
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
return {
"checkout_url": checkout_session.url,
"session_id": checkout_session.id
}
except stripe.StripeError as e:
print(f"Error creating checkout session: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))

View file

@ -1,5 +1,4 @@
'use client'
import { removeCourse, startCourse } from '@services/courses/activity'
import Link from 'next/link'
import React, { useEffect, useState } from 'react'
import { getUriWithOrg } from '@services/config/config'
@ -15,15 +14,13 @@ import {
import { ArrowRight, Backpack, Check, File, Sparkles, Video } from 'lucide-react'
import { useOrg } from '@components/Contexts/OrgContext'
import UserAvatar from '@components/Objects/UserAvatar'
import CourseUpdates from '@components/Objects/CourseUpdates/CourseUpdates'
import CourseUpdates from '@components/Objects/Courses/CourseUpdates/CourseUpdates'
import { CourseProvider } from '@components/Contexts/CourseContext'
import { useLHSession } from '@components/Contexts/LHSessionContext'
import { useMediaQuery } from 'usehooks-ts'
import CoursesActions from '@components/Objects/Courses/CourseActions/CoursesActions'
const CourseClient = (props: any) => {
const [user, setUser] = useState<any>({})
const [learnings, setLearnings] = useState<any>([])
const session = useLHSession() as any;
const courseuuid = props.courseuuid
const orgslug = props.orgslug
const course = props.course
@ -37,33 +34,6 @@ const CourseClient = (props: any) => {
setLearnings(learnings)
}
async function startCourseUI() {
// Create activity
await startCourse('course_' + courseuuid, orgslug, session.data?.tokens?.access_token)
await revalidateTags(['courses'], orgslug)
router.refresh()
// refresh page (FIX for Next.js BUG)
// window.location.reload();
}
function isCourseStarted() {
const runs = course.trail?.runs
if (!runs) return false
return runs.some(
(run: any) =>
run.status === 'STATUS_IN_PROGRESS' && run.course_id === course.id
)
}
async function quitCourse() {
// Close activity
let activity = await removeCourse('course_' + courseuuid, orgslug, session.data?.tokens?.access_token)
// Mutate course
await revalidateTags(['courses'], orgslug)
router.refresh()
}
useEffect(() => {
getLearningTags()
}, [org, course])
@ -80,7 +50,7 @@ const CourseClient = (props: any) => {
<h1 className="text-3xl md:text-3xl -mt-3 font-bold">{course.name}</h1>
</div>
<div className="mt-4 md:mt-0">
{!isMobile && <CourseProvider courseuuid={course.course_uuid}>
{!isMobile && <CourseProvider courseuuid={course.course_uuid}>
<CourseUpdates />
</CourseProvider>}
</div>
@ -113,11 +83,11 @@ const CourseClient = (props: any) => {
course={course}
/>
<div className="flex flex-col md:flex-row pt-10">
<div className="course_metadata_left grow space-y-2">
<h2 className="py-3 text-2xl font-bold">Description</h2>
<div className="flex flex-col md:flex-row md:space-x-10 space-y-6 md:space-y-0 pt-10">
<div className="course_metadata_left w-full md:basis-3/4 space-y-2">
<h2 className="py-3 text-2xl font-bold">About</h2>
<div className="bg-white shadow-md shadow-gray-300/25 outline outline-1 outline-neutral-200/40 rounded-lg overflow-hidden">
<p className="py-5 px-5">{course.description}</p>
<p className="py-5 px-5 whitespace-pre-wrap">{course.about}</p>
</div>
{learnings.length > 0 && learnings[0] !== 'null' && (
@ -187,7 +157,7 @@ const CourseClient = (props: any) => {
/>
</div>
)}
{activity.activity_type ===
{activity.activity_type ===
'TYPE_ASSIGNMENT' && (
<div className="bg-gray-100 px-2 py-2 rounded-full">
<Backpack
@ -273,7 +243,7 @@ const CourseClient = (props: any) => {
</Link>
</>
)}
{activity.activity_type ===
{activity.activity_type ===
'TYPE_ASSIGNMENT' && (
<>
<Link
@ -305,60 +275,8 @@ const CourseClient = (props: any) => {
})}
</div>
</div>
<div className="course_metadata_right space-y-3 w-full md:w-72 antialiased flex flex-col md:ml-10 h-fit p-3 py-5 bg-white shadow-md shadow-gray-300/25 outline outline-1 outline-neutral-200/40 rounded-lg overflow-hidden mt-6 md:mt-0">
{user && (
<div className="flex flex-row md:flex-col mx-auto space-y-0 md:space-y-3 space-x-4 md:space-x-0 px-2 py-2 items-center">
<UserAvatar
border="border-8"
avatar_url={course.authors[0].avatar_image ? getUserAvatarMediaDirectory(course.authors[0].user_uuid, course.authors[0].avatar_image) : ''}
predefined_avatar={course.authors[0].avatar_image ? undefined : 'empty'}
width={isMobile ? 60 : 100}
/>
<div className="md:-space-y-2">
<div className="text-[12px] text-neutral-400 font-semibold">
Author
</div>
<div className="text-lg md:text-xl font-bold text-neutral-800">
{course.authors[0].first_name &&
course.authors[0].last_name && (
<div className="flex space-x-2 items-center">
<p>
{course.authors[0].first_name +
' ' +
course.authors[0].last_name}
</p>
<span className="text-xs bg-neutral-100 p-1 px-3 rounded-full text-neutral-400 font-semibold">
{' '}
@{course.authors[0].username}
</span>
</div>
)}
{!course.authors[0].first_name &&
!course.authors[0].last_name && (
<div className="flex space-x-2 items-center">
<p>@{course.authors[0].username}</p>
</div>
)}
</div>
</div>
</div>
)}
{isCourseStarted() ? (
<button
className="py-2 px-5 mx-auto rounded-xl text-white font-bold h-12 w-full md:w-[200px] drop-shadow-md bg-red-600 hover:bg-red-700 hover:cursor-pointer"
onClick={quitCourse}
>
Quit Course
</button>
) : (
<button
className="py-2 px-5 mx-auto rounded-xl text-white font-bold h-12 w-full md:w-[200px] drop-shadow-md bg-black hover:bg-gray-900 hover:cursor-pointer"
onClick={startCourseUI}
>
Start Course
</button>
)}
<div className='course_metadata_right basis-1/4'>
<CoursesActions courseuuid={courseuuid} orgslug={orgslug} course={course} />
</div>
</div>
</GeneralWrapperStyled>

View file

@ -0,0 +1,161 @@
import React, { useState } from 'react'
import { useOrg } from '@components/Contexts/OrgContext'
import { useLHSession } from '@components/Contexts/LHSessionContext'
import useSWR from 'swr'
import { getProductsByCourse, getStripeProductCheckoutSession } from '@services/payments/products'
import { RefreshCcw, SquareCheck, ChevronDown, ChevronUp } from 'lucide-react'
import { Badge } from '@components/ui/badge'
import { Button } from '@components/ui/button'
import toast from 'react-hot-toast'
import { useRouter } from 'next/navigation'
import { getUriWithOrg } from '@services/config/config'
interface CoursePaidOptionsProps {
course: {
id: string;
org_id: number;
}
}
function CoursePaidOptions({ course }: CoursePaidOptionsProps) {
const org = useOrg() as any
const session = useLHSession() as any
const [expandedProducts, setExpandedProducts] = useState<{ [key: string]: boolean }>({})
const [isProcessing, setIsProcessing] = useState<{ [key: string]: boolean }>({})
const router = useRouter()
const { data: linkedProducts, error } = useSWR(
() => org && session ? [`/payments/${course.org_id}/courses/${course.id}/products`, session.data?.tokens?.access_token] : null,
([url, token]) => getProductsByCourse(course.org_id, course.id, token)
)
const handleCheckout = async (productId: number) => {
if (!session.data?.user) {
// Redirect to login if user is not authenticated
router.push(`/signup?orgslug=${org.slug}`)
return
}
try {
setIsProcessing(prev => ({ ...prev, [productId]: true }))
const redirect_uri = getUriWithOrg(org.slug, '/courses')
const response = await getStripeProductCheckoutSession(
course.org_id,
productId,
redirect_uri,
session.data?.tokens?.access_token
)
if (response.success) {
router.push(response.data.checkout_url)
} else {
toast.error('Failed to initiate checkout process')
}
} catch (error) {
toast.error('An error occurred while processing your request')
} finally {
setIsProcessing(prev => ({ ...prev, [productId]: false }))
}
}
const toggleProductExpansion = (productId: string) => {
setExpandedProducts(prev => ({
...prev,
[productId]: !prev[productId]
}))
}
if (error) return <div>Failed to load product options</div>
if (!linkedProducts) return <div>Loading...</div>
return (
<div className="space-y-4 p-1">
{linkedProducts.data.map((product: any) => (
<div key={product.id} className="bg-slate-50/30 p-4 rounded-lg nice-shadow flex flex-col">
<div className="flex justify-between items-start mb-2">
<div className="flex flex-col space-y-1 items-start">
<Badge className='w-fit flex items-center space-x-2' variant="outline">
{product.product_type === 'subscription' ? <RefreshCcw size={12} /> : <SquareCheck size={12} />}
<span className='text-sm'>
{product.product_type === 'subscription' ? 'Subscription' : 'One-time payment'}
{product.product_type === 'subscription' && ' (per month)'}
</span>
</Badge>
<h3 className="font-bold text-lg">{product.name}</h3>
</div>
</div>
<div className="flex-grow overflow-hidden">
<div className={`transition-all duration-300 ease-in-out ${expandedProducts[product.id] ? 'max-h-[1000px]' : 'max-h-24'
} overflow-hidden`}>
<p className="text-gray-600">
{product.description}
</p>
{product.benefits && (
<div className="mt-2">
<h4 className="font-semibold text-sm">Benefits:</h4>
<p className="text-sm text-gray-600">
{product.benefits}
</p>
</div>
)}
</div>
</div>
<div className="mt-2">
<button
onClick={() => toggleProductExpansion(product.id)}
className="text-slate-500 hover:text-slate-700 text-sm flex items-center"
>
{expandedProducts[product.id] ? (
<>
<ChevronUp size={16} />
<span>Show less</span>
</>
) : (
<>
<ChevronDown size={16} />
<span>Show more</span>
</>
)}
</button>
</div>
<div className="mt-2 flex items-center justify-between bg-gray-100 rounded-md p-2">
<span className="text-sm text-gray-600">
{product.price_type === 'customer_choice' ? 'Minimum Price:' : 'Price:'}
</span>
<div className="flex flex-col items-end">
<span className="font-semibold text-lg">
{new Intl.NumberFormat('en-US', {
style: 'currency',
currency: product.currency
}).format(product.amount)}
{product.product_type === 'subscription' && <span className="text-sm text-gray-500 ml-1">/month</span>}
</span>
{product.price_type === 'customer_choice' && (
<span className="text-sm text-gray-500">Choose your price</span>
)}
</div>
</div>
<Button
className="mt-4 w-full"
variant="default"
onClick={() => handleCheckout(product.id)}
disabled={isProcessing[product.id]}
>
{isProcessing[product.id]
? 'Processing...'
: product.product_type === 'subscription'
? 'Subscribe Now'
: 'Purchase Now'
}
</Button>
</div>
))}
</div>
)
}
export default CoursePaidOptions

View file

@ -0,0 +1,195 @@
import React, { useState, useEffect } from 'react'
import UserAvatar from '../../UserAvatar'
import { getUserAvatarMediaDirectory } from '@services/media/media'
import { removeCourse, startCourse } from '@services/courses/activity'
import { revalidateTags } from '@services/utils/ts/requests'
import { useRouter } from 'next/navigation'
import { useLHSession } from '@components/Contexts/LHSessionContext'
import { useMediaQuery } from 'usehooks-ts'
import { getUriWithOrg } from '@services/config/config'
import { getProductsByCourse } from '@services/payments/products'
import { LogIn, LogOut, ShoppingCart, AlertCircle } from 'lucide-react'
import Modal from '@components/StyledElements/Modal/Modal'
import CourseCTA from './CoursePaidOptions'
import CoursePaidOptions from './CoursePaidOptions'
interface Author {
user_uuid: string
avatar_image: string
first_name: string
last_name: string
username: string
}
interface CourseRun {
status: string
course_id: string
}
interface Course {
id: string
authors: Author[]
trail?: {
runs: CourseRun[]
}
}
interface CourseActionsProps {
courseuuid: string
orgslug: string
course: Course & {
org_id: number
}
}
// Separate component for author display
const AuthorInfo = ({ author, isMobile }: { author: Author, isMobile: boolean }) => (
<div className="flex flex-row md:flex-col mx-auto space-y-0 md:space-y-3 space-x-4 md:space-x-0 px-2 py-2 items-center">
<UserAvatar
border="border-8"
avatar_url={author.avatar_image ? getUserAvatarMediaDirectory(author.user_uuid, author.avatar_image) : ''}
predefined_avatar={author.avatar_image ? undefined : 'empty'}
width={isMobile ? 60 : 100}
/>
<div className="md:-space-y-2">
<div className="text-[12px] text-neutral-400 font-semibold">Author</div>
<div className="text-lg md:text-xl font-bold text-neutral-800">
{(author.first_name && author.last_name) ? (
<div className="flex space-x-2 items-center">
<p>{`${author.first_name} ${author.last_name}`}</p>
<span className="text-xs bg-neutral-100 p-1 px-3 rounded-full text-neutral-400 font-semibold">
@{author.username}
</span>
</div>
) : (
<div className="flex space-x-2 items-center">
<p>@{author.username}</p>
</div>
)}
</div>
</div>
</div>
)
const Actions = ({ courseuuid, orgslug, course }: CourseActionsProps) => {
const router = useRouter()
const session = useLHSession() as any
const [linkedProducts, setLinkedProducts] = useState<any[]>([])
const [isLoading, setIsLoading] = useState(true)
const [isModalOpen, setIsModalOpen] = useState(false)
const isStarted = course.trail?.runs?.some(
(run) => run.status === 'STATUS_IN_PROGRESS' && run.course_id === course.id
) ?? false
useEffect(() => {
const fetchLinkedProducts = async () => {
try {
const response = await getProductsByCourse(
course.org_id,
course.id,
session.data?.tokens?.access_token
)
setLinkedProducts(response.data || [])
} catch (error) {
console.error('Failed to fetch linked products')
} finally {
setIsLoading(false)
}
}
fetchLinkedProducts()
}, [course.id, course.org_id, session.data?.tokens?.access_token])
const handleCourseAction = async () => {
if (!session.data?.user) {
router.push(getUriWithOrg(orgslug, '/signup?orgslug=' + orgslug))
return
}
const action = isStarted ? removeCourse : startCourse
await action('course_' + courseuuid, orgslug, session.data?.tokens?.access_token)
await revalidateTags(['courses'], orgslug)
router.refresh()
}
if (isLoading) {
return <div className="animate-pulse h-20 bg-gray-100 rounded-lg nice-shadow" />
}
if (linkedProducts.length > 0) {
return (
<div className="space-y-4">
<div className="p-4 bg-amber-50 border border-amber-200 rounded-lg nice-shadow">
<div className="flex items-center gap-3">
<AlertCircle className="w-5 h-5 text-amber-800" />
<h3 className="text-amber-800 font-semibold">Paid Course</h3>
</div>
<p className="text-amber-700 text-sm">
This course requires purchase to access its content.
</p>
</div>
<Modal
isDialogOpen={isModalOpen}
onOpenChange={setIsModalOpen}
dialogContent={<CoursePaidOptions course={course} />}
dialogTitle="Purchase Course"
dialogDescription="Select a payment option to access this course"
minWidth="sm"
/>
<button
className="w-full bg-neutral-900 text-white py-3 rounded-lg nice-shadow font-semibold hover:bg-neutral-800 transition-colors flex items-center justify-center gap-2"
onClick={() => setIsModalOpen(true)}
>
<ShoppingCart className="w-5 h-5" />
Purchase Course
</button>
</div>
)
}
return (
<button
onClick={handleCourseAction}
className={`w-full py-3 rounded-lg nice-shadow font-semibold transition-colors flex items-center justify-center gap-2 ${
isStarted
? 'bg-red-500 text-white hover:bg-red-600'
: 'bg-neutral-900 text-white hover:bg-neutral-800'
}`}
>
{!session.data?.user ? (
<>
<LogIn className="w-5 h-5" />
Authenticate to start course
</>
) : isStarted ? (
<>
<LogOut className="w-5 h-5" />
Leave Course
</>
) : (
<>
<LogIn className="w-5 h-5" />
Start Course
</>
)}
</button>
)
}
function CoursesActions({ courseuuid, orgslug, course }: CourseActionsProps) {
const router = useRouter()
const session = useLHSession() as any
const isMobile = useMediaQuery('(max-width: 768px)')
return (
<div className=" space-y-3 antialiased flex flex-col p-3 py-5 bg-white shadow-md shadow-gray-300/25 outline outline-1 outline-neutral-200/40 rounded-lg overflow-hidden">
<AuthorInfo author={course.authors[0]} isMobile={isMobile} />
<div className='px-3 py-2'>
<Actions courseuuid={courseuuid} orgslug={orgslug} course={course} />
</div>
</div>
)
}
export default CoursesActions

View file

@ -73,4 +73,21 @@ export async function getCoursesLinkedToProduct(orgId: number, productId: string
return res;
}
export async function getProductsByCourse(orgId: number, courseId: string, access_token: string) {
const result = await fetch(
`${getAPIUrl()}payments/${orgId}/courses/${courseId}/products`,
RequestBodyWithAuthHeader('GET', null, null, access_token)
);
const res = await getResponseMetadata(result);
return res;
}
export async function getStripeProductCheckoutSession(orgId: number, productId: number, redirect_uri: string, access_token: string) {
const result = await fetch(
`${getAPIUrl()}payments/${orgId}/stripe/checkout/product/${productId}?redirect_uri=${redirect_uri}`,
RequestBodyWithAuthHeader('POST', null, null, access_token)
);
const res = await getResponseMetadata(result);
return res;
}