fix: Update stripe_service.py to use database abstraction layer

This commit is contained in:
Your Name
2026-03-31 22:03:19 +02:00
parent 030418f6db
commit 892db08d22

View File

@@ -1,10 +1,12 @@
"""Stripe payment service — checkout sessions, webhooks, and subscription management.""" """Stripe payment service — checkout sessions, webhooks, and subscription management."""
import logging import logging
import stripe import stripe
from flask import current_app from flask import current_app
from app.services.account_service import update_user_plan, _connect, _utc_now from app.services.account_service import update_user_plan, _utc_now
from app.utils.database import db_connection, execute_query, is_postgres, row_to_dict
from app.utils.config_placeholders import normalize_optional_config from app.utils.config_placeholders import normalize_optional_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -36,7 +38,10 @@ def get_stripe_price_id(billing: str = "monthly") -> str:
def is_stripe_configured() -> bool: def is_stripe_configured() -> bool:
"""Return True when billing has a usable secret key and at least one price id.""" """Return True when billing has a usable secret key and at least one price id."""
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly"))) return bool(
get_stripe_secret_key()
and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly"))
)
def _init_stripe(): def _init_stripe():
@@ -46,23 +51,35 @@ def _init_stripe():
def _ensure_stripe_columns(): def _ensure_stripe_columns():
"""Add stripe_customer_id and stripe_subscription_id columns if missing.""" """Add stripe_customer_id and stripe_subscription_id columns if missing."""
conn = _connect() with db_connection() as conn:
try: if is_postgres():
# Check that users table exists before altering it cursor = conn.cursor()
cursor.execute(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')"
)
if not cursor.fetchone()[0]:
return
cursor.execute(
"SELECT column_name FROM information_schema.columns WHERE table_name = 'users'"
)
cols = [row[0] for row in cursor.fetchall()]
else:
table_exists = conn.execute( table_exists = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'" "SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
).fetchone() ).fetchone()
if table_exists is None: if table_exists is None:
return return
cols = [
row["name"]
for row in conn.execute("PRAGMA table_info(users)").fetchall()
]
cols = [row["name"] for row in conn.execute("PRAGMA table_info(users)").fetchall()]
if "stripe_customer_id" not in cols: if "stripe_customer_id" not in cols:
conn.execute("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT") sql = "ALTER TABLE users ADD COLUMN stripe_customer_id TEXT"
execute_query(conn, sql)
if "stripe_subscription_id" not in cols: if "stripe_subscription_id" not in cols:
conn.execute("ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT") sql = "ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT"
conn.commit() execute_query(conn, sql)
finally:
conn.close()
def init_stripe_db(): def init_stripe_db():
@@ -73,41 +90,40 @@ def init_stripe_db():
def _get_or_create_customer(user_id: int) -> str: def _get_or_create_customer(user_id: int) -> str:
"""Get existing Stripe customer or create one.""" """Get existing Stripe customer or create one."""
_init_stripe() _init_stripe()
conn = _connect() with db_connection() as conn:
try: sql = (
row = conn.execute( "SELECT email, stripe_customer_id FROM users WHERE id = %s"
"SELECT email, stripe_customer_id FROM users WHERE id = ?", if is_postgres()
(user_id,), else "SELECT email, stripe_customer_id FROM users WHERE id = ?"
).fetchone() )
finally: cursor = execute_query(conn, sql, (user_id,))
conn.close() row = row_to_dict(cursor.fetchone())
if row is None: if row is None:
raise ValueError("User not found.") raise ValueError("User not found.")
if row["stripe_customer_id"]: if row.get("stripe_customer_id"):
return row["stripe_customer_id"] return row["stripe_customer_id"]
# Create new Stripe customer
customer = stripe.Customer.create( customer = stripe.Customer.create(
email=row["email"], email=row["email"],
metadata={"user_id": str(user_id)}, metadata={"user_id": str(user_id)},
) )
conn = _connect() with db_connection() as conn:
try: sql = (
conn.execute( "UPDATE users SET stripe_customer_id = %s, updated_at = %s WHERE id = %s"
"UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?", if is_postgres()
(customer.id, _utc_now(), user_id), else "UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?"
) )
conn.commit() execute_query(conn, sql, (customer.id, _utc_now(), user_id))
finally:
conn.close()
return customer.id return customer.id
def create_checkout_session(user_id: int, price_id: str, success_url: str, cancel_url: str) -> str: def create_checkout_session(
user_id: int, price_id: str, success_url: str, cancel_url: str
) -> str:
"""Create a Stripe Checkout Session and return the URL.""" """Create a Stripe Checkout Session and return the URL."""
_init_stripe() _init_stripe()
customer_id = _get_or_create_customer(user_id) customer_id = _get_or_create_customer(user_id)
@@ -172,15 +188,15 @@ def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
def _find_user_by_customer_id(customer_id: str) -> dict | None: def _find_user_by_customer_id(customer_id: str) -> dict | None:
"""Find user by Stripe customer ID.""" """Find user by Stripe customer ID."""
conn = _connect() with db_connection() as conn:
try: sql = (
row = conn.execute( "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = %s"
"SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?", if is_postgres()
(customer_id,), else "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?"
).fetchone() )
finally: cursor = execute_query(conn, sql, (customer_id,))
conn.close() row = row_to_dict(cursor.fetchone())
return dict(row) if row else None return row
def _handle_checkout_completed(session: dict): def _handle_checkout_completed(session: dict):
@@ -190,29 +206,27 @@ def _handle_checkout_completed(session: dict):
user_id = session.get("metadata", {}).get("user_id") user_id = session.get("metadata", {}).get("user_id")
if user_id: if user_id:
conn = _connect() with db_connection() as conn:
try: sql = (
conn.execute( "UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s"
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?", if is_postgres()
(subscription_id, _utc_now(), int(user_id)), else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?"
) )
conn.commit() execute_query(conn, sql, (subscription_id, _utc_now(), int(user_id)))
finally:
conn.close()
logger.info("User %s upgraded to Pro via checkout.", user_id) logger.info("User %s upgraded to Pro via checkout.", user_id)
elif customer_id: elif customer_id:
user = _find_user_by_customer_id(customer_id) user = _find_user_by_customer_id(customer_id)
if user: if user:
conn = _connect() with db_connection() as conn:
try: sql = (
conn.execute( "UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s"
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?", if is_postgres()
(subscription_id, _utc_now(), user["id"]), else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?"
)
execute_query(conn, sql, (subscription_id, _utc_now(), user["id"]))
logger.info(
"User %s upgraded to Pro via checkout (customer match).", user["id"]
) )
conn.commit()
finally:
conn.close()
logger.info("User %s upgraded to Pro via checkout (customer match).", user["id"])
def _handle_subscription_updated(subscription: dict): def _handle_subscription_updated(subscription: dict):
@@ -239,15 +253,13 @@ def _handle_subscription_deleted(subscription: dict):
user = _find_user_by_customer_id(customer_id) user = _find_user_by_customer_id(customer_id)
if user: if user:
update_user_plan(user["id"], "free") update_user_plan(user["id"], "free")
conn = _connect() with db_connection() as conn:
try: sql = (
conn.execute( "UPDATE users SET stripe_subscription_id = NULL, updated_at = %s WHERE id = %s"
"UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?", if is_postgres()
(_utc_now(), user["id"]), else "UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?"
) )
conn.commit() execute_query(conn, sql, (_utc_now(), user["id"]))
finally:
conn.close()
logger.info("User %s subscription deleted — downgraded to Free.", user["id"]) logger.info("User %s subscription deleted — downgraded to Free.", user["id"])
@@ -256,4 +268,6 @@ def _handle_payment_failed(invoice: dict):
customer_id = invoice.get("customer") customer_id = invoice.get("customer")
user = _find_user_by_customer_id(customer_id) user = _find_user_by_customer_id(customer_id)
if user: if user:
logger.warning("Payment failed for user %s (customer %s).", user["id"], customer_id) logger.warning(
"Payment failed for user %s (customer %s).", user["id"], customer_id
)