Files
SaaS-PDF/backend/app/services/stripe_service.py

228 lines
7.6 KiB
Python

"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
import logging
import os
import stripe
from flask import current_app
from app.services.account_service import update_user_plan, get_user_by_id, _connect, _utc_now
# Logger named after this module so its records can be filtered by source.
logger: logging.Logger = logging.getLogger(name=__name__)
def _init_stripe():
    """Point the global ``stripe`` client at this app's secret key.

    The key is read from ``STRIPE_SECRET_KEY`` in the active Flask app's
    config; an empty string is used when that setting is absent.
    """
    app_config = current_app.config
    secret_key = app_config.get("STRIPE_SECRET_KEY", "")
    stripe.api_key = secret_key
def _ensure_stripe_columns():
    """Add the Stripe bookkeeping columns to ``users`` if they are missing.

    Adds ``stripe_customer_id`` and ``stripe_subscription_id`` (both TEXT).
    When the ``users`` table does not exist yet, nothing is changed.
    """
    # Each entry pairs a column name with the DDL that creates it; they are
    # applied in this order, and only for columns that are not present yet.
    column_ddl = (
        ("stripe_customer_id", "ALTER TABLE users ADD COLUMN stripe_customer_id TEXT"),
        ("stripe_subscription_id", "ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT"),
    )

    conn = _connect()
    try:
        users_table = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
        ).fetchone()
        if users_table is None:
            # Nothing to alter until the users table has been created.
            return

        # Columns are read by name, which assumes _connect() installs a
        # name-addressable row factory (e.g. sqlite3.Row) — TODO confirm.
        present = {info["name"] for info in conn.execute("PRAGMA table_info(users)").fetchall()}
        for column_name, ddl in column_ddl:
            if column_name not in present:
                conn.execute(ddl)
        conn.commit()
    finally:
        conn.close()
def init_stripe_db():
    """Prepare the database schema for the Stripe integration.

    Public entry point that delegates to ``_ensure_stripe_columns``, which
    adds the Stripe ID columns to ``users`` when they are missing.
    """
    return _ensure_stripe_columns()
def _get_or_create_customer(user_id: int) -> str:
    """Return the Stripe customer ID for *user_id*, creating one on first use.

    If the user row already has a ``stripe_customer_id``, that value is
    returned. Otherwise a Stripe customer is created with the user's email
    and a ``user_id`` metadata entry, and its ID is saved on the user row.

    Raises:
        ValueError: if there is no user with *user_id*.
    """
    _init_stripe()

    conn = _connect()
    try:
        user_row = conn.execute(
            "SELECT email, stripe_customer_id FROM users WHERE id = ?",
            (user_id,),
        ).fetchone()
    finally:
        conn.close()

    if user_row is None:
        raise ValueError("User not found.")

    stored_customer_id = user_row["stripe_customer_id"]
    if stored_customer_id:
        return stored_customer_id

    # No customer on file yet: create one in Stripe, then persist its ID.
    # NOTE(review): two concurrent first calls for the same user can each
    # create a Stripe customer; the later UPDATE below wins.
    new_customer_id = stripe.Customer.create(
        email=user_row["email"],
        metadata={"user_id": str(user_id)},
    ).id

    conn = _connect()
    try:
        conn.execute(
            "UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?",
            (new_customer_id, _utc_now(), user_id),
        )
        conn.commit()
    finally:
        conn.close()
    return new_customer_id
def create_checkout_session(user_id: int, price_id: str, success_url: str, cancel_url: str) -> str:
    """Create a subscription-mode Stripe Checkout Session and return its URL.

    Args:
        user_id: local user being subscribed. It is also stored in the
            session's ``metadata["user_id"]``, which the checkout webhook
            handler reads back.
        price_id: Stripe Price for the single line item (quantity 1).
        success_url: passed through to Stripe as ``success_url``.
        cancel_url: passed through to Stripe as ``cancel_url``.

    Returns:
        The ``url`` attribute of the created Checkout Session.
    """
    _init_stripe()
    stripe_customer = _get_or_create_customer(user_id)

    session_params = {
        "customer": stripe_customer,
        "payment_method_types": ["card"],
        "line_items": [{"price": price_id, "quantity": 1}],
        "mode": "subscription",
        "success_url": success_url,
        "cancel_url": cancel_url,
        "metadata": {"user_id": str(user_id)},
    }
    checkout = stripe.checkout.Session.create(**session_params)
    return checkout.url
def create_portal_session(user_id: int, return_url: str) -> str:
    """Open a Stripe Customer Portal session for *user_id* and return its URL.

    A Stripe customer is created for the user first if none is stored yet.
    *return_url* is passed straight through to Stripe.
    """
    _init_stripe()
    stripe_customer = _get_or_create_customer(user_id)
    portal = stripe.billing_portal.Session.create(
        customer=stripe_customer,
        return_url=return_url,
    )
    return portal.url
def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
    """Verify a Stripe webhook and route it to the matching handler.

    Args:
        payload: raw request body as received.
        sig_header: the Stripe signature header value.

    Returns:
        ``{"status": "ignored", ...}`` when ``STRIPE_WEBHOOK_SECRET`` is not
        configured; ``{"status": "error", "reason": ...}`` when signature
        verification or payload parsing fails; otherwise
        ``{"status": "ok", "event_type": <type>}``, including for event types
        that have no handler here.
    """
    webhook_secret = current_app.config.get("STRIPE_WEBHOOK_SECRET", "")
    if not webhook_secret:
        logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.")
        return {"status": "ignored", "reason": "no webhook secret"}

    try:
        event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
    except stripe.SignatureVerificationError:
        logger.warning("Stripe webhook signature verification failed.")
        return {"status": "error", "reason": "signature_failed"}
    except ValueError:
        logger.warning("Invalid Stripe webhook payload.")
        return {"status": "error", "reason": "invalid_payload"}

    # Event type -> handler; unlisted types are acknowledged without action.
    dispatch = {
        "checkout.session.completed": _handle_checkout_completed,
        "customer.subscription.updated": _handle_subscription_updated,
        "customer.subscription.deleted": _handle_subscription_deleted,
        "invoice.payment_failed": _handle_payment_failed,
    }

    kind = event["type"]
    obj = event["data"]["object"]
    handler = dispatch.get(kind)
    if handler is not None:
        handler(obj)
    return {"status": "ok", "event_type": kind}
def _find_user_by_customer_id(customer_id: str) -> dict | None:
    """Look up the local user linked to a Stripe customer ID.

    Returns:
        A dict with ``id``, ``email``, ``plan`` and ``created_at``, or
        ``None`` when no user row has this ``stripe_customer_id``.
    """
    conn = _connect()
    try:
        found = conn.execute(
            "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?",
            (customer_id,),
        ).fetchone()
    finally:
        conn.close()

    if not found:
        return None
    return dict(found)
def _set_pro_plan(user_id: int, subscription_id) -> bool:
    """Set *user_id* to the Pro plan and store *subscription_id*.

    Returns:
        True if a user row was updated, False if no row matched *user_id*.
    """
    conn = _connect()
    try:
        cursor = conn.execute(
            "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
            (subscription_id, _utc_now(), user_id),
        )
        conn.commit()
        return cursor.rowcount > 0
    finally:
        conn.close()


def _handle_checkout_completed(session: dict):
    """Handle a completed Checkout session by activating the Pro plan.

    The user is resolved from ``metadata["user_id"]``, which
    ``create_checkout_session`` sets. If that is absent, the session's Stripe
    ``customer`` ID is used instead. The matched user's plan becomes ``'pro'``
    and the session's ``subscription`` ID is stored on their row.

    Args:
        session: the ``data.object`` of a ``checkout.session.completed`` event.
    """
    customer_id = session.get("customer")
    subscription_id = session.get("subscription")
    # Metadata can be present but null in the payload; .get("metadata", {})
    # would then return None and the chained .get() would raise.
    metadata = session.get("metadata") or {}
    user_id = metadata.get("user_id")

    if user_id:
        if _set_pro_plan(int(user_id), subscription_id):
            logger.info("User %s upgraded to Pro via checkout.", user_id)
        else:
            logger.warning(
                "Checkout completed for unknown user %s (customer %s); no plan change.",
                user_id,
                customer_id,
            )
    elif customer_id:
        user = _find_user_by_customer_id(customer_id)
        if user and _set_pro_plan(user["id"], subscription_id):
            logger.info("User %s upgraded to Pro via checkout (customer match).", user["id"])
def _handle_subscription_updated(subscription: dict):
    """Sync a user's plan with a subscription status change.

    ``active``/``trialing`` sets Pro and ``canceled``/``incomplete_expired``
    sets Free. ``past_due``/``unpaid`` are only logged and the plan is left
    unchanged. Other statuses, and customers with no matching user, are
    ignored.
    """
    customer_id = subscription.get("customer")
    sub_status = subscription.get("status")

    user = _find_user_by_customer_id(customer_id)
    if not user:
        return
    uid = user["id"]

    if sub_status in ("active", "trialing"):
        update_user_plan(uid, "pro")
        logger.info("User %s subscription active — Pro plan.", uid)
        return

    if sub_status in ("past_due", "unpaid"):
        # Plan is deliberately not changed here; the state is only logged.
        logger.warning("User %s subscription %s.", uid, sub_status)
        return

    if sub_status in ("canceled", "incomplete_expired"):
        update_user_plan(uid, "free")
        logger.info("User %s subscription ended — Free plan.", uid)
def _handle_subscription_deleted(subscription: dict):
    """Downgrade a user to Free when their current subscription is deleted.

    The stored ``stripe_subscription_id`` is cleared as well. If the user has
    a *different* subscription ID on file, the event refers to a stale
    subscription (for example, an old one ending after the user checked out
    again). In that case the user is left untouched instead of being
    downgraded while their current subscription is still active.

    Args:
        subscription: the ``data.object`` of a
            ``customer.subscription.deleted`` event.
    """
    customer_id = subscription.get("customer")
    user = _find_user_by_customer_id(customer_id)
    if not user:
        return

    deleted_id = subscription.get("id")
    conn = _connect()
    try:
        row = conn.execute(
            "SELECT stripe_subscription_id FROM users WHERE id = ?",
            (user["id"],),
        ).fetchone()
    finally:
        conn.close()
    current_id = row["stripe_subscription_id"] if row else None

    # Skip only when both IDs are known and differ. When no ID is stored
    # (e.g. the checkout webhook was missed), keep the original behavior and
    # downgrade.
    if current_id and deleted_id and current_id != deleted_id:
        logger.info(
            "Ignoring deletion of stale subscription %s for user %s (current: %s).",
            deleted_id,
            user["id"],
            current_id,
        )
        return

    update_user_plan(user["id"], "free")
    conn = _connect()
    try:
        conn.execute(
            "UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?",
            (_utc_now(), user["id"]),
        )
        conn.commit()
    finally:
        conn.close()
    logger.info("User %s subscription deleted — downgraded to Free.", user["id"])
def _handle_payment_failed(invoice: dict):
    """Log a failed invoice payment for the matching user.

    No plan change happens here. Invoices whose customer does not map to a
    local user are ignored.
    """
    customer_id = invoice.get("customer")
    user = _find_user_by_customer_id(customer_id)
    if not user:
        return
    logger.warning("Payment failed for user %s (customer %s).", user["id"], customer_id)