Refactor configuration handling and improve error management across services; normalize placeholder values for SMTP and Stripe configurations; enhance local storage fallback logic in StorageService; add tests for new behaviors and edge cases.
This commit is contained in:
@@ -3,7 +3,10 @@ import json
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from app.services.openrouter_config_service import get_openrouter_settings
|
||||
from app.services.openrouter_config_service import (
|
||||
extract_openrouter_text,
|
||||
get_openrouter_settings,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,12 +75,7 @@ def chat_about_flowchart(message: str, flow_data: dict | None = None) -> dict:
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
reply = (
|
||||
data.get("choices", [{}])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
.strip()
|
||||
)
|
||||
reply = extract_openrouter_text(data)
|
||||
|
||||
if not reply:
|
||||
reply = "I couldn't generate a response. Please try again."
|
||||
|
||||
@@ -63,7 +63,10 @@ def save_message(name: str, email: str, category: str, subject: str, message: st
|
||||
conn.close()
|
||||
|
||||
# Send notification email to admin
|
||||
admin_email = current_app.config.get("SMTP_FROM", "noreply@dociva.io")
|
||||
admin_emails = tuple(current_app.config.get("INTERNAL_ADMIN_EMAILS", ()))
|
||||
admin_email = admin_emails[0] if admin_emails else current_app.config.get(
|
||||
"SMTP_FROM", "noreply@dociva.io"
|
||||
)
|
||||
try:
|
||||
send_email(
|
||||
to=admin_email,
|
||||
|
||||
@@ -6,16 +6,27 @@ from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from app.utils.config_placeholders import normalize_optional_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_smtp_config() -> dict:
|
||||
"""Read SMTP settings from Flask config."""
|
||||
return {
|
||||
"host": current_app.config.get("SMTP_HOST", ""),
|
||||
"host": normalize_optional_config(
|
||||
current_app.config.get("SMTP_HOST", ""),
|
||||
("your-provider", "replace-with"),
|
||||
),
|
||||
"port": current_app.config.get("SMTP_PORT", 587),
|
||||
"user": current_app.config.get("SMTP_USER", ""),
|
||||
"password": current_app.config.get("SMTP_PASSWORD", ""),
|
||||
"user": normalize_optional_config(
|
||||
current_app.config.get("SMTP_USER", ""),
|
||||
("replace-with",),
|
||||
),
|
||||
"password": normalize_optional_config(
|
||||
current_app.config.get("SMTP_PASSWORD", ""),
|
||||
("replace-with",),
|
||||
),
|
||||
"from_addr": current_app.config.get("SMTP_FROM", "noreply@dociva.io"),
|
||||
"use_tls": current_app.config.get("SMTP_USE_TLS", True),
|
||||
}
|
||||
|
||||
@@ -18,6 +18,42 @@ class OpenRouterSettings:
|
||||
base_url: str
|
||||
|
||||
|
||||
def extract_openrouter_text(payload: dict) -> str:
|
||||
"""Extract assistant text from OpenRouter/OpenAI-style payloads safely."""
|
||||
choices = payload.get("choices") or []
|
||||
if not choices:
|
||||
return ""
|
||||
|
||||
message = choices[0].get("message") or {}
|
||||
content = message.get("content")
|
||||
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
text_parts.append(item.strip())
|
||||
continue
|
||||
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if isinstance(item.get("text"), str) and item["text"].strip():
|
||||
text_parts.append(item["text"].strip())
|
||||
continue
|
||||
|
||||
nested_text = item.get("content")
|
||||
if item.get("type") == "text" and isinstance(nested_text, str) and nested_text.strip():
|
||||
text_parts.append(nested_text.strip())
|
||||
|
||||
return "\n".join(text_parts).strip()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _load_dotenv_settings() -> dict[str, str]:
|
||||
"""Read .env values directly so workers can recover from blank in-app config."""
|
||||
service_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
@@ -97,4 +133,4 @@ def get_openrouter_settings() -> OpenRouterSettings:
|
||||
dotenv_settings.get("OPENROUTER_BASE_URL", DEFAULT_OPENROUTER_BASE_URL),
|
||||
default=DEFAULT_OPENROUTER_BASE_URL,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4,7 +4,10 @@ import logging
|
||||
|
||||
import requests
|
||||
|
||||
from app.services.openrouter_config_service import get_openrouter_settings
|
||||
from app.services.openrouter_config_service import (
|
||||
extract_openrouter_text,
|
||||
get_openrouter_settings,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,6 +37,12 @@ def _extract_text_from_pdf(input_path: str, max_pages: int = 50) -> str:
|
||||
from PyPDF2 import PdfReader
|
||||
|
||||
reader = PdfReader(input_path)
|
||||
if reader.is_encrypted and reader.decrypt("") == 0:
|
||||
raise PdfAiError(
|
||||
"This PDF is password-protected. Please unlock it first.",
|
||||
error_code="PDF_ENCRYPTED",
|
||||
)
|
||||
|
||||
pages = reader.pages[:max_pages]
|
||||
texts = []
|
||||
for i, page in enumerate(pages):
|
||||
@@ -41,6 +50,8 @@ def _extract_text_from_pdf(input_path: str, max_pages: int = 50) -> str:
|
||||
if text.strip():
|
||||
texts.append(f"[Page {i + 1}]\n{text}")
|
||||
return "\n\n".join(texts)
|
||||
except PdfAiError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise PdfAiError(
|
||||
"Failed to extract text from PDF.",
|
||||
@@ -98,29 +109,31 @@ def _call_openrouter(
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
status_code = getattr(response, "status_code", 200)
|
||||
|
||||
if status_code == 401:
|
||||
logger.error("OpenRouter API key is invalid or expired (401).")
|
||||
raise PdfAiError(
|
||||
"AI features are temporarily unavailable due to a configuration issue. Our team has been notified.",
|
||||
error_code="OPENROUTER_UNAUTHORIZED",
|
||||
)
|
||||
|
||||
if response.status_code == 402:
|
||||
if status_code == 402:
|
||||
logger.error("OpenRouter account has insufficient credits (402).")
|
||||
raise PdfAiError(
|
||||
"AI processing credits have been exhausted. Please try again later.",
|
||||
error_code="OPENROUTER_INSUFFICIENT_CREDITS",
|
||||
)
|
||||
|
||||
if response.status_code == 429:
|
||||
if status_code == 429:
|
||||
logger.warning("OpenRouter rate limit reached (429).")
|
||||
raise PdfAiError(
|
||||
"AI service is experiencing high demand. Please wait a moment and try again.",
|
||||
error_code="OPENROUTER_RATE_LIMIT",
|
||||
)
|
||||
|
||||
if response.status_code >= 500:
|
||||
logger.error("OpenRouter server error (%s).", response.status_code)
|
||||
if status_code >= 500:
|
||||
logger.error("OpenRouter server error (%s).", status_code)
|
||||
raise PdfAiError(
|
||||
"AI service provider is experiencing issues. Please try again shortly.",
|
||||
error_code="OPENROUTER_SERVER_ERROR",
|
||||
@@ -139,12 +152,7 @@ def _call_openrouter(
|
||||
detail=error_msg,
|
||||
)
|
||||
|
||||
reply = (
|
||||
data.get("choices", [{}])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
.strip()
|
||||
)
|
||||
reply = extract_openrouter_text(data)
|
||||
|
||||
if not reply:
|
||||
raise PdfAiError(
|
||||
|
||||
@@ -9,7 +9,10 @@ from datetime import datetime, timezone
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from app.services.openrouter_config_service import get_openrouter_settings
|
||||
from app.services.openrouter_config_service import (
|
||||
extract_openrouter_text,
|
||||
get_openrouter_settings,
|
||||
)
|
||||
from app.services.ai_cost_service import AiBudgetExceededError, check_ai_budget, log_ai_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -465,12 +468,7 @@ def _request_ai_reply(
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
reply = (
|
||||
data.get("choices", [{}])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
.strip()
|
||||
)
|
||||
reply = extract_openrouter_text(data)
|
||||
if not reply:
|
||||
raise RuntimeError("Assistant returned an empty reply.")
|
||||
|
||||
@@ -593,4 +591,4 @@ def _fallback_reply(message: str, tool_slug: str) -> str:
|
||||
|
||||
def _response_model_name() -> str:
|
||||
settings = get_openrouter_settings()
|
||||
return settings.model if settings.api_key else "fallback"
|
||||
return settings.model if settings.api_key else "fallback"
|
||||
|
||||
@@ -5,14 +5,32 @@ import logging
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from app.utils.config_placeholders import normalize_optional_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolved_s3_settings() -> tuple[str, str, str]:
|
||||
"""Return sanitized S3 credentials, treating copied sample values as blank."""
|
||||
key = normalize_optional_config(
|
||||
current_app.config.get("AWS_ACCESS_KEY_ID"),
|
||||
("your-access-key", "replace-with"),
|
||||
)
|
||||
secret = normalize_optional_config(
|
||||
current_app.config.get("AWS_SECRET_ACCESS_KEY"),
|
||||
("your-secret-key", "replace-with"),
|
||||
)
|
||||
bucket = normalize_optional_config(
|
||||
current_app.config.get("AWS_S3_BUCKET"),
|
||||
("your-bucket-name", "replace-with"),
|
||||
)
|
||||
return key, secret, bucket
|
||||
|
||||
|
||||
def _is_s3_configured() -> bool:
|
||||
"""Check if AWS S3 credentials are provided."""
|
||||
key = current_app.config.get("AWS_ACCESS_KEY_ID")
|
||||
secret = current_app.config.get("AWS_SECRET_ACCESS_KEY")
|
||||
return bool(key and secret and key.strip() and secret.strip())
|
||||
key, secret, bucket = _resolved_s3_settings()
|
||||
return bool(key and secret and bucket)
|
||||
|
||||
|
||||
class StorageService:
|
||||
@@ -25,22 +43,60 @@ class StorageService:
|
||||
def use_s3(self) -> bool:
|
||||
return _is_s3_configured()
|
||||
|
||||
@property
|
||||
def allow_local_fallback(self) -> bool:
|
||||
value = current_app.config.get("STORAGE_ALLOW_LOCAL_FALLBACK", True)
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return str(value).strip().lower() != "false"
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy-initialize S3 client (only when S3 is configured)."""
|
||||
if self._client is None:
|
||||
import boto3
|
||||
key, secret, _ = _resolved_s3_settings()
|
||||
self._client = boto3.client(
|
||||
"s3",
|
||||
region_name=current_app.config["AWS_S3_REGION"],
|
||||
aws_access_key_id=current_app.config["AWS_ACCESS_KEY_ID"],
|
||||
aws_secret_access_key=current_app.config["AWS_SECRET_ACCESS_KEY"],
|
||||
aws_access_key_id=key,
|
||||
aws_secret_access_key=secret,
|
||||
)
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def bucket(self):
|
||||
return current_app.config["AWS_S3_BUCKET"]
|
||||
_, _, bucket = _resolved_s3_settings()
|
||||
return bucket
|
||||
|
||||
def _local_key(self, task_id: str, filename: str, folder: str = "outputs") -> str:
|
||||
return f"{folder}/{task_id}/{filename}"
|
||||
|
||||
def _local_destination(self, task_id: str, filename: str) -> str:
|
||||
output_dir = current_app.config["OUTPUT_FOLDER"]
|
||||
dest_dir = os.path.join(output_dir, task_id)
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
return os.path.join(dest_dir, filename)
|
||||
|
||||
def _store_locally(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
|
||||
"""Copy a generated file into the app's local download storage."""
|
||||
filename = os.path.basename(local_path)
|
||||
dest_path = self._local_destination(task_id, filename)
|
||||
|
||||
if os.path.abspath(local_path) != os.path.abspath(dest_path):
|
||||
shutil.copy2(local_path, dest_path)
|
||||
|
||||
logger.info("[Local] Stored file: %s", dest_path)
|
||||
return self._local_key(task_id, filename, folder=folder)
|
||||
|
||||
def _resolve_local_path(self, storage_key: str) -> str | None:
|
||||
parts = [part for part in storage_key.strip("/").split("/") if part]
|
||||
if len(parts) < 3:
|
||||
return None
|
||||
|
||||
task_id = parts[1]
|
||||
filename = parts[-1]
|
||||
return os.path.join(current_app.config["OUTPUT_FOLDER"], task_id, filename)
|
||||
|
||||
def upload_file(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
|
||||
"""
|
||||
@@ -53,7 +109,7 @@ class StorageService:
|
||||
S3 key or local relative path (used as identifier)
|
||||
"""
|
||||
filename = os.path.basename(local_path)
|
||||
key = f"{folder}/{task_id}/{filename}"
|
||||
key = self._local_key(task_id, filename, folder=folder)
|
||||
|
||||
if self.use_s3:
|
||||
from botocore.exceptions import ClientError
|
||||
@@ -61,19 +117,16 @@ class StorageService:
|
||||
self.client.upload_file(local_path, self.bucket, key)
|
||||
return key
|
||||
except ClientError as e:
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}")
|
||||
else:
|
||||
# Local mode — keep file in the outputs directory
|
||||
output_dir = current_app.config["OUTPUT_FOLDER"]
|
||||
dest_dir = os.path.join(output_dir, task_id)
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
dest_path = os.path.join(dest_dir, filename)
|
||||
if not self.allow_local_fallback:
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||
|
||||
if os.path.abspath(local_path) != os.path.abspath(dest_path):
|
||||
shutil.copy2(local_path, dest_path)
|
||||
logger.exception(
|
||||
"S3 upload failed for %s. Falling back to local storage.",
|
||||
key,
|
||||
)
|
||||
return self._store_locally(local_path, task_id, folder=folder)
|
||||
|
||||
logger.info(f"[Local] Stored file: {dest_path}")
|
||||
return key
|
||||
return self._store_locally(local_path, task_id, folder=folder)
|
||||
|
||||
def generate_presigned_url(
|
||||
self, s3_key: str, expiry: int | None = None, original_filename: str | None = None
|
||||
@@ -84,6 +137,14 @@ class StorageService:
|
||||
S3 mode: presigned URL.
|
||||
Local mode: /api/download/<task_id>/<filename>
|
||||
"""
|
||||
local_path = self._resolve_local_path(s3_key)
|
||||
if local_path and os.path.isfile(local_path):
|
||||
parts = [part for part in s3_key.strip("/").split("/") if part]
|
||||
task_id = parts[1]
|
||||
filename = parts[-1]
|
||||
download_name = original_filename or filename
|
||||
return f"/api/download/{task_id}/{filename}?name={download_name}"
|
||||
|
||||
if self.use_s3:
|
||||
from botocore.exceptions import ClientError
|
||||
if expiry is None:
|
||||
@@ -108,20 +169,21 @@ class StorageService:
|
||||
raise RuntimeError(f"Failed to generate presigned URL: {e}")
|
||||
else:
|
||||
# Local mode — return path to Flask download route
|
||||
parts = s3_key.strip("/").split("/")
|
||||
# key = "outputs/<task_id>/<filename>"
|
||||
if len(parts) >= 3:
|
||||
task_id = parts[1]
|
||||
filename = parts[2]
|
||||
else:
|
||||
task_id = parts[0]
|
||||
filename = parts[-1]
|
||||
|
||||
parts = [part for part in s3_key.strip("/").split("/") if part]
|
||||
task_id = parts[1] if len(parts) >= 3 else parts[0]
|
||||
filename = parts[-1]
|
||||
download_name = original_filename or filename
|
||||
return f"/api/download/{task_id}/{filename}?name={download_name}"
|
||||
|
||||
def delete_file(self, s3_key: str):
|
||||
"""Delete a file from S3 (no-op in local mode)."""
|
||||
local_path = self._resolve_local_path(s3_key)
|
||||
if local_path and os.path.isfile(local_path):
|
||||
try:
|
||||
os.remove(local_path)
|
||||
except OSError:
|
||||
logger.warning("Failed to delete local fallback file: %s", local_path)
|
||||
|
||||
if self.use_s3:
|
||||
from botocore.exceptions import ClientError
|
||||
try:
|
||||
@@ -131,6 +193,10 @@ class StorageService:
|
||||
|
||||
def file_exists(self, s3_key: str) -> bool:
|
||||
"""Check if a file exists."""
|
||||
local_path = self._resolve_local_path(s3_key)
|
||||
if local_path and os.path.isfile(local_path):
|
||||
return True
|
||||
|
||||
if self.use_s3:
|
||||
from botocore.exceptions import ClientError
|
||||
try:
|
||||
@@ -138,16 +204,8 @@ class StorageService:
|
||||
return True
|
||||
except ClientError:
|
||||
return False
|
||||
else:
|
||||
parts = s3_key.strip("/").split("/")
|
||||
if len(parts) >= 3:
|
||||
task_id = parts[1]
|
||||
filename = parts[2]
|
||||
else:
|
||||
task_id = parts[0]
|
||||
filename = parts[-1]
|
||||
output_dir = current_app.config["OUTPUT_FOLDER"]
|
||||
return os.path.isfile(os.path.join(output_dir, task_id, filename))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance
|
||||
|
||||
@@ -1,18 +1,47 @@
|
||||
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
import stripe
|
||||
from flask import current_app
|
||||
|
||||
from app.services.account_service import update_user_plan, get_user_by_id, _connect, _utc_now
|
||||
from app.services.account_service import update_user_plan, _connect, _utc_now
|
||||
from app.utils.config_placeholders import normalize_optional_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_stripe_secret_key() -> str:
|
||||
"""Return the configured Stripe secret key, ignoring copied sample values."""
|
||||
return normalize_optional_config(
|
||||
current_app.config.get("STRIPE_SECRET_KEY", ""),
|
||||
("replace-with",),
|
||||
)
|
||||
|
||||
|
||||
def get_stripe_price_id(billing: str = "monthly") -> str:
|
||||
"""Return the configured Stripe price id for the requested billing cycle."""
|
||||
monthly = normalize_optional_config(
|
||||
current_app.config.get("STRIPE_PRICE_ID_PRO_MONTHLY", ""),
|
||||
("replace-with",),
|
||||
)
|
||||
yearly = normalize_optional_config(
|
||||
current_app.config.get("STRIPE_PRICE_ID_PRO_YEARLY", ""),
|
||||
("replace-with",),
|
||||
)
|
||||
|
||||
if billing == "yearly" and yearly:
|
||||
return yearly
|
||||
return monthly
|
||||
|
||||
|
||||
def is_stripe_configured() -> bool:
|
||||
"""Return True when billing has a usable secret key and at least one price id."""
|
||||
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")))
|
||||
|
||||
|
||||
def _init_stripe():
|
||||
"""Configure stripe with the app's secret key."""
|
||||
stripe.api_key = current_app.config.get("STRIPE_SECRET_KEY", "")
|
||||
stripe.api_key = get_stripe_secret_key()
|
||||
|
||||
|
||||
def _ensure_stripe_columns():
|
||||
@@ -109,7 +138,10 @@ def create_portal_session(user_id: int, return_url: str) -> str:
|
||||
|
||||
def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
|
||||
"""Process a Stripe webhook event. Returns a status dict."""
|
||||
webhook_secret = current_app.config.get("STRIPE_WEBHOOK_SECRET", "")
|
||||
webhook_secret = normalize_optional_config(
|
||||
current_app.config.get("STRIPE_WEBHOOK_SECRET", ""),
|
||||
("replace-with",),
|
||||
)
|
||||
if not webhook_secret:
|
||||
logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.")
|
||||
return {"status": "ignored", "reason": "no webhook secret"}
|
||||
|
||||
Reference in New Issue
Block a user