mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: add course checkout UI and stripe integration and webhook wip
This commit is contained in:
parent
d8913d1a60
commit
1bff401e73
18 changed files with 1086 additions and 131 deletions
|
|
@ -48,6 +48,7 @@ def install(
|
|||
slug="default",
|
||||
email="",
|
||||
logo_image="",
|
||||
thumbnail_image="",
|
||||
)
|
||||
install_create_organization(org, db_session)
|
||||
print("Default organization created ✅")
|
||||
|
|
@ -89,6 +90,7 @@ def install(
|
|||
slug=slug.lower(),
|
||||
email="",
|
||||
logo_image="",
|
||||
thumbnail_image="",
|
||||
)
|
||||
install_create_organization(org, db_session)
|
||||
print(orgname + " Organization created ✅")
|
||||
|
|
|
|||
|
|
@ -1,26 +1,55 @@
|
|||
import logging
|
||||
import os
|
||||
import importlib
|
||||
from typing import Optional
|
||||
from config.config import get_learnhouse_config
|
||||
from fastapi import FastAPI
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
from sqlmodel import Field, SQLModel, Session, create_engine
|
||||
|
||||
def import_all_models():
|
||||
base_dir = 'src/db'
|
||||
base_module_path = 'src.db'
|
||||
|
||||
# Recursively walk through the base directory
|
||||
for root, dirs, files in os.walk(base_dir):
|
||||
# Filter out __init__.py and non-Python files
|
||||
module_files = [f for f in files if f.endswith('.py') and f != '__init__.py']
|
||||
|
||||
# Calculate the module's base path from its directory structure
|
||||
path_diff = os.path.relpath(root, base_dir)
|
||||
if path_diff == '.':
|
||||
current_module_base = base_module_path
|
||||
else:
|
||||
current_module_base = f"{base_module_path}.{path_diff.replace(os.sep, '.')}"
|
||||
|
||||
# Dynamically import each module
|
||||
for file_name in module_files:
|
||||
module_name = file_name[:-3] # Remove the '.py' extension
|
||||
full_module_path = f"{current_module_base}.{module_name}"
|
||||
importlib.import_module(full_module_path)
|
||||
|
||||
# Import all models before creating engine
|
||||
import_all_models()
|
||||
|
||||
learnhouse_config = get_learnhouse_config()
|
||||
engine = create_engine(
|
||||
learnhouse_config.database_config.sql_connection_string, echo=False, pool_pre_ping=True # type: ignore
|
||||
learnhouse_config.database_config.sql_connection_string, # type: ignore
|
||||
echo=False,
|
||||
pool_pre_ping=True # type: ignore
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
# Create all tables after importing all models
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
async def connect_to_db(app: FastAPI):
|
||||
app.db_engine = engine # type: ignore
|
||||
logging.info("LearnHouse database has been started.")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
def get_db_session():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def close_database(app: FastAPI):
|
||||
logging.info("LearnHouse has been shut down.")
|
||||
return app
|
||||
|
|
|
|||
|
|
@ -2,12 +2,14 @@ from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey
|
|||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class PaymentCourseBase(SQLModel):
|
||||
class PaymentsCourseBase(SQLModel):
|
||||
course_id: int = Field(sa_column=Column(BigInteger, ForeignKey("course.id", ondelete="CASCADE")))
|
||||
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
|
||||
org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")))
|
||||
|
||||
class PaymentCourse(PaymentCourseBase, table=True):
|
||||
class PaymentsCourse(PaymentsCourseBase, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
|
||||
org_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
|
||||
)
|
||||
creation_date: datetime = Field(default=datetime.now())
|
||||
update_date: datetime = Field(default=datetime.now())
|
||||
|
|
@ -1,19 +1,37 @@
|
|||
from enum import Enum
|
||||
from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey
|
||||
from openai import BaseModel
|
||||
from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey, JSON
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
class PaymentUserStatusEnum(str, Enum):
|
||||
class PaymentStatusEnum(str, Enum):
|
||||
PENDING = "pending"
|
||||
COMPLETED = "completed"
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
CANCELLED = "cancelled"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
|
||||
|
||||
class ProviderSpecificData(BaseModel):
|
||||
stripe_customer: dict | None = None
|
||||
custom_customer: dict | None = None
|
||||
|
||||
class PaymentsUserBase(SQLModel):
|
||||
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE")))
|
||||
status: PaymentUserStatusEnum = PaymentUserStatusEnum.ACTIVE
|
||||
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
|
||||
org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")))
|
||||
status: PaymentStatusEnum = PaymentStatusEnum.PENDING
|
||||
provider_specific_data: dict = Field(default={}, sa_column=Column(JSON))
|
||||
|
||||
class PaymentsUser(PaymentsUserBase, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE"))
|
||||
)
|
||||
org_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
|
||||
)
|
||||
payment_product_id: int = Field(
|
||||
sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE"))
|
||||
)
|
||||
creation_date: datetime = Field(default=datetime.now())
|
||||
update_date: datetime = Field(default=datetime.now())
|
||||
|
||||
|
|
|
|||
|
|
@ -59,6 +59,11 @@ class AnonymousUser(SQLModel):
|
|||
user_uuid: str = "user_anonymous"
|
||||
username: str = "anonymous"
|
||||
|
||||
class InternalUser(SQLModel):
|
||||
id: int = 0
|
||||
user_uuid: str = "user_internal"
|
||||
username: str = "internal"
|
||||
|
||||
|
||||
class User(UserBase, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
|
|
|
|||
|
|
@ -11,12 +11,14 @@ from src.services.payments.payments_config import (
|
|||
delete_payments_config,
|
||||
)
|
||||
from src.db.payments.payments_products import PaymentsProductCreate, PaymentsProductRead, PaymentsProductUpdate
|
||||
from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, list_payments_products, update_payments_product
|
||||
from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, get_products_by_course, list_payments_products, update_payments_product
|
||||
from src.services.payments.payments_courses import (
|
||||
link_course_to_product,
|
||||
unlink_course_from_product,
|
||||
get_courses_by_product
|
||||
get_courses_by_product,
|
||||
)
|
||||
from src.services.payments.payments_webhook import handle_stripe_webhook
|
||||
from src.services.payments.stripe import create_checkout_session
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -148,3 +150,38 @@ async def api_get_courses_by_product(
|
|||
return await get_courses_by_product(
|
||||
request, org_id, product_id, current_user, db_session
|
||||
)
|
||||
|
||||
@router.get("/{org_id}/courses/{course_id}/products")
|
||||
async def api_get_products_by_course(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
course_id: int,
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
return await get_products_by_course(
|
||||
request, org_id, course_id, current_user, db_session
|
||||
)
|
||||
|
||||
# Payments webhooks
|
||||
|
||||
@router.post("/{org_id}/stripe/webhook")
|
||||
async def api_handle_stripe_webhook(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
return await handle_stripe_webhook(request, org_id, db_session)
|
||||
|
||||
# Payments checkout
|
||||
|
||||
@router.post("/{org_id}/stripe/checkout/product/{product_id}")
|
||||
async def api_create_checkout_session(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: int,
|
||||
redirect_uri: str,
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
return await create_checkout_session(request, org_id, product_id, redirect_uri, current_user, db_session)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from src.security.rbac.rbac import (
|
|||
authorization_verify_based_on_org_admin_status,
|
||||
authorization_verify_if_user_is_anon,
|
||||
)
|
||||
from src.db.users import AnonymousUser, PublicUser
|
||||
from src.db.users import AnonymousUser, InternalUser, PublicUser
|
||||
from src.db.user_organizations import UserOrganization
|
||||
from src.db.organizations import (
|
||||
Organization,
|
||||
|
|
@ -682,7 +682,7 @@ async def get_org_join_mechanism(
|
|||
async def rbac_check(
|
||||
request: Request,
|
||||
org_uuid: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
action: Literal["create", "read", "update", "delete"],
|
||||
db_session: Session,
|
||||
):
|
||||
|
|
@ -690,6 +690,10 @@ async def rbac_check(
|
|||
if action == "read":
|
||||
return True
|
||||
|
||||
# Internal users can do anything
|
||||
if isinstance(current_user, InternalUser):
|
||||
return True
|
||||
|
||||
else:
|
||||
isUserAnon = await authorization_verify_if_user_is_anon(current_user.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from src.db.payments.payments import (
|
|||
PaymentsConfigUpdate,
|
||||
PaymentsConfigRead,
|
||||
)
|
||||
from src.db.users import PublicUser, AnonymousUser
|
||||
from src.db.users import PublicUser, AnonymousUser, InternalUser
|
||||
from src.db.organizations import Organization
|
||||
from src.services.orgs.orgs import rbac_check
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ async def create_payments_config(
|
|||
async def get_payments_config(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> list[PaymentsConfigRead]:
|
||||
# Check if organization exists
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from src.db.payments.payments_courses import PaymentCourse
|
||||
from src.db.payments.payments_courses import PaymentsCourse
|
||||
from src.db.payments.payments_products import PaymentsProduct
|
||||
from src.db.courses.courses import Course
|
||||
from src.db.users import PublicUser, AnonymousUser
|
||||
|
|
@ -36,7 +36,7 @@ async def link_course_to_product(
|
|||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
|
||||
# Check if course is already linked to another product
|
||||
statement = select(PaymentCourse).where(PaymentCourse.course_id == course.id)
|
||||
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
|
||||
existing_link = db_session.exec(statement).first()
|
||||
|
||||
if existing_link:
|
||||
|
|
@ -46,7 +46,7 @@ async def link_course_to_product(
|
|||
)
|
||||
|
||||
# Create new payment course link
|
||||
payment_course = PaymentCourse(
|
||||
payment_course = PaymentsCourse(
|
||||
course_id=course.id, # type: ignore
|
||||
payment_product_id=product_id,
|
||||
org_id=org_id,
|
||||
|
|
@ -75,9 +75,9 @@ async def unlink_course_from_product(
|
|||
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
|
||||
|
||||
# Find and delete the payment course link
|
||||
statement = select(PaymentCourse).where(
|
||||
PaymentCourse.course_id == course.id,
|
||||
PaymentCourse.org_id == org_id
|
||||
statement = select(PaymentsCourse).where(
|
||||
PaymentsCourse.course_id == course.id,
|
||||
PaymentsCourse.org_id == org_id
|
||||
)
|
||||
payment_course = db_session.exec(statement).first()
|
||||
|
||||
|
|
@ -113,12 +113,14 @@ async def get_courses_by_product(
|
|||
statement = (
|
||||
select(Course)
|
||||
.select_from(Course)
|
||||
.join(PaymentCourse, Course.id == PaymentCourse.course_id) # type: ignore
|
||||
.join(PaymentsCourse, Course.id == PaymentsCourse.course_id) # type: ignore
|
||||
.where(
|
||||
PaymentCourse.payment_product_id == product_id,
|
||||
PaymentCourse.org_id == org_id
|
||||
PaymentsCourse.payment_product_id == product_id,
|
||||
PaymentsCourse.org_id == org_id
|
||||
)
|
||||
)
|
||||
courses = db_session.exec(statement).all()
|
||||
|
||||
return courses
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from src.db.courses.courses import Course
|
||||
from src.db.payments.payments import PaymentsConfig
|
||||
from src.db.payments.payments_courses import PaymentsCourse
|
||||
from src.db.payments.payments_products import (
|
||||
PaymentsProduct,
|
||||
PaymentsProductCreate,
|
||||
|
|
@ -12,7 +14,7 @@ from src.db.organizations import Organization
|
|||
from src.services.orgs.orgs import rbac_check
|
||||
from datetime import datetime
|
||||
|
||||
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
|
||||
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product
|
||||
|
||||
async def create_payments_product(
|
||||
request: Request,
|
||||
|
|
@ -163,3 +165,36 @@ async def list_payments_products(
|
|||
products = db_session.exec(statement).all()
|
||||
|
||||
return [PaymentsProductRead.model_validate(product) for product in products]
|
||||
|
||||
async def get_products_by_course(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
course_id: int,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
) -> list[PaymentsProductRead]:
|
||||
# Check if course exists and user has permission
|
||||
statement = select(Course).where(Course.id == course_id)
|
||||
course = db_session.exec(statement).first()
|
||||
|
||||
if not course:
|
||||
raise HTTPException(status_code=404, detail="Course not found")
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
|
||||
|
||||
# Get all products linked to this course with explicit join
|
||||
statement = (
|
||||
select(PaymentsProduct)
|
||||
.select_from(PaymentsProduct)
|
||||
.join(PaymentsCourse, PaymentsProduct.id == PaymentsCourse.payment_product_id) # type: ignore
|
||||
.where(
|
||||
PaymentsCourse.course_id == course_id,
|
||||
PaymentsCourse.org_id == org_id
|
||||
)
|
||||
)
|
||||
products = db_session.exec(statement).all()
|
||||
|
||||
return [PaymentsProductRead.model_validate(product) for product in products]
|
||||
|
||||
|
||||
|
|
|
|||
181
apps/api/src/services/payments/payments_users.py
Normal file
181
apps/api/src/services/payments/payments_users.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from typing import Any
|
||||
from src.db.payments.payments_users import PaymentsUser, PaymentStatusEnum, ProviderSpecificData
|
||||
from src.db.payments.payments_products import PaymentsProduct
|
||||
from src.db.users import InternalUser, PublicUser, AnonymousUser
|
||||
from src.db.organizations import Organization
|
||||
from src.services.orgs.orgs import rbac_check
|
||||
from datetime import datetime
|
||||
|
||||
async def create_payment_user(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
user_id: int,
|
||||
product_id: int,
|
||||
status: PaymentStatusEnum,
|
||||
provider_data: Any,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> PaymentsUser:
|
||||
# Check if organization exists
|
||||
statement = select(Organization).where(Organization.id == org_id)
|
||||
org = db_session.exec(statement).first()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
|
||||
|
||||
# Check if product exists
|
||||
statement = select(PaymentsProduct).where(
|
||||
PaymentsProduct.id == product_id,
|
||||
PaymentsProduct.org_id == org_id
|
||||
)
|
||||
product = db_session.exec(statement).first()
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
|
||||
provider_specific_data = ProviderSpecificData(
|
||||
stripe_customer=provider_data if provider_data else None,
|
||||
)
|
||||
|
||||
# Check if user already has a payment user
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user_id,
|
||||
PaymentsUser.org_id == org_id
|
||||
)
|
||||
existing_payment_user = db_session.exec(statement).first()
|
||||
|
||||
if existing_payment_user:
|
||||
raise HTTPException(status_code=400, detail="User already has purchase for this product")
|
||||
|
||||
# Create new payment user
|
||||
payment_user = PaymentsUser(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
payment_product_id=product_id,
|
||||
provider_specific_data=provider_specific_data.model_dump(),
|
||||
status=status
|
||||
)
|
||||
|
||||
db_session.add(payment_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(payment_user)
|
||||
|
||||
return payment_user
|
||||
|
||||
async def get_payment_user(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
payment_user_id: int,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> PaymentsUser:
|
||||
# Check if organization exists
|
||||
statement = select(Organization).where(Organization.id == org_id)
|
||||
org = db_session.exec(statement).first()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
|
||||
|
||||
# Get payment user
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.id == payment_user_id,
|
||||
PaymentsUser.org_id == org_id
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
if not payment_user:
|
||||
raise HTTPException(status_code=404, detail="Payment user not found")
|
||||
|
||||
return payment_user
|
||||
|
||||
async def update_payment_user_status(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
payment_user_id: int,
|
||||
status: PaymentStatusEnum,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> PaymentsUser:
|
||||
# Check if organization exists
|
||||
statement = select(Organization).where(Organization.id == org_id)
|
||||
org = db_session.exec(statement).first()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
|
||||
|
||||
# Get existing payment user
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.id == payment_user_id,
|
||||
PaymentsUser.org_id == org_id
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
if not payment_user:
|
||||
raise HTTPException(status_code=404, detail="Payment user not found")
|
||||
|
||||
# Update status
|
||||
payment_user.status = status
|
||||
payment_user.update_date = datetime.now()
|
||||
|
||||
db_session.add(payment_user)
|
||||
db_session.commit()
|
||||
db_session.refresh(payment_user)
|
||||
|
||||
return payment_user
|
||||
|
||||
async def list_payment_users(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> list[PaymentsUser]:
|
||||
# Check if organization exists
|
||||
statement = select(Organization).where(Organization.id == org_id)
|
||||
org = db_session.exec(statement).first()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
|
||||
|
||||
# Get all payment users for org ordered by id
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.org_id == org_id
|
||||
).order_by(PaymentsUser.id.desc()) # type: ignore
|
||||
payment_users = list(db_session.exec(statement).all()) # Convert to list
|
||||
|
||||
return payment_users
|
||||
|
||||
async def delete_payment_user(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
payment_user_id: int,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Check if organization exists
|
||||
statement = select(Organization).where(Organization.id == org_id)
|
||||
org = db_session.exec(statement).first()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# RBAC check
|
||||
await rbac_check(request, org.org_uuid, current_user, "delete", db_session)
|
||||
|
||||
# Get existing payment user
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.id == payment_user_id,
|
||||
PaymentsUser.org_id == org_id
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
if not payment_user:
|
||||
raise HTTPException(status_code=404, detail="Payment user not found")
|
||||
|
||||
# Delete payment user
|
||||
db_session.delete(payment_user)
|
||||
db_session.commit()
|
||||
260
apps/api/src/services/payments/payments_webhook.py
Normal file
260
apps/api/src/services/payments/payments_webhook.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
import stripe
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict
|
||||
import logging
|
||||
|
||||
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
|
||||
from src.db.payments.payments_products import PaymentsProduct
|
||||
from src.db.users import InternalUser, User
|
||||
from src.services.payments.payments_users import create_payment_user, update_payment_user_status
|
||||
from src.services.payments.stripe import get_stripe_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_user_from_customer(customer_id: str, db_session: Session) -> User:
|
||||
"""Helper function to get user from Stripe customer ID"""
|
||||
try:
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
statement = select(User).where(User.email == customer.email)
|
||||
user = db_session.exec(statement).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail=f"User not found for customer {customer_id}")
|
||||
return user
|
||||
except stripe.StripeError as e:
|
||||
logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail="Error retrieving customer information")
|
||||
|
||||
async def get_product_from_stripe_id(product_id: str, db_session: Session) -> PaymentsProduct:
|
||||
"""Helper function to get product from Stripe product ID"""
|
||||
statement = select(PaymentsProduct).where(PaymentsProduct.provider_product_id == product_id)
|
||||
product = db_session.exec(statement).first()
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail=f"Product not found: {product_id}")
|
||||
return product
|
||||
|
||||
async def handle_stripe_webhook(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
db_session: Session,
|
||||
) -> dict:
|
||||
# Get Stripe credentials for the organization
|
||||
creds = await get_stripe_credentials(request, org_id, InternalUser(), db_session)
|
||||
|
||||
# Get the webhook secret and API key from credentials
|
||||
webhook_secret = creds.get('stripe_webhook_secret')
|
||||
stripe.api_key = creds.get('stripe_secret_key') # Set API key globally
|
||||
|
||||
if not webhook_secret:
|
||||
raise HTTPException(status_code=400, detail="Stripe webhook secret not configured")
|
||||
|
||||
# Get the raw request body
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
# Verify webhook signature and construct event
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, webhook_secret
|
||||
)
|
||||
|
||||
# Get the appropriate handler
|
||||
handler = STRIPE_EVENT_HANDLERS.get(event.type)
|
||||
if handler:
|
||||
await handler(request, event.data.object, org_id, db_session)
|
||||
return {"status": "success", "event": event.type}
|
||||
else:
|
||||
logger.info(f"Unhandled event type: {event.type}")
|
||||
return {"status": "ignored", "event": event.type}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}")
|
||||
|
||||
async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session):
|
||||
# Get the customer and product details from the session
|
||||
customer_email = session.customer_details.email
|
||||
product_id = session.line_items.data[0].price.product
|
||||
|
||||
# Use helper functions
|
||||
user = await get_user_from_customer(session.customer, db_session)
|
||||
product = await get_product_from_stripe_id(product_id, db_session)
|
||||
|
||||
# Find payment user record
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user.id,
|
||||
PaymentsUser.payment_product_id == product.id
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
|
||||
# Update status to completed
|
||||
await update_payment_user_status(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
payment_user_id=payment_user.id, # type: ignore
|
||||
status=PaymentStatusEnum.COMPLETED,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session):
|
||||
customer_id = subscription.customer
|
||||
|
||||
# Get product_id from metadata
|
||||
product_id = subscription.metadata.get('product_id')
|
||||
if not product_id:
|
||||
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
|
||||
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
|
||||
|
||||
# Get customer email from Stripe
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
|
||||
# Find user and create/update payment record
|
||||
statement = select(User).where(User.email == customer.email)
|
||||
user = db_session.exec(statement).first()
|
||||
|
||||
if user:
|
||||
payment_user = await create_payment_user(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
user_id=user.id, # type: ignore
|
||||
product_id=int(product_id), # Convert string from metadata to int
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
await update_payment_user_status(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
payment_user_id=payment_user.id, # type: ignore
|
||||
status=PaymentStatusEnum.ACTIVE,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session):
|
||||
customer_id = subscription.customer
|
||||
|
||||
# Get product_id from metadata
|
||||
product_id = subscription.metadata.get('product_id')
|
||||
if not product_id:
|
||||
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
|
||||
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
|
||||
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
|
||||
statement = select(User).where(User.email == customer.email)
|
||||
user = db_session.exec(statement).first()
|
||||
|
||||
if user:
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user.id,
|
||||
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
|
||||
if payment_user:
|
||||
status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING
|
||||
await update_payment_user_status(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
payment_user_id=payment_user.id, # type: ignore
|
||||
status=status,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session):
|
||||
customer_id = subscription.customer
|
||||
|
||||
# Get product_id from metadata
|
||||
product_id = subscription.metadata.get('product_id')
|
||||
if not product_id:
|
||||
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
|
||||
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
|
||||
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
|
||||
statement = select(User).where(User.email == customer.email)
|
||||
user = db_session.exec(statement).first()
|
||||
|
||||
if user:
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user.id,
|
||||
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
|
||||
if payment_user:
|
||||
await update_payment_user_status(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
payment_user_id=payment_user.id, # type: ignore
|
||||
status=PaymentStatusEnum.FAILED,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session):
|
||||
customer_id = payment_intent.customer
|
||||
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
|
||||
statement = select(User).where(User.email == customer.email)
|
||||
user = db_session.exec(statement).first()
|
||||
|
||||
# Get product_id directly from metadata
|
||||
product_id = payment_intent.metadata.get('product_id')
|
||||
if not product_id:
|
||||
logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}")
|
||||
raise HTTPException(status_code=400, detail="No product_id found in payment metadata")
|
||||
|
||||
if user:
|
||||
await create_payment_user(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
user_id=user.id, # type: ignore
|
||||
product_id=int(product_id), # Convert string from metadata to int
|
||||
status=PaymentStatusEnum.COMPLETED,
|
||||
provider_data=customer,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session):
|
||||
# Update payment status to failed
|
||||
customer_id = payment_intent.customer
|
||||
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
|
||||
statement = select(User).where(User.email == customer.email)
|
||||
user = db_session.exec(statement).first()
|
||||
|
||||
if user:
|
||||
statement = select(PaymentsUser).where(
|
||||
PaymentsUser.user_id == user.id,
|
||||
PaymentsUser.org_id == org_id,
|
||||
PaymentsUser.status == PaymentStatusEnum.PENDING
|
||||
)
|
||||
payment_user = db_session.exec(statement).first()
|
||||
|
||||
if payment_user:
|
||||
await update_payment_user_status(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
payment_user_id=payment_user.id, # type: ignore
|
||||
status=PaymentStatusEnum.FAILED,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
# Create event handler mapping
|
||||
STRIPE_EVENT_HANDLERS = {
|
||||
'checkout.session.completed': handle_checkout_session_completed,
|
||||
'customer.subscription.created': handle_subscription_created,
|
||||
'customer.subscription.updated': handle_subscription_updated,
|
||||
'customer.subscription.deleted': handle_subscription_deleted,
|
||||
'payment_intent.succeeded': handle_payment_succeeded,
|
||||
'payment_intent.payment_failed': handle_payment_failed,
|
||||
}
|
||||
|
|
@ -1,15 +1,16 @@
|
|||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session
|
||||
import stripe
|
||||
from config.config import get_learnhouse_config
|
||||
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
|
||||
from src.db.users import AnonymousUser, PublicUser
|
||||
from src.db.users import AnonymousUser, InternalUser, PublicUser
|
||||
from src.services.payments.payments_config import get_payments_config
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
async def get_stripe_credentials(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
configs = await get_payments_config(request, org_id, current_user, db_session)
|
||||
|
|
@ -149,6 +150,94 @@ async def update_stripe_product(
|
|||
except stripe.StripeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")
|
||||
|
||||
async def create_checkout_session(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: int,
|
||||
redirect_uri: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
# Get Stripe credentials
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
# Get product details
|
||||
statement = select(PaymentsProduct).where(
|
||||
PaymentsProduct.id == product_id,
|
||||
PaymentsProduct.org_id == org_id
|
||||
)
|
||||
product = db_session.exec(statement).first()
|
||||
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
|
||||
|
||||
success_url = redirect_uri
|
||||
cancel_url = redirect_uri
|
||||
|
||||
# Get the default price for the product
|
||||
stripe_product = stripe.Product.retrieve(product.provider_product_id)
|
||||
line_items = [{
|
||||
"price": stripe_product.default_price,
|
||||
"quantity": 1
|
||||
}]
|
||||
|
||||
# Create or retrieve Stripe customer
|
||||
try:
|
||||
customers = stripe.Customer.list(email=current_user.email)
|
||||
if customers.data:
|
||||
customer = customers.data[0]
|
||||
else:
|
||||
customer = stripe.Customer.create(
|
||||
email=current_user.email,
|
||||
metadata={
|
||||
"user_id": str(current_user.id),
|
||||
"org_id": str(org_id)
|
||||
}
|
||||
)
|
||||
except stripe.StripeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
|
||||
|
||||
# Create checkout session with customer
|
||||
try:
|
||||
checkout_session_params = {
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription',
|
||||
"line_items": line_items,
|
||||
"customer": customer.id,
|
||||
"metadata": {
|
||||
"product_id": str(product.id)
|
||||
}
|
||||
}
|
||||
|
||||
# Add payment_intent_data only for one-time payments
|
||||
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
|
||||
checkout_session_params["payment_intent_data"] = {
|
||||
"metadata": {
|
||||
"product_id": str(product.id)
|
||||
}
|
||||
}
|
||||
# Add subscription_data for subscription payments
|
||||
else:
|
||||
checkout_session_params["subscription_data"] = {
|
||||
"metadata": {
|
||||
"product_id": str(product.id)
|
||||
}
|
||||
}
|
||||
|
||||
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
|
||||
|
||||
return {
|
||||
"checkout_url": checkout_session.url,
|
||||
"session_id": checkout_session.id
|
||||
}
|
||||
|
||||
except stripe.StripeError as e:
|
||||
print(f"Error creating checkout session: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
'use client'
|
||||
import { removeCourse, startCourse } from '@services/courses/activity'
|
||||
import Link from 'next/link'
|
||||
import React, { useEffect, useState } from 'react'
|
||||
import { getUriWithOrg } from '@services/config/config'
|
||||
|
|
@ -15,15 +14,13 @@ import {
|
|||
import { ArrowRight, Backpack, Check, File, Sparkles, Video } from 'lucide-react'
|
||||
import { useOrg } from '@components/Contexts/OrgContext'
|
||||
import UserAvatar from '@components/Objects/UserAvatar'
|
||||
import CourseUpdates from '@components/Objects/CourseUpdates/CourseUpdates'
|
||||
import CourseUpdates from '@components/Objects/Courses/CourseUpdates/CourseUpdates'
|
||||
import { CourseProvider } from '@components/Contexts/CourseContext'
|
||||
import { useLHSession } from '@components/Contexts/LHSessionContext'
|
||||
import { useMediaQuery } from 'usehooks-ts'
|
||||
import CoursesActions from '@components/Objects/Courses/CourseActions/CoursesActions'
|
||||
|
||||
const CourseClient = (props: any) => {
|
||||
const [user, setUser] = useState<any>({})
|
||||
const [learnings, setLearnings] = useState<any>([])
|
||||
const session = useLHSession() as any;
|
||||
const courseuuid = props.courseuuid
|
||||
const orgslug = props.orgslug
|
||||
const course = props.course
|
||||
|
|
@ -37,33 +34,6 @@ const CourseClient = (props: any) => {
|
|||
setLearnings(learnings)
|
||||
}
|
||||
|
||||
async function startCourseUI() {
|
||||
// Create activity
|
||||
await startCourse('course_' + courseuuid, orgslug, session.data?.tokens?.access_token)
|
||||
await revalidateTags(['courses'], orgslug)
|
||||
router.refresh()
|
||||
|
||||
// refresh page (FIX for Next.js BUG)
|
||||
// window.location.reload();
|
||||
}
|
||||
|
||||
function isCourseStarted() {
|
||||
const runs = course.trail?.runs
|
||||
if (!runs) return false
|
||||
return runs.some(
|
||||
(run: any) =>
|
||||
run.status === 'STATUS_IN_PROGRESS' && run.course_id === course.id
|
||||
)
|
||||
}
|
||||
|
||||
async function quitCourse() {
|
||||
// Close activity
|
||||
let activity = await removeCourse('course_' + courseuuid, orgslug, session.data?.tokens?.access_token)
|
||||
// Mutate course
|
||||
await revalidateTags(['courses'], orgslug)
|
||||
router.refresh()
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
getLearningTags()
|
||||
}, [org, course])
|
||||
|
|
@ -113,11 +83,11 @@ const CourseClient = (props: any) => {
|
|||
course={course}
|
||||
/>
|
||||
|
||||
<div className="flex flex-col md:flex-row pt-10">
|
||||
<div className="course_metadata_left grow space-y-2">
|
||||
<h2 className="py-3 text-2xl font-bold">Description</h2>
|
||||
<div className="flex flex-col md:flex-row md:space-x-10 space-y-6 md:space-y-0 pt-10">
|
||||
<div className="course_metadata_left w-full md:basis-3/4 space-y-2">
|
||||
<h2 className="py-3 text-2xl font-bold">About</h2>
|
||||
<div className="bg-white shadow-md shadow-gray-300/25 outline outline-1 outline-neutral-200/40 rounded-lg overflow-hidden">
|
||||
<p className="py-5 px-5">{course.description}</p>
|
||||
<p className="py-5 px-5 whitespace-pre-wrap">{course.about}</p>
|
||||
</div>
|
||||
|
||||
{learnings.length > 0 && learnings[0] !== 'null' && (
|
||||
|
|
@ -305,60 +275,8 @@ const CourseClient = (props: any) => {
|
|||
})}
|
||||
</div>
|
||||
</div>
|
||||
<div className="course_metadata_right space-y-3 w-full md:w-72 antialiased flex flex-col md:ml-10 h-fit p-3 py-5 bg-white shadow-md shadow-gray-300/25 outline outline-1 outline-neutral-200/40 rounded-lg overflow-hidden mt-6 md:mt-0">
|
||||
{user && (
|
||||
<div className="flex flex-row md:flex-col mx-auto space-y-0 md:space-y-3 space-x-4 md:space-x-0 px-2 py-2 items-center">
|
||||
<UserAvatar
|
||||
border="border-8"
|
||||
avatar_url={course.authors[0].avatar_image ? getUserAvatarMediaDirectory(course.authors[0].user_uuid, course.authors[0].avatar_image) : ''}
|
||||
predefined_avatar={course.authors[0].avatar_image ? undefined : 'empty'}
|
||||
width={isMobile ? 60 : 100}
|
||||
/>
|
||||
<div className="md:-space-y-2">
|
||||
<div className="text-[12px] text-neutral-400 font-semibold">
|
||||
Author
|
||||
</div>
|
||||
<div className="text-lg md:text-xl font-bold text-neutral-800">
|
||||
{course.authors[0].first_name &&
|
||||
course.authors[0].last_name && (
|
||||
<div className="flex space-x-2 items-center">
|
||||
<p>
|
||||
{course.authors[0].first_name +
|
||||
' ' +
|
||||
course.authors[0].last_name}
|
||||
</p>
|
||||
<span className="text-xs bg-neutral-100 p-1 px-3 rounded-full text-neutral-400 font-semibold">
|
||||
{' '}
|
||||
@{course.authors[0].username}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
{!course.authors[0].first_name &&
|
||||
!course.authors[0].last_name && (
|
||||
<div className="flex space-x-2 items-center">
|
||||
<p>@{course.authors[0].username}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isCourseStarted() ? (
|
||||
<button
|
||||
className="py-2 px-5 mx-auto rounded-xl text-white font-bold h-12 w-full md:w-[200px] drop-shadow-md bg-red-600 hover:bg-red-700 hover:cursor-pointer"
|
||||
onClick={quitCourse}
|
||||
>
|
||||
Quit Course
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="py-2 px-5 mx-auto rounded-xl text-white font-bold h-12 w-full md:w-[200px] drop-shadow-md bg-black hover:bg-gray-900 hover:cursor-pointer"
|
||||
onClick={startCourseUI}
|
||||
>
|
||||
Start Course
|
||||
</button>
|
||||
)}
|
||||
<div className='course_metadata_right basis-1/4'>
|
||||
<CoursesActions courseuuid={courseuuid} orgslug={orgslug} course={course} />
|
||||
</div>
|
||||
</div>
|
||||
</GeneralWrapperStyled>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,161 @@
|
|||
import React, { useState } from 'react'
|
||||
import { useOrg } from '@components/Contexts/OrgContext'
|
||||
import { useLHSession } from '@components/Contexts/LHSessionContext'
|
||||
import useSWR from 'swr'
|
||||
import { getProductsByCourse, getStripeProductCheckoutSession } from '@services/payments/products'
|
||||
import { RefreshCcw, SquareCheck, ChevronDown, ChevronUp } from 'lucide-react'
|
||||
import { Badge } from '@components/ui/badge'
|
||||
import { Button } from '@components/ui/button'
|
||||
import toast from 'react-hot-toast'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { getUriWithOrg } from '@services/config/config'
|
||||
|
||||
interface CoursePaidOptionsProps {
|
||||
course: {
|
||||
id: string;
|
||||
org_id: number;
|
||||
}
|
||||
}
|
||||
|
||||
function CoursePaidOptions({ course }: CoursePaidOptionsProps) {
|
||||
const org = useOrg() as any
|
||||
const session = useLHSession() as any
|
||||
const [expandedProducts, setExpandedProducts] = useState<{ [key: string]: boolean }>({})
|
||||
const [isProcessing, setIsProcessing] = useState<{ [key: string]: boolean }>({})
|
||||
const router = useRouter()
|
||||
|
||||
const { data: linkedProducts, error } = useSWR(
|
||||
() => org && session ? [`/payments/${course.org_id}/courses/${course.id}/products`, session.data?.tokens?.access_token] : null,
|
||||
([url, token]) => getProductsByCourse(course.org_id, course.id, token)
|
||||
)
|
||||
|
||||
const handleCheckout = async (productId: number) => {
|
||||
if (!session.data?.user) {
|
||||
// Redirect to login if user is not authenticated
|
||||
router.push(`/signup?orgslug=${org.slug}`)
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
setIsProcessing(prev => ({ ...prev, [productId]: true }))
|
||||
const redirect_uri = getUriWithOrg(org.slug, '/courses')
|
||||
const response = await getStripeProductCheckoutSession(
|
||||
course.org_id,
|
||||
productId,
|
||||
redirect_uri,
|
||||
session.data?.tokens?.access_token
|
||||
)
|
||||
|
||||
if (response.success) {
|
||||
router.push(response.data.checkout_url)
|
||||
} else {
|
||||
toast.error('Failed to initiate checkout process')
|
||||
}
|
||||
} catch (error) {
|
||||
toast.error('An error occurred while processing your request')
|
||||
} finally {
|
||||
setIsProcessing(prev => ({ ...prev, [productId]: false }))
|
||||
}
|
||||
}
|
||||
|
||||
const toggleProductExpansion = (productId: string) => {
|
||||
setExpandedProducts(prev => ({
|
||||
...prev,
|
||||
[productId]: !prev[productId]
|
||||
}))
|
||||
}
|
||||
|
||||
if (error) return <div>Failed to load product options</div>
|
||||
if (!linkedProducts) return <div>Loading...</div>
|
||||
|
||||
return (
|
||||
<div className="space-y-4 p-1">
|
||||
{linkedProducts.data.map((product: any) => (
|
||||
<div key={product.id} className="bg-slate-50/30 p-4 rounded-lg nice-shadow flex flex-col">
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<div className="flex flex-col space-y-1 items-start">
|
||||
<Badge className='w-fit flex items-center space-x-2' variant="outline">
|
||||
{product.product_type === 'subscription' ? <RefreshCcw size={12} /> : <SquareCheck size={12} />}
|
||||
<span className='text-sm'>
|
||||
{product.product_type === 'subscription' ? 'Subscription' : 'One-time payment'}
|
||||
{product.product_type === 'subscription' && ' (per month)'}
|
||||
</span>
|
||||
</Badge>
|
||||
<h3 className="font-bold text-lg">{product.name}</h3>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-grow overflow-hidden">
|
||||
<div className={`transition-all duration-300 ease-in-out ${expandedProducts[product.id] ? 'max-h-[1000px]' : 'max-h-24'
|
||||
} overflow-hidden`}>
|
||||
<p className="text-gray-600">
|
||||
{product.description}
|
||||
</p>
|
||||
{product.benefits && (
|
||||
<div className="mt-2">
|
||||
<h4 className="font-semibold text-sm">Benefits:</h4>
|
||||
<p className="text-sm text-gray-600">
|
||||
{product.benefits}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-2">
|
||||
<button
|
||||
onClick={() => toggleProductExpansion(product.id)}
|
||||
className="text-slate-500 hover:text-slate-700 text-sm flex items-center"
|
||||
>
|
||||
{expandedProducts[product.id] ? (
|
||||
<>
|
||||
<ChevronUp size={16} />
|
||||
<span>Show less</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<ChevronDown size={16} />
|
||||
<span>Show more</span>
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="mt-2 flex items-center justify-between bg-gray-100 rounded-md p-2">
|
||||
<span className="text-sm text-gray-600">
|
||||
{product.price_type === 'customer_choice' ? 'Minimum Price:' : 'Price:'}
|
||||
</span>
|
||||
<div className="flex flex-col items-end">
|
||||
<span className="font-semibold text-lg">
|
||||
{new Intl.NumberFormat('en-US', {
|
||||
style: 'currency',
|
||||
currency: product.currency
|
||||
}).format(product.amount)}
|
||||
{product.product_type === 'subscription' && <span className="text-sm text-gray-500 ml-1">/month</span>}
|
||||
</span>
|
||||
{product.price_type === 'customer_choice' && (
|
||||
<span className="text-sm text-gray-500">Choose your price</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
className="mt-4 w-full"
|
||||
variant="default"
|
||||
onClick={() => handleCheckout(product.id)}
|
||||
disabled={isProcessing[product.id]}
|
||||
>
|
||||
{isProcessing[product.id]
|
||||
? 'Processing...'
|
||||
: product.product_type === 'subscription'
|
||||
? 'Subscribe Now'
|
||||
: 'Purchase Now'
|
||||
}
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default CoursePaidOptions
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
import React, { useState, useEffect } from 'react'
|
||||
import UserAvatar from '../../UserAvatar'
|
||||
import { getUserAvatarMediaDirectory } from '@services/media/media'
|
||||
import { removeCourse, startCourse } from '@services/courses/activity'
|
||||
import { revalidateTags } from '@services/utils/ts/requests'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useLHSession } from '@components/Contexts/LHSessionContext'
|
||||
import { useMediaQuery } from 'usehooks-ts'
|
||||
import { getUriWithOrg } from '@services/config/config'
|
||||
import { getProductsByCourse } from '@services/payments/products'
|
||||
import { LogIn, LogOut, ShoppingCart, AlertCircle } from 'lucide-react'
|
||||
import Modal from '@components/StyledElements/Modal/Modal'
|
||||
import CourseCTA from './CoursePaidOptions'
|
||||
import CoursePaidOptions from './CoursePaidOptions'
|
||||
|
||||
interface Author {
|
||||
user_uuid: string
|
||||
avatar_image: string
|
||||
first_name: string
|
||||
last_name: string
|
||||
username: string
|
||||
}
|
||||
|
||||
interface CourseRun {
|
||||
status: string
|
||||
course_id: string
|
||||
}
|
||||
|
||||
interface Course {
|
||||
id: string
|
||||
authors: Author[]
|
||||
trail?: {
|
||||
runs: CourseRun[]
|
||||
}
|
||||
}
|
||||
|
||||
interface CourseActionsProps {
|
||||
courseuuid: string
|
||||
orgslug: string
|
||||
course: Course & {
|
||||
org_id: number
|
||||
}
|
||||
}
|
||||
|
||||
// Separate component for author display
|
||||
const AuthorInfo = ({ author, isMobile }: { author: Author, isMobile: boolean }) => (
|
||||
<div className="flex flex-row md:flex-col mx-auto space-y-0 md:space-y-3 space-x-4 md:space-x-0 px-2 py-2 items-center">
|
||||
<UserAvatar
|
||||
border="border-8"
|
||||
avatar_url={author.avatar_image ? getUserAvatarMediaDirectory(author.user_uuid, author.avatar_image) : ''}
|
||||
predefined_avatar={author.avatar_image ? undefined : 'empty'}
|
||||
width={isMobile ? 60 : 100}
|
||||
/>
|
||||
<div className="md:-space-y-2">
|
||||
<div className="text-[12px] text-neutral-400 font-semibold">Author</div>
|
||||
<div className="text-lg md:text-xl font-bold text-neutral-800">
|
||||
{(author.first_name && author.last_name) ? (
|
||||
<div className="flex space-x-2 items-center">
|
||||
<p>{`${author.first_name} ${author.last_name}`}</p>
|
||||
<span className="text-xs bg-neutral-100 p-1 px-3 rounded-full text-neutral-400 font-semibold">
|
||||
@{author.username}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex space-x-2 items-center">
|
||||
<p>@{author.username}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
const Actions = ({ courseuuid, orgslug, course }: CourseActionsProps) => {
|
||||
const router = useRouter()
|
||||
const session = useLHSession() as any
|
||||
const [linkedProducts, setLinkedProducts] = useState<any[]>([])
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
const [isModalOpen, setIsModalOpen] = useState(false)
|
||||
|
||||
const isStarted = course.trail?.runs?.some(
|
||||
(run) => run.status === 'STATUS_IN_PROGRESS' && run.course_id === course.id
|
||||
) ?? false
|
||||
|
||||
useEffect(() => {
|
||||
const fetchLinkedProducts = async () => {
|
||||
try {
|
||||
const response = await getProductsByCourse(
|
||||
course.org_id,
|
||||
course.id,
|
||||
session.data?.tokens?.access_token
|
||||
)
|
||||
setLinkedProducts(response.data || [])
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch linked products')
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}
|
||||
|
||||
fetchLinkedProducts()
|
||||
}, [course.id, course.org_id, session.data?.tokens?.access_token])
|
||||
|
||||
const handleCourseAction = async () => {
|
||||
if (!session.data?.user) {
|
||||
router.push(getUriWithOrg(orgslug, '/signup?orgslug=' + orgslug))
|
||||
return
|
||||
}
|
||||
const action = isStarted ? removeCourse : startCourse
|
||||
await action('course_' + courseuuid, orgslug, session.data?.tokens?.access_token)
|
||||
await revalidateTags(['courses'], orgslug)
|
||||
router.refresh()
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return <div className="animate-pulse h-20 bg-gray-100 rounded-lg nice-shadow" />
|
||||
}
|
||||
|
||||
if (linkedProducts.length > 0) {
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="p-4 bg-amber-50 border border-amber-200 rounded-lg nice-shadow">
|
||||
<div className="flex items-center gap-3">
|
||||
<AlertCircle className="w-5 h-5 text-amber-800" />
|
||||
<h3 className="text-amber-800 font-semibold">Paid Course</h3>
|
||||
</div>
|
||||
<p className="text-amber-700 text-sm">
|
||||
This course requires purchase to access its content.
|
||||
</p>
|
||||
</div>
|
||||
<Modal
|
||||
isDialogOpen={isModalOpen}
|
||||
onOpenChange={setIsModalOpen}
|
||||
dialogContent={<CoursePaidOptions course={course} />}
|
||||
dialogTitle="Purchase Course"
|
||||
dialogDescription="Select a payment option to access this course"
|
||||
minWidth="sm"
|
||||
/>
|
||||
<button
|
||||
className="w-full bg-neutral-900 text-white py-3 rounded-lg nice-shadow font-semibold hover:bg-neutral-800 transition-colors flex items-center justify-center gap-2"
|
||||
onClick={() => setIsModalOpen(true)}
|
||||
>
|
||||
<ShoppingCart className="w-5 h-5" />
|
||||
Purchase Course
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={handleCourseAction}
|
||||
className={`w-full py-3 rounded-lg nice-shadow font-semibold transition-colors flex items-center justify-center gap-2 ${
|
||||
isStarted
|
||||
? 'bg-red-500 text-white hover:bg-red-600'
|
||||
: 'bg-neutral-900 text-white hover:bg-neutral-800'
|
||||
}`}
|
||||
>
|
||||
{!session.data?.user ? (
|
||||
<>
|
||||
<LogIn className="w-5 h-5" />
|
||||
Authenticate to start course
|
||||
</>
|
||||
) : isStarted ? (
|
||||
<>
|
||||
<LogOut className="w-5 h-5" />
|
||||
Leave Course
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<LogIn className="w-5 h-5" />
|
||||
Start Course
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
function CoursesActions({ courseuuid, orgslug, course }: CourseActionsProps) {
|
||||
const router = useRouter()
|
||||
const session = useLHSession() as any
|
||||
const isMobile = useMediaQuery('(max-width: 768px)')
|
||||
|
||||
|
||||
return (
|
||||
<div className=" space-y-3 antialiased flex flex-col p-3 py-5 bg-white shadow-md shadow-gray-300/25 outline outline-1 outline-neutral-200/40 rounded-lg overflow-hidden">
|
||||
<AuthorInfo author={course.authors[0]} isMobile={isMobile} />
|
||||
<div className='px-3 py-2'>
|
||||
<Actions courseuuid={courseuuid} orgslug={orgslug} course={course} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default CoursesActions
|
||||
|
|
@ -73,4 +73,21 @@ export async function getCoursesLinkedToProduct(orgId: number, productId: string
|
|||
return res;
|
||||
}
|
||||
|
||||
export async function getProductsByCourse(orgId: number, courseId: string, access_token: string) {
|
||||
const result = await fetch(
|
||||
`${getAPIUrl()}payments/${orgId}/courses/${courseId}/products`,
|
||||
RequestBodyWithAuthHeader('GET', null, null, access_token)
|
||||
);
|
||||
const res = await getResponseMetadata(result);
|
||||
return res;
|
||||
}
|
||||
|
||||
export async function getStripeProductCheckoutSession(orgId: number, productId: number, redirect_uri: string, access_token: string) {
|
||||
const result = await fetch(
|
||||
`${getAPIUrl()}payments/${orgId}/stripe/checkout/product/${productId}?redirect_uri=${redirect_uri}`,
|
||||
RequestBodyWithAuthHeader('POST', null, null, access_token)
|
||||
);
|
||||
const res = await getResponseMetadata(result);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue