From b7f09885df8c05a7bd6bd046dd47bf0f9f2be393 Mon Sep 17 00:00:00 2001 From: swve Date: Sat, 2 Nov 2024 22:17:01 +0100 Subject: [PATCH] feat: support subscriptions and onetime payments w/ webhooks --- apps/api/src/core/events/database.py | 3 +- .../src/services/payments/payments_courses.py | 1 - .../services/payments/payments_products.py | 2 +- .../src/services/payments/payments_users.py | 10 +- .../src/services/payments/payments_webhook.py | 233 ++++-------------- apps/api/src/services/payments/stripe.py | 46 +++- 6 files changed, 94 insertions(+), 201 deletions(-) diff --git a/apps/api/src/core/events/database.py b/apps/api/src/core/events/database.py index 53120f7a..e910f628 100644 --- a/apps/api/src/core/events/database.py +++ b/apps/api/src/core/events/database.py @@ -1,10 +1,9 @@ import logging import os import importlib -from typing import Optional from config.config import get_learnhouse_config from fastapi import FastAPI -from sqlmodel import Field, SQLModel, Session, create_engine +from sqlmodel import SQLModel, Session, create_engine def import_all_models(): base_dir = 'src/db' diff --git a/apps/api/src/services/payments/payments_courses.py b/apps/api/src/services/payments/payments_courses.py index 2e54c8c7..d35d9cfe 100644 --- a/apps/api/src/services/payments/payments_courses.py +++ b/apps/api/src/services/payments/payments_courses.py @@ -1,4 +1,3 @@ -from datetime import datetime from fastapi import HTTPException, Request from sqlmodel import Session, select from src.db.payments.payments_courses import PaymentsCourse diff --git a/apps/api/src/services/payments/payments_products.py b/apps/api/src/services/payments/payments_products.py index cd3f5b8b..d0963fe0 100644 --- a/apps/api/src/services/payments/payments_products.py +++ b/apps/api/src/services/payments/payments_products.py @@ -14,7 +14,7 @@ from src.db.organizations import Organization from src.services.orgs.orgs import rbac_check from datetime import datetime -from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product +from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product async def create_payments_product( request: Request, diff --git a/apps/api/src/services/payments/payments_users.py b/apps/api/src/services/payments/payments_users.py index 253364bd..5b34e8c6 100644 --- a/apps/api/src/services/payments/payments_users.py +++ b/apps/api/src/services/payments/payments_users.py @@ -43,12 +43,18 @@ async def create_payment_user( # Check if user already has a payment user statement = select(PaymentsUser).where( PaymentsUser.user_id == user_id, - PaymentsUser.org_id == org_id + PaymentsUser.org_id == org_id, + PaymentsUser.payment_product_id == product_id ) existing_payment_user = db_session.exec(statement).first() if existing_payment_user: - raise HTTPException(status_code=400, detail="User already has purchase for this product") + if existing_payment_user.status == PaymentStatusEnum.PENDING: + # Delete existing pending payment + db_session.delete(existing_payment_user) + db_session.commit() + else: + raise HTTPException(status_code=400, detail="User already has purchase for this product") # Create new payment user payment_user = PaymentsUser( diff --git a/apps/api/src/services/payments/payments_webhook.py b/apps/api/src/services/payments/payments_webhook.py index 9ff3bb4a..0f79f265 100644 --- a/apps/api/src/services/payments/payments_webhook.py +++ b/apps/api/src/services/payments/payments_webhook.py @@ -1,14 +1,12 @@ from fastapi import HTTPException, Request from sqlmodel import Session, select import stripe -from datetime import datetime -from typing import Callable, Dict import logging -from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser +from src.db.payments.payments_users import PaymentStatusEnum from src.db.payments.payments_products import PaymentsProduct from src.db.users import InternalUser, User -from src.services.payments.payments_users import create_payment_user, update_payment_user_status +from src.services.payments.payments_users import update_payment_user_status from src.services.payments.stripe import get_stripe_credentials logger = logging.getLogger(__name__) @@ -54,207 +52,70 @@ async def handle_stripe_webhook( sig_header = request.headers.get('stripe-signature') try: - # Verify webhook signature and construct event event = stripe.Webhook.construct_event( payload, sig_header, webhook_secret ) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid payload") + except stripe.SignatureVerificationError: + raise HTTPException(status_code=400, detail="Invalid signature") + + # Handle the event + if event.type == 'checkout.session.completed': + session = event.data.object + payment_user_id = int(session.get('metadata', {}).get('payment_user_id')) - # Get the appropriate handler - handler = STRIPE_EVENT_HANDLERS.get(event.type) - if handler: - await handler(request, event.data.object, org_id, db_session) - return {"status": "success", "event": event.type} + if session.get('mode') == 'subscription': + # Handle subscription payment + if session.get('subscription'): + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user_id, + status=PaymentStatusEnum.ACTIVE, + current_user=InternalUser(), + db_session=db_session + ) else: - logger.info(f"Unhandled event type: {event.type}") - return {"status": "ignored", "event": event.type} + # Handle one-time payment + if session.get('payment_status') == 'paid': + await update_payment_user_status( + request=request, + org_id=org_id, + payment_user_id=payment_user_id, + status=PaymentStatusEnum.COMPLETED, + current_user=InternalUser(), + db_session=db_session + ) - except Exception as e: - logger.error(f"Error processing webhook: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}") - -async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session): - # Get the customer and product details from the session - customer_email = session.customer_details.email - product_id = session.line_items.data[0].price.product - - # Use helper functions - user = await get_user_from_customer(session.customer, db_session) - product = await get_product_from_stripe_id(product_id, db_session) - - # Find payment user record - statement = select(PaymentsUser).where( - PaymentsUser.user_id == user.id, - PaymentsUser.payment_product_id == product.id - ) - payment_user = db_session.exec(statement).first() - - # Update status to completed - await update_payment_user_status( - request=request, - org_id=org_id, - payment_user_id=payment_user.id, # type: ignore - status=PaymentStatusEnum.COMPLETED, - current_user=InternalUser(), - db_session=db_session - ) - -async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session): - customer_id = subscription.customer - - # Get product_id from metadata - product_id = subscription.metadata.get('product_id') - if not product_id: - logger.error(f"No product_id found in subscription metadata: {subscription.id}") - raise HTTPException(status_code=400, detail="No product_id found in subscription metadata") - - # Get customer email from Stripe - customer = stripe.Customer.retrieve(customer_id) - - # Find user and create/update payment record - statement = select(User).where(User.email == customer.email) - user = db_session.exec(statement).first() - - if user: - payment_user = await create_payment_user( - request=request, - org_id=org_id, - user_id=user.id, # type: ignore - product_id=int(product_id), # Convert string from metadata to int - current_user=InternalUser(), - db_session=db_session - ) + elif event.type == 'customer.subscription.deleted': + subscription = event.data.object + payment_user_id = int(subscription.get('metadata', {}).get('payment_user_id')) await update_payment_user_status( request=request, org_id=org_id, - payment_user_id=payment_user.id, # type: ignore - status=PaymentStatusEnum.ACTIVE, + payment_user_id=payment_user_id, + status=PaymentStatusEnum.CANCELLED, current_user=InternalUser(), db_session=db_session ) -async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session): - customer_id = subscription.customer - - # Get product_id from metadata - product_id = subscription.metadata.get('product_id') - if not product_id: - logger.error(f"No product_id found in subscription metadata: {subscription.id}") - raise HTTPException(status_code=400, detail="No product_id found in subscription metadata") - - customer = stripe.Customer.retrieve(customer_id) - - statement = select(User).where(User.email == customer.email) - user = db_session.exec(statement).first() - - if user: - statement = select(PaymentsUser).where( - PaymentsUser.user_id == user.id, - PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int - ) - payment_user = db_session.exec(statement).first() + elif event.type == 'payment_intent.payment_failed': + payment_intent = event.data.object + payment_user_id = int(payment_intent.get('metadata', {}).get('payment_user_id')) - if payment_user: - status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING - await update_payment_user_status( - request=request, - org_id=org_id, - payment_user_id=payment_user.id, # type: ignore - status=status, - current_user=InternalUser(), - db_session=db_session - ) - -async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session): - customer_id = subscription.customer - - # Get product_id from metadata - product_id = subscription.metadata.get('product_id') - if not product_id: - logger.error(f"No product_id found in subscription metadata: {subscription.id}") - raise HTTPException(status_code=400, detail="No product_id found in subscription metadata") - - customer = stripe.Customer.retrieve(customer_id) - - statement = select(User).where(User.email == customer.email) - user = db_session.exec(statement).first() - - if user: - statement = select(PaymentsUser).where( - PaymentsUser.user_id == user.id, - PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int - ) - payment_user = db_session.exec(statement).first() - - if payment_user: - await update_payment_user_status( - request=request, - org_id=org_id, - payment_user_id=payment_user.id, # type: ignore - status=PaymentStatusEnum.FAILED, - current_user=InternalUser(), - db_session=db_session - ) - -async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session): - customer_id = payment_intent.customer - - customer = stripe.Customer.retrieve(customer_id) - - statement = select(User).where(User.email == customer.email) - user = db_session.exec(statement).first() - - # Get product_id directly from metadata - product_id = payment_intent.metadata.get('product_id') - if not product_id: - logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}") - raise HTTPException(status_code=400, detail="No product_id found in payment metadata") - - if user: - await create_payment_user( + await update_payment_user_status( request=request, org_id=org_id, - user_id=user.id, # type: ignore - product_id=int(product_id), # Convert string from metadata to int - status=PaymentStatusEnum.COMPLETED, - provider_data=customer, + payment_user_id=payment_user_id, + status=PaymentStatusEnum.FAILED, current_user=InternalUser(), db_session=db_session ) -async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session): - # Update payment status to failed - customer_id = payment_intent.customer - - customer = stripe.Customer.retrieve(customer_id) - - statement = select(User).where(User.email == customer.email) - user = db_session.exec(statement).first() - - if user: - statement = select(PaymentsUser).where( - PaymentsUser.user_id == user.id, - PaymentsUser.org_id == org_id, - PaymentsUser.status == PaymentStatusEnum.PENDING - ) - payment_user = db_session.exec(statement).first() - - if payment_user: - await update_payment_user_status( - request=request, - org_id=org_id, - payment_user_id=payment_user.id, # type: ignore - status=PaymentStatusEnum.FAILED, - current_user=InternalUser(), - db_session=db_session - ) + return {"status": "success"} -# Create event handler mapping -STRIPE_EVENT_HANDLERS = { - 'checkout.session.completed': handle_checkout_session_completed, - 'customer.subscription.created': handle_subscription_created, - 'customer.subscription.updated': handle_subscription_updated, - 'customer.subscription.deleted': handle_subscription_deleted, - 'payment_intent.succeeded': handle_payment_succeeded, - 'payment_intent.payment_failed': handle_payment_failed, -} + + + \ No newline at end of file diff --git a/apps/api/src/services/payments/stripe.py b/apps/api/src/services/payments/stripe.py index e987ddff..6154b9a0 100644 --- a/apps/api/src/services/payments/stripe.py +++ b/apps/api/src/services/payments/stripe.py @@ -1,12 +1,15 @@ +import logging from fastapi import HTTPException, Request from sqlmodel import Session import stripe -from config.config import get_learnhouse_config from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct +from src.db.payments.payments_users import PaymentStatusEnum from src.db.users import AnonymousUser, InternalUser, PublicUser from src.services.payments.payments_config import get_payments_config from sqlmodel import select +from src.services.payments.payments_users import create_payment_user, delete_payment_user + async def get_stripe_credentials( request: Request, org_id: int, @@ -179,9 +182,9 @@ async def create_checkout_session( # Get the default price for the product stripe_product = stripe.Product.retrieve(product.provider_product_id) line_items = [{ - "price": stripe_product.default_price, - "quantity": 1 - }] + "price": stripe_product.default_price, + "quantity": 1 + }] # Create or retrieve Stripe customer try: @@ -193,10 +196,29 @@ async def create_checkout_session( email=current_user.email, metadata={ "user_id": str(current_user.id), - "org_id": str(org_id) + "org_id": str(org_id), } ) + + # Create initial payment user with pending status + payment_user = await create_payment_user( + request=request, + org_id=org_id, + user_id=current_user.id, + product_id=product_id, + status=PaymentStatusEnum.PENDING, + provider_data=customer, + current_user=current_user, + db_session=db_session + ) + + if not payment_user: + raise HTTPException(status_code=400, detail="Error creating payment user") + except stripe.StripeError as e: + # Clean up payment user if customer creation fails + if payment_user and payment_user.id: + await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session) raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}") # Create checkout session with customer @@ -208,7 +230,8 @@ async def create_checkout_session( "line_items": line_items, "customer": customer.id, "metadata": { - "product_id": str(product.id) + "product_id": str(product.id), + "payment_user_id": str(payment_user.id) } } @@ -216,14 +239,16 @@ async def create_checkout_session( if product.product_type == PaymentProductTypeEnum.ONE_TIME: checkout_session_params["payment_intent_data"] = { "metadata": { - "product_id": str(product.id) + "product_id": str(product.id), + "payment_user_id": str(payment_user.id) } } # Add subscription_data for subscription payments else: checkout_session_params["subscription_data"] = { "metadata": { - "product_id": str(product.id) + "product_id": str(product.id), + "payment_user_id": str(payment_user.id) } } @@ -235,7 +260,10 @@ async def create_checkout_session( } except stripe.StripeError as e: - print(f"Error creating checkout session: {str(e)}") + # Clean up payment user if checkout session creation fails + if payment_user and payment_user.id: + await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session) + logging.error(f"Error creating checkout session: {str(e)}") raise HTTPException(status_code=400, detail=str(e))