fix: Update stripe_service.py to use database abstraction layer
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
||||
|
||||
import logging
|
||||
|
||||
import stripe
|
||||
from flask import current_app
|
||||
|
||||
from app.services.account_service import update_user_plan, _connect, _utc_now
|
||||
from app.services.account_service import update_user_plan, _utc_now
|
||||
from app.utils.database import db_connection, execute_query, is_postgres, row_to_dict
|
||||
from app.utils.config_placeholders import normalize_optional_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,7 +38,10 @@ def get_stripe_price_id(billing: str = "monthly") -> str:
|
||||
|
||||
def is_stripe_configured() -> bool:
|
||||
"""Return True when billing has a usable secret key and at least one price id."""
|
||||
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")))
|
||||
return bool(
|
||||
get_stripe_secret_key()
|
||||
and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly"))
|
||||
)
|
||||
|
||||
|
||||
def _init_stripe():
|
||||
@@ -46,23 +51,35 @@ def _init_stripe():
|
||||
|
||||
def _ensure_stripe_columns():
|
||||
"""Add stripe_customer_id and stripe_subscription_id columns if missing."""
|
||||
conn = _connect()
|
||||
try:
|
||||
# Check that users table exists before altering it
|
||||
table_exists = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||
).fetchone()
|
||||
if table_exists is None:
|
||||
return
|
||||
with db_connection() as conn:
|
||||
if is_postgres():
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')"
|
||||
)
|
||||
if not cursor.fetchone()[0]:
|
||||
return
|
||||
cursor.execute(
|
||||
"SELECT column_name FROM information_schema.columns WHERE table_name = 'users'"
|
||||
)
|
||||
cols = [row[0] for row in cursor.fetchall()]
|
||||
else:
|
||||
table_exists = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||
).fetchone()
|
||||
if table_exists is None:
|
||||
return
|
||||
cols = [
|
||||
row["name"]
|
||||
for row in conn.execute("PRAGMA table_info(users)").fetchall()
|
||||
]
|
||||
|
||||
cols = [row["name"] for row in conn.execute("PRAGMA table_info(users)").fetchall()]
|
||||
if "stripe_customer_id" not in cols:
|
||||
conn.execute("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT")
|
||||
sql = "ALTER TABLE users ADD COLUMN stripe_customer_id TEXT"
|
||||
execute_query(conn, sql)
|
||||
if "stripe_subscription_id" not in cols:
|
||||
conn.execute("ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT")
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
sql = "ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT"
|
||||
execute_query(conn, sql)
|
||||
|
||||
|
||||
def init_stripe_db():
|
||||
@@ -73,41 +90,40 @@ def init_stripe_db():
|
||||
def _get_or_create_customer(user_id: int) -> str:
|
||||
"""Get existing Stripe customer or create one."""
|
||||
_init_stripe()
|
||||
conn = _connect()
|
||||
try:
|
||||
row = conn.execute(
|
||||
"SELECT email, stripe_customer_id FROM users WHERE id = ?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
finally:
|
||||
conn.close()
|
||||
with db_connection() as conn:
|
||||
sql = (
|
||||
"SELECT email, stripe_customer_id FROM users WHERE id = %s"
|
||||
if is_postgres()
|
||||
else "SELECT email, stripe_customer_id FROM users WHERE id = ?"
|
||||
)
|
||||
cursor = execute_query(conn, sql, (user_id,))
|
||||
row = row_to_dict(cursor.fetchone())
|
||||
|
||||
if row is None:
|
||||
raise ValueError("User not found.")
|
||||
|
||||
if row["stripe_customer_id"]:
|
||||
if row.get("stripe_customer_id"):
|
||||
return row["stripe_customer_id"]
|
||||
|
||||
# Create new Stripe customer
|
||||
customer = stripe.Customer.create(
|
||||
email=row["email"],
|
||||
metadata={"user_id": str(user_id)},
|
||||
)
|
||||
|
||||
conn = _connect()
|
||||
try:
|
||||
conn.execute(
|
||||
"UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?",
|
||||
(customer.id, _utc_now(), user_id),
|
||||
with db_connection() as conn:
|
||||
sql = (
|
||||
"UPDATE users SET stripe_customer_id = %s, updated_at = %s WHERE id = %s"
|
||||
if is_postgres()
|
||||
else "UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?"
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
execute_query(conn, sql, (customer.id, _utc_now(), user_id))
|
||||
|
||||
return customer.id
|
||||
|
||||
|
||||
def create_checkout_session(user_id: int, price_id: str, success_url: str, cancel_url: str) -> str:
|
||||
def create_checkout_session(
|
||||
user_id: int, price_id: str, success_url: str, cancel_url: str
|
||||
) -> str:
|
||||
"""Create a Stripe Checkout Session and return the URL."""
|
||||
_init_stripe()
|
||||
customer_id = _get_or_create_customer(user_id)
|
||||
@@ -172,15 +188,15 @@ def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
|
||||
|
||||
def _find_user_by_customer_id(customer_id: str) -> dict | None:
|
||||
"""Find user by Stripe customer ID."""
|
||||
conn = _connect()
|
||||
try:
|
||||
row = conn.execute(
|
||||
"SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?",
|
||||
(customer_id,),
|
||||
).fetchone()
|
||||
finally:
|
||||
conn.close()
|
||||
return dict(row) if row else None
|
||||
with db_connection() as conn:
|
||||
sql = (
|
||||
"SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = %s"
|
||||
if is_postgres()
|
||||
else "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?"
|
||||
)
|
||||
cursor = execute_query(conn, sql, (customer_id,))
|
||||
row = row_to_dict(cursor.fetchone())
|
||||
return row
|
||||
|
||||
|
||||
def _handle_checkout_completed(session: dict):
|
||||
@@ -190,29 +206,27 @@ def _handle_checkout_completed(session: dict):
|
||||
user_id = session.get("metadata", {}).get("user_id")
|
||||
|
||||
if user_id:
|
||||
conn = _connect()
|
||||
try:
|
||||
conn.execute(
|
||||
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
|
||||
(subscription_id, _utc_now(), int(user_id)),
|
||||
with db_connection() as conn:
|
||||
sql = (
|
||||
"UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s"
|
||||
if is_postgres()
|
||||
else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?"
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
execute_query(conn, sql, (subscription_id, _utc_now(), int(user_id)))
|
||||
logger.info("User %s upgraded to Pro via checkout.", user_id)
|
||||
elif customer_id:
|
||||
user = _find_user_by_customer_id(customer_id)
|
||||
if user:
|
||||
conn = _connect()
|
||||
try:
|
||||
conn.execute(
|
||||
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
|
||||
(subscription_id, _utc_now(), user["id"]),
|
||||
with db_connection() as conn:
|
||||
sql = (
|
||||
"UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s"
|
||||
if is_postgres()
|
||||
else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?"
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
logger.info("User %s upgraded to Pro via checkout (customer match).", user["id"])
|
||||
execute_query(conn, sql, (subscription_id, _utc_now(), user["id"]))
|
||||
logger.info(
|
||||
"User %s upgraded to Pro via checkout (customer match).", user["id"]
|
||||
)
|
||||
|
||||
|
||||
def _handle_subscription_updated(subscription: dict):
|
||||
@@ -239,15 +253,13 @@ def _handle_subscription_deleted(subscription: dict):
|
||||
user = _find_user_by_customer_id(customer_id)
|
||||
if user:
|
||||
update_user_plan(user["id"], "free")
|
||||
conn = _connect()
|
||||
try:
|
||||
conn.execute(
|
||||
"UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?",
|
||||
(_utc_now(), user["id"]),
|
||||
with db_connection() as conn:
|
||||
sql = (
|
||||
"UPDATE users SET stripe_subscription_id = NULL, updated_at = %s WHERE id = %s"
|
||||
if is_postgres()
|
||||
else "UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?"
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
execute_query(conn, sql, (_utc_now(), user["id"]))
|
||||
logger.info("User %s subscription deleted — downgraded to Free.", user["id"])
|
||||
|
||||
|
||||
@@ -256,4 +268,6 @@ def _handle_payment_failed(invoice: dict):
|
||||
customer_id = invoice.get("customer")
|
||||
user = _find_user_by_customer_id(customer_id)
|
||||
if user:
|
||||
logger.warning("Payment failed for user %s (customer %s).", user["id"], customer_id)
|
||||
logger.warning(
|
||||
"Payment failed for user %s (customer %s).", user["id"], customer_id
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user