feat: add course checkout UI and stripe integration and webhook wip

This commit is contained in:
swve 2024-11-01 20:51:52 +01:00
parent d8913d1a60
commit 1bff401e73
18 changed files with 1086 additions and 131 deletions

View file

@ -48,6 +48,7 @@ def install(
slug="default",
email="",
logo_image="",
thumbnail_image="",
)
install_create_organization(org, db_session)
print("Default organization created ✅")
@ -89,6 +90,7 @@ def install(
slug=slug.lower(),
email="",
logo_image="",
thumbnail_image="",
)
install_create_organization(org, db_session)
print(orgname + " Organization created ✅")

View file

@ -1,26 +1,55 @@
import logging
import os
import importlib
from typing import Optional
from config.config import get_learnhouse_config
from fastapi import FastAPI
from sqlmodel import SQLModel, Session, create_engine
from sqlmodel import Field, SQLModel, Session, create_engine
def import_all_models():
base_dir = 'src/db'
base_module_path = 'src.db'
# Recursively walk through the base directory
for root, dirs, files in os.walk(base_dir):
# Filter out __init__.py and non-Python files
module_files = [f for f in files if f.endswith('.py') and f != '__init__.py']
# Calculate the module's base path from its directory structure
path_diff = os.path.relpath(root, base_dir)
if path_diff == '.':
current_module_base = base_module_path
else:
current_module_base = f"{base_module_path}.{path_diff.replace(os.sep, '.')}"
# Dynamically import each module
for file_name in module_files:
module_name = file_name[:-3] # Remove the '.py' extension
full_module_path = f"{current_module_base}.{module_name}"
importlib.import_module(full_module_path)
# Import all models before creating engine
import_all_models()
learnhouse_config = get_learnhouse_config()
engine = create_engine(
learnhouse_config.database_config.sql_connection_string, echo=False, pool_pre_ping=True # type: ignore
learnhouse_config.database_config.sql_connection_string, # type: ignore
echo=False,
pool_pre_ping=True # type: ignore
)
SQLModel.metadata.create_all(engine)
# Create all tables after importing all models
SQLModel.metadata.create_all(engine)
async def connect_to_db(app: FastAPI):
app.db_engine = engine # type: ignore
logging.info("LearnHouse database has been started.")
SQLModel.metadata.create_all(engine)
def get_db_session():
with Session(engine) as session:
yield session
async def close_database(app: FastAPI):
logging.info("LearnHouse has been shut down.")
return app

View file

@ -2,12 +2,14 @@ from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey
from typing import Optional
from datetime import datetime
class PaymentCourseBase(SQLModel):
class PaymentsCourseBase(SQLModel):
course_id: int = Field(sa_column=Column(BigInteger, ForeignKey("course.id", ondelete="CASCADE")))
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")))
class PaymentCourse(PaymentCourseBase, table=True):
class PaymentsCourse(PaymentsCourseBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
org_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
)
creation_date: datetime = Field(default=datetime.now())
update_date: datetime = Field(default=datetime.now())

View file

@ -1,19 +1,37 @@
from enum import Enum
from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey
from openai import BaseModel
from sqlmodel import SQLModel, Field, Column, BigInteger, ForeignKey, JSON
from typing import Optional
from datetime import datetime
from enum import Enum
class PaymentUserStatusEnum(str, Enum):
class PaymentStatusEnum(str, Enum):
PENDING = "pending"
COMPLETED = "completed"
ACTIVE = "active"
INACTIVE = "inactive"
CANCELLED = "cancelled"
FAILED = "failed"
REFUNDED = "refunded"
class ProviderSpecificData(BaseModel):
stripe_customer: dict | None = None
custom_customer: dict | None = None
class PaymentsUserBase(SQLModel):
user_id: int = Field(sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE")))
status: PaymentUserStatusEnum = PaymentUserStatusEnum.ACTIVE
payment_product_id: int = Field(sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE")))
org_id: int = Field(sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE")))
status: PaymentStatusEnum = PaymentStatusEnum.PENDING
provider_specific_data: dict = Field(default={}, sa_column=Column(JSON))
class PaymentsUser(PaymentsUserBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
user_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("user.id", ondelete="CASCADE"))
)
org_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("organization.id", ondelete="CASCADE"))
)
payment_product_id: int = Field(
sa_column=Column(BigInteger, ForeignKey("paymentsproduct.id", ondelete="CASCADE"))
)
creation_date: datetime = Field(default=datetime.now())
update_date: datetime = Field(default=datetime.now())

View file

@ -59,6 +59,11 @@ class AnonymousUser(SQLModel):
user_uuid: str = "user_anonymous"
username: str = "anonymous"
class InternalUser(SQLModel):
id: int = 0
user_uuid: str = "user_internal"
username: str = "internal"
class User(UserBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)

View file

@ -11,12 +11,14 @@ from src.services.payments.payments_config import (
delete_payments_config,
)
from src.db.payments.payments_products import PaymentsProductCreate, PaymentsProductRead, PaymentsProductUpdate
from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, list_payments_products, update_payments_product
from src.services.payments.payments_products import create_payments_product, delete_payments_product, get_payments_product, get_products_by_course, list_payments_products, update_payments_product
from src.services.payments.payments_courses import (
link_course_to_product,
unlink_course_from_product,
get_courses_by_product
get_courses_by_product,
)
from src.services.payments.payments_webhook import handle_stripe_webhook
from src.services.payments.stripe import create_checkout_session
router = APIRouter()
@ -148,3 +150,38 @@ async def api_get_courses_by_product(
return await get_courses_by_product(
request, org_id, product_id, current_user, db_session
)
@router.get("/{org_id}/courses/{course_id}/products")
async def api_get_products_by_course(
request: Request,
org_id: int,
course_id: int,
current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session),
):
return await get_products_by_course(
request, org_id, course_id, current_user, db_session
)
# Payments webhooks
@router.post("/{org_id}/stripe/webhook")
async def api_handle_stripe_webhook(
request: Request,
org_id: int,
db_session: Session = Depends(get_db_session),
):
return await handle_stripe_webhook(request, org_id, db_session)
# Payments checkout
@router.post("/{org_id}/stripe/checkout/product/{product_id}")
async def api_create_checkout_session(
request: Request,
org_id: int,
product_id: int,
redirect_uri: str,
current_user: PublicUser = Depends(get_current_user),
db_session: Session = Depends(get_db_session),
):
return await create_checkout_session(request, org_id, product_id, redirect_uri, current_user, db_session)

View file

@ -26,7 +26,7 @@ from src.security.rbac.rbac import (
authorization_verify_based_on_org_admin_status,
authorization_verify_if_user_is_anon,
)
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.db.user_organizations import UserOrganization
from src.db.organizations import (
Organization,
@ -682,13 +682,17 @@ async def get_org_join_mechanism(
async def rbac_check(
request: Request,
org_uuid: str,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
# Organizations are readable by anyone
if action == "read":
return True
# Internal users can do anything
if isinstance(current_user, InternalUser):
return True
else:
isUserAnon = await authorization_verify_if_user_is_anon(current_user.id)

View file

@ -6,7 +6,7 @@ from src.db.payments.payments import (
PaymentsConfigUpdate,
PaymentsConfigRead,
)
from src.db.users import PublicUser, AnonymousUser
from src.db.users import PublicUser, AnonymousUser, InternalUser
from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
@ -48,7 +48,7 @@ async def create_payments_config(
async def get_payments_config(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> list[PaymentsConfigRead]:
# Check if organization exists

View file

@ -1,7 +1,7 @@
from datetime import datetime
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.payments.payments_courses import PaymentCourse
from src.db.payments.payments_courses import PaymentsCourse
from src.db.payments.payments_products import PaymentsProduct
from src.db.courses.courses import Course
from src.db.users import PublicUser, AnonymousUser
@ -36,7 +36,7 @@ async def link_course_to_product(
raise HTTPException(status_code=404, detail="Product not found")
# Check if course is already linked to another product
statement = select(PaymentCourse).where(PaymentCourse.course_id == course.id)
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
existing_link = db_session.exec(statement).first()
if existing_link:
@ -46,7 +46,7 @@ async def link_course_to_product(
)
# Create new payment course link
payment_course = PaymentCourse(
payment_course = PaymentsCourse(
course_id=course.id, # type: ignore
payment_product_id=product_id,
org_id=org_id,
@ -75,9 +75,9 @@ async def unlink_course_from_product(
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
# Find and delete the payment course link
statement = select(PaymentCourse).where(
PaymentCourse.course_id == course.id,
PaymentCourse.org_id == org_id
statement = select(PaymentsCourse).where(
PaymentsCourse.course_id == course.id,
PaymentsCourse.org_id == org_id
)
payment_course = db_session.exec(statement).first()
@ -113,12 +113,14 @@ async def get_courses_by_product(
statement = (
select(Course)
.select_from(Course)
.join(PaymentCourse, Course.id == PaymentCourse.course_id) # type: ignore
.join(PaymentsCourse, Course.id == PaymentsCourse.course_id) # type: ignore
.where(
PaymentCourse.payment_product_id == product_id,
PaymentCourse.org_id == org_id
PaymentsCourse.payment_product_id == product_id,
PaymentsCourse.org_id == org_id
)
)
courses = db_session.exec(statement).all()
return courses

View file

@ -1,6 +1,8 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.courses.courses import Course
from src.db.payments.payments import PaymentsConfig
from src.db.payments.payments_courses import PaymentsCourse
from src.db.payments.payments_products import (
PaymentsProduct,
PaymentsProductCreate,
@ -12,7 +14,7 @@ from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
from datetime import datetime
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product
async def create_payments_product(
request: Request,
@ -163,3 +165,36 @@ async def list_payments_products(
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]
async def get_products_by_course(
request: Request,
org_id: int,
course_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
) -> list[PaymentsProductRead]:
# Check if course exists and user has permission
statement = select(Course).where(Course.id == course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(status_code=404, detail="Course not found")
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
# Get all products linked to this course with explicit join
statement = (
select(PaymentsProduct)
.select_from(PaymentsProduct)
.join(PaymentsCourse, PaymentsProduct.id == PaymentsCourse.payment_product_id) # type: ignore
.where(
PaymentsCourse.course_id == course_id,
PaymentsCourse.org_id == org_id
)
)
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]

View file

@ -0,0 +1,181 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from typing import Any
from src.db.payments.payments_users import PaymentsUser, PaymentStatusEnum, ProviderSpecificData
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, PublicUser, AnonymousUser
from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
from datetime import datetime
async def create_payment_user(
request: Request,
org_id: int,
user_id: int,
product_id: int,
status: PaymentStatusEnum,
provider_data: Any,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
# Check if product exists
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id,
PaymentsProduct.org_id == org_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Product not found")
provider_specific_data = ProviderSpecificData(
stripe_customer=provider_data if provider_data else None,
)
# Check if user already has a payment user
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user_id,
PaymentsUser.org_id == org_id
)
existing_payment_user = db_session.exec(statement).first()
if existing_payment_user:
raise HTTPException(status_code=400, detail="User already has purchase for this product")
# Create new payment user
payment_user = PaymentsUser(
user_id=user_id,
org_id=org_id,
payment_product_id=product_id,
provider_specific_data=provider_specific_data.model_dump(),
status=status
)
db_session.add(payment_user)
db_session.commit()
db_session.refresh(payment_user)
return payment_user
async def get_payment_user(
request: Request,
org_id: int,
payment_user_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
return payment_user
async def update_payment_user_status(
request: Request,
org_id: int,
payment_user_id: int,
status: PaymentStatusEnum,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
# Get existing payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
# Update status
payment_user.status = status
payment_user.update_date = datetime.now()
db_session.add(payment_user)
db_session.commit()
db_session.refresh(payment_user)
return payment_user
async def list_payment_users(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> list[PaymentsUser]:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get all payment users for org ordered by id
statement = select(PaymentsUser).where(
PaymentsUser.org_id == org_id
).order_by(PaymentsUser.id.desc()) # type: ignore
payment_users = list(db_session.exec(statement).all()) # Convert to list
return payment_users
async def delete_payment_user(
request: Request,
org_id: int,
payment_user_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> None:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "delete", db_session)
# Get existing payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
# Delete payment user
db_session.delete(payment_user)
db_session.commit()

View file

@ -0,0 +1,260 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
import stripe
from datetime import datetime
from typing import Callable, Dict
import logging
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, User
from src.services.payments.payments_users import create_payment_user, update_payment_user_status
from src.services.payments.stripe import get_stripe_credentials
logger = logging.getLogger(__name__)
async def get_user_from_customer(customer_id: str, db_session: Session) -> User:
"""Helper function to get user from Stripe customer ID"""
try:
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if not user:
raise HTTPException(status_code=404, detail=f"User not found for customer {customer_id}")
return user
except stripe.StripeError as e:
logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}")
raise HTTPException(status_code=400, detail="Error retrieving customer information")
async def get_product_from_stripe_id(product_id: str, db_session: Session) -> PaymentsProduct:
"""Helper function to get product from Stripe product ID"""
statement = select(PaymentsProduct).where(PaymentsProduct.provider_product_id == product_id)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail=f"Product not found: {product_id}")
return product
async def handle_stripe_webhook(
request: Request,
org_id: int,
db_session: Session,
) -> dict:
# Get Stripe credentials for the organization
creds = await get_stripe_credentials(request, org_id, InternalUser(), db_session)
# Get the webhook secret and API key from credentials
webhook_secret = creds.get('stripe_webhook_secret')
stripe.api_key = creds.get('stripe_secret_key') # Set API key globally
if not webhook_secret:
raise HTTPException(status_code=400, detail="Stripe webhook secret not configured")
# Get the raw request body
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
# Verify webhook signature and construct event
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
# Get the appropriate handler
handler = STRIPE_EVENT_HANDLERS.get(event.type)
if handler:
await handler(request, event.data.object, org_id, db_session)
return {"status": "success", "event": event.type}
else:
logger.info(f"Unhandled event type: {event.type}")
return {"status": "ignored", "event": event.type}
except Exception as e:
logger.error(f"Error processing webhook: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}")
async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session):
# Get the customer and product details from the session
customer_email = session.customer_details.email
product_id = session.line_items.data[0].price.product
# Use helper functions
user = await get_user_from_customer(session.customer, db_session)
product = await get_product_from_stripe_id(product_id, db_session)
# Find payment user record
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == product.id
)
payment_user = db_session.exec(statement).first()
# Update status to completed
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.COMPLETED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
# Get customer email from Stripe
customer = stripe.Customer.retrieve(customer_id)
# Find user and create/update payment record
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
payment_user = await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
current_user=InternalUser(),
db_session=db_session
)
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.ACTIVE,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=status,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session):
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
# Get product_id directly from metadata
product_id = payment_intent.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}")
raise HTTPException(status_code=400, detail="No product_id found in payment metadata")
if user:
await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
status=PaymentStatusEnum.COMPLETED,
provider_data=customer,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session):
# Update payment status to failed
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.org_id == org_id,
PaymentsUser.status == PaymentStatusEnum.PENDING
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
# Create event handler mapping
STRIPE_EVENT_HANDLERS = {
'checkout.session.completed': handle_checkout_session_completed,
'customer.subscription.created': handle_subscription_created,
'customer.subscription.updated': handle_subscription_updated,
'customer.subscription.deleted': handle_subscription_deleted,
'payment_intent.succeeded': handle_payment_succeeded,
'payment_intent.payment_failed': handle_payment_failed,
}

View file

@ -1,15 +1,16 @@
from fastapi import HTTPException, Request
from sqlmodel import Session
import stripe
from config.config import get_learnhouse_config
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.services.payments.payments_config import get_payments_config
from sqlmodel import select
async def get_stripe_credentials(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
):
configs = await get_payments_config(request, org_id, current_user, db_session)
@ -149,6 +150,94 @@ async def update_stripe_product(
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")
async def create_checkout_session(
request: Request,
org_id: int,
product_id: int,
redirect_uri: str,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
# Get Stripe credentials
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
stripe.api_key = creds.get('stripe_secret_key')
# Get product details
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id,
PaymentsProduct.org_id == org_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Product not found")
success_url = redirect_uri
cancel_url = redirect_uri
# Get the default price for the product
stripe_product = stripe.Product.retrieve(product.provider_product_id)
line_items = [{
"price": stripe_product.default_price,
"quantity": 1
}]
# Create or retrieve Stripe customer
try:
customers = stripe.Customer.list(email=current_user.email)
if customers.data:
customer = customers.data[0]
else:
customer = stripe.Customer.create(
email=current_user.email,
metadata={
"user_id": str(current_user.id),
"org_id": str(org_id)
}
)
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
# Create checkout session with customer
try:
checkout_session_params = {
"success_url": success_url,
"cancel_url": cancel_url,
"mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription',
"line_items": line_items,
"customer": customer.id,
"metadata": {
"product_id": str(product.id)
}
}
# Add payment_intent_data only for one-time payments
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
checkout_session_params["payment_intent_data"] = {
"metadata": {
"product_id": str(product.id)
}
}
# Add subscription_data for subscription payments
else:
checkout_session_params["subscription_data"] = {
"metadata": {
"product_id": str(product.id)
}
}
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
return {
"checkout_url": checkout_session.url,
"session_id": checkout_session.id
}
except stripe.StripeError as e:
print(f"Error creating checkout session: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))