260 lines
8.6 KiB
Python
260 lines
8.6 KiB
Python
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
|
import logging
|
|
|
|
import stripe
|
|
from flask import current_app
|
|
|
|
from app.services.account_service import update_user_plan, _connect, _utc_now
|
|
from app.utils.config_placeholders import normalize_optional_config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def get_stripe_secret_key() -> str:
|
|
"""Return the configured Stripe secret key, ignoring copied sample values."""
|
|
return normalize_optional_config(
|
|
current_app.config.get("STRIPE_SECRET_KEY", ""),
|
|
("replace-with",),
|
|
)
|
|
|
|
|
|
def get_stripe_price_id(billing: str = "monthly") -> str:
|
|
"""Return the configured Stripe price id for the requested billing cycle."""
|
|
monthly = normalize_optional_config(
|
|
current_app.config.get("STRIPE_PRICE_ID_PRO_MONTHLY", ""),
|
|
("replace-with",),
|
|
)
|
|
yearly = normalize_optional_config(
|
|
current_app.config.get("STRIPE_PRICE_ID_PRO_YEARLY", ""),
|
|
("replace-with",),
|
|
)
|
|
|
|
if billing == "yearly" and yearly:
|
|
return yearly
|
|
return monthly
|
|
|
|
|
|
def is_stripe_configured() -> bool:
|
|
"""Return True when billing has a usable secret key and at least one price id."""
|
|
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")))
|
|
|
|
|
|
def _init_stripe():
|
|
"""Configure stripe with the app's secret key."""
|
|
stripe.api_key = get_stripe_secret_key()
|
|
|
|
|
|
def _ensure_stripe_columns():
|
|
"""Add stripe_customer_id and stripe_subscription_id columns if missing."""
|
|
conn = _connect()
|
|
try:
|
|
# Check that users table exists before altering it
|
|
table_exists = conn.execute(
|
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
|
).fetchone()
|
|
if table_exists is None:
|
|
return
|
|
|
|
cols = [row["name"] for row in conn.execute("PRAGMA table_info(users)").fetchall()]
|
|
if "stripe_customer_id" not in cols:
|
|
conn.execute("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT")
|
|
if "stripe_subscription_id" not in cols:
|
|
conn.execute("ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT")
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
|
|
|
|
def init_stripe_db():
|
|
"""Initialize stripe-related DB columns."""
|
|
_ensure_stripe_columns()
|
|
|
|
|
|
def _get_or_create_customer(user_id: int) -> str:
|
|
"""Get existing Stripe customer or create one."""
|
|
_init_stripe()
|
|
conn = _connect()
|
|
try:
|
|
row = conn.execute(
|
|
"SELECT email, stripe_customer_id FROM users WHERE id = ?",
|
|
(user_id,),
|
|
).fetchone()
|
|
finally:
|
|
conn.close()
|
|
|
|
if row is None:
|
|
raise ValueError("User not found.")
|
|
|
|
if row["stripe_customer_id"]:
|
|
return row["stripe_customer_id"]
|
|
|
|
# Create new Stripe customer
|
|
customer = stripe.Customer.create(
|
|
email=row["email"],
|
|
metadata={"user_id": str(user_id)},
|
|
)
|
|
|
|
conn = _connect()
|
|
try:
|
|
conn.execute(
|
|
"UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?",
|
|
(customer.id, _utc_now(), user_id),
|
|
)
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
|
|
return customer.id
|
|
|
|
|
|
def create_checkout_session(user_id: int, price_id: str, success_url: str, cancel_url: str) -> str:
|
|
"""Create a Stripe Checkout Session and return the URL."""
|
|
_init_stripe()
|
|
customer_id = _get_or_create_customer(user_id)
|
|
|
|
session = stripe.checkout.Session.create(
|
|
customer=customer_id,
|
|
payment_method_types=["card"],
|
|
line_items=[{"price": price_id, "quantity": 1}],
|
|
mode="subscription",
|
|
success_url=success_url,
|
|
cancel_url=cancel_url,
|
|
metadata={"user_id": str(user_id)},
|
|
)
|
|
return session.url
|
|
|
|
|
|
def create_portal_session(user_id: int, return_url: str) -> str:
|
|
"""Create a Stripe Customer Portal session for managing subscriptions."""
|
|
_init_stripe()
|
|
customer_id = _get_or_create_customer(user_id)
|
|
|
|
session = stripe.billing_portal.Session.create(
|
|
customer=customer_id,
|
|
return_url=return_url,
|
|
)
|
|
return session.url
|
|
|
|
|
|
def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
|
|
"""Process a Stripe webhook event. Returns a status dict."""
|
|
webhook_secret = normalize_optional_config(
|
|
current_app.config.get("STRIPE_WEBHOOK_SECRET", ""),
|
|
("replace-with",),
|
|
)
|
|
if not webhook_secret:
|
|
logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.")
|
|
return {"status": "ignored", "reason": "no webhook secret"}
|
|
|
|
try:
|
|
event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
|
|
except stripe.SignatureVerificationError:
|
|
logger.warning("Stripe webhook signature verification failed.")
|
|
return {"status": "error", "reason": "signature_failed"}
|
|
except ValueError:
|
|
logger.warning("Invalid Stripe webhook payload.")
|
|
return {"status": "error", "reason": "invalid_payload"}
|
|
|
|
event_type = event["type"]
|
|
data_object = event["data"]["object"]
|
|
|
|
if event_type == "checkout.session.completed":
|
|
_handle_checkout_completed(data_object)
|
|
elif event_type == "customer.subscription.updated":
|
|
_handle_subscription_updated(data_object)
|
|
elif event_type == "customer.subscription.deleted":
|
|
_handle_subscription_deleted(data_object)
|
|
elif event_type == "invoice.payment_failed":
|
|
_handle_payment_failed(data_object)
|
|
|
|
return {"status": "ok", "event_type": event_type}
|
|
|
|
|
|
def _find_user_by_customer_id(customer_id: str) -> dict | None:
|
|
"""Find user by Stripe customer ID."""
|
|
conn = _connect()
|
|
try:
|
|
row = conn.execute(
|
|
"SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?",
|
|
(customer_id,),
|
|
).fetchone()
|
|
finally:
|
|
conn.close()
|
|
return dict(row) if row else None
|
|
|
|
|
|
def _handle_checkout_completed(session: dict):
|
|
"""Handle successful checkout — activate Pro plan."""
|
|
customer_id = session.get("customer")
|
|
subscription_id = session.get("subscription")
|
|
user_id = session.get("metadata", {}).get("user_id")
|
|
|
|
if user_id:
|
|
conn = _connect()
|
|
try:
|
|
conn.execute(
|
|
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
|
|
(subscription_id, _utc_now(), int(user_id)),
|
|
)
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
logger.info("User %s upgraded to Pro via checkout.", user_id)
|
|
elif customer_id:
|
|
user = _find_user_by_customer_id(customer_id)
|
|
if user:
|
|
conn = _connect()
|
|
try:
|
|
conn.execute(
|
|
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
|
|
(subscription_id, _utc_now(), user["id"]),
|
|
)
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
logger.info("User %s upgraded to Pro via checkout (customer match).", user["id"])
|
|
|
|
|
|
def _handle_subscription_updated(subscription: dict):
|
|
"""Handle subscription changes (upgrade/downgrade)."""
|
|
customer_id = subscription.get("customer")
|
|
status = subscription.get("status")
|
|
user = _find_user_by_customer_id(customer_id)
|
|
if not user:
|
|
return
|
|
|
|
if status in ("active", "trialing"):
|
|
update_user_plan(user["id"], "pro")
|
|
logger.info("User %s subscription active — Pro plan.", user["id"])
|
|
elif status in ("past_due", "unpaid"):
|
|
logger.warning("User %s subscription %s.", user["id"], status)
|
|
elif status in ("canceled", "incomplete_expired"):
|
|
update_user_plan(user["id"], "free")
|
|
logger.info("User %s subscription ended — Free plan.", user["id"])
|
|
|
|
|
|
def _handle_subscription_deleted(subscription: dict):
|
|
"""Handle subscription cancellation."""
|
|
customer_id = subscription.get("customer")
|
|
user = _find_user_by_customer_id(customer_id)
|
|
if user:
|
|
update_user_plan(user["id"], "free")
|
|
conn = _connect()
|
|
try:
|
|
conn.execute(
|
|
"UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?",
|
|
(_utc_now(), user["id"]),
|
|
)
|
|
conn.commit()
|
|
finally:
|
|
conn.close()
|
|
logger.info("User %s subscription deleted — downgraded to Free.", user["id"])
|
|
|
|
|
|
def _handle_payment_failed(invoice: dict):
|
|
"""Log payment failures."""
|
|
customer_id = invoice.get("customer")
|
|
user = _find_user_by_customer_id(customer_id)
|
|
if user:
|
|
logger.warning("Payment failed for user %s (customer %s).", user["id"], customer_id)
|