feat: support subscriptions and onetime payments w/ webhooks

This commit is contained in:
swve 2024-11-02 22:17:01 +01:00
parent 1bff401e73
commit b7f09885df
6 changed files with 94 additions and 201 deletions

View file

@ -1,10 +1,9 @@
import logging import logging
import os import os
import importlib import importlib
from typing import Optional
from config.config import get_learnhouse_config from config.config import get_learnhouse_config
from fastapi import FastAPI from fastapi import FastAPI
from sqlmodel import Field, SQLModel, Session, create_engine from sqlmodel import SQLModel, Session, create_engine
def import_all_models(): def import_all_models():
base_dir = 'src/db' base_dir = 'src/db'

View file

@ -1,4 +1,3 @@
from datetime import datetime
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from sqlmodel import Session, select from sqlmodel import Session, select
from src.db.payments.payments_courses import PaymentsCourse from src.db.payments.payments_courses import PaymentsCourse

View file

@ -14,7 +14,7 @@ from src.db.organizations import Organization
from src.services.orgs.orgs import rbac_check from src.services.orgs.orgs import rbac_check
from datetime import datetime from datetime import datetime
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
async def create_payments_product( async def create_payments_product(
request: Request, request: Request,

View file

@ -43,11 +43,17 @@ async def create_payment_user(
# Check if user already has a payment user # Check if user already has a payment user
statement = select(PaymentsUser).where( statement = select(PaymentsUser).where(
PaymentsUser.user_id == user_id, PaymentsUser.user_id == user_id,
PaymentsUser.org_id == org_id PaymentsUser.org_id == org_id,
PaymentsUser.payment_product_id == product_id
) )
existing_payment_user = db_session.exec(statement).first() existing_payment_user = db_session.exec(statement).first()
if existing_payment_user: if existing_payment_user:
if existing_payment_user.status == PaymentStatusEnum.PENDING:
# Delete existing pending payment
db_session.delete(existing_payment_user)
db_session.commit()
else:
raise HTTPException(status_code=400, detail="User already has purchase for this product") raise HTTPException(status_code=400, detail="User already has purchase for this product")
# Create new payment user # Create new payment user

View file

@ -1,14 +1,12 @@
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from sqlmodel import Session, select from sqlmodel import Session, select
import stripe import stripe
from datetime import datetime
from typing import Callable, Dict
import logging import logging
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser from src.db.payments.payments_users import PaymentStatusEnum
from src.db.payments.payments_products import PaymentsProduct from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, User from src.db.users import InternalUser, User
from src.services.payments.payments_users import create_payment_user, update_payment_user_status from src.services.payments.payments_users import update_payment_user_status
from src.services.payments.stripe import get_stripe_credentials from src.services.payments.stripe import get_stripe_credentials
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,207 +52,70 @@ async def handle_stripe_webhook(
sig_header = request.headers.get('stripe-signature') sig_header = request.headers.get('stripe-signature')
try: try:
# Verify webhook signature and construct event
event = stripe.Webhook.construct_event( event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret payload, sig_header, webhook_secret
) )
except ValueError:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
raise HTTPException(status_code=400, detail="Invalid signature")
# Get the appropriate handler # Handle the event
handler = STRIPE_EVENT_HANDLERS.get(event.type) if event.type == 'checkout.session.completed':
if handler: session = event.data.object
await handler(request, event.data.object, org_id, db_session) payment_user_id = int(session.get('metadata', {}).get('payment_user_id'))
return {"status": "success", "event": event.type}
else:
logger.info(f"Unhandled event type: {event.type}")
return {"status": "ignored", "event": event.type}
except Exception as e: if session.get('mode') == 'subscription':
logger.error(f"Error processing webhook: {str(e)}", exc_info=True) # Handle subscription payment
raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}") if session.get('subscription'):
async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session):
# Get the customer and product details from the session
customer_email = session.customer_details.email
product_id = session.line_items.data[0].price.product
# Use helper functions
user = await get_user_from_customer(session.customer, db_session)
product = await get_product_from_stripe_id(product_id, db_session)
# Find payment user record
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == product.id
)
payment_user = db_session.exec(statement).first()
# Update status to completed
await update_payment_user_status( await update_payment_user_status(
request=request, request=request,
org_id=org_id, org_id=org_id,
payment_user_id=payment_user.id, # type: ignore payment_user_id=payment_user_id,
status=PaymentStatusEnum.COMPLETED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
# Get customer email from Stripe
customer = stripe.Customer.retrieve(customer_id)
# Find user and create/update payment record
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
payment_user = await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
current_user=InternalUser(),
db_session=db_session
)
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.ACTIVE, status=PaymentStatusEnum.ACTIVE,
current_user=InternalUser(), current_user=InternalUser(),
db_session=db_session db_session=db_session
) )
else:
async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session): # Handle one-time payment
customer_id = subscription.customer if session.get('payment_status') == 'paid':
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING
await update_payment_user_status( await update_payment_user_status(
request=request, request=request,
org_id=org_id, org_id=org_id,
payment_user_id=payment_user.id, # type: ignore payment_user_id=payment_user_id,
status=status,
current_user=InternalUser(),
db_session=db_session
)
async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session):
customer_id = subscription.customer
# Get product_id from metadata
product_id = subscription.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user.id, # type: ignore
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session):
customer_id = payment_intent.customer
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
# Get product_id directly from metadata
product_id = payment_intent.metadata.get('product_id')
if not product_id:
logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}")
raise HTTPException(status_code=400, detail="No product_id found in payment metadata")
if user:
await create_payment_user(
request=request,
org_id=org_id,
user_id=user.id, # type: ignore
product_id=int(product_id), # Convert string from metadata to int
status=PaymentStatusEnum.COMPLETED, status=PaymentStatusEnum.COMPLETED,
provider_data=customer,
current_user=InternalUser(), current_user=InternalUser(),
db_session=db_session db_session=db_session
) )
async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session): elif event.type == 'customer.subscription.deleted':
# Update payment status to failed subscription = event.data.object
customer_id = payment_intent.customer payment_user_id = int(subscription.get('metadata', {}).get('payment_user_id'))
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if user:
statement = select(PaymentsUser).where(
PaymentsUser.user_id == user.id,
PaymentsUser.org_id == org_id,
PaymentsUser.status == PaymentStatusEnum.PENDING
)
payment_user = db_session.exec(statement).first()
if payment_user:
await update_payment_user_status( await update_payment_user_status(
request=request, request=request,
org_id=org_id, org_id=org_id,
payment_user_id=payment_user.id, # type: ignore payment_user_id=payment_user_id,
status=PaymentStatusEnum.CANCELLED,
current_user=InternalUser(),
db_session=db_session
)
elif event.type == 'payment_intent.payment_failed':
payment_intent = event.data.object
payment_user_id = int(payment_intent.get('metadata', {}).get('payment_user_id'))
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.FAILED, status=PaymentStatusEnum.FAILED,
current_user=InternalUser(), current_user=InternalUser(),
db_session=db_session db_session=db_session
) )
# Create event handler mapping return {"status": "success"}
STRIPE_EVENT_HANDLERS = {
'checkout.session.completed': handle_checkout_session_completed,
'customer.subscription.created': handle_subscription_created,
'customer.subscription.updated': handle_subscription_updated,
'customer.subscription.deleted': handle_subscription_deleted,
'payment_intent.succeeded': handle_payment_succeeded,
'payment_intent.payment_failed': handle_payment_failed,
}

View file

@ -1,12 +1,15 @@
import logging
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from sqlmodel import Session from sqlmodel import Session
import stripe import stripe
from config.config import get_learnhouse_config
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
from src.db.payments.payments_users import PaymentStatusEnum
from src.db.users import AnonymousUser, InternalUser, PublicUser from src.db.users import AnonymousUser, InternalUser, PublicUser
from src.services.payments.payments_config import get_payments_config from src.services.payments.payments_config import get_payments_config
from sqlmodel import select from sqlmodel import select
from src.services.payments.payments_users import create_payment_user, delete_payment_user
async def get_stripe_credentials( async def get_stripe_credentials(
request: Request, request: Request,
org_id: int, org_id: int,
@ -193,10 +196,29 @@ async def create_checkout_session(
email=current_user.email, email=current_user.email,
metadata={ metadata={
"user_id": str(current_user.id), "user_id": str(current_user.id),
"org_id": str(org_id) "org_id": str(org_id),
} }
) )
# Create initial payment user with pending status
payment_user = await create_payment_user(
request=request,
org_id=org_id,
user_id=current_user.id,
product_id=product_id,
status=PaymentStatusEnum.PENDING,
provider_data=customer,
current_user=current_user,
db_session=db_session
)
if not payment_user:
raise HTTPException(status_code=400, detail="Error creating payment user")
except stripe.StripeError as e: except stripe.StripeError as e:
# Clean up payment user if customer creation fails
if payment_user and payment_user.id:
await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session)
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}") raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
# Create checkout session with customer # Create checkout session with customer
@ -208,7 +230,8 @@ async def create_checkout_session(
"line_items": line_items, "line_items": line_items,
"customer": customer.id, "customer": customer.id,
"metadata": { "metadata": {
"product_id": str(product.id) "product_id": str(product.id),
"payment_user_id": str(payment_user.id)
} }
} }
@ -216,14 +239,16 @@ async def create_checkout_session(
if product.product_type == PaymentProductTypeEnum.ONE_TIME: if product.product_type == PaymentProductTypeEnum.ONE_TIME:
checkout_session_params["payment_intent_data"] = { checkout_session_params["payment_intent_data"] = {
"metadata": { "metadata": {
"product_id": str(product.id) "product_id": str(product.id),
"payment_user_id": str(payment_user.id)
} }
} }
# Add subscription_data for subscription payments # Add subscription_data for subscription payments
else: else:
checkout_session_params["subscription_data"] = { checkout_session_params["subscription_data"] = {
"metadata": { "metadata": {
"product_id": str(product.id) "product_id": str(product.id),
"payment_user_id": str(payment_user.id)
} }
} }
@ -235,7 +260,10 @@ async def create_checkout_session(
} }
except stripe.StripeError as e: except stripe.StripeError as e:
print(f"Error creating checkout session: {str(e)}") # Clean up payment user if checkout session creation fails
if payment_user and payment_user.id:
await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session)
logging.error(f"Error creating checkout session: {str(e)}")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))