mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: init stripe utils
This commit is contained in:
parent
412651e817
commit
416c3a4afc
8 changed files with 315 additions and 41 deletions
19
apps/api/poetry.lock
generated
19
apps/api/poetry.lock
generated
|
|
@ -3513,6 +3513,21 @@ anyio = ">=3.4.0,<5"
|
|||
[package.extras]
|
||||
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "stripe"
|
||||
version = "11.1.1"
|
||||
description = "Python bindings for the Stripe API"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "stripe-11.1.1-py2.py3-none-any.whl", hash = "sha256:e79e02238d0ec7c89a64986af941dcae41e4857489b7cc83497acce9def356e5"},
|
||||
{file = "stripe-11.1.1.tar.gz", hash = "sha256:0bbdfe54a09728fc54db6bb099b2f440ffc111d07d9674b0f04bfd0d3c1cbdcf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = {version = ">=2.20", markers = "python_version >= \"3.0\""}
|
||||
typing-extensions = {version = ">=4.5.0", markers = "python_version >= \"3.7\""}
|
||||
|
||||
[[package]]
|
||||
name = "sympy"
|
||||
version = "1.13.3"
|
||||
|
|
@ -4281,4 +4296,8 @@ type = ["pytest-mypy"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
<<<<<<< HEAD
|
||||
content-hash = "f833ec3787697499d05e2aafb89bcb275b0d7468a6a4a33eb20cd139a21880d8"
|
||||
=======
|
||||
content-hash = "5d2f7ddfb277f39999b7798b9659c5bd2c2751ad667dbcab76a9d83fd6bdfa33"
|
||||
>>>>>>> 59f348e (feat: init stripe utils)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ chromadb = "^0.5.13"
|
|||
alembic = "^1.13.2"
|
||||
alembic-postgresql-enum = "^1.2.0"
|
||||
sqlalchemy-utils = "^0.41.2"
|
||||
stripe = "^11.1.1"
|
||||
|
||||
[build-system]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from src.services.orgs.orgs import rbac_check
|
|||
from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
|
||||
|
||||
async def create_payments_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
|
|
@ -40,9 +42,14 @@ async def create_payments_product(
|
|||
new_product.creation_date = datetime.now()
|
||||
new_product.update_date = datetime.now()
|
||||
|
||||
# Create product in Stripe
|
||||
stripe_product = await create_stripe_product(request, org_id, new_product, current_user, db_session)
|
||||
new_product.provider_product_id = stripe_product.id
|
||||
|
||||
# Save to DB
|
||||
db_session.add(new_product)
|
||||
db_session.commit()
|
||||
db_session.refresh(new_product)
|
||||
db_session.refresh(new_product)
|
||||
|
||||
return PaymentsProductRead.model_validate(new_product)
|
||||
|
||||
|
|
@ -103,6 +110,9 @@ async def update_payments_product(
|
|||
db_session.commit()
|
||||
db_session.refresh(product)
|
||||
|
||||
# Update product in Stripe
|
||||
await update_stripe_product(request, org_id, product.provider_product_id, product, current_user, db_session)
|
||||
|
||||
return PaymentsProductRead.model_validate(product)
|
||||
|
||||
async def delete_payments_product(
|
||||
|
|
@ -126,6 +136,9 @@ async def delete_payments_product(
|
|||
product = db_session.exec(statement).first()
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Payments product not found")
|
||||
|
||||
# Archive product in Stripe
|
||||
await archive_stripe_product(request, org_id, product.provider_product_id, current_user, db_session)
|
||||
|
||||
# Delete product
|
||||
db_session.delete(product)
|
||||
|
|
@ -147,7 +160,7 @@ async def list_payments_products(
|
|||
await rbac_check(request, org.org_uuid, current_user, "read", db_session)
|
||||
|
||||
# Get payments products ordered by id
|
||||
statement = select(PaymentsProduct).where(PaymentsProduct.org_id == org_id).order_by(PaymentsProduct.id.desc())
|
||||
statement = select(PaymentsProduct).where(PaymentsProduct.org_id == org_id).order_by(PaymentsProduct.id.desc()) # type: ignore
|
||||
products = db_session.exec(statement).all()
|
||||
|
||||
return [PaymentsProductRead.model_validate(product) for product in products]
|
||||
|
|
|
|||
150
apps/api/src/services/payments/stripe.py
Normal file
150
apps/api/src/services/payments/stripe.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
from email.policy import default
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session
|
||||
import stripe
|
||||
from src.db.payments.payments_products import PaymentProductTypeEnum, PaymentsProduct, PaymentsProductCreate
|
||||
from src.db.users import AnonymousUser, PublicUser
|
||||
from src.services.payments.payments import get_payments_config
|
||||
|
||||
|
||||
async def get_stripe_credentials(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
configs = await get_payments_config(request, org_id, current_user, db_session)
|
||||
|
||||
if len(configs) == 0:
|
||||
raise HTTPException(status_code=404, detail="Payments config not found")
|
||||
if len(configs) > 1:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Organization has multiple payments configs"
|
||||
)
|
||||
config = configs[0]
|
||||
if config.provider != "stripe":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Payments config is not a Stripe config"
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
credentials = config.provider_config
|
||||
|
||||
return credentials
|
||||
|
||||
async def create_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_data: PaymentsProduct,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
## Create product
|
||||
|
||||
# Interval or one time
|
||||
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
|
||||
interval = "month"
|
||||
else:
|
||||
interval = None
|
||||
|
||||
# Prepare default_price_data
|
||||
default_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"unit_amount": int(product_data.amount * 100) # Convert to cents
|
||||
}
|
||||
|
||||
if interval:
|
||||
default_price_data["recurring"] = {"interval": interval}
|
||||
|
||||
product = stripe.Product.create(
|
||||
name=product_data.name,
|
||||
description=product_data.description or "",
|
||||
marketing_features=[{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()],
|
||||
default_price_data=default_price_data # type: ignore
|
||||
)
|
||||
|
||||
return product
|
||||
|
||||
async def archive_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
try:
|
||||
# Archive the product in Stripe
|
||||
archived_product = stripe.Product.modify(product_id, active=False)
|
||||
|
||||
return archived_product
|
||||
except stripe.StripeError as e:
|
||||
print(f"Error archiving Stripe product: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"Error archiving Stripe product: {str(e)}")
|
||||
|
||||
async def update_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: str,
|
||||
product_data: PaymentsProduct,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
try:
|
||||
|
||||
# Always create a new price
|
||||
new_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"unit_amount": int(product_data.amount * 100), # Convert to cents
|
||||
"product": product_id,
|
||||
}
|
||||
|
||||
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
|
||||
new_price_data["recurring"] = {"interval": "month"}
|
||||
|
||||
new_price = stripe.Price.create(**new_price_data)
|
||||
|
||||
# Prepare the update data
|
||||
update_data = {
|
||||
"name": product_data.name,
|
||||
"description": product_data.description or "",
|
||||
"metadata": {"benefits": product_data.benefits},
|
||||
"marketing_features": [{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()],
|
||||
"default_price": new_price.id
|
||||
}
|
||||
|
||||
# Update the product in Stripe
|
||||
updated_product = stripe.Product.modify(product_id, **update_data)
|
||||
|
||||
|
||||
# Archive all existing prices for the product
|
||||
existing_prices = stripe.Price.list(product=product_id, active=True)
|
||||
for price in existing_prices:
|
||||
if price.id != new_price.id:
|
||||
stripe.Price.modify(price.id, active=False)
|
||||
|
||||
# Set the new price as the default price for the product
|
||||
updated_product = stripe.Product.modify(product_id, default_price=new_price.id)
|
||||
|
||||
return updated_product
|
||||
except stripe.StripeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue