mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: support subscriptions and onetime payments w/ webhooks
This commit is contained in:
parent
1bff401e73
commit
b7f09885df
6 changed files with 94 additions and 201 deletions
|
|
@ -1,10 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Optional
|
|
||||||
from config.config import get_learnhouse_config
|
from config.config import get_learnhouse_config
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from sqlmodel import Field, SQLModel, Session, create_engine
|
from sqlmodel import SQLModel, Session, create_engine
|
||||||
|
|
||||||
def import_all_models():
|
def import_all_models():
|
||||||
base_dir = 'src/db'
|
base_dir = 'src/db'
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from datetime import datetime
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from src.db.payments.payments_courses import PaymentsCourse
|
from src.db.payments.payments_courses import PaymentsCourse
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from src.db.organizations import Organization
|
||||||
from src.services.orgs.orgs import rbac_check
|
from src.services.orgs.orgs import rbac_check
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, get_stripe_credentials, update_stripe_product
|
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
|
||||||
|
|
||||||
async def create_payments_product(
|
async def create_payments_product(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
|
||||||
|
|
@ -43,12 +43,18 @@ async def create_payment_user(
|
||||||
# Check if user already has a payment user
|
# Check if user already has a payment user
|
||||||
statement = select(PaymentsUser).where(
|
statement = select(PaymentsUser).where(
|
||||||
PaymentsUser.user_id == user_id,
|
PaymentsUser.user_id == user_id,
|
||||||
PaymentsUser.org_id == org_id
|
PaymentsUser.org_id == org_id,
|
||||||
|
PaymentsUser.payment_product_id == product_id
|
||||||
)
|
)
|
||||||
existing_payment_user = db_session.exec(statement).first()
|
existing_payment_user = db_session.exec(statement).first()
|
||||||
|
|
||||||
if existing_payment_user:
|
if existing_payment_user:
|
||||||
raise HTTPException(status_code=400, detail="User already has purchase for this product")
|
if existing_payment_user.status == PaymentStatusEnum.PENDING:
|
||||||
|
# Delete existing pending payment
|
||||||
|
db_session.delete(existing_payment_user)
|
||||||
|
db_session.commit()
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=400, detail="User already has purchase for this product")
|
||||||
|
|
||||||
# Create new payment user
|
# Create new payment user
|
||||||
payment_user = PaymentsUser(
|
payment_user = PaymentsUser(
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,12 @@
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
import stripe
|
import stripe
|
||||||
from datetime import datetime
|
|
||||||
from typing import Callable, Dict
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from src.db.payments.payments_users import PaymentStatusEnum, PaymentsUser
|
from src.db.payments.payments_users import PaymentStatusEnum
|
||||||
from src.db.payments.payments_products import PaymentsProduct
|
from src.db.payments.payments_products import PaymentsProduct
|
||||||
from src.db.users import InternalUser, User
|
from src.db.users import InternalUser, User
|
||||||
from src.services.payments.payments_users import create_payment_user, update_payment_user_status
|
from src.services.payments.payments_users import update_payment_user_status
|
||||||
from src.services.payments.stripe import get_stripe_credentials
|
from src.services.payments.stripe import get_stripe_credentials
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -54,207 +52,70 @@ async def handle_stripe_webhook(
|
||||||
sig_header = request.headers.get('stripe-signature')
|
sig_header = request.headers.get('stripe-signature')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Verify webhook signature and construct event
|
|
||||||
event = stripe.Webhook.construct_event(
|
event = stripe.Webhook.construct_event(
|
||||||
payload, sig_header, webhook_secret
|
payload, sig_header, webhook_secret
|
||||||
)
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||||
|
except stripe.SignatureVerificationError:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||||
|
|
||||||
|
# Handle the event
|
||||||
|
if event.type == 'checkout.session.completed':
|
||||||
|
session = event.data.object
|
||||||
|
payment_user_id = int(session.get('metadata', {}).get('payment_user_id'))
|
||||||
|
|
||||||
# Get the appropriate handler
|
if session.get('mode') == 'subscription':
|
||||||
handler = STRIPE_EVENT_HANDLERS.get(event.type)
|
# Handle subscription payment
|
||||||
if handler:
|
if session.get('subscription'):
|
||||||
await handler(request, event.data.object, org_id, db_session)
|
await update_payment_user_status(
|
||||||
return {"status": "success", "event": event.type}
|
request=request,
|
||||||
|
org_id=org_id,
|
||||||
|
payment_user_id=payment_user_id,
|
||||||
|
status=PaymentStatusEnum.ACTIVE,
|
||||||
|
current_user=InternalUser(),
|
||||||
|
db_session=db_session
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Unhandled event type: {event.type}")
|
# Handle one-time payment
|
||||||
return {"status": "ignored", "event": event.type}
|
if session.get('payment_status') == 'paid':
|
||||||
|
await update_payment_user_status(
|
||||||
|
request=request,
|
||||||
|
org_id=org_id,
|
||||||
|
payment_user_id=payment_user_id,
|
||||||
|
status=PaymentStatusEnum.COMPLETED,
|
||||||
|
current_user=InternalUser(),
|
||||||
|
db_session=db_session
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
elif event.type == 'customer.subscription.deleted':
|
||||||
logger.error(f"Error processing webhook: {str(e)}", exc_info=True)
|
subscription = event.data.object
|
||||||
raise HTTPException(status_code=500, detail=f"Error processing webhook: {str(e)}")
|
payment_user_id = int(subscription.get('metadata', {}).get('payment_user_id'))
|
||||||
|
|
||||||
async def handle_checkout_session_completed(request: Request, session, org_id: int, db_session: Session):
|
|
||||||
# Get the customer and product details from the session
|
|
||||||
customer_email = session.customer_details.email
|
|
||||||
product_id = session.line_items.data[0].price.product
|
|
||||||
|
|
||||||
# Use helper functions
|
|
||||||
user = await get_user_from_customer(session.customer, db_session)
|
|
||||||
product = await get_product_from_stripe_id(product_id, db_session)
|
|
||||||
|
|
||||||
# Find payment user record
|
|
||||||
statement = select(PaymentsUser).where(
|
|
||||||
PaymentsUser.user_id == user.id,
|
|
||||||
PaymentsUser.payment_product_id == product.id
|
|
||||||
)
|
|
||||||
payment_user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
# Update status to completed
|
|
||||||
await update_payment_user_status(
|
|
||||||
request=request,
|
|
||||||
org_id=org_id,
|
|
||||||
payment_user_id=payment_user.id, # type: ignore
|
|
||||||
status=PaymentStatusEnum.COMPLETED,
|
|
||||||
current_user=InternalUser(),
|
|
||||||
db_session=db_session
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_subscription_created(request: Request, subscription, org_id: int, db_session: Session):
|
|
||||||
customer_id = subscription.customer
|
|
||||||
|
|
||||||
# Get product_id from metadata
|
|
||||||
product_id = subscription.metadata.get('product_id')
|
|
||||||
if not product_id:
|
|
||||||
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
|
|
||||||
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
|
|
||||||
|
|
||||||
# Get customer email from Stripe
|
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
|
||||||
|
|
||||||
# Find user and create/update payment record
|
|
||||||
statement = select(User).where(User.email == customer.email)
|
|
||||||
user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if user:
|
|
||||||
payment_user = await create_payment_user(
|
|
||||||
request=request,
|
|
||||||
org_id=org_id,
|
|
||||||
user_id=user.id, # type: ignore
|
|
||||||
product_id=int(product_id), # Convert string from metadata to int
|
|
||||||
current_user=InternalUser(),
|
|
||||||
db_session=db_session
|
|
||||||
)
|
|
||||||
|
|
||||||
await update_payment_user_status(
|
await update_payment_user_status(
|
||||||
request=request,
|
request=request,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
payment_user_id=payment_user.id, # type: ignore
|
payment_user_id=payment_user_id,
|
||||||
status=PaymentStatusEnum.ACTIVE,
|
status=PaymentStatusEnum.CANCELLED,
|
||||||
current_user=InternalUser(),
|
current_user=InternalUser(),
|
||||||
db_session=db_session
|
db_session=db_session
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_subscription_updated(request: Request, subscription, org_id: int, db_session: Session):
|
elif event.type == 'payment_intent.payment_failed':
|
||||||
customer_id = subscription.customer
|
payment_intent = event.data.object
|
||||||
|
payment_user_id = int(payment_intent.get('metadata', {}).get('payment_user_id'))
|
||||||
# Get product_id from metadata
|
|
||||||
product_id = subscription.metadata.get('product_id')
|
|
||||||
if not product_id:
|
|
||||||
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
|
|
||||||
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
|
|
||||||
|
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
|
||||||
|
|
||||||
statement = select(User).where(User.email == customer.email)
|
|
||||||
user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if user:
|
|
||||||
statement = select(PaymentsUser).where(
|
|
||||||
PaymentsUser.user_id == user.id,
|
|
||||||
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
|
|
||||||
)
|
|
||||||
payment_user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if payment_user:
|
await update_payment_user_status(
|
||||||
status = PaymentStatusEnum.ACTIVE if subscription.status == 'active' else PaymentStatusEnum.PENDING
|
|
||||||
await update_payment_user_status(
|
|
||||||
request=request,
|
|
||||||
org_id=org_id,
|
|
||||||
payment_user_id=payment_user.id, # type: ignore
|
|
||||||
status=status,
|
|
||||||
current_user=InternalUser(),
|
|
||||||
db_session=db_session
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_subscription_deleted(request: Request, subscription, org_id: int, db_session: Session):
|
|
||||||
customer_id = subscription.customer
|
|
||||||
|
|
||||||
# Get product_id from metadata
|
|
||||||
product_id = subscription.metadata.get('product_id')
|
|
||||||
if not product_id:
|
|
||||||
logger.error(f"No product_id found in subscription metadata: {subscription.id}")
|
|
||||||
raise HTTPException(status_code=400, detail="No product_id found in subscription metadata")
|
|
||||||
|
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
|
||||||
|
|
||||||
statement = select(User).where(User.email == customer.email)
|
|
||||||
user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if user:
|
|
||||||
statement = select(PaymentsUser).where(
|
|
||||||
PaymentsUser.user_id == user.id,
|
|
||||||
PaymentsUser.payment_product_id == int(product_id) # Convert string from metadata to int
|
|
||||||
)
|
|
||||||
payment_user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if payment_user:
|
|
||||||
await update_payment_user_status(
|
|
||||||
request=request,
|
|
||||||
org_id=org_id,
|
|
||||||
payment_user_id=payment_user.id, # type: ignore
|
|
||||||
status=PaymentStatusEnum.FAILED,
|
|
||||||
current_user=InternalUser(),
|
|
||||||
db_session=db_session
|
|
||||||
)
|
|
||||||
|
|
||||||
async def handle_payment_succeeded(request: Request, payment_intent, org_id: int, db_session: Session):
|
|
||||||
customer_id = payment_intent.customer
|
|
||||||
|
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
|
||||||
|
|
||||||
statement = select(User).where(User.email == customer.email)
|
|
||||||
user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
# Get product_id directly from metadata
|
|
||||||
product_id = payment_intent.metadata.get('product_id')
|
|
||||||
if not product_id:
|
|
||||||
logger.error(f"No product_id found in payment_intent metadata: {payment_intent.id}")
|
|
||||||
raise HTTPException(status_code=400, detail="No product_id found in payment metadata")
|
|
||||||
|
|
||||||
if user:
|
|
||||||
await create_payment_user(
|
|
||||||
request=request,
|
request=request,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
user_id=user.id, # type: ignore
|
payment_user_id=payment_user_id,
|
||||||
product_id=int(product_id), # Convert string from metadata to int
|
status=PaymentStatusEnum.FAILED,
|
||||||
status=PaymentStatusEnum.COMPLETED,
|
|
||||||
provider_data=customer,
|
|
||||||
current_user=InternalUser(),
|
current_user=InternalUser(),
|
||||||
db_session=db_session
|
db_session=db_session
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_payment_failed(request: Request, payment_intent, org_id: int, db_session: Session):
|
return {"status": "success"}
|
||||||
# Update payment status to failed
|
|
||||||
customer_id = payment_intent.customer
|
|
||||||
|
|
||||||
customer = stripe.Customer.retrieve(customer_id)
|
|
||||||
|
|
||||||
statement = select(User).where(User.email == customer.email)
|
|
||||||
user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if user:
|
|
||||||
statement = select(PaymentsUser).where(
|
|
||||||
PaymentsUser.user_id == user.id,
|
|
||||||
PaymentsUser.org_id == org_id,
|
|
||||||
PaymentsUser.status == PaymentStatusEnum.PENDING
|
|
||||||
)
|
|
||||||
payment_user = db_session.exec(statement).first()
|
|
||||||
|
|
||||||
if payment_user:
|
|
||||||
await update_payment_user_status(
|
|
||||||
request=request,
|
|
||||||
org_id=org_id,
|
|
||||||
payment_user_id=payment_user.id, # type: ignore
|
|
||||||
status=PaymentStatusEnum.FAILED,
|
|
||||||
current_user=InternalUser(),
|
|
||||||
db_session=db_session
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create event handler mapping
|
|
||||||
STRIPE_EVENT_HANDLERS = {
|
|
||||||
'checkout.session.completed': handle_checkout_session_completed,
|
|
||||||
'customer.subscription.created': handle_subscription_created,
|
|
||||||
'customer.subscription.updated': handle_subscription_updated,
|
|
||||||
'customer.subscription.deleted': handle_subscription_deleted,
|
|
||||||
'payment_intent.succeeded': handle_payment_succeeded,
|
|
||||||
'payment_intent.payment_failed': handle_payment_failed,
|
|
||||||
}
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
|
import logging
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
import stripe
|
import stripe
|
||||||
from config.config import get_learnhouse_config
|
|
||||||
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
|
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
|
||||||
|
from src.db.payments.payments_users import PaymentStatusEnum
|
||||||
from src.db.users import AnonymousUser, InternalUser, PublicUser
|
from src.db.users import AnonymousUser, InternalUser, PublicUser
|
||||||
from src.services.payments.payments_config import get_payments_config
|
from src.services.payments.payments_config import get_payments_config
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from src.services.payments.payments_users import create_payment_user, delete_payment_user
|
||||||
|
|
||||||
async def get_stripe_credentials(
|
async def get_stripe_credentials(
|
||||||
request: Request,
|
request: Request,
|
||||||
org_id: int,
|
org_id: int,
|
||||||
|
|
@ -179,9 +182,9 @@ async def create_checkout_session(
|
||||||
# Get the default price for the product
|
# Get the default price for the product
|
||||||
stripe_product = stripe.Product.retrieve(product.provider_product_id)
|
stripe_product = stripe.Product.retrieve(product.provider_product_id)
|
||||||
line_items = [{
|
line_items = [{
|
||||||
"price": stripe_product.default_price,
|
"price": stripe_product.default_price,
|
||||||
"quantity": 1
|
"quantity": 1
|
||||||
}]
|
}]
|
||||||
|
|
||||||
# Create or retrieve Stripe customer
|
# Create or retrieve Stripe customer
|
||||||
try:
|
try:
|
||||||
|
|
@ -193,10 +196,29 @@ async def create_checkout_session(
|
||||||
email=current_user.email,
|
email=current_user.email,
|
||||||
metadata={
|
metadata={
|
||||||
"user_id": str(current_user.id),
|
"user_id": str(current_user.id),
|
||||||
"org_id": str(org_id)
|
"org_id": str(org_id),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create initial payment user with pending status
|
||||||
|
payment_user = await create_payment_user(
|
||||||
|
request=request,
|
||||||
|
org_id=org_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
product_id=product_id,
|
||||||
|
status=PaymentStatusEnum.PENDING,
|
||||||
|
provider_data=customer,
|
||||||
|
current_user=current_user,
|
||||||
|
db_session=db_session
|
||||||
|
)
|
||||||
|
|
||||||
|
if not payment_user:
|
||||||
|
raise HTTPException(status_code=400, detail="Error creating payment user")
|
||||||
|
|
||||||
except stripe.StripeError as e:
|
except stripe.StripeError as e:
|
||||||
|
# Clean up payment user if customer creation fails
|
||||||
|
if payment_user and payment_user.id:
|
||||||
|
await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session)
|
||||||
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
|
||||||
|
|
||||||
# Create checkout session with customer
|
# Create checkout session with customer
|
||||||
|
|
@ -208,7 +230,8 @@ async def create_checkout_session(
|
||||||
"line_items": line_items,
|
"line_items": line_items,
|
||||||
"customer": customer.id,
|
"customer": customer.id,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"product_id": str(product.id)
|
"product_id": str(product.id),
|
||||||
|
"payment_user_id": str(payment_user.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -216,14 +239,16 @@ async def create_checkout_session(
|
||||||
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
|
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
|
||||||
checkout_session_params["payment_intent_data"] = {
|
checkout_session_params["payment_intent_data"] = {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"product_id": str(product.id)
|
"product_id": str(product.id),
|
||||||
|
"payment_user_id": str(payment_user.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# Add subscription_data for subscription payments
|
# Add subscription_data for subscription payments
|
||||||
else:
|
else:
|
||||||
checkout_session_params["subscription_data"] = {
|
checkout_session_params["subscription_data"] = {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"product_id": str(product.id)
|
"product_id": str(product.id),
|
||||||
|
"payment_user_id": str(payment_user.id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -235,7 +260,10 @@ async def create_checkout_session(
|
||||||
}
|
}
|
||||||
|
|
||||||
except stripe.StripeError as e:
|
except stripe.StripeError as e:
|
||||||
print(f"Error creating checkout session: {str(e)}")
|
# Clean up payment user if checkout session creation fails
|
||||||
|
if payment_user and payment_user.id:
|
||||||
|
await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session)
|
||||||
|
logging.error(f"Error creating checkout session: {str(e)}")
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue