mirror of
https://github.com/rzmk/learnhouse.git
synced 2025-12-19 04:19:25 +00:00
feat: use stripe connect for payments
This commit is contained in:
parent
cdd893ca6f
commit
a8ba053447
17 changed files with 835 additions and 364 deletions
|
|
@ -1,22 +1,16 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON
|
||||
from sqlmodel import Field, SQLModel, Column, BigInteger, ForeignKey
|
||||
|
||||
# Stripe provider config
|
||||
class StripeProviderConfig(BaseModel):
|
||||
stripe_key: str = ""
|
||||
stripe_secret_key: str = ""
|
||||
stripe_webhook_secret: str = ""
|
||||
|
||||
# PaymentsConfig
|
||||
class PaymentProviderEnum(str, Enum):
|
||||
STRIPE = "stripe"
|
||||
|
||||
class PaymentsConfigBase(SQLModel):
|
||||
enabled: bool = True
|
||||
active: bool = False
|
||||
provider: PaymentProviderEnum = PaymentProviderEnum.STRIPE
|
||||
provider_config: dict = Field(default={}, sa_column=Column(JSON))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from typing import Literal
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlmodel import Session
|
||||
from src.core.events.database import get_db_session
|
||||
from src.db.payments.payments import PaymentsConfig, PaymentsConfigCreate, PaymentsConfigRead, PaymentsConfigUpdate
|
||||
from src.db.payments.payments import PaymentsConfig, PaymentsConfigRead
|
||||
from src.db.users import PublicUser
|
||||
from src.security.auth import get_current_user
|
||||
from src.services.payments.payments_config import (
|
||||
create_payments_config,
|
||||
init_payments_config,
|
||||
get_payments_config,
|
||||
update_payments_config,
|
||||
delete_payments_config,
|
||||
)
|
||||
from src.db.payments.payments_products import PaymentsProductCreate, PaymentsProductRead, PaymentsProductUpdate
|
||||
|
|
@ -18,10 +18,11 @@ from src.services.payments.payments_courses import (
|
|||
get_courses_by_product,
|
||||
)
|
||||
from src.services.payments.payments_users import get_owned_courses
|
||||
from src.services.payments.payments_webhook import handle_stripe_webhook
|
||||
from src.services.payments.stripe import create_checkout_session
|
||||
from src.services.payments.webhooks.payments_connected_webhook import handle_stripe_webhook
|
||||
from src.services.payments.payments_stripe import create_checkout_session, update_stripe_account_id
|
||||
from src.services.payments.payments_access import check_course_paid_access
|
||||
from src.services.payments.payments_customers import get_customers
|
||||
from src.services.payments.payments_stripe import generate_stripe_connect_link
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -30,11 +31,12 @@ router = APIRouter()
|
|||
async def api_create_payments_config(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
payments_config: PaymentsConfigCreate,
|
||||
provider: Literal["stripe"],
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
) -> PaymentsConfig:
|
||||
return await create_payments_config(request, org_id, payments_config, current_user, db_session)
|
||||
return await init_payments_config(request, org_id, provider, current_user, db_session)
|
||||
|
||||
|
||||
@router.get("/{org_id}/config")
|
||||
async def api_get_payments_config(
|
||||
|
|
@ -45,16 +47,6 @@ async def api_get_payments_config(
|
|||
) -> list[PaymentsConfigRead]:
|
||||
return await get_payments_config(request, org_id, current_user, db_session)
|
||||
|
||||
@router.put("/{org_id}/config")
|
||||
async def api_update_payments_config(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
payments_config: PaymentsConfigUpdate,
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
) -> PaymentsConfig:
|
||||
return await update_payments_config(request, org_id, payments_config, current_user, db_session)
|
||||
|
||||
@router.delete("/{org_id}/config")
|
||||
async def api_delete_payments_config(
|
||||
request: Request,
|
||||
|
|
@ -227,4 +219,31 @@ async def api_get_owned_courses(
|
|||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
return await get_owned_courses(request, current_user, db_session)
|
||||
return await get_owned_courses(request, current_user, db_session)
|
||||
|
||||
@router.put("/{org_id}/stripe/account")
|
||||
async def api_update_stripe_account_id(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
stripe_account_id: str,
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
return await update_stripe_account_id(
|
||||
request, org_id, stripe_account_id, current_user, db_session
|
||||
)
|
||||
|
||||
@router.post("/{org_id}/stripe/connect/link")
|
||||
async def api_generate_stripe_connect_link(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
redirect_uri: str,
|
||||
current_user: PublicUser = Depends(get_current_user),
|
||||
db_session: Session = Depends(get_db_session),
|
||||
):
|
||||
"""
|
||||
Generate a Stripe OAuth link for connecting a Stripe account
|
||||
"""
|
||||
return await generate_stripe_connect_link(
|
||||
request, org_id, redirect_uri, current_user, db_session
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Literal
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from src.db.payments.payments import (
|
||||
PaymentProviderEnum,
|
||||
PaymentsConfig,
|
||||
PaymentsConfigCreate,
|
||||
PaymentsConfigUpdate,
|
||||
PaymentsConfigRead,
|
||||
)
|
||||
|
|
@ -11,33 +12,45 @@ from src.db.organizations import Organization
|
|||
from src.services.orgs.orgs import rbac_check
|
||||
|
||||
|
||||
async def create_payments_config(
|
||||
async def init_payments_config(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
payments_config: PaymentsConfigCreate,
|
||||
provider: Literal["stripe"],
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
) -> PaymentsConfig:
|
||||
# Check if organization exists
|
||||
statement = select(Organization).where(Organization.id == org_id)
|
||||
org = db_session.exec(statement).first()
|
||||
# Validate organization exists
|
||||
org = db_session.exec(
|
||||
select(Organization).where(Organization.id == org_id)
|
||||
).first()
|
||||
if not org:
|
||||
raise HTTPException(status_code=404, detail="Organization not found")
|
||||
|
||||
# RBAC check
|
||||
# Verify permissions
|
||||
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
|
||||
|
||||
# Check if payments config already exists for this organization
|
||||
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
|
||||
existing_config = db_session.exec(statement).first()
|
||||
# Check for existing config
|
||||
existing_config = db_session.exec(
|
||||
select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Payments config already exists for this organization",
|
||||
detail="Payments config already exists for this organization"
|
||||
)
|
||||
|
||||
# Create new payments config
|
||||
new_config = PaymentsConfig(**payments_config.model_dump(), org_id=org_id)
|
||||
# Initialize new config
|
||||
new_config = PaymentsConfig(
|
||||
org_id=org_id,
|
||||
provider=PaymentProviderEnum.STRIPE,
|
||||
provider_config={
|
||||
"onboarding_completed": False,
|
||||
"stripe_account_id": ""
|
||||
}
|
||||
)
|
||||
|
||||
# Save to database
|
||||
db_session.add(new_config)
|
||||
db_session.commit()
|
||||
db_session.refresh(new_config)
|
||||
|
|
@ -71,7 +84,7 @@ async def update_payments_config(
|
|||
request: Request,
|
||||
org_id: int,
|
||||
payments_config: PaymentsConfigUpdate,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
) -> PaymentsConfig:
|
||||
# Check if organization exists
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from src.db.organizations import Organization
|
|||
from src.services.orgs.orgs import rbac_check
|
||||
from datetime import datetime
|
||||
|
||||
from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product
|
||||
from src.services.payments.payments_stripe import archive_stripe_product, create_stripe_product, update_stripe_product
|
||||
|
||||
async def create_payments_product(
|
||||
request: Request,
|
||||
|
|
@ -33,12 +33,15 @@ async def create_payments_product(
|
|||
# RBAC check
|
||||
await rbac_check(request, org.org_uuid, current_user, "create", db_session)
|
||||
|
||||
# Check if payments config exists and has a valid id
|
||||
# Check if payments config exists, has a valid id, and is active
|
||||
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
|
||||
config = db_session.exec(statement).first()
|
||||
if not config or config.id is None:
|
||||
raise HTTPException(status_code=404, detail="Valid payments config not found")
|
||||
|
||||
if not config.active:
|
||||
raise HTTPException(status_code=400, detail="Payments config is not active")
|
||||
|
||||
# Create new payments product
|
||||
new_product = PaymentsProduct(**payments_product.model_dump(), org_id=org_id, payments_config_id=config.id)
|
||||
new_product.creation_date = datetime.now()
|
||||
|
|
|
|||
429
apps/api/src/services/payments/payments_stripe.py
Normal file
429
apps/api/src/services/payments/payments_stripe.py
Normal file
|
|
@ -0,0 +1,429 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session
|
||||
import stripe
|
||||
from config.config import get_learnhouse_config
|
||||
from src.db.payments.payments import PaymentsConfigUpdate, PaymentsConfig
|
||||
from src.db.payments.payments_products import (
|
||||
PaymentPriceTypeEnum,
|
||||
PaymentProductTypeEnum,
|
||||
PaymentsProduct,
|
||||
)
|
||||
from src.db.payments.payments_users import PaymentStatusEnum
|
||||
from src.db.users import AnonymousUser, InternalUser, PublicUser
|
||||
from src.services.payments.payments_config import (
|
||||
get_payments_config,
|
||||
update_payments_config,
|
||||
)
|
||||
from sqlmodel import select
|
||||
|
||||
from src.services.payments.payments_users import (
|
||||
create_payment_user,
|
||||
delete_payment_user,
|
||||
)
|
||||
|
||||
|
||||
async def get_stripe_connected_account_id(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
# Get payments config
|
||||
payments_config = await get_payments_config(request, org_id, current_user, db_session)
|
||||
|
||||
return payments_config[0].provider_config.get("stripe_account_id")
|
||||
|
||||
|
||||
async def get_stripe_credentials(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
# Get payments config from config file
|
||||
learnhouse_config = get_learnhouse_config()
|
||||
|
||||
if not learnhouse_config.payments_config.stripe.stripe_secret_key:
|
||||
raise HTTPException(status_code=400, detail="Stripe secret key not configured")
|
||||
|
||||
if not learnhouse_config.payments_config.stripe.stripe_publishable_key:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Stripe publishable key not configured"
|
||||
)
|
||||
|
||||
return {
|
||||
"stripe_secret_key": learnhouse_config.payments_config.stripe.stripe_secret_key,
|
||||
"stripe_publishable_key": learnhouse_config.payments_config.stripe.stripe_publishable_key,
|
||||
}
|
||||
|
||||
|
||||
async def create_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_data: PaymentsProduct,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get("stripe_secret_key")
|
||||
|
||||
# Prepare default_price_data based on price_type
|
||||
if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE:
|
||||
default_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"custom_unit_amount": {
|
||||
"enabled": True,
|
||||
"minimum": int(product_data.amount * 100), # Convert to cents
|
||||
},
|
||||
}
|
||||
else:
|
||||
default_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"unit_amount": int(product_data.amount * 100), # Convert to cents
|
||||
}
|
||||
|
||||
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
|
||||
default_price_data["recurring"] = {"interval": "month"}
|
||||
|
||||
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
|
||||
|
||||
product = stripe.Product.create(
|
||||
name=product_data.name,
|
||||
description=product_data.description or "",
|
||||
marketing_features=[
|
||||
{"name": benefit.strip()}
|
||||
for benefit in product_data.benefits.split(",")
|
||||
if benefit.strip()
|
||||
],
|
||||
default_price_data=default_price_data, # type: ignore
|
||||
stripe_account=stripe_acc_id,
|
||||
)
|
||||
|
||||
return product
|
||||
|
||||
|
||||
async def archive_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get("stripe_secret_key")
|
||||
|
||||
try:
|
||||
# Archive the product in Stripe
|
||||
archived_product = stripe.Product.modify(product_id, active=False)
|
||||
|
||||
return archived_product
|
||||
except stripe.StripeError as e:
|
||||
print(f"Error archiving Stripe product: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Error archiving Stripe product: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def update_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: str,
|
||||
product_data: PaymentsProduct,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get("stripe_secret_key")
|
||||
|
||||
try:
|
||||
# Create new price based on price_type
|
||||
if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE:
|
||||
new_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"product": product_id,
|
||||
"custom_unit_amount": {
|
||||
"enabled": True,
|
||||
"minimum": int(product_data.amount * 100), # Convert to cents
|
||||
},
|
||||
}
|
||||
else:
|
||||
new_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"unit_amount": int(product_data.amount * 100), # Convert to cents
|
||||
"product": product_id,
|
||||
}
|
||||
|
||||
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
|
||||
new_price_data["recurring"] = {"interval": "month"}
|
||||
|
||||
new_price = stripe.Price.create(**new_price_data)
|
||||
|
||||
# Prepare the update data
|
||||
update_data = {
|
||||
"name": product_data.name,
|
||||
"description": product_data.description or "",
|
||||
"metadata": {"benefits": product_data.benefits},
|
||||
"marketing_features": [
|
||||
{"name": benefit.strip()}
|
||||
for benefit in product_data.benefits.split(",")
|
||||
if benefit.strip()
|
||||
],
|
||||
"default_price": new_price.id,
|
||||
}
|
||||
|
||||
# Update the product in Stripe
|
||||
updated_product = stripe.Product.modify(product_id, **update_data)
|
||||
|
||||
# Archive all existing prices for the product
|
||||
existing_prices = stripe.Price.list(product=product_id, active=True)
|
||||
for price in existing_prices:
|
||||
if price.id != new_price.id:
|
||||
stripe.Price.modify(price.id, active=False)
|
||||
|
||||
return updated_product
|
||||
except stripe.StripeError as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Error updating Stripe product: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: int,
|
||||
redirect_uri: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
# Get Stripe credentials
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
stripe.api_key = creds.get("stripe_secret_key")
|
||||
|
||||
# Get product details
|
||||
statement = select(PaymentsProduct).where(
|
||||
PaymentsProduct.id == product_id, PaymentsProduct.org_id == org_id
|
||||
)
|
||||
product = db_session.exec(statement).first()
|
||||
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
|
||||
success_url = redirect_uri
|
||||
cancel_url = redirect_uri
|
||||
|
||||
# Get the default price for the product
|
||||
stripe_product = stripe.Product.retrieve(product.provider_product_id)
|
||||
line_items = [{"price": stripe_product.default_price, "quantity": 1}]
|
||||
|
||||
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
|
||||
|
||||
# Create or retrieve Stripe customer
|
||||
try:
|
||||
customers = stripe.Customer.list(
|
||||
email=current_user.email, stripe_account=stripe_acc_id
|
||||
)
|
||||
if customers.data:
|
||||
customer = customers.data[0]
|
||||
else:
|
||||
customer = stripe.Customer.create(
|
||||
email=current_user.email,
|
||||
metadata={
|
||||
"user_id": str(current_user.id),
|
||||
"org_id": str(org_id),
|
||||
},
|
||||
stripe_account=stripe_acc_id,
|
||||
)
|
||||
|
||||
# Create initial payment user with pending status
|
||||
payment_user = await create_payment_user(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
user_id=current_user.id,
|
||||
product_id=product_id,
|
||||
status=PaymentStatusEnum.PENDING,
|
||||
provider_data=customer,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if not payment_user:
|
||||
raise HTTPException(status_code=400, detail="Error creating payment user")
|
||||
|
||||
except stripe.StripeError as e:
|
||||
# Clean up payment user if customer creation fails
|
||||
if payment_user and payment_user.id:
|
||||
await delete_payment_user(
|
||||
request, org_id, payment_user.id, InternalUser(), db_session
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Error creating/retrieving customer: {str(e)}"
|
||||
)
|
||||
|
||||
# Create checkout session with customer
|
||||
try:
|
||||
checkout_session_params = {
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"mode": (
|
||||
"payment"
|
||||
if product.product_type == PaymentProductTypeEnum.ONE_TIME
|
||||
else "subscription"
|
||||
),
|
||||
"line_items": line_items,
|
||||
"customer": customer.id,
|
||||
"metadata": {
|
||||
"product_id": str(product.id),
|
||||
"payment_user_id": str(payment_user.id),
|
||||
},
|
||||
"stripe_account": stripe_acc_id,
|
||||
}
|
||||
|
||||
# Add payment_intent_data only for one-time payments
|
||||
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
|
||||
checkout_session_params["payment_intent_data"] = {
|
||||
"metadata": {
|
||||
"product_id": str(product.id),
|
||||
"payment_user_id": str(payment_user.id),
|
||||
}
|
||||
}
|
||||
# Add subscription_data for subscription payments
|
||||
else:
|
||||
checkout_session_params["subscription_data"] = {
|
||||
"metadata": {
|
||||
"product_id": str(product.id),
|
||||
"payment_user_id": str(payment_user.id),
|
||||
}
|
||||
}
|
||||
|
||||
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
|
||||
|
||||
return {"checkout_url": checkout_session.url, "session_id": checkout_session.id}
|
||||
|
||||
except stripe.StripeError as e:
|
||||
# Clean up payment user if checkout session creation fails
|
||||
if payment_user and payment_user.id:
|
||||
await delete_payment_user(
|
||||
request, org_id, payment_user.id, InternalUser(), db_session
|
||||
)
|
||||
logging.error(f"Error creating checkout session: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
async def generate_stripe_connect_link(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
redirect_uri: str,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
"""
|
||||
Generate a Stripe OAuth link for connecting a Stripe account
|
||||
"""
|
||||
# Get credentials
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
stripe.api_key = creds.get("stripe_secret_key")
|
||||
|
||||
# Get config
|
||||
learnhouse_config = get_learnhouse_config()
|
||||
|
||||
# Get client id
|
||||
stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session)
|
||||
if not stripe_acc_id:
|
||||
raise HTTPException(status_code=400, detail="No Stripe account ID found for this organization")
|
||||
|
||||
# Generate OAuth link
|
||||
connect_link = stripe.AccountLink.create(
|
||||
account=stripe_acc_id,
|
||||
type="account_onboarding",
|
||||
return_url=redirect_uri,
|
||||
refresh_url=redirect_uri,
|
||||
)
|
||||
|
||||
return {"connect_url": connect_link.url}
|
||||
|
||||
async def create_stripe_account(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
type: Literal["standard"], # Only standard is supported for now, we'll see if we need express later
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
# Get credentials
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
stripe.api_key = creds.get("stripe_secret_key")
|
||||
|
||||
# Get existing payments config
|
||||
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
|
||||
existing_config = db_session.exec(statement).first()
|
||||
|
||||
if existing_config and existing_config.provider_config.get("stripe_account_id"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="A Stripe Express account is already linked to this organization"
|
||||
)
|
||||
|
||||
# Create Stripe account
|
||||
stripe_account = stripe.Account.create(
|
||||
type="standard",
|
||||
capabilities={
|
||||
"card_payments": {"requested": True},
|
||||
"transfers": {"requested": True},
|
||||
},
|
||||
)
|
||||
|
||||
# Update payments config for the org
|
||||
await update_payments_config(
|
||||
request,
|
||||
org_id,
|
||||
PaymentsConfigUpdate(
|
||||
enabled=True,
|
||||
provider_config={"stripe_account_id": stripe_account.id}
|
||||
),
|
||||
current_user,
|
||||
db_session,
|
||||
)
|
||||
|
||||
return stripe_account
|
||||
|
||||
|
||||
async def update_stripe_account_id(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
stripe_account_id: str,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
"""
|
||||
Update the Stripe account ID for an organization
|
||||
"""
|
||||
# Get existing payments config
|
||||
statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id)
|
||||
existing_config = db_session.exec(statement).first()
|
||||
|
||||
if not existing_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No payments configuration found for this organization"
|
||||
)
|
||||
|
||||
# Update payments config with new stripe account id
|
||||
await update_payments_config(
|
||||
request,
|
||||
org_id,
|
||||
PaymentsConfigUpdate(
|
||||
enabled=True,
|
||||
provider_config={"stripe_account_id": stripe_account_id}
|
||||
),
|
||||
current_user,
|
||||
db_session,
|
||||
)
|
||||
|
||||
return {"message": "Stripe account ID updated successfully"}
|
||||
|
|
@ -1,272 +0,0 @@
|
|||
import logging
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlmodel import Session
|
||||
import stripe
|
||||
from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct
|
||||
from src.db.payments.payments_users import PaymentStatusEnum
|
||||
from src.db.users import AnonymousUser, InternalUser, PublicUser
|
||||
from src.services.payments.payments_config import get_payments_config
|
||||
from sqlmodel import select
|
||||
|
||||
from src.services.payments.payments_users import create_payment_user, delete_payment_user
|
||||
|
||||
async def get_stripe_credentials(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
current_user: PublicUser | AnonymousUser | InternalUser,
|
||||
db_session: Session,
|
||||
):
|
||||
configs = await get_payments_config(request, org_id, current_user, db_session)
|
||||
|
||||
if len(configs) == 0:
|
||||
raise HTTPException(status_code=404, detail="Payments config not found")
|
||||
if len(configs) > 1:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Organization has multiple payments configs"
|
||||
)
|
||||
config = configs[0]
|
||||
if config.provider != "stripe":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Payments config is not a Stripe config"
|
||||
)
|
||||
|
||||
# Get provider config
|
||||
credentials = config.provider_config
|
||||
|
||||
return credentials
|
||||
|
||||
async def create_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_data: PaymentsProduct,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
# Prepare default_price_data based on price_type
|
||||
if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE:
|
||||
default_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"custom_unit_amount": {
|
||||
"enabled": True,
|
||||
"minimum": int(product_data.amount * 100), # Convert to cents
|
||||
}
|
||||
}
|
||||
else:
|
||||
default_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"unit_amount": int(product_data.amount * 100) # Convert to cents
|
||||
}
|
||||
|
||||
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
|
||||
default_price_data["recurring"] = {"interval": "month"}
|
||||
|
||||
product = stripe.Product.create(
|
||||
name=product_data.name,
|
||||
description=product_data.description or "",
|
||||
marketing_features=[{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()],
|
||||
default_price_data=default_price_data # type: ignore
|
||||
)
|
||||
|
||||
return product
|
||||
|
||||
async def archive_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
try:
|
||||
# Archive the product in Stripe
|
||||
archived_product = stripe.Product.modify(product_id, active=False)
|
||||
|
||||
return archived_product
|
||||
except stripe.StripeError as e:
|
||||
print(f"Error archiving Stripe product: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"Error archiving Stripe product: {str(e)}")
|
||||
|
||||
async def update_stripe_product(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: str,
|
||||
product_data: PaymentsProduct,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
|
||||
# Set the Stripe API key using the credentials
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
try:
|
||||
# Create new price based on price_type
|
||||
if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE:
|
||||
new_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"product": product_id,
|
||||
"custom_unit_amount": {
|
||||
"enabled": True,
|
||||
"minimum": int(product_data.amount * 100), # Convert to cents
|
||||
}
|
||||
}
|
||||
else:
|
||||
new_price_data = {
|
||||
"currency": product_data.currency,
|
||||
"unit_amount": int(product_data.amount * 100), # Convert to cents
|
||||
"product": product_id,
|
||||
}
|
||||
|
||||
if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION:
|
||||
new_price_data["recurring"] = {"interval": "month"}
|
||||
|
||||
new_price = stripe.Price.create(**new_price_data)
|
||||
|
||||
# Prepare the update data
|
||||
update_data = {
|
||||
"name": product_data.name,
|
||||
"description": product_data.description or "",
|
||||
"metadata": {"benefits": product_data.benefits},
|
||||
"marketing_features": [{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()],
|
||||
"default_price": new_price.id
|
||||
}
|
||||
|
||||
# Update the product in Stripe
|
||||
updated_product = stripe.Product.modify(product_id, **update_data)
|
||||
|
||||
# Archive all existing prices for the product
|
||||
existing_prices = stripe.Price.list(product=product_id, active=True)
|
||||
for price in existing_prices:
|
||||
if price.id != new_price.id:
|
||||
stripe.Price.modify(price.id, active=False)
|
||||
|
||||
return updated_product
|
||||
except stripe.StripeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}")
|
||||
|
||||
async def create_checkout_session(
|
||||
request: Request,
|
||||
org_id: int,
|
||||
product_id: int,
|
||||
redirect_uri: str,
|
||||
current_user: PublicUser | AnonymousUser,
|
||||
db_session: Session,
|
||||
):
|
||||
# Get Stripe credentials
|
||||
creds = await get_stripe_credentials(request, org_id, current_user, db_session)
|
||||
stripe.api_key = creds.get('stripe_secret_key')
|
||||
|
||||
# Get product details
|
||||
statement = select(PaymentsProduct).where(
|
||||
PaymentsProduct.id == product_id,
|
||||
PaymentsProduct.org_id == org_id
|
||||
)
|
||||
product = db_session.exec(statement).first()
|
||||
|
||||
if not product:
|
||||
raise HTTPException(status_code=404, detail="Product not found")
|
||||
|
||||
|
||||
success_url = redirect_uri
|
||||
cancel_url = redirect_uri
|
||||
|
||||
# Get the default price for the product
|
||||
stripe_product = stripe.Product.retrieve(product.provider_product_id)
|
||||
line_items = [{
|
||||
"price": stripe_product.default_price,
|
||||
"quantity": 1
|
||||
}]
|
||||
|
||||
# Create or retrieve Stripe customer
|
||||
try:
|
||||
customers = stripe.Customer.list(email=current_user.email)
|
||||
if customers.data:
|
||||
customer = customers.data[0]
|
||||
else:
|
||||
customer = stripe.Customer.create(
|
||||
email=current_user.email,
|
||||
metadata={
|
||||
"user_id": str(current_user.id),
|
||||
"org_id": str(org_id),
|
||||
}
|
||||
)
|
||||
|
||||
# Create initial payment user with pending status
|
||||
payment_user = await create_payment_user(
|
||||
request=request,
|
||||
org_id=org_id,
|
||||
user_id=current_user.id,
|
||||
product_id=product_id,
|
||||
status=PaymentStatusEnum.PENDING,
|
||||
provider_data=customer,
|
||||
current_user=InternalUser(),
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
if not payment_user:
|
||||
raise HTTPException(status_code=400, detail="Error creating payment user")
|
||||
|
||||
except stripe.StripeError as e:
|
||||
# Clean up payment user if customer creation fails
|
||||
if payment_user and payment_user.id:
|
||||
await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session)
|
||||
raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}")
|
||||
|
||||
# Create checkout session with customer
|
||||
try:
|
||||
checkout_session_params = {
|
||||
"success_url": success_url,
|
||||
"cancel_url": cancel_url,
|
||||
"mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription',
|
||||
"line_items": line_items,
|
||||
"customer": customer.id,
|
||||
"metadata": {
|
||||
"product_id": str(product.id),
|
||||
"payment_user_id": str(payment_user.id)
|
||||
}
|
||||
}
|
||||
|
||||
# Add payment_intent_data only for one-time payments
|
||||
if product.product_type == PaymentProductTypeEnum.ONE_TIME:
|
||||
checkout_session_params["payment_intent_data"] = {
|
||||
"metadata": {
|
||||
"product_id": str(product.id),
|
||||
"payment_user_id": str(payment_user.id)
|
||||
}
|
||||
}
|
||||
# Add subscription_data for subscription payments
|
||||
else:
|
||||
checkout_session_params["subscription_data"] = {
|
||||
"metadata": {
|
||||
"product_id": str(product.id),
|
||||
"payment_user_id": str(payment_user.id)
|
||||
}
|
||||
}
|
||||
|
||||
checkout_session = stripe.checkout.Session.create(**checkout_session_params)
|
||||
|
||||
return {
|
||||
"checkout_url": checkout_session.url,
|
||||
"session_id": checkout_session.id
|
||||
}
|
||||
|
||||
except stripe.StripeError as e:
|
||||
# Clean up payment user if checkout session creation fails
|
||||
if payment_user and payment_user.id:
|
||||
await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session)
|
||||
logging.error(f"Error creating checkout session: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -7,7 +7,7 @@ from src.db.payments.payments_users import PaymentStatusEnum
|
|||
from src.db.payments.payments_products import PaymentsProduct
|
||||
from src.db.users import InternalUser, User
|
||||
from src.services.payments.payments_users import update_payment_user_status
|
||||
from src.services.payments.stripe import get_stripe_credentials
|
||||
from src.services.payments.payments_stripe import get_stripe_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue