feat: init stripe utils

This commit is contained in:
swve 2024-10-19 01:10:26 +02:00
parent 412651e817
commit 416c3a4afc
8 changed files with 315 additions and 41 deletions

View file

@ -13,6 +13,8 @@ from src.services.orgs.orgs import rbac_check
from datetime import datetime
from uuid import uuid4
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
async def create_payments_product(
request: Request,
org_id: int,
@ -40,9 +42,14 @@ async def create_payments_product(
new_product.creation_date = datetime.now()
new_product.update_date = datetime.now()
# Create product in Stripe
stripe_product = await create_stripe_product(request, org_id, new_product, current_user, db_session)
new_product.provider_product_id = stripe_product.id
# Save to DB
db_session.add(new_product)
db_session.commit()
db_session.refresh(new_product)
db_session.refresh(new_product)
return PaymentsProductRead.model_validate(new_product)
@ -103,6 +110,9 @@ async def update_payments_product(
db_session.commit()
db_session.refresh(product)
# Update product in Stripe
await update_stripe_product(request, org_id, product.provider_product_id, product, current_user, db_session)
return PaymentsProductRead.model_validate(product)
async def delete_payments_product(
@ -126,6 +136,9 @@ async def delete_payments_product(
product = db_session.exec(statement).first()
if not product:
raise HTTPException(status_code=404, detail="Payments product not found")
# Archive product in Stripe
await archive_stripe_product(request, org_id, product.provider_product_id, current_user, db_session)
# Delete product
db_session.delete(product)
@ -147,7 +160,7 @@ async def list_payments_products(
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
# Get payments products ordered by id
statement = select(PaymentsProduct).where(PaymentsProduct.org_id == org_id).order_by(PaymentsProduct.id.desc())
statement = select(PaymentsProduct).where(PaymentsProduct.org_id == org_id).order_by(PaymentsProduct.id.desc()) # type: ignore
products = db_session.exec(statement).all()
return [PaymentsProductRead.model_validate(product) for product in products]

View file

@ -0,0 +1,150 @@
from email.policy import default
from fastapi import HTTPException, Request
from sqlmodel import Session
import stripe
from src.db.payments.payments_products import PaymentProductTypeEnum, PaymentsProduct, PaymentsProductCreate
from src.db.users import AnonymousUser, PublicUser
from src.services.payments.payments import get_payments_config
async def get_stripe_credentials(
request: Request,
org_id: int,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
configs = await get_payments_config(request, org_id, current_user, db_session)
if len(configs) == 0:
raise HTTPException(status_code=404, detail="Payments config not found")
if len(configs) > 1:
raise HTTPException(
status_code=400, detail="Organization has multiple payments configs"
)
config = configs[0]
if config.provider != "stripe":
raise HTTPException(
status_code=400, detail="Payments config is not a Stripe config"
)
# Get provider config
credentials = config.provider_config
return credentials
async def create_stripe_product(
request: Request,
org_id: int,
product_data: PaymentsProduct,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
# Set the Stripe API key using the credentials
stripe.api_key = creds.get('stripe_secret_key')
## Create product
# Interval or one time
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
interval = "month"
else:
interval = None
# Prepare default_price_data
default_price_data = {
"currency": product_data.currency,
"unit_amount": int(product_data.amount * 100) # Convert to cents
}
if interval:
default_price_data["recurring"] = {"interval": interval}
product = stripe.Product.create(
name=product_data.name,
description=product_data.description or "",
marketing_features=[{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()],
default_price_data=default_price_data # type: ignore
)
return product
async def archive_stripe_product(
request: Request,
org_id: int,
product_id: str,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
# Set the Stripe API key using the credentials
stripe.api_key = creds.get('stripe_secret_key')
try:
# Archive the product in Stripe
archived_product = stripe.Product.modify(product_id, active=False)
return archived_product
except stripe.StripeError as e:
print(f"Error archiving Stripe product: {str(e)}")
raise HTTPException(status_code=400, detail=f"Error archiving Stripe product: {str(e)}")
async def update_stripe_product(
request: Request,
org_id: int,
product_id: str,
product_data: PaymentsProduct,
current_user: PublicUser | AnonymousUser,
db_session: Session,
):
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
# Set the Stripe API key using the credentials
stripe.api_key = creds.get('stripe_secret_key')
try:
# Always create a new price
new_price_data = {
"currency": product_data.currency,
"unit_amount": int(product_data.amount * 100), # Convert to cents
"product": product_id,
}
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
new_price_data["recurring"] = {"interval": "month"}
new_price = stripe.Price.create(**new_price_data)
# Prepare the update data
update_data = {
"name": product_data.name,
"description": product_data.description or "",
"metadata": {"benefits": product_data.benefits},
"marketing_features": [{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()],
"default_price": new_price.id
}
# Update the product in Stripe
updated_product = stripe.Product.modify(product_id, **update_data)
# Archive all existing prices for the product
existing_prices = stripe.Price.list(product=product_id, active=True)
for price in existing_prices:
if price.id != new_price.id:
stripe.Price.modify(price.id, active=False)
# Set the new price as the default price for the product
updated_product = stripe.Product.modify(product_id, default_price=new_price.id)
return updated_product
except stripe.StripeError as e:
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")