diff --git a/apps/api/config/config.py b/apps/api/config/config.py index 4ff84e96..7be01dd8 100644 --- a/apps/api/config/config.py +++ b/apps/api/config/config.py @@ -71,6 +71,15 @@ class RedisConfig(BaseModel): redis_connection_string: Optional[str] +class InternalStripeConfig(BaseModel): + stripe_secret_key: str | None + stripe_publishable_key: str | None + + +class InternalPaymentsConfig(BaseModel): + stripe: InternalStripeConfig + + class LearnHouseConfig(BaseModel): site_name: str site_description: str @@ -82,6 +91,7 @@ class LearnHouseConfig(BaseModel): security_config: SecurityConfig ai_config: AIConfig mailing_config: MailingConfig + payments_config: InternalPaymentsConfig def get_learnhouse_config() -> LearnHouseConfig: @@ -261,6 +271,18 @@ def get_learnhouse_config() -> LearnHouseConfig: else: sentry_config = None + # Payments config + env_stripe_secret_key = os.environ.get("LEARNHOUSE_STRIPE_SECRET_KEY") + env_stripe_publishable_key = os.environ.get("LEARNHOUSE_STRIPE_PUBLISHABLE_KEY") + + stripe_secret_key = env_stripe_secret_key or yaml_config.get("payments_config", {}).get( + "stripe", {} + ).get("stripe_secret_key") + + stripe_publishable_key = env_stripe_publishable_key or yaml_config.get("payments_config", {}).get( + "stripe", {} + ).get("stripe_publishable_key") + # Create HostingConfig and DatabaseConfig objects hosting_config = HostingConfig( domain=domain, @@ -303,6 +325,12 @@ def get_learnhouse_config() -> LearnHouseConfig: mailing_config=MailingConfig( resend_api_key=resend_api_key, system_email_address=system_email_address ), + payments_config=InternalPaymentsConfig( + stripe=InternalStripeConfig( + stripe_secret_key=stripe_secret_key, + stripe_publishable_key=stripe_publishable_key + ) + ) ) return config diff --git a/apps/api/config/config.yaml b/apps/api/config/config.yaml index b9a8c0b6..2a9694a1 100644 --- a/apps/api/config/config.yaml +++ b/apps/api/config/config.yaml @@ -37,6 +37,11 @@ database_config: redis_config: redis_connection_string: redis://localhost:6379/learnhouse +payments_config: + stripe: + stripe_secret_key: "" + stripe_publishable_key: "" + ai_config: chromadb_config: isSeparateDatabaseEnabled: True diff --git a/apps/api/src/db/payments/payments.py b/apps/api/src/db/payments/payments.py index bfecc7e6..7f8d5c85 100644 --- a/apps/api/src/db/payments/payments.py +++ b/apps/api/src/db/payments/payments.py @@ -1,22 +1,16 @@ from datetime import datetime from enum import Enum from typing import Optional -from pydantic import BaseModel from sqlalchemy import JSON from sqlmodel import Field, SQLModel, Column, BigInteger, ForeignKey -# Stripe provider config -class StripeProviderConfig(BaseModel): - stripe_key: str = "" - stripe_secret_key: str = "" - stripe_webhook_secret: str = "" - # PaymentsConfig class PaymentProviderEnum(str, Enum): STRIPE = "stripe" class PaymentsConfigBase(SQLModel): enabled: bool = True + active: bool = False provider: PaymentProviderEnum = PaymentProviderEnum.STRIPE provider_config: dict = Field(default={}, sa_column=Column(JSON)) diff --git a/apps/api/src/routers/ee/payments.py b/apps/api/src/routers/ee/payments.py index aaec689c..e4569c33 100644 --- a/apps/api/src/routers/ee/payments.py +++ b/apps/api/src/routers/ee/payments.py @@ -1,13 +1,13 @@ +from typing import Literal from fastapi import APIRouter, Depends, Request from sqlmodel import Session from src.core.events.database import get_db_session -from src.db.payments.payments import PaymentsConfig, PaymentsConfigCreate, PaymentsConfigRead, PaymentsConfigUpdate +from src.db.payments.payments import PaymentsConfig, PaymentsConfigRead from src.db.users import PublicUser from src.security.auth import get_current_user from src.services.payments.payments_config import ( - create_payments_config, + init_payments_config, get_payments_config, - update_payments_config, delete_payments_config, ) from src.db.payments.payments_products import PaymentsProductCreate, PaymentsProductRead, PaymentsProductUpdate @@ -18,10 +18,11 @@ from src.services.payments.payments_courses import ( get_courses_by_product, ) from src.services.payments.payments_users import get_owned_courses -from src.services.payments.payments_webhook import handle_stripe_webhook -from src.services.payments.stripe import create_checkout_session +from src.services.payments.webhooks.payments_connected_webhook import handle_stripe_webhook +from src.services.payments.payments_stripe import create_checkout_session, update_stripe_account_id from src.services.payments.payments_access import check_course_paid_access from src.services.payments.payments_customers import get_customers +from src.services.payments.payments_stripe import generate_stripe_connect_link router = APIRouter() @@ -30,11 +31,12 @@ router = APIRouter() async def api_create_payments_config( request: Request, org_id: int, - payments_config: PaymentsConfigCreate, + provider: Literal["stripe"], current_user: PublicUser = Depends(get_current_user), db_session: Session = Depends(get_db_session), ) -> PaymentsConfig: - return await create_payments_config(request, org_id, payments_config, current_user, db_session) + return await init_payments_config(request, org_id, provider, current_user, db_session) + @router.get("/{org_id}/config") async def api_get_payments_config( @@ -45,16 +47,6 @@ async def api_get_payments_config( ) -> list[PaymentsConfigRead]: return await get_payments_config(request, org_id, current_user, db_session) -@router.put("/{org_id}/config") -async def api_update_payments_config( - request: Request, - org_id: int, - payments_config: PaymentsConfigUpdate, - current_user: PublicUser = Depends(get_current_user), - db_session: Session = Depends(get_db_session), -) -> PaymentsConfig: - return await update_payments_config(request, org_id, payments_config, current_user, db_session) - @router.delete("/{org_id}/config") async def api_delete_payments_config( request: Request, @@ -227,4 +219,31 @@ async def api_get_owned_courses( current_user: PublicUser = Depends(get_current_user), db_session: Session = Depends(get_db_session), ): - return await get_owned_courses(request, current_user, db_session) \ No newline at end of file + return await get_owned_courses(request, current_user, db_session) + +@router.put("/{org_id}/stripe/account") +async def api_update_stripe_account_id( + request: Request, + org_id: int, + stripe_account_id: str, + current_user: PublicUser = Depends(get_current_user), + db_session: Session = Depends(get_db_session), +): + return await update_stripe_account_id( + request, org_id, stripe_account_id, current_user, db_session + ) + +@router.post("/{org_id}/stripe/connect/link") +async def api_generate_stripe_connect_link( + request: Request, + org_id: int, + redirect_uri: str, + current_user: PublicUser = Depends(get_current_user), + db_session: Session = Depends(get_db_session), +): + """ + Generate a Stripe OAuth link for connecting a Stripe account + """ + return await generate_stripe_connect_link( + request, org_id, redirect_uri, current_user, db_session + ) diff --git a/apps/api/src/services/payments/payments_config.py b/apps/api/src/services/payments/payments_config.py index 074a68d8..16ba2fcc 100644 --- a/apps/api/src/services/payments/payments_config.py +++ b/apps/api/src/services/payments/payments_config.py @@ -1,8 +1,9 @@ +from typing import Literal from fastapi import HTTPException, Request from sqlmodel import Session, select from src.db.payments.payments import ( + PaymentProviderEnum, PaymentsConfig, - PaymentsConfigCreate, PaymentsConfigUpdate, PaymentsConfigRead, ) @@ -11,33 +12,45 @@ from src.db.organizations import Organization from src.services.orgs.orgs import rbac_check -async def create_payments_config( +async def init_payments_config( request: Request, org_id: int, - payments_config: PaymentsConfigCreate, + provider: Literal["stripe"], current_user: PublicUser | AnonymousUser, db_session: Session, ) -> PaymentsConfig: - # Check if organization exists - statement = select(Organization).where(Organization.id == org_id) - org = db_session.exec(statement).first() + # Validate organization exists + org = db_session.exec( + select(Organization).where(Organization.id == org_id) + ).first() if not org: raise HTTPException(status_code=404, detail="Organization not found") - # RBAC check + # Verify permissions await rbac_check(request, org.org_uuid, current_user, "create", db_session) - # Check if payments config already exists for this organization - statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id) - existing_config = db_session.exec(statement).first() + # Check for existing config + existing_config = db_session.exec( + select(PaymentsConfig).where(PaymentsConfig.org_id == org_id) + ).first() + if existing_config: raise HTTPException( status_code=409, - detail="Payments config already exists for this organization", + detail="Payments config already exists for this organization" ) - # Create new payments config - new_config = PaymentsConfig(**payments_config.model_dump(), org_id=org_id) + # Initialize new config + new_config = PaymentsConfig( + org_id=org_id, + provider=PaymentProviderEnum.STRIPE, + provider_config={ + "onboarding_completed": False, + "stripe_account_id": "" + } + ) + + # Save to database db_session.add(new_config) db_session.commit() db_session.refresh(new_config) @@ -71,7 +84,7 @@ async def update_payments_config( request: Request, org_id: int, payments_config: PaymentsConfigUpdate, - current_user: PublicUser | AnonymousUser, + current_user: PublicUser | AnonymousUser | InternalUser, db_session: Session, ) -> PaymentsConfig: # Check if organization exists diff --git a/apps/api/src/services/payments/payments_products.py b/apps/api/src/services/payments/payments_products.py index ce2dede1..81874e80 100644 --- a/apps/api/src/services/payments/payments_products.py +++ b/apps/api/src/services/payments/payments_products.py @@ -15,7 +15,7 @@ from src.db.organizations import Organization from src.services.orgs.orgs import rbac_check from datetime import datetime -from src.services.payments.stripe import archive_stripe_product, create_stripe_product, update_stripe_product +from src.services.payments.payments_stripe import archive_stripe_product, create_stripe_product, update_stripe_product async def create_payments_product( request: Request, @@ -33,12 +33,15 @@ async def create_payments_product( # RBAC check await rbac_check(request, org.org_uuid, current_user, "create", db_session) - # Check if payments config exists and has a valid id + # Check if payments config exists, has a valid id, and is active statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id) config = db_session.exec(statement).first() if not config or config.id is None: raise HTTPException(status_code=404, detail="Valid payments config not found") + if not config.active: + raise HTTPException(status_code=400, detail="Payments config is not active") + # Create new payments product new_product = PaymentsProduct(**payments_product.model_dump(), org_id=org_id, payments_config_id=config.id) new_product.creation_date = datetime.now() diff --git a/apps/api/src/services/payments/payments_stripe.py b/apps/api/src/services/payments/payments_stripe.py new file mode 100644 index 00000000..65c17ddc --- /dev/null +++ b/apps/api/src/services/payments/payments_stripe.py @@ -0,0 +1,429 @@ +import logging +from typing import Literal +from fastapi import HTTPException, Request +from sqlmodel import Session +import stripe +from config.config import get_learnhouse_config +from src.db.payments.payments import PaymentsConfigUpdate, PaymentsConfig +from src.db.payments.payments_products import ( + PaymentPriceTypeEnum, + PaymentProductTypeEnum, + PaymentsProduct, +) +from src.db.payments.payments_users import PaymentStatusEnum +from src.db.users import AnonymousUser, InternalUser, PublicUser +from src.services.payments.payments_config import ( + get_payments_config, + update_payments_config, +) +from sqlmodel import select + +from src.services.payments.payments_users import ( + create_payment_user, + delete_payment_user, +) + + +async def get_stripe_connected_account_id( + request: Request, + org_id: int, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +): + # Get payments config + payments_config = await get_payments_config(request, org_id, current_user, db_session) + + return payments_config[0].provider_config.get("stripe_account_id") + + +async def get_stripe_credentials( + request: Request, + org_id: int, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +): + # Get payments config from config file + learnhouse_config = get_learnhouse_config() + + if not learnhouse_config.payments_config.stripe.stripe_secret_key: + raise HTTPException(status_code=400, detail="Stripe secret key not configured") + + if not learnhouse_config.payments_config.stripe.stripe_publishable_key: + raise HTTPException( + status_code=400, detail="Stripe publishable key not configured" + ) + + return { + "stripe_secret_key": learnhouse_config.payments_config.stripe.stripe_secret_key, + "stripe_publishable_key": learnhouse_config.payments_config.stripe.stripe_publishable_key, + } + + +async def create_stripe_product( + request: Request, + org_id: int, + product_data: PaymentsProduct, + current_user: PublicUser | AnonymousUser, + db_session: Session, +): + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + + # Set the Stripe API key using the credentials + stripe.api_key = creds.get("stripe_secret_key") + + # Prepare default_price_data based on price_type + if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE: + default_price_data = { + "currency": product_data.currency, + "custom_unit_amount": { + "enabled": True, + "minimum": int(product_data.amount * 100), # Convert to cents + }, + } + else: + default_price_data = { + "currency": product_data.currency, + "unit_amount": int(product_data.amount * 100), # Convert to cents + } + + if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION: + default_price_data["recurring"] = {"interval": "month"} + + stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session) + + product = stripe.Product.create( + name=product_data.name, + description=product_data.description or "", + marketing_features=[ + {"name": benefit.strip()} + for benefit in product_data.benefits.split(",") + if benefit.strip() + ], + default_price_data=default_price_data, # type: ignore + stripe_account=stripe_acc_id, + ) + + return product + + +async def archive_stripe_product( + request: Request, + org_id: int, + product_id: str, + current_user: PublicUser | AnonymousUser, + db_session: Session, +): + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + + # Set the Stripe API key using the credentials + stripe.api_key = creds.get("stripe_secret_key") + + try: + # Archive the product in Stripe + archived_product = stripe.Product.modify(product_id, active=False) + + return archived_product + except stripe.StripeError as e: + print(f"Error archiving Stripe product: {str(e)}") + raise HTTPException( + status_code=400, detail=f"Error archiving Stripe product: {str(e)}" + ) + + +async def update_stripe_product( + request: Request, + org_id: int, + product_id: str, + product_data: PaymentsProduct, + current_user: PublicUser | AnonymousUser, + db_session: Session, +): + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + + # Set the Stripe API key using the credentials + stripe.api_key = creds.get("stripe_secret_key") + + try: + # Create new price based on price_type + if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE: + new_price_data = { + "currency": product_data.currency, + "product": product_id, + "custom_unit_amount": { + "enabled": True, + "minimum": int(product_data.amount * 100), # Convert to cents + }, + } + else: + new_price_data = { + "currency": product_data.currency, + "unit_amount": int(product_data.amount * 100), # Convert to cents + "product": product_id, + } + + if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION: + new_price_data["recurring"] = {"interval": "month"} + + new_price = stripe.Price.create(**new_price_data) + + # Prepare the update data + update_data = { + "name": product_data.name, + "description": product_data.description or "", + "metadata": {"benefits": product_data.benefits}, + "marketing_features": [ + {"name": benefit.strip()} + for benefit in product_data.benefits.split(",") + if benefit.strip() + ], + "default_price": new_price.id, + } + + # Update the product in Stripe + updated_product = stripe.Product.modify(product_id, **update_data) + + # Archive all existing prices for the product + existing_prices = stripe.Price.list(product=product_id, active=True) + for price in existing_prices: + if price.id != new_price.id: + stripe.Price.modify(price.id, active=False) + + return updated_product + except stripe.StripeError as e: + raise HTTPException( + status_code=400, detail=f"Error updating Stripe product: {str(e)}" + ) + + +async def create_checkout_session( + request: Request, + org_id: int, + product_id: int, + redirect_uri: str, + current_user: PublicUser | AnonymousUser, + db_session: Session, +): + # Get Stripe credentials + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + stripe.api_key = creds.get("stripe_secret_key") + + # Get product details + statement = select(PaymentsProduct).where( + PaymentsProduct.id == product_id, PaymentsProduct.org_id == org_id + ) + product = db_session.exec(statement).first() + + if not product: + raise HTTPException(status_code=404, detail="Product not found") + + success_url = redirect_uri + cancel_url = redirect_uri + + # Get the default price for the product + stripe_product = stripe.Product.retrieve(product.provider_product_id) + line_items = [{"price": stripe_product.default_price, "quantity": 1}] + + stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session) + + # Create or retrieve Stripe customer + try: + customers = stripe.Customer.list( + email=current_user.email, stripe_account=stripe_acc_id + ) + if customers.data: + customer = customers.data[0] + else: + customer = stripe.Customer.create( + email=current_user.email, + metadata={ + "user_id": str(current_user.id), + "org_id": str(org_id), + }, + stripe_account=stripe_acc_id, + ) + + # Create initial payment user with pending status + payment_user = await create_payment_user( + request=request, + org_id=org_id, + user_id=current_user.id, + product_id=product_id, + status=PaymentStatusEnum.PENDING, + provider_data=customer, + current_user=InternalUser(), + db_session=db_session, + ) + + if not payment_user: + raise HTTPException(status_code=400, detail="Error creating payment user") + + except stripe.StripeError as e: + # Clean up payment user if customer creation fails + if payment_user and payment_user.id: + await delete_payment_user( + request, org_id, payment_user.id, InternalUser(), db_session + ) + raise HTTPException( + status_code=400, detail=f"Error creating/retrieving customer: {str(e)}" + ) + + # Create checkout session with customer + try: + checkout_session_params = { + "success_url": success_url, + "cancel_url": cancel_url, + "mode": ( + "payment" + if product.product_type == PaymentProductTypeEnum.ONE_TIME + else "subscription" + ), + "line_items": line_items, + "customer": customer.id, + "metadata": { + "product_id": str(product.id), + "payment_user_id": str(payment_user.id), + }, + "stripe_account": stripe_acc_id, + } + + # Add payment_intent_data only for one-time payments + if product.product_type == PaymentProductTypeEnum.ONE_TIME: + checkout_session_params["payment_intent_data"] = { + "metadata": { + "product_id": str(product.id), + "payment_user_id": str(payment_user.id), + } + } + # Add subscription_data for subscription payments + else: + checkout_session_params["subscription_data"] = { + "metadata": { + "product_id": str(product.id), + "payment_user_id": str(payment_user.id), + } + } + + checkout_session = stripe.checkout.Session.create(**checkout_session_params) + + return {"checkout_url": checkout_session.url, "session_id": checkout_session.id} + + except stripe.StripeError as e: + # Clean up payment user if checkout session creation fails + if payment_user and payment_user.id: + await delete_payment_user( + request, org_id, payment_user.id, InternalUser(), db_session + ) + logging.error(f"Error creating checkout session: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + + +async def generate_stripe_connect_link( + request: Request, + org_id: int, + redirect_uri: str, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +): + """ + Generate a Stripe OAuth link for connecting a Stripe account + """ + # Get credentials + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + stripe.api_key = creds.get("stripe_secret_key") + + # Get config + learnhouse_config = get_learnhouse_config() + + # Get client id + stripe_acc_id = await get_stripe_connected_account_id(request, org_id, current_user, db_session) + if not stripe_acc_id: + raise HTTPException(status_code=400, detail="No Stripe account ID found for this organization") + + # Generate OAuth link + connect_link = stripe.AccountLink.create( + account=stripe_acc_id, + type="account_onboarding", + return_url=redirect_uri, + refresh_url=redirect_uri, + ) + + return {"connect_url": connect_link.url} + +async def create_stripe_account( + request: Request, + org_id: int, + type: Literal["standard"], # Only standard is supported for now, we'll see if we need express later + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +): + # Get credentials + creds = await get_stripe_credentials(request, org_id, current_user, db_session) + stripe.api_key = creds.get("stripe_secret_key") + + # Get existing payments config + statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id) + existing_config = db_session.exec(statement).first() + + if existing_config and existing_config.provider_config.get("stripe_account_id"): + raise HTTPException( + status_code=400, + detail="A Stripe Express account is already linked to this organization" + ) + + # Create Stripe account + stripe_account = stripe.Account.create( + type="standard", + capabilities={ + "card_payments": {"requested": True}, + "transfers": {"requested": True}, + }, + ) + + # Update payments config for the org + await update_payments_config( + request, + org_id, + PaymentsConfigUpdate( + enabled=True, + provider_config={"stripe_account_id": stripe_account.id} + ), + current_user, + db_session, + ) + + return stripe_account + + +async def update_stripe_account_id( + request: Request, + org_id: int, + stripe_account_id: str, + current_user: PublicUser | AnonymousUser | InternalUser, + db_session: Session, +): + """ + Update the Stripe account ID for an organization + """ + # Get existing payments config + statement = select(PaymentsConfig).where(PaymentsConfig.org_id == org_id) + existing_config = db_session.exec(statement).first() + + if not existing_config: + raise HTTPException( + status_code=404, + detail="No payments configuration found for this organization" + ) + + # Update payments config with new stripe account id + await update_payments_config( + request, + org_id, + PaymentsConfigUpdate( + enabled=True, + provider_config={"stripe_account_id": stripe_account_id} + ), + current_user, + db_session, + ) + + return {"message": "Stripe account ID updated successfully"} diff --git a/apps/api/src/services/payments/stripe.py b/apps/api/src/services/payments/stripe.py deleted file mode 100644 index b7c8643d..00000000 --- a/apps/api/src/services/payments/stripe.py +++ /dev/null @@ -1,272 +0,0 @@ -import logging -from fastapi import HTTPException, Request -from sqlmodel import Session -import stripe -from src.db.payments.payments_products import PaymentPriceTypeEnum, PaymentProductTypeEnum, PaymentsProduct -from src.db.payments.payments_users import PaymentStatusEnum -from src.db.users import AnonymousUser, InternalUser, PublicUser -from src.services.payments.payments_config import get_payments_config -from sqlmodel import select - -from src.services.payments.payments_users import create_payment_user, delete_payment_user - -async def get_stripe_credentials( - request: Request, - org_id: int, - current_user: PublicUser | AnonymousUser | InternalUser, - db_session: Session, -): - configs = await get_payments_config(request, org_id, current_user, db_session) - - if len(configs) == 0: - raise HTTPException(status_code=404, detail="Payments config not found") - if len(configs) > 1: - raise HTTPException( - status_code=400, detail="Organization has multiple payments configs" - ) - config = configs[0] - if config.provider != "stripe": - raise HTTPException( - status_code=400, detail="Payments config is not a Stripe config" - ) - - # Get provider config - credentials = config.provider_config - - return credentials - -async def create_stripe_product( - request: Request, - org_id: int, - product_data: PaymentsProduct, - current_user: PublicUser | AnonymousUser, - db_session: Session, -): - creds = await get_stripe_credentials(request, org_id, current_user, db_session) - - # Set the Stripe API key using the credentials - stripe.api_key = creds.get('stripe_secret_key') - - # Prepare default_price_data based on price_type - if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE: - default_price_data = { - "currency": product_data.currency, - "custom_unit_amount": { - "enabled": True, - "minimum": int(product_data.amount * 100), # Convert to cents - } - } - else: - default_price_data = { - "currency": product_data.currency, - "unit_amount": int(product_data.amount * 100) # Convert to cents - } - - if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION: - default_price_data["recurring"] = {"interval": "month"} - - product = stripe.Product.create( - name=product_data.name, - description=product_data.description or "", - marketing_features=[{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()], - default_price_data=default_price_data # type: ignore - ) - - return product - -async def archive_stripe_product( - request: Request, - org_id: int, - product_id: str, - current_user: PublicUser | AnonymousUser, - db_session: Session, -): - creds = await get_stripe_credentials(request, org_id, current_user, db_session) - - # Set the Stripe API key using the credentials - stripe.api_key = creds.get('stripe_secret_key') - - try: - # Archive the product in Stripe - archived_product = stripe.Product.modify(product_id, active=False) - - return archived_product - except stripe.StripeError as e: - print(f"Error archiving Stripe product: {str(e)}") - raise HTTPException(status_code=400, detail=f"Error archiving Stripe product: {str(e)}") - -async def update_stripe_product( - request: Request, - org_id: int, - product_id: str, - product_data: PaymentsProduct, - current_user: PublicUser | AnonymousUser, - db_session: Session, -): - creds = await get_stripe_credentials(request, org_id, current_user, db_session) - - # Set the Stripe API key using the credentials - stripe.api_key = creds.get('stripe_secret_key') - - try: - # Create new price based on price_type - if product_data.price_type == PaymentPriceTypeEnum.CUSTOMER_CHOICE: - new_price_data = { - "currency": product_data.currency, - "product": product_id, - "custom_unit_amount": { - "enabled": True, - "minimum": int(product_data.amount * 100), # Convert to cents - } - } - else: - new_price_data = { - "currency": product_data.currency, - "unit_amount": int(product_data.amount * 100), # Convert to cents - "product": product_id, - } - - if product_data.product_type == PaymentProductTypeEnum.SUBSCRIPTION: - new_price_data["recurring"] = {"interval": "month"} - - new_price = stripe.Price.create(**new_price_data) - - # Prepare the update data - update_data = { - "name": product_data.name, - "description": product_data.description or "", - "metadata": {"benefits": product_data.benefits}, - "marketing_features": [{"name": benefit.strip()} for benefit in product_data.benefits.split(",") if benefit.strip()], - "default_price": new_price.id - } - - # Update the product in Stripe - updated_product = stripe.Product.modify(product_id, **update_data) - - # Archive all existing prices for the product - existing_prices = stripe.Price.list(product=product_id, active=True) - for price in existing_prices: - if price.id != new_price.id: - stripe.Price.modify(price.id, active=False) - - return updated_product - except stripe.StripeError as e: - raise HTTPException(status_code=400, detail=f"Error updating Stripe product: {str(e)}") - -async def create_checkout_session( - request: Request, - org_id: int, - product_id: int, - redirect_uri: str, - current_user: PublicUser | AnonymousUser, - db_session: Session, -): - # Get Stripe credentials - creds = await get_stripe_credentials(request, org_id, current_user, db_session) - stripe.api_key = creds.get('stripe_secret_key') - - # Get product details - statement = select(PaymentsProduct).where( - PaymentsProduct.id == product_id, - PaymentsProduct.org_id == org_id - ) - product = db_session.exec(statement).first() - - if not product: - raise HTTPException(status_code=404, detail="Product not found") - - - success_url = redirect_uri - cancel_url = redirect_uri - - # Get the default price for the product - stripe_product = stripe.Product.retrieve(product.provider_product_id) - line_items = [{ - "price": stripe_product.default_price, - "quantity": 1 - }] - - # Create or retrieve Stripe customer - try: - customers = stripe.Customer.list(email=current_user.email) - if customers.data: - customer = customers.data[0] - else: - customer = stripe.Customer.create( - email=current_user.email, - metadata={ - "user_id": str(current_user.id), - "org_id": str(org_id), - } - ) - - # Create initial payment user with pending status - payment_user = await create_payment_user( - request=request, - org_id=org_id, - user_id=current_user.id, - product_id=product_id, - status=PaymentStatusEnum.PENDING, - provider_data=customer, - current_user=InternalUser(), - db_session=db_session - ) - - if not payment_user: - raise HTTPException(status_code=400, detail="Error creating payment user") - - except stripe.StripeError as e: - # Clean up payment user if customer creation fails - if payment_user and payment_user.id: - await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session) - raise HTTPException(status_code=400, detail=f"Error creating/retrieving customer: {str(e)}") - - # Create checkout session with customer - try: - checkout_session_params = { - "success_url": success_url, - "cancel_url": cancel_url, - "mode": 'payment' if product.product_type == PaymentProductTypeEnum.ONE_TIME else 'subscription', - "line_items": line_items, - "customer": customer.id, - "metadata": { - "product_id": str(product.id), - "payment_user_id": str(payment_user.id) - } - } - - # Add payment_intent_data only for one-time payments - if product.product_type == PaymentProductTypeEnum.ONE_TIME: - checkout_session_params["payment_intent_data"] = { - "metadata": { - "product_id": str(product.id), - "payment_user_id": str(payment_user.id) - } - } - # Add subscription_data for subscription payments - else: - checkout_session_params["subscription_data"] = { - "metadata": { - "product_id": str(product.id), - "payment_user_id": str(payment_user.id) - } - } - - checkout_session = stripe.checkout.Session.create(**checkout_session_params) - - return { - "checkout_url": checkout_session.url, - "session_id": checkout_session.id - } - - except stripe.StripeError as e: - # Clean up payment user if checkout session creation fails - if payment_user and payment_user.id: - await delete_payment_user(request, org_id, payment_user.id, InternalUser(), db_session) - logging.error(f"Error creating checkout session: {str(e)}") - raise HTTPException(status_code=400, detail=str(e)) - - - - - diff --git a/apps/api/src/services/payments/payments_webhook.py b/apps/api/src/services/payments/webhooks/payments_connected_webhook.py similarity index 98% rename from apps/api/src/services/payments/payments_webhook.py rename to apps/api/src/services/payments/webhooks/payments_connected_webhook.py index 0f79f265..e1e849e2 100644 --- a/apps/api/src/services/payments/payments_webhook.py +++ b/apps/api/src/services/payments/webhooks/payments_connected_webhook.py @@ -7,7 +7,7 @@ from src.db.payments.payments_users import PaymentStatusEnum from src.db.payments.payments_products import PaymentsProduct from src.db.users import InternalUser, User from src.services.payments.payments_users import update_payment_user_status -from src.services.payments.stripe import get_stripe_credentials +from src.services.payments.payments_stripe import get_stripe_credentials logger = logging.getLogger(__name__) diff --git a/apps/web/components/Dashboard/Payments/PaymentsConfigurationPage.tsx b/apps/web/components/Dashboard/Payments/PaymentsConfigurationPage.tsx index 25561ef2..acd09f0e 100644 --- a/apps/web/components/Dashboard/Payments/PaymentsConfigurationPage.tsx +++ b/apps/web/components/Dashboard/Payments/PaymentsConfigurationPage.tsx @@ -3,18 +3,21 @@ import React, { useState, useEffect } from 'react'; import { useOrg } from '@components/Contexts/OrgContext'; import { SiStripe } from '@icons-pack/react-simple-icons' import { useLHSession } from '@components/Contexts/LHSessionContext'; -import { getPaymentConfigs, createPaymentConfig, updatePaymentConfig, deletePaymentConfig } from '@services/payments/payments'; +import { getPaymentConfigs, initializePaymentConfig, updatePaymentConfig, deletePaymentConfig, updateStripeAccountID, getStripeOnboardingLink } from '@services/payments/payments'; import FormLayout, { ButtonBlack, Input, Textarea, FormField, FormLabelAndMessage, Flex } from '@components/StyledElements/Form/Form'; -import { Check, Edit, Trash2 } from 'lucide-react'; +import { AlertTriangle, BarChart2, Check, Coins, CreditCard, Edit, ExternalLink, Info, Loader2, RefreshCcw, Trash2 } from 'lucide-react'; import toast from 'react-hot-toast'; import useSWR, { mutate } from 'swr'; import Modal from '@components/StyledElements/Modal/Modal'; import ConfirmationModal from '@components/StyledElements/ConfirmationModal/ConfirmationModal'; import { Button } from '@components/ui/button'; +import { Alert, AlertDescription, AlertTitle } from '@components/ui/alert'; +import { useRouter } from 'next/navigation'; const PaymentsConfigurationPage: React.FC = () => { const org = useOrg() as any; const session = useLHSession() as any; + const router = useRouter(); const access_token = session?.data?.tokens?.access_token; const { data: paymentConfigs, error, isLoading } = useSWR( () => (org && access_token ? [`/payments/${org.id}/config`, access_token] : null), @@ -23,16 +26,21 @@ const PaymentsConfigurationPage: React.FC = () => { const stripeConfig = paymentConfigs?.find((config: any) => config.provider === 'stripe'); const [isModalOpen, setIsModalOpen] = useState(false); + const [isOnboarding, setIsOnboarding] = useState(false); + const [isOnboardingLoading, setIsOnboardingLoading] = useState(false); const enableStripe = async () => { try { + setIsOnboarding(true); const newConfig = { provider: 'stripe', enabled: true }; - const config = await createPaymentConfig(org.id, newConfig, access_token); + const config = await initializePaymentConfig(org.id, newConfig, 'stripe', access_token); toast.success('Stripe enabled successfully'); mutate([`/payments/${org.id}/config`, access_token]); } catch (error) { console.error('Error enabling Stripe:', error); toast.error('Failed to enable Stripe'); + } finally { + setIsOnboarding(false); } }; @@ -51,6 +59,19 @@ const PaymentsConfigurationPage: React.FC = () => { } }; + const handleStripeOnboarding = async () => { + try { + setIsOnboardingLoading(true); + const { connect_url } = await getStripeOnboardingLink(org.id, access_token, window.location.href); + router.push(connect_url); + } catch (error) { + console.error('Error getting onboarding link:', error); + toast.error('Failed to start Stripe onboarding'); + } finally { + setIsOnboardingLoading(false); + } + }; + if (isLoading) { return