From 892db08d220bc107a5e908a4f90fc89adbccc366 Mon Sep 17 00:00:00 2001 From: Your Name <119736744+aborayan2022@users.noreply.github.com> Date: Tue, 31 Mar 2026 22:03:19 +0200 Subject: [PATCH] fix: Update stripe_service.py to use database abstraction layer --- backend/app/services/stripe_service.py | 154 ++++++++++++++----------- 1 file changed, 84 insertions(+), 70 deletions(-) diff --git a/backend/app/services/stripe_service.py b/backend/app/services/stripe_service.py index 4a95176..f2ebc31 100644 --- a/backend/app/services/stripe_service.py +++ b/backend/app/services/stripe_service.py @@ -1,10 +1,12 @@ """Stripe payment service — checkout sessions, webhooks, and subscription management.""" + import logging import stripe from flask import current_app -from app.services.account_service import update_user_plan, _connect, _utc_now +from app.services.account_service import update_user_plan, _utc_now +from app.utils.database import db_connection, execute_query, is_postgres, row_to_dict from app.utils.config_placeholders import normalize_optional_config logger = logging.getLogger(__name__) @@ -36,7 +38,10 @@ def get_stripe_price_id(billing: str = "monthly") -> str: def is_stripe_configured() -> bool: """Return True when billing has a usable secret key and at least one price id.""" - return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly"))) + return bool( + get_stripe_secret_key() + and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")) + ) def _init_stripe(): @@ -46,23 +51,35 @@ def _init_stripe(): def _ensure_stripe_columns(): """Add stripe_customer_id and stripe_subscription_id columns if missing.""" - conn = _connect() - try: - # Check that users table exists before altering it - table_exists = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='users'" - ).fetchone() - if table_exists is None: - return + with db_connection() as conn: + if is_postgres(): + cursor = conn.cursor() + cursor.execute( + "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')" + ) + if not cursor.fetchone()[0]: + return + cursor.execute( + "SELECT column_name FROM information_schema.columns WHERE table_name = 'users'" + ) + cols = [row[0] for row in cursor.fetchall()] + else: + table_exists = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='users'" + ).fetchone() + if table_exists is None: + return + cols = [ + row["name"] + for row in conn.execute("PRAGMA table_info(users)").fetchall() + ] - cols = [row["name"] for row in conn.execute("PRAGMA table_info(users)").fetchall()] if "stripe_customer_id" not in cols: - conn.execute("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT") + sql = "ALTER TABLE users ADD COLUMN stripe_customer_id TEXT" + execute_query(conn, sql) if "stripe_subscription_id" not in cols: - conn.execute("ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT") - conn.commit() - finally: - conn.close() + sql = "ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT" + execute_query(conn, sql) def init_stripe_db(): @@ -73,41 +90,40 @@ def init_stripe_db(): def _get_or_create_customer(user_id: int) -> str: """Get existing Stripe customer or create one.""" _init_stripe() - conn = _connect() - try: - row = conn.execute( - "SELECT email, stripe_customer_id FROM users WHERE id = ?", - (user_id,), - ).fetchone() - finally: - conn.close() + with db_connection() as conn: + sql = ( + "SELECT email, stripe_customer_id FROM users WHERE id = %s" + if is_postgres() + else "SELECT email, stripe_customer_id FROM users WHERE id = ?" + ) + cursor = execute_query(conn, sql, (user_id,)) + row = row_to_dict(cursor.fetchone()) if row is None: raise ValueError("User not found.") - if row["stripe_customer_id"]: + if row.get("stripe_customer_id"): return row["stripe_customer_id"] - # Create new Stripe customer customer = stripe.Customer.create( email=row["email"], metadata={"user_id": str(user_id)}, ) - conn = _connect() - try: - conn.execute( - "UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?", - (customer.id, _utc_now(), user_id), + with db_connection() as conn: + sql = ( + "UPDATE users SET stripe_customer_id = %s, updated_at = %s WHERE id = %s" + if is_postgres() + else "UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?" ) - conn.commit() - finally: - conn.close() + execute_query(conn, sql, (customer.id, _utc_now(), user_id)) return customer.id -def create_checkout_session(user_id: int, price_id: str, success_url: str, cancel_url: str) -> str: +def create_checkout_session( + user_id: int, price_id: str, success_url: str, cancel_url: str +) -> str: """Create a Stripe Checkout Session and return the URL.""" _init_stripe() customer_id = _get_or_create_customer(user_id) @@ -172,15 +188,15 @@ def handle_webhook_event(payload: bytes, sig_header: str) -> dict: def _find_user_by_customer_id(customer_id: str) -> dict | None: """Find user by Stripe customer ID.""" - conn = _connect() - try: - row = conn.execute( - "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?", - (customer_id,), - ).fetchone() - finally: - conn.close() - return dict(row) if row else None + with db_connection() as conn: + sql = ( + "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = %s" + if is_postgres() + else "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?" + ) + cursor = execute_query(conn, sql, (customer_id,)) + row = row_to_dict(cursor.fetchone()) + return row def _handle_checkout_completed(session: dict): @@ -190,29 +206,27 @@ def _handle_checkout_completed(session: dict): user_id = session.get("metadata", {}).get("user_id") if user_id: - conn = _connect() - try: - conn.execute( - "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?", - (subscription_id, _utc_now(), int(user_id)), + with db_connection() as conn: + sql = ( + "UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s" + if is_postgres() + else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?" ) - conn.commit() - finally: - conn.close() + execute_query(conn, sql, (subscription_id, _utc_now(), int(user_id))) logger.info("User %s upgraded to Pro via checkout.", user_id) elif customer_id: user = _find_user_by_customer_id(customer_id) if user: - conn = _connect() - try: - conn.execute( - "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?", - (subscription_id, _utc_now(), user["id"]), + with db_connection() as conn: + sql = ( + "UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s" + if is_postgres() + else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?" ) - conn.commit() - finally: - conn.close() - logger.info("User %s upgraded to Pro via checkout (customer match).", user["id"]) + execute_query(conn, sql, (subscription_id, _utc_now(), user["id"])) + logger.info( + "User %s upgraded to Pro via checkout (customer match).", user["id"] + ) def _handle_subscription_updated(subscription: dict): @@ -239,15 +253,13 @@ def _handle_subscription_deleted(subscription: dict): user = _find_user_by_customer_id(customer_id) if user: update_user_plan(user["id"], "free") - conn = _connect() - try: - conn.execute( - "UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?", - (_utc_now(), user["id"]), + with db_connection() as conn: + sql = ( + "UPDATE users SET stripe_subscription_id = NULL, updated_at = %s WHERE id = %s" + if is_postgres() + else "UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?" ) - conn.commit() - finally: - conn.close() + execute_query(conn, sql, (_utc_now(), user["id"])) logger.info("User %s subscription deleted — downgraded to Free.", user["id"]) @@ -256,4 +268,6 @@ def _handle_payment_failed(invoice: dict): customer_id = invoice.get("customer") user = _find_user_by_customer_id(customer_id) if user: - logger.warning("Payment failed for user %s (customer %s).", user["id"], customer_id) + logger.warning( + "Payment failed for user %s (customer %s).", user["id"], customer_id + )