fix: webhooks issues

This commit is contained in:
swve 2024-11-10 15:42:55 +01:00
parent a8ba053447
commit 93c0838fab
12 changed files with 323 additions and 182 deletions

View file

@ -12,6 +12,7 @@ class PaymentsConfigBase(SQLModel):
enabled: bool = True
active: bool = False
provider: PaymentProviderEnum = PaymentProviderEnum.STRIPE
provider_specific_id: str | None = None
provider_config: dict = Field(default={}, sa_column=Column(JSON))
@ -31,6 +32,7 @@ class PaymentsConfigCreate(PaymentsConfigBase):
class PaymentsConfigUpdate(PaymentsConfigBase):
enabled: Optional[bool] = True
provider_config: Optional[dict] = None
provider_specific_id: Optional[str] = None
class PaymentsConfigRead(PaymentsConfigBase):

View file

@ -18,11 +18,11 @@ from src.services.payments.payments_courses import (
get_courses_by_product,
)
from src.services.payments.payments_users import get_owned_courses
from src.services.payments.webhooks.payments_connected_webhook import handle_stripe_webhook
from src.services.payments.payments_stripe import create_checkout_session, update_stripe_account_id
from src.services.payments.payments_access import check_course_paid_access
from src.services.payments.payments_customers import get_customers
from src.services.payments.payments_stripe import generate_stripe_connect_link
from src.services.payments.webhooks.payments_webhooks import handle_stripe_webhook
router = APIRouter()
@ -160,13 +160,12 @@ async def api_get_products_by_course(
# Payments webhooks
@router.post("/{org_id}/stripe/webhook")
async def api_handle_stripe_webhook(
@router.post("/stripe/webhook")
async def api_handle_connected_accounts_stripe_webhook(
request: Request,
org_id: int,
db_session: Session = Depends(get_db_session),
):
return await handle_stripe_webhook(request, org_id, db_session)
return await handle_stripe_webhook(request, db_session)
# Payments checkout

View file

@ -46,8 +46,8 @@ async def init_payments_config(
provider=PaymentProviderEnum.STRIPE,
provider_config={
"onboarding_completed": False,
"stripe_account_id": ""
}
},
provider_specific_id=None
)
# Save to database

View file

@ -33,14 +33,10 @@ async def get_stripe_connected_account_id(
# Get payments config
payments_config = await get_payments_config(request, org_id, current_user, db_session)
return payments_config[0].provider_config.get("stripe_account_id")
return payments_config[0].provider_specific_id
async def get_stripe_credentials(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser | InternalUser,
db_session: Session,
async def get_stripe_internal_credentials(
):
# Get payments config from config file
learnhouse_config = get_learnhouse_config()
@ -56,6 +52,7 @@ async def get_stripe_credentials(
return {
"stripe_secret_key": learnhouse_config.payments_config.stripe.stripe_secret_key,
"stripe_publishable_key": learnhouse_config.payments_config.stripe.stripe_publishable_key,
"stripe_webhook_secret": learnhouse_config.payments_config.stripe.stripe_webhook_secret,
}
@ -66,7 +63,7 @@ async def create_stripe_product(
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
creds = await get_stripe_internal_credentials()
# Set the Stripe API key using the credentials
stripe.api_key = creds.get("stripe_secret_key")
@ -113,14 +110,16 @@ async def archive_stripe_product(
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
creds = await get_stripe_internal_credentials()
# Set the Stripe API key using the credentials
stripe.api_key = creds.get("stripe_secret_key")
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
try:
# Archive the product in Stripe
archived_product = stripe.Product.modify(product_id, active=False)
archived_product = stripe.Product.modify(product_id, active=False, stripe_account=stripe_acc_id)
return archived_product
except stripe.StripeError as e:
@ -138,11 +137,13 @@ async def update_stripe_product(
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
creds = await get_stripe_internal_credentials()
# Set the Stripe API key using the credentials
stripe.api_key = creds.get("stripe_secret_key")
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
try:
# Create new price based on price_type
if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE:
@ -180,13 +181,13 @@ async def update_stripe_product(
}
# Update the product in Stripe
updated_product = stripe.Product.modify(product_id, **update_data)
updated_product = stripe.Product.modify(product_id, **update_data, stripe_account=stripe_acc_id)
# Archive all existing prices for the product
existing_prices = stripe.Price.list(product=product_id, active=True)
for price in existing_prices:
if price.id != new_price.id:
stripe.Price.modify(price.id, active=False)
stripe.Price.modify(price.id, active=False, stripe_account=stripe_acc_id)
return updated_product
except stripe.StripeError as e:
@ -204,9 +205,12 @@ async def create_checkout_session(
db_session: Session,
):
# Get Stripe credentials
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
creds = await get_stripe_internal_credentials()
stripe.api_key = creds.get("stripe_secret_key")
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
# Get product details
statement = select(PaymentsProduct).where(
PaymentsProduct.id == product_id, PaymentsProduct.org_id == org_id
@ -220,10 +224,9 @@ async def create_checkout_session(
cancel_url = redirect_uri
# Get the default price for the product
stripe_product = stripe.Product.retrieve(product.provider_product_id)
stripe_product = stripe.Product.retrieve(product.provider_product_id, stripe_account=stripe_acc_id)
line_items = [{"price": stripe_product.default_price, "quantity": 1}]
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
# Create or retrieve Stripe customer
try:
@ -282,8 +285,7 @@ async def create_checkout_session(
"metadata": {
"product_id": str(product.id),
"payment_user_id": str(payment_user.id),
},
"stripe_account": stripe_acc_id,
}
}
# Add payment_intent_data only for one-time payments
@ -303,7 +305,7 @@ async def create_checkout_session(
}
}
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
checkout_session = stripe.checkout.Session.create(**checkout_session_params, stripe_account=stripe_acc_id)
return {"checkout_url": checkout_session.url, "session_id": checkout_session.id}
@ -328,20 +330,26 @@ async def generate_stripe_connect_link(
Generate a Stripe OAuth link for connecting a Stripe account
"""
# Get credentials
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
creds = await get_stripe_internal_credentials()
stripe.api_key = creds.get("stripe_secret_key")
# Get config
learnhouse_config = get_learnhouse_config()
# Get client id
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
if not stripe_acc_id:
raise HTTPException(status_code=400, detail="No Stripe account ID found for this organization")
try:
# Try to get existing account ID
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
except HTTPException:
# If no account exists, create one
stripe_account = await create_stripe_account(
request,
org_id,
"standard",
current_user,
db_session
)
stripe_acc_id = stripe_account
# Generate OAuth link
connect_link = stripe.AccountLink.create(
account=stripe_acc_id,
account=str(stripe_acc_id),
type="account_onboarding",
return_url=redirect_uri,
refresh_url=redirect_uri,
@ -357,18 +365,16 @@ async def create_stripe_account(
db_session: Session,
):
# Get credentials
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
creds = await get_stripe_internal_credentials()
stripe.api_key = creds.get("stripe_secret_key")
# Get existing payments config
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
existing_config = db_session.exec(statement).first()
if existing_config and existing_config.provider_config.get("stripe_account_id"):
raise HTTPException(
status_code=400,
detail="A Stripe Express account is already linked to this organization"
)
if existing_config and existing_config.provider_specific_id:
logging.error(f"A Stripe Account is already linked to this organization: {existing_config.provider_specific_id}")
return existing_config.provider_specific_id
# Create Stripe account
stripe_account = stripe.Account.create(
@ -379,14 +385,18 @@ async def create_stripe_account(
},
)
config_data = existing_config.model_dump() if existing_config else {}
config_data.update({
"enabled": True,
"provider_specific_id": stripe_account.id, # Use the ID directly
"provider_config": {"onboarding_completed": False}
})
# Update payments config for the org
await update_payments_config(
request,
org_id,
PaymentsConfigUpdate(
enabled=True,
provider_config={"stripe_account_id": stripe_account.id}
),
PaymentsConfigUpdate(**config_data),
current_user,
db_session,
)
@ -414,14 +424,15 @@ async def update_stripe_account_id(
detail="No payments configuration found for this organization"
)
# Update payments config with new stripe account id
# Create config update with existing values but new stripe account id
config_data = existing_config.model_dump()
config_data["provider_specific_id"] = stripe_account_id
# Update payments config
await update_payments_config(
request,
org_id,
PaymentsConfigUpdate(
enabled=True,
provider_config={"stripe_account_id": stripe_account_id}
),
PaymentsConfigUpdate(**config_data),
current_user,
db_session,
)

View file

@ -0,0 +1,59 @@
from fastapi import HTTPException
from sqlmodel import Session, select
import stripe
import logging
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import User
from src.db.payments.payments import PaymentsConfig
logger = logging.getLogger(__name__)
async def get_user_from_customer(customer_id: str, db_session: Session) -> User:
"""Helper function to get user from Stripe customer ID"""
try:
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if not user:
raise HTTPException(
status_code=404, detail=f"User not found for customer {customer_id}"
)
return user
except stripe.StripeError as e:
logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}")
raise HTTPException(
status_code=400, detail="Error retrieving customer information"
)
async def get_product_from_stripe_id(
product_id: str, db_session: Session
) -> PaymentsProduct:
"""Helper function to get product from Stripe product ID"""
statement = select(PaymentsProduct).where(
PaymentsProduct.provider_product_id == product_id
)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail=f"Product not found: {product_id}")
return product
async def get_org_id_from_stripe_account(
stripe_account_id: str,
db_session: Session,
) -> int:
"""Get organization ID from Stripe account ID"""
statement = select(PaymentsConfig).where(
PaymentsConfig.provider_specific_id == stripe_account_id
)
config = db_session.exec(statement).first()
if not config:
raise HTTPException(
status_code=404,
detail=f"No organization found for Stripe account {stripe_account_id}",
)
return config.org_id

View file

@ -1,121 +0,0 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
import stripe
import logging
from src.db.payments.payments_users import PaymentStatusEnum
from src.db.payments.payments_products import PaymentsProduct
from src.db.users import InternalUser, User
from src.services.payments.payments_users import update_payment_user_status
from src.services.payments.payments_stripe import get_stripe_credentials
logger = logging.getLogger(__name__)
async def get_user_from_customer(customer_id: str, db_session: Session) -> User:
"""Helper function to get user from Stripe customer ID"""
try:
customer = stripe.Customer.retrieve(customer_id)
statement = select(User).where(User.email == customer.email)
user = db_session.exec(statement).first()
if not user:
raise HTTPException(status_code=404, detail=f"User not found for customer {customer_id}")
return user
except stripe.StripeError as e:
logger.error(f"Stripe error retrieving customer {customer_id}: {str(e)}")
raise HTTPException(status_code=400, detail="Error retrieving customer information")
async def get_product_from_stripe_id(product_id: str, db_session: Session) -> PaymentsProduct:
"""Helper function to get product from Stripe product ID"""
statement = select(PaymentsProduct).where(PaymentsProduct.provider_product_id == product_id)
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail=f"Product not found: {product_id}")
return product
async def handle_stripe_webhook(
request: Request,
org_id: int,
db_session: Session,
) -> dict:
# Get Stripe credentials for the organization
creds = await get_stripe_credentials(request, org_id, InternalUser(), db_session)
# Get the webhook secret and API key from credentials
webhook_secret = creds.get('stripe_webhook_secret')
stripe.api_key = creds.get('stripe_secret_key') # Set API key globally
if not webhook_secret:
raise HTTPException(status_code=400, detail="Stripe webhook secret not configured")
# Get the raw request body
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, webhook_secret
)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
raise HTTPException(status_code=400, detail="Invalid signature")
# Handle the event
if event.type == 'checkout.session.completed':
session = event.data.object
payment_user_id = int(session.get('metadata', {}).get('payment_user_id'))
if session.get('mode') == 'subscription':
# Handle subscription payment
if session.get('subscription'):
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.ACTIVE,
current_user=InternalUser(),
db_session=db_session
)
else:
# Handle one-time payment
if session.get('payment_status') == 'paid':
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.COMPLETED,
current_user=InternalUser(),
db_session=db_session
)
elif event.type == 'customer.subscription.deleted':
subscription = event.data.object
payment_user_id = int(subscription.get('metadata', {}).get('payment_user_id'))
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.CANCELLED,
current_user=InternalUser(),
db_session=db_session
)
elif event.type == 'payment_intent.payment_failed':
payment_intent = event.data.object
payment_user_id = int(payment_intent.get('metadata', {}).get('payment_user_id'))
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session
)
return {"status": "success"}

View file

@ -0,0 +1,177 @@
from fastapi import HTTPException, Request
from sqlmodel import Session, select
import stripe
import logging
from src.db.payments.payments_users import PaymentStatusEnum
from src.db.users import InternalUser
from src.services.payments.payments_users import update_payment_user_status
from src.services.payments.payments_stripe import get_stripe_internal_credentials
from src.db.payments.payments import PaymentsConfig, PaymentsConfigUpdate
from src.services.payments.payments_config import update_payments_config
from src.services.payments.utils.stripe_utils import get_org_id_from_stripe_account
logger = logging.getLogger(__name__)
async def handle_stripe_webhook(
request: Request,
db_session: Session,
) -> dict:
# Get Stripe credentials
creds = await get_stripe_internal_credentials()
webhook_secret = creds.get('stripe_webhook_secret')
stripe.api_key = creds.get("stripe_secret_key")
if not webhook_secret:
logger.error("Stripe webhook secret not configured")
raise HTTPException(status_code=400, detail="Stripe webhook secret not configured")
# Get request data
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
# Verify webhook signature
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
except ValueError:
logger.error(ValueError)
raise HTTPException(status_code=400, detail="Invalid payload")
except stripe.SignatureVerificationError:
logger.error(stripe.SignatureVerificationError)
raise HTTPException(status_code=400, detail="Invalid signature")
try:
event_type = event.type
event_data = event.data.object
# Get organization ID based on the event type
stripe_account_id = event.account
if not stripe_account_id:
logger.error("Stripe account ID not found")
raise HTTPException(status_code=400, detail="Stripe account ID not found")
org_id = await get_org_id_from_stripe_account(stripe_account_id, db_session)
# Handle internal account events
if event_type == 'account.application.authorized':
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
config = db_session.exec(statement).first()
if not config:
logger.error("No payments configuration found for this organization")
raise HTTPException(
status_code=404,
detail="No payments configuration found for this organization"
)
config_data = config.model_dump()
config_data.update({
"enabled": True,
"active": True,
"provider_config": {
**config.provider_config,
"onboarding_completed": True
}
})
await update_payments_config(
request,
org_id,
PaymentsConfigUpdate(**config_data),
InternalUser(),
db_session,
)
logger.info(f"Account authorized for organization {org_id}")
return {"status": "success", "message": "Account authorized successfully"}
elif event_type == 'account.application.deauthorized':
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
config = db_session.exec(statement).first()
if not config:
raise HTTPException(
status_code=404,
detail="No payments configuration found for this organization"
)
config_data = config.model_dump()
config_data.update({
"enabled": True,
"active": False,
"provider_config": {
**config.provider_config,
"onboarding_completed": False
}
})
await update_payments_config(
request,
org_id,
PaymentsConfigUpdate(**config_data),
InternalUser(),
db_session,
)
logger.info(f"Account deauthorized for organization {org_id}")
return {"status": "success", "message": "Account deauthorized successfully"}
# Handle payment-related events
elif event_type == "checkout.session.completed":
session = event_data
payment_user_id = int(session.get("metadata", {}).get("payment_user_id"))
if session.get("mode") == "subscription":
if session.get("subscription"):
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.ACTIVE,
current_user=InternalUser(),
db_session=db_session,
)
else:
if session.get("payment_status") == "paid":
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.COMPLETED,
current_user=InternalUser(),
db_session=db_session,
)
elif event_type == "customer.subscription.deleted":
subscription = event_data
payment_user_id = int(subscription.get("metadata", {}).get("payment_user_id"))
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.CANCELLED,
current_user=InternalUser(),
db_session=db_session,
)
elif event_type == "payment_intent.payment_failed":
payment_intent = event_data
payment_user_id = int(payment_intent.get("metadata", {}).get("payment_user_id"))
await update_payment_user_status(
request=request,
org_id=org_id,
payment_user_id=payment_user_id,
status=PaymentStatusEnum.FAILED,
current_user=InternalUser(),
db_session=db_session,
)
else:
logger.warning(f"Unhandled event type: {event_type}")
return {"status": "ignored", "message": f"Unhandled event type: {event_type}"}
return {"status": "success"}
except Exception as e:
logger.error(f"Error processing webhook: {str(e)}")
raise HTTPException(status_code=400, detail=f"Error processing webhook: {str(e)}")