feat: add course checkout UI and stripe integration and webhook wip

This commit is contained in:
swve 2024-11-01 20:51:52 +01:00
parent d8913d1a60
commit 1bff401e73
18 changed files with 1086 additions and 131 deletions

View file

@ -26,7 +26,7 @@ from src.security.rbac.rbac import (
authorization_verify_based_on_org_admin_status,
authorization_verify_if_user_is_anon,
)
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.db.user_organizations import UserOrganization
from src.db.organizations import (
Organization,
@ -682,13 +682,17 @@ async def get_org_join_mechanism(
async def rbac_check(
request: Request,
org_uuid: str,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
action: Literal["create", "read", "update", "delete"],
db_session: Session,
):
# Organizations are readable by anyone
if action == "read":
return True
# Internal users can do anything
if isinstance(current_user, InternalUser):
return True
else:
isUserAnon = await authorization_verify_if_user_is_anon(current_user.id)

View file

@ -6,7 +6,7 @@ from src.db.payments.payments import (
PaymentsConfigUpdate,
PaymentsConfigRead,
)
from src.db.users import PublicUser, AnonymousUser
from src.db.users import PublicUser, AnonymousUser, InternalUser
from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
@ -48,7 +48,7 @@ async def create_payments_config(
async def get_payments_config(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> list[PaymentsConfigRead]:
# Check if organization exists

View file

@ -1,7 +1,7 @@
from datetime import datetime
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.payments.payments_courses import PaymentCourse
from src.db.payments.payments_courses import PaymentsCourse
from src.db.payments.payments_products import PaymentsProduct
from src.db.courses.courses import Course
from src.db.users import PublicUser, AnonymousUser
@ -36,7 +36,7 @@ async def link_course_to_product(
raise HTTPException(status_code=404, detail="Product not found")
# Check if course is already linked to another product
statement = select(PaymentCourse).where(PaymentCourse.course_id == course.id)
statement = select(PaymentsCourse).where(PaymentsCourse.course_id == course.id)
existing_link = db_session.exec(statement).first()
if existing_link:
@ -46,7 +46,7 @@ async def link_course_to_product(
)
# Create new payment course link
payment_course = PaymentCourse(
payment_course = PaymentsCourse(
course_id=course.id, # type: ignore
payment_product_id=product_id,
org_id=org_id,
@ -75,9 +75,9 @@ async def unlink_course_from_product(
await rbac_check(request, course.course_uuid, current_user, "update", db_session)
# Find and delete the payment course link
statement = select(PaymentCourse).where(
PaymentCourse.course_id == course.id,
PaymentCourse.org_id == org_id
statement = select(PaymentsCourse).where(
PaymentsCourse.course_id == course.id,
PaymentsCourse.org_id == org_id
)
payment_course = db_session.exec(statement).first()
@ -113,12 +113,14 @@ async def get_courses_by_product(
statement = (
select(Course)
.select_from(Course)
.join(PaymentCourse, Course.id == PaymentCourse.course_id) # type: ignore
.join(PaymentsCourse, Course.id == PaymentsCourse.course_id) # type: ignore
.where(
PaymentCourse.payment_product_id == product_id,
PaymentCourse.org_id == org_id
PaymentsCourse.payment_product_id == product_id,
PaymentsCourse.org_id == org_id
)
)
courses = db_session.exec(statement).all()
return courses

View file

@ -1,6 +1,8 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from src.db.courses.courses import Course
from src.db.payments.payments import PaymentsConfig
from src.db.payments.payments_courses import PaymentsCourse
from src.db.payments.payments_products import (
PaymentsProduct,
PaymentsProductCreate,
@ -12,7 +14,7 @@ from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
from datetime import datetime
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product
async def create_payments_product(
request: Request,
@ -163,3 +165,36 @@ async def list_payments_products(
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]
async def get_products_by_course(
request: Request,
org_id: int,
course_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
) -> list[PaymentsProductRead]:
# Check if course exists and user has permission
statement = select(Course).where(Course.id == course_id)
course = db_session.exec(statement).first()
if not course:
raise HTTPException(status_code=404, detail="Course not found")
# RBAC check
await rbac_check(request, course.course_uuid, current_user, "read", db_session)
# Get all products linked to this course with explicit join
statement = (
select(PaymentsProduct)
.select_from(PaymentsProduct)
.join(PaymentsCourse, PaymentsProduct.id == PaymentsCourse.payment_product_id) # type: ignore
.where(
PaymentsCourse.course_id == course_id,
PaymentsCourse.org_id == org_id
)
)
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]

View file

@ -0,0 +1,181 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
from typing import Any
from src.db.payments.payments_users import PaymentsUser, PaymentStatusEnum, ProviderSpecificData
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, PublicUser, AnonymousUser
from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check
from datetime import datetime
async def create_payment_user(
request: Request,
org_id: int,
user_id: int,
product_id: int,
status: PaymentStatusEnum,
provider_data: Any,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
# Check if product exists
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id,
PaymentsProduct.org_id == org_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Product not found")
provider_specific_data = ProviderSpecificData(
stripe_customer=provider_data if provider_data else None,
)
# Check if user already has a payment user
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user_id,
PaymentsUser.org_id == org_id
)
existing_payment_user = db_session.exec(statement).first()
if existing_payment_user:
raise HTTPException(status_code=400, detail="User already has purchase for this product")
# Create new payment user
payment_user = PaymentsUser(
user_id=user_id,
org_id=org_id,
payment_product_id=product_id,
provider_specific_data=provider_specific_data.model_dump(),
status=status
)
db_session.add(payment_user)
db_session.commit()
db_session.refresh(payment_user)
return payment_user
async def get_payment_user(
request: Request,
org_id: int,
payment_user_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
return payment_user
async def update_payment_user_status(
request: Request,
org_id: int,
payment_user_id: int,
status: PaymentStatusEnum,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> PaymentsUser:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "update", db_session)
# Get existing payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
# Update status
payment_user.status = status
payment_user.update_date = datetime.now()
db_session.add(payment_user)
db_session.commit()
db_session.refresh(payment_user)
return payment_user
async def list_payment_users(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> list[PaymentsUser]:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get all payment users for org ordered by id
statement = select(PaymentsUser).where(
PaymentsUser.org_id == org_id
).order_by(PaymentsUser.id.desc()) # type: ignore
payment_users = list(db_session.exec(statement).all()) # Convert to list
return payment_users
async def delete_payment_user(
request: Request,
org_id: int,
payment_user_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
) -> None:
# Check if organization exists
statement = select(Organization).where(Organization.id == org_id)
org = db_session.exec(statement).first()
if not org:
raise HTTPException(status_code=404, detail="Organization not found")
# RBAC check
await rbac_check(request, org.org_uuid, current_user, "delete", db_session)
# Get existing payment user
statement = select(PaymentsUser).where(
PaymentsUser.id == payment_user_id,
PaymentsUser.org_id == org_id
)
payment_user = db_session.exec(statement).first()
if not payment_user:
raise HTTPException(status_code=404, detail="Payment user not found")
# Delete payment user
db_session.delete(payment_user)
db_session.commit()

View file

@ -0,0 +1,260 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
import stripe
from datetime import datetime
from typing import Callable, Dict
import logging
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, User
from src.services.payments.payments_users import create_payment_user, update_payment_user_status
from src.services.payments.stripe import get_stripe_credentials
logger = logging.getLogger(__name__)
async def get_user_from_customer(customer_id: str, db_session: Session) -> User:
"""Helper function to get user from Stripe customer ID"""
try:
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if not user:
raise HTTPException(status_code=404, detail=f"User not found for customer {customer_id}")
return user
except stripe.StripeError as e:
logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}")
raise HTTPException(status_code=400, detail="Error retrieving customer information")
async def get_product_from_stripe_id(product_id: str, db_session: Session) -> PaymentsProduct:
"""Helper function to get product from Stripe product ID"""
statement = select(PaymentsProduct).where(PaymentsProduct.provider_product_id == product_id)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail=f"Product not found: {product_id}")
return product
async def handle_stripe_webhook(
request: Request,
org_id: int,
db_session: Session,
) -> dict:
# Get Stripe credentials for the organization
creds = await get_stripe_credentials(request, org_id, InternalUser(), db_session)
# Get the webhook secret and API key from credentials
webhook_secret = creds.get('stripe_webhook_secret')
stripe.api_key = creds.get('stripe_secret_key') # Set API key globally
if not webhook_secret:
raise HTTPException(status_code=400, detail="Stripe webhook secret not configured")
# Get the raw request body
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
# Verify webhook signature and construct event
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
# Get the appropriate handler
handler = STRIPE_EVENT_HANDLERS.get(event.type)
if handler:
await handler(request, event.data.object, org_id, db_session)
return {"status": "success", "event": event.type}
else:
logger.info(f"Unhandled event type: {event.type}")
return {"status": "ignored", "event": event.type}
except Exception as e:
logger.error(f"Error processing webhook: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}")
async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session):
# Get the customer and product details from the session
customer_email = session.customer_details.email
product_id = session.line_items.data[0].price.product
# Use helper functions
user = await get_user_from_customer(session.customer, db_session)
product = await get_product_from_stripe_id(product_id, db_session)
# Find payment user record
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == product.id
)
payment_user = db_session.exec(statement).first()
# Update status to completed
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.COMPLETED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
# Get customer email from Stripe
customer = stripe.Customer.retrieve(customer_id)
# Find user and create/update payment record
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
payment_user = await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
current_user=InternalUser(),
db_session=db_session
)
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.ACTIVE,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=status,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session):
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
# Get product_id directly from metadata
product_id = payment_intent.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}")
raise HTTPException(status_code=400, detail="No product_id found in payment metadata")
if user:
await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
status=PaymentStatusEnum.COMPLETED,
provider_data=customer,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session):
# Update payment status to failed
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.org_id == org_id,
PaymentsUser.status == PaymentStatusEnum.PENDING
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
# Create event handler mapping
STRIPE_EVENT_HANDLERS = {
'checkout.session.completed': handle_checkout_session_completed,
'customer.subscription.created': handle_subscription_created,
'customer.subscription.updated': handle_subscription_updated,
'customer.subscription.deleted': handle_subscription_deleted,
'payment_intent.succeeded': handle_payment_succeeded,
'payment_intent.payment_failed': handle_payment_failed,
}

View file

@ -1,15 +1,16 @@
from fastapi import HTTPException, Request
from sqlmodel import Session
import stripe
from config.config import get_learnhouse_config
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
from src.db.users import AnonymousUser, PublicUser
from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.services.payments.payments_config import get_payments_config
from sqlmodel import select
async def get_stripe_credentials(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
):
configs = await get_payments_config(request, org_id, current_user, db_session)
@ -149,6 +150,94 @@ async def update_stripe_product(
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")
async def create_checkout_session(
request: Request,
org_id: int,
product_id: int,
redirect_uri: str,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
# Get Stripe credentials
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
stripe.api_key = creds.get('stripe_secret_key')
# Get product details
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id,
PaymentsProduct.org_id == org_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Product not found")
success_url = redirect_uri
cancel_url = redirect_uri
# Get the default price for the product
stripe_product = stripe.Product.retrieve(product.provider_product_id)
line_items = [{
"price": stripe_product.default_price,
"quantity": 1
}]
# Create or retrieve Stripe customer
try:
customers = stripe.Customer.list(email=current_user.email)
if customers.data:
customer = customers.data[0]
else:
customer = stripe.Customer.create(
email=current_user.email,
metadata={
"user_id": str(current_user.id),
"org_id": str(org_id)
}
)
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
# Create checkout session with customer
try:
checkout_session_params = {
"success_url": success_url,
"cancel_url": cancel_url,
"mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription',
"line_items": line_items,
"customer": customer.id,
"metadata": {
"product_id": str(product.id)
}
}
# Add payment_intent_data only for one-time payments
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
checkout_session_params["payment_intent_data"] = {
"metadata": {
"product_id": str(product.id)
}
}
# Add subscription_data for subscription payments
else:
checkout_session_params["subscription_data"] = {
"metadata": {
"product_id": str(product.id)
}
}
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
return {
"checkout_url": checkout_session.url,
"session_id": checkout_session.id
}
except stripe.StripeError as e:
print(f"Error creating checkout session: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))