fix: Update stripe_service.py to use database abstraction layer
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import stripe
|
import stripe
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
from app.services.account_service import update_user_plan, _connect, _utc_now
|
from app.services.account_service import update_user_plan, _utc_now
|
||||||
|
from app.utils.database import db_connection, execute_query, is_postgres, row_to_dict
|
||||||
from app.utils.config_placeholders import normalize_optional_config
|
from app.utils.config_placeholders import normalize_optional_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -36,7 +38,10 @@ def get_stripe_price_id(billing: str = "monthly") -> str:
|
|||||||
|
|
||||||
def is_stripe_configured() -> bool:
|
def is_stripe_configured() -> bool:
|
||||||
"""Return True when billing has a usable secret key and at least one price id."""
|
"""Return True when billing has a usable secret key and at least one price id."""
|
||||||
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")))
|
return bool(
|
||||||
|
get_stripe_secret_key()
|
||||||
|
and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _init_stripe():
|
def _init_stripe():
|
||||||
@@ -46,23 +51,35 @@ def _init_stripe():
|
|||||||
|
|
||||||
def _ensure_stripe_columns():
|
def _ensure_stripe_columns():
|
||||||
"""Add stripe_customer_id and stripe_subscription_id columns if missing."""
|
"""Add stripe_customer_id and stripe_subscription_id columns if missing."""
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
if is_postgres():
|
||||||
# Check that users table exists before altering it
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')"
|
||||||
|
)
|
||||||
|
if not cursor.fetchone()[0]:
|
||||||
|
return
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT column_name FROM information_schema.columns WHERE table_name = 'users'"
|
||||||
|
)
|
||||||
|
cols = [row[0] for row in cursor.fetchall()]
|
||||||
|
else:
|
||||||
table_exists = conn.execute(
|
table_exists = conn.execute(
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
|
||||||
).fetchone()
|
).fetchone()
|
||||||
if table_exists is None:
|
if table_exists is None:
|
||||||
return
|
return
|
||||||
|
cols = [
|
||||||
|
row["name"]
|
||||||
|
for row in conn.execute("PRAGMA table_info(users)").fetchall()
|
||||||
|
]
|
||||||
|
|
||||||
cols = [row["name"] for row in conn.execute("PRAGMA table_info(users)").fetchall()]
|
|
||||||
if "stripe_customer_id" not in cols:
|
if "stripe_customer_id" not in cols:
|
||||||
conn.execute("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT")
|
sql = "ALTER TABLE users ADD COLUMN stripe_customer_id TEXT"
|
||||||
|
execute_query(conn, sql)
|
||||||
if "stripe_subscription_id" not in cols:
|
if "stripe_subscription_id" not in cols:
|
||||||
conn.execute("ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT")
|
sql = "ALTER TABLE users ADD COLUMN stripe_subscription_id TEXT"
|
||||||
conn.commit()
|
execute_query(conn, sql)
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def init_stripe_db():
|
def init_stripe_db():
|
||||||
@@ -73,41 +90,40 @@ def init_stripe_db():
|
|||||||
def _get_or_create_customer(user_id: int) -> str:
|
def _get_or_create_customer(user_id: int) -> str:
|
||||||
"""Get existing Stripe customer or create one."""
|
"""Get existing Stripe customer or create one."""
|
||||||
_init_stripe()
|
_init_stripe()
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
sql = (
|
||||||
row = conn.execute(
|
"SELECT email, stripe_customer_id FROM users WHERE id = %s"
|
||||||
"SELECT email, stripe_customer_id FROM users WHERE id = ?",
|
if is_postgres()
|
||||||
(user_id,),
|
else "SELECT email, stripe_customer_id FROM users WHERE id = ?"
|
||||||
).fetchone()
|
)
|
||||||
finally:
|
cursor = execute_query(conn, sql, (user_id,))
|
||||||
conn.close()
|
row = row_to_dict(cursor.fetchone())
|
||||||
|
|
||||||
if row is None:
|
if row is None:
|
||||||
raise ValueError("User not found.")
|
raise ValueError("User not found.")
|
||||||
|
|
||||||
if row["stripe_customer_id"]:
|
if row.get("stripe_customer_id"):
|
||||||
return row["stripe_customer_id"]
|
return row["stripe_customer_id"]
|
||||||
|
|
||||||
# Create new Stripe customer
|
|
||||||
customer = stripe.Customer.create(
|
customer = stripe.Customer.create(
|
||||||
email=row["email"],
|
email=row["email"],
|
||||||
metadata={"user_id": str(user_id)},
|
metadata={"user_id": str(user_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
sql = (
|
||||||
conn.execute(
|
"UPDATE users SET stripe_customer_id = %s, updated_at = %s WHERE id = %s"
|
||||||
"UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?",
|
if is_postgres()
|
||||||
(customer.id, _utc_now(), user_id),
|
else "UPDATE users SET stripe_customer_id = ?, updated_at = ? WHERE id = ?"
|
||||||
)
|
)
|
||||||
conn.commit()
|
execute_query(conn, sql, (customer.id, _utc_now(), user_id))
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
return customer.id
|
return customer.id
|
||||||
|
|
||||||
|
|
||||||
def create_checkout_session(user_id: int, price_id: str, success_url: str, cancel_url: str) -> str:
|
def create_checkout_session(
|
||||||
|
user_id: int, price_id: str, success_url: str, cancel_url: str
|
||||||
|
) -> str:
|
||||||
"""Create a Stripe Checkout Session and return the URL."""
|
"""Create a Stripe Checkout Session and return the URL."""
|
||||||
_init_stripe()
|
_init_stripe()
|
||||||
customer_id = _get_or_create_customer(user_id)
|
customer_id = _get_or_create_customer(user_id)
|
||||||
@@ -172,15 +188,15 @@ def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
|
|||||||
|
|
||||||
def _find_user_by_customer_id(customer_id: str) -> dict | None:
|
def _find_user_by_customer_id(customer_id: str) -> dict | None:
|
||||||
"""Find user by Stripe customer ID."""
|
"""Find user by Stripe customer ID."""
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
sql = (
|
||||||
row = conn.execute(
|
"SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = %s"
|
||||||
"SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?",
|
if is_postgres()
|
||||||
(customer_id,),
|
else "SELECT id, email, plan, created_at FROM users WHERE stripe_customer_id = ?"
|
||||||
).fetchone()
|
)
|
||||||
finally:
|
cursor = execute_query(conn, sql, (customer_id,))
|
||||||
conn.close()
|
row = row_to_dict(cursor.fetchone())
|
||||||
return dict(row) if row else None
|
return row
|
||||||
|
|
||||||
|
|
||||||
def _handle_checkout_completed(session: dict):
|
def _handle_checkout_completed(session: dict):
|
||||||
@@ -190,29 +206,27 @@ def _handle_checkout_completed(session: dict):
|
|||||||
user_id = session.get("metadata", {}).get("user_id")
|
user_id = session.get("metadata", {}).get("user_id")
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
sql = (
|
||||||
conn.execute(
|
"UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s"
|
||||||
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
|
if is_postgres()
|
||||||
(subscription_id, _utc_now(), int(user_id)),
|
else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?"
|
||||||
)
|
)
|
||||||
conn.commit()
|
execute_query(conn, sql, (subscription_id, _utc_now(), int(user_id)))
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
logger.info("User %s upgraded to Pro via checkout.", user_id)
|
logger.info("User %s upgraded to Pro via checkout.", user_id)
|
||||||
elif customer_id:
|
elif customer_id:
|
||||||
user = _find_user_by_customer_id(customer_id)
|
user = _find_user_by_customer_id(customer_id)
|
||||||
if user:
|
if user:
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
sql = (
|
||||||
conn.execute(
|
"UPDATE users SET plan = 'pro', stripe_subscription_id = %s, updated_at = %s WHERE id = %s"
|
||||||
"UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?",
|
if is_postgres()
|
||||||
(subscription_id, _utc_now(), user["id"]),
|
else "UPDATE users SET plan = 'pro', stripe_subscription_id = ?, updated_at = ? WHERE id = ?"
|
||||||
|
)
|
||||||
|
execute_query(conn, sql, (subscription_id, _utc_now(), user["id"]))
|
||||||
|
logger.info(
|
||||||
|
"User %s upgraded to Pro via checkout (customer match).", user["id"]
|
||||||
)
|
)
|
||||||
conn.commit()
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
logger.info("User %s upgraded to Pro via checkout (customer match).", user["id"])
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_subscription_updated(subscription: dict):
|
def _handle_subscription_updated(subscription: dict):
|
||||||
@@ -239,15 +253,13 @@ def _handle_subscription_deleted(subscription: dict):
|
|||||||
user = _find_user_by_customer_id(customer_id)
|
user = _find_user_by_customer_id(customer_id)
|
||||||
if user:
|
if user:
|
||||||
update_user_plan(user["id"], "free")
|
update_user_plan(user["id"], "free")
|
||||||
conn = _connect()
|
with db_connection() as conn:
|
||||||
try:
|
sql = (
|
||||||
conn.execute(
|
"UPDATE users SET stripe_subscription_id = NULL, updated_at = %s WHERE id = %s"
|
||||||
"UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?",
|
if is_postgres()
|
||||||
(_utc_now(), user["id"]),
|
else "UPDATE users SET stripe_subscription_id = NULL, updated_at = ? WHERE id = ?"
|
||||||
)
|
)
|
||||||
conn.commit()
|
execute_query(conn, sql, (_utc_now(), user["id"]))
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
logger.info("User %s subscription deleted — downgraded to Free.", user["id"])
|
logger.info("User %s subscription deleted — downgraded to Free.", user["id"])
|
||||||
|
|
||||||
|
|
||||||
@@ -256,4 +268,6 @@ def _handle_payment_failed(invoice: dict):
|
|||||||
customer_id = invoice.get("customer")
|
customer_id = invoice.get("customer")
|
||||||
user = _find_user_by_customer_id(customer_id)
|
user = _find_user_by_customer_id(customer_id)
|
||||||
if user:
|
if user:
|
||||||
logger.warning("Payment failed for user %s (customer %s).", user["id"], customer_id)
|
logger.warning(
|
||||||
|
"Payment failed for user %s (customer %s).", user["id"], customer_id
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user