Refactor configuration handling and improve error management across services; normalize placeholder values for SMTP and Stripe configurations; enhance local storage fallback logic in StorageService; add tests for new behaviors and edge cases.

This commit is contained in:
Your Name
2026-03-26 14:15:10 +02:00
parent 688d411537
commit bc8a5dc290
19 changed files with 423 additions and 95 deletions

View File

@@ -16,13 +16,13 @@ CELERY_BROKER_URL=redis://redis:6379/0
CELERY_RESULT_BACKEND=redis://redis:6379/1 CELERY_RESULT_BACKEND=redis://redis:6379/1
# OpenRouter AI # OpenRouter AI
OPENROUTER_API_KEY=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 OPENROUTER_API_KEY=
OPENROUTER_MODEL=nvidia/nemotron-3-super-120b-a12b:free OPENROUTER_MODEL=nvidia/nemotron-3-super-120b-a12b:free
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1/chat/completions OPENROUTER_BASE_URL=https://openrouter.ai/api/v1/chat/completions
# AWS S3 # AWS S3
AWS_ACCESS_KEY_ID=your-access-key AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=your-secret-key AWS_SECRET_ACCESS_KEY=
AWS_S3_BUCKET=dociva-temp-files AWS_S3_BUCKET=dociva-temp-files
AWS_S3_REGION=eu-west-1 AWS_S3_REGION=eu-west-1
@@ -31,31 +31,32 @@ MAX_CONTENT_LENGTH_MB=50
UPLOAD_FOLDER=/tmp/uploads UPLOAD_FOLDER=/tmp/uploads
OUTPUT_FOLDER=/tmp/outputs OUTPUT_FOLDER=/tmp/outputs
FILE_EXPIRY_SECONDS=1800 FILE_EXPIRY_SECONDS=1800
STORAGE_ALLOW_LOCAL_FALLBACK=true
DATABASE_PATH=/app/data/dociva.db DATABASE_PATH=/app/data/dociva.db
# CORS # CORS
CORS_ORIGINS=https://dociva.io,https://www.dociva.io CORS_ORIGINS=https://dociva.io,https://www.dociva.io
# SMTP (Password reset + contact notifications) # SMTP (Password reset + contact notifications)
SMTP_HOST=smtp.your-provider.com SMTP_HOST=
SMTP_PORT=587 SMTP_PORT=587
SMTP_USER=noreply@dociva.io SMTP_USER=
SMTP_PASSWORD=replace-with-smtp-password SMTP_PASSWORD=
SMTP_FROM=noreply@dociva.io SMTP_FROM=noreply@dociva.io
SMTP_USE_TLS=true SMTP_USE_TLS=true
# Stripe Payments # Stripe Payments
STRIPE_SECRET_KEY=sk_test_XXXXXXXXXXXXXXXXXXXXXXXX STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=whsec_XXXXXXXXXXXXXXXXXXXXXXXX STRIPE_WEBHOOK_SECRET=
STRIPE_PRICE_ID_PRO_MONTHLY=price_XXXXXXXXXXXXXXXX STRIPE_PRICE_ID_PRO_MONTHLY=
STRIPE_PRICE_ID_PRO_YEARLY=price_XXXXXXXXXXXXXXXX STRIPE_PRICE_ID_PRO_YEARLY=
# Sentry Error Monitoring # Sentry Error Monitoring
SENTRY_DSN= SENTRY_DSN=
SENTRY_ENVIRONMENT=production SENTRY_ENVIRONMENT=production
# PostgreSQL (production) — leave empty to use SQLite # PostgreSQL (production) — leave empty to use SQLite
DATABASE_URL=sqlite3 /app/data/dociva.db DATABASE_URL=
POSTGRES_DB=dociva POSTGRES_DB=dociva
POSTGRES_USER=dociva POSTGRES_USER=dociva
POSTGRES_PASSWORD=replace-with-strong-postgres-password POSTGRES_PASSWORD=replace-with-strong-postgres-password

View File

@@ -8,6 +8,8 @@ from app.services.stripe_service import (
create_checkout_session, create_checkout_session,
create_portal_session, create_portal_session,
handle_webhook_event, handle_webhook_event,
get_stripe_price_id,
is_stripe_configured,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,11 +33,9 @@ def checkout():
data = request.get_json(silent=True) or {} data = request.get_json(silent=True) or {}
billing = data.get("billing", "monthly") billing = data.get("billing", "monthly")
monthly_price = current_app.config.get("STRIPE_PRICE_ID_PRO_MONTHLY", "") price_id = get_stripe_price_id(billing)
yearly_price = current_app.config.get("STRIPE_PRICE_ID_PRO_YEARLY", "")
price_id = yearly_price if billing == "yearly" and yearly_price else monthly_price
if not price_id: if not is_stripe_configured() or not price_id:
return jsonify({"error": "Payment is not configured yet."}), 503 return jsonify({"error": "Payment is not configured yet."}), 503
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173") frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173")
@@ -62,6 +62,9 @@ def portal():
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173") frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173")
return_url = f"{frontend_url}/account" return_url = f"{frontend_url}/account"
if not is_stripe_configured():
return jsonify({"error": "Payment is not configured yet."}), 503
try: try:
url = create_portal_session(user_id, return_url) url = create_portal_session(user_id, return_url)
except Exception as e: except Exception as e:

View File

@@ -3,7 +3,10 @@ import json
import logging import logging
import requests import requests
from app.services.openrouter_config_service import get_openrouter_settings from app.services.openrouter_config_service import (
extract_openrouter_text,
get_openrouter_settings,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -72,12 +75,7 @@ def chat_about_flowchart(message: str, flow_data: dict | None = None) -> dict:
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
reply = ( reply = extract_openrouter_text(data)
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "")
.strip()
)
if not reply: if not reply:
reply = "I couldn't generate a response. Please try again." reply = "I couldn't generate a response. Please try again."

View File

@@ -63,7 +63,10 @@ def save_message(name: str, email: str, category: str, subject: str, message: st
conn.close() conn.close()
# Send notification email to admin # Send notification email to admin
admin_email = current_app.config.get("SMTP_FROM", "noreply@dociva.io") admin_emails = tuple(current_app.config.get("INTERNAL_ADMIN_EMAILS", ()))
admin_email = admin_emails[0] if admin_emails else current_app.config.get(
"SMTP_FROM", "noreply@dociva.io"
)
try: try:
send_email( send_email(
to=admin_email, to=admin_email,

View File

@@ -6,16 +6,27 @@ from email.mime.multipart import MIMEMultipart
from flask import current_app from flask import current_app
from app.utils.config_placeholders import normalize_optional_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_smtp_config() -> dict: def _get_smtp_config() -> dict:
"""Read SMTP settings from Flask config.""" """Read SMTP settings from Flask config."""
return { return {
"host": current_app.config.get("SMTP_HOST", ""), "host": normalize_optional_config(
current_app.config.get("SMTP_HOST", ""),
("your-provider", "replace-with"),
),
"port": current_app.config.get("SMTP_PORT", 587), "port": current_app.config.get("SMTP_PORT", 587),
"user": current_app.config.get("SMTP_USER", ""), "user": normalize_optional_config(
"password": current_app.config.get("SMTP_PASSWORD", ""), current_app.config.get("SMTP_USER", ""),
("replace-with",),
),
"password": normalize_optional_config(
current_app.config.get("SMTP_PASSWORD", ""),
("replace-with",),
),
"from_addr": current_app.config.get("SMTP_FROM", "noreply@dociva.io"), "from_addr": current_app.config.get("SMTP_FROM", "noreply@dociva.io"),
"use_tls": current_app.config.get("SMTP_USE_TLS", True), "use_tls": current_app.config.get("SMTP_USE_TLS", True),
} }

View File

@@ -18,6 +18,42 @@ class OpenRouterSettings:
base_url: str base_url: str
def extract_openrouter_text(payload: dict) -> str:
"""Extract assistant text from OpenRouter/OpenAI-style payloads safely."""
choices = payload.get("choices") or []
if not choices:
return ""
message = choices[0].get("message") or {}
content = message.get("content")
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
text_parts: list[str] = []
for item in content:
if isinstance(item, str):
if item.strip():
text_parts.append(item.strip())
continue
if not isinstance(item, dict):
continue
if isinstance(item.get("text"), str) and item["text"].strip():
text_parts.append(item["text"].strip())
continue
nested_text = item.get("content")
if item.get("type") == "text" and isinstance(nested_text, str) and nested_text.strip():
text_parts.append(nested_text.strip())
return "\n".join(text_parts).strip()
return ""
def _load_dotenv_settings() -> dict[str, str]: def _load_dotenv_settings() -> dict[str, str]:
"""Read .env values directly so workers can recover from blank in-app config.""" """Read .env values directly so workers can recover from blank in-app config."""
service_dir = os.path.abspath(os.path.dirname(__file__)) service_dir = os.path.abspath(os.path.dirname(__file__))

View File

@@ -4,7 +4,10 @@ import logging
import requests import requests
from app.services.openrouter_config_service import get_openrouter_settings from app.services.openrouter_config_service import (
extract_openrouter_text,
get_openrouter_settings,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,6 +37,12 @@ def _extract_text_from_pdf(input_path: str, max_pages: int = 50) -> str:
from PyPDF2 import PdfReader from PyPDF2 import PdfReader
reader = PdfReader(input_path) reader = PdfReader(input_path)
if reader.is_encrypted and reader.decrypt("") == 0:
raise PdfAiError(
"This PDF is password-protected. Please unlock it first.",
error_code="PDF_ENCRYPTED",
)
pages = reader.pages[:max_pages] pages = reader.pages[:max_pages]
texts = [] texts = []
for i, page in enumerate(pages): for i, page in enumerate(pages):
@@ -41,6 +50,8 @@ def _extract_text_from_pdf(input_path: str, max_pages: int = 50) -> str:
if text.strip(): if text.strip():
texts.append(f"[Page {i + 1}]\n{text}") texts.append(f"[Page {i + 1}]\n{text}")
return "\n\n".join(texts) return "\n\n".join(texts)
except PdfAiError:
raise
except Exception as e: except Exception as e:
raise PdfAiError( raise PdfAiError(
"Failed to extract text from PDF.", "Failed to extract text from PDF.",
@@ -98,29 +109,31 @@ def _call_openrouter(
timeout=60, timeout=60,
) )
if response.status_code == 401: status_code = getattr(response, "status_code", 200)
if status_code == 401:
logger.error("OpenRouter API key is invalid or expired (401).") logger.error("OpenRouter API key is invalid or expired (401).")
raise PdfAiError( raise PdfAiError(
"AI features are temporarily unavailable due to a configuration issue. Our team has been notified.", "AI features are temporarily unavailable due to a configuration issue. Our team has been notified.",
error_code="OPENROUTER_UNAUTHORIZED", error_code="OPENROUTER_UNAUTHORIZED",
) )
if response.status_code == 402: if status_code == 402:
logger.error("OpenRouter account has insufficient credits (402).") logger.error("OpenRouter account has insufficient credits (402).")
raise PdfAiError( raise PdfAiError(
"AI processing credits have been exhausted. Please try again later.", "AI processing credits have been exhausted. Please try again later.",
error_code="OPENROUTER_INSUFFICIENT_CREDITS", error_code="OPENROUTER_INSUFFICIENT_CREDITS",
) )
if response.status_code == 429: if status_code == 429:
logger.warning("OpenRouter rate limit reached (429).") logger.warning("OpenRouter rate limit reached (429).")
raise PdfAiError( raise PdfAiError(
"AI service is experiencing high demand. Please wait a moment and try again.", "AI service is experiencing high demand. Please wait a moment and try again.",
error_code="OPENROUTER_RATE_LIMIT", error_code="OPENROUTER_RATE_LIMIT",
) )
if response.status_code >= 500: if status_code >= 500:
logger.error("OpenRouter server error (%s).", response.status_code) logger.error("OpenRouter server error (%s).", status_code)
raise PdfAiError( raise PdfAiError(
"AI service provider is experiencing issues. Please try again shortly.", "AI service provider is experiencing issues. Please try again shortly.",
error_code="OPENROUTER_SERVER_ERROR", error_code="OPENROUTER_SERVER_ERROR",
@@ -139,12 +152,7 @@ def _call_openrouter(
detail=error_msg, detail=error_msg,
) )
reply = ( reply = extract_openrouter_text(data)
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "")
.strip()
)
if not reply: if not reply:
raise PdfAiError( raise PdfAiError(

View File

@@ -9,7 +9,10 @@ from datetime import datetime, timezone
import requests import requests
from flask import current_app from flask import current_app
from app.services.openrouter_config_service import get_openrouter_settings from app.services.openrouter_config_service import (
extract_openrouter_text,
get_openrouter_settings,
)
from app.services.ai_cost_service import AiBudgetExceededError, check_ai_budget, log_ai_usage from app.services.ai_cost_service import AiBudgetExceededError, check_ai_budget, log_ai_usage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -465,12 +468,7 @@ def _request_ai_reply(
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
reply = ( reply = extract_openrouter_text(data)
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "")
.strip()
)
if not reply: if not reply:
raise RuntimeError("Assistant returned an empty reply.") raise RuntimeError("Assistant returned an empty reply.")

View File

@@ -5,14 +5,32 @@ import logging
from flask import current_app from flask import current_app
from app.utils.config_placeholders import normalize_optional_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _resolved_s3_settings() -> tuple[str, str, str]:
"""Return sanitized S3 credentials, treating copied sample values as blank."""
key = normalize_optional_config(
current_app.config.get("AWS_ACCESS_KEY_ID"),
("your-access-key", "replace-with"),
)
secret = normalize_optional_config(
current_app.config.get("AWS_SECRET_ACCESS_KEY"),
("your-secret-key", "replace-with"),
)
bucket = normalize_optional_config(
current_app.config.get("AWS_S3_BUCKET"),
("your-bucket-name", "replace-with"),
)
return key, secret, bucket
def _is_s3_configured() -> bool: def _is_s3_configured() -> bool:
"""Check if AWS S3 credentials are provided.""" """Check if AWS S3 credentials are provided."""
key = current_app.config.get("AWS_ACCESS_KEY_ID") key, secret, bucket = _resolved_s3_settings()
secret = current_app.config.get("AWS_SECRET_ACCESS_KEY") return bool(key and secret and bucket)
return bool(key and secret and key.strip() and secret.strip())
class StorageService: class StorageService:
@@ -25,22 +43,60 @@ class StorageService:
def use_s3(self) -> bool: def use_s3(self) -> bool:
return _is_s3_configured() return _is_s3_configured()
@property
def allow_local_fallback(self) -> bool:
value = current_app.config.get("STORAGE_ALLOW_LOCAL_FALLBACK", True)
if isinstance(value, bool):
return value
return str(value).strip().lower() != "false"
@property @property
def client(self): def client(self):
"""Lazy-initialize S3 client (only when S3 is configured).""" """Lazy-initialize S3 client (only when S3 is configured)."""
if self._client is None: if self._client is None:
import boto3 import boto3
key, secret, _ = _resolved_s3_settings()
self._client = boto3.client( self._client = boto3.client(
"s3", "s3",
region_name=current_app.config["AWS_S3_REGION"], region_name=current_app.config["AWS_S3_REGION"],
aws_access_key_id=current_app.config["AWS_ACCESS_KEY_ID"], aws_access_key_id=key,
aws_secret_access_key=current_app.config["AWS_SECRET_ACCESS_KEY"], aws_secret_access_key=secret,
) )
return self._client return self._client
@property @property
def bucket(self): def bucket(self):
return current_app.config["AWS_S3_BUCKET"] _, _, bucket = _resolved_s3_settings()
return bucket
def _local_key(self, task_id: str, filename: str, folder: str = "outputs") -> str:
return f"{folder}/{task_id}/{filename}"
def _local_destination(self, task_id: str, filename: str) -> str:
output_dir = current_app.config["OUTPUT_FOLDER"]
dest_dir = os.path.join(output_dir, task_id)
os.makedirs(dest_dir, exist_ok=True)
return os.path.join(dest_dir, filename)
def _store_locally(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
"""Copy a generated file into the app's local download storage."""
filename = os.path.basename(local_path)
dest_path = self._local_destination(task_id, filename)
if os.path.abspath(local_path) != os.path.abspath(dest_path):
shutil.copy2(local_path, dest_path)
logger.info("[Local] Stored file: %s", dest_path)
return self._local_key(task_id, filename, folder=folder)
def _resolve_local_path(self, storage_key: str) -> str | None:
parts = [part for part in storage_key.strip("/").split("/") if part]
if len(parts) < 3:
return None
task_id = parts[1]
filename = parts[-1]
return os.path.join(current_app.config["OUTPUT_FOLDER"], task_id, filename)
def upload_file(self, local_path: str, task_id: str, folder: str = "outputs") -> str: def upload_file(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
""" """
@@ -53,7 +109,7 @@ class StorageService:
S3 key or local relative path (used as identifier) S3 key or local relative path (used as identifier)
""" """
filename = os.path.basename(local_path) filename = os.path.basename(local_path)
key = f"{folder}/{task_id}/{filename}" key = self._local_key(task_id, filename, folder=folder)
if self.use_s3: if self.use_s3:
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
@@ -61,19 +117,16 @@ class StorageService:
self.client.upload_file(local_path, self.bucket, key) self.client.upload_file(local_path, self.bucket, key)
return key return key
except ClientError as e: except ClientError as e:
raise RuntimeError(f"Failed to upload file to S3: {e}") if not self.allow_local_fallback:
else: raise RuntimeError(f"Failed to upload file to S3: {e}") from e
# Local mode — keep file in the outputs directory
output_dir = current_app.config["OUTPUT_FOLDER"]
dest_dir = os.path.join(output_dir, task_id)
os.makedirs(dest_dir, exist_ok=True)
dest_path = os.path.join(dest_dir, filename)
if os.path.abspath(local_path) != os.path.abspath(dest_path): logger.exception(
shutil.copy2(local_path, dest_path) "S3 upload failed for %s. Falling back to local storage.",
key,
)
return self._store_locally(local_path, task_id, folder=folder)
logger.info(f"[Local] Stored file: {dest_path}") return self._store_locally(local_path, task_id, folder=folder)
return key
def generate_presigned_url( def generate_presigned_url(
self, s3_key: str, expiry: int | None = None, original_filename: str | None = None self, s3_key: str, expiry: int | None = None, original_filename: str | None = None
@@ -84,6 +137,14 @@ class StorageService:
S3 mode: presigned URL. S3 mode: presigned URL.
Local mode: /api/download/<task_id>/<filename> Local mode: /api/download/<task_id>/<filename>
""" """
local_path = self._resolve_local_path(s3_key)
if local_path and os.path.isfile(local_path):
parts = [part for part in s3_key.strip("/").split("/") if part]
task_id = parts[1]
filename = parts[-1]
download_name = original_filename or filename
return f"/api/download/{task_id}/{filename}?name={download_name}"
if self.use_s3: if self.use_s3:
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
if expiry is None: if expiry is None:
@@ -108,20 +169,21 @@ class StorageService:
raise RuntimeError(f"Failed to generate presigned URL: {e}") raise RuntimeError(f"Failed to generate presigned URL: {e}")
else: else:
# Local mode — return path to Flask download route # Local mode — return path to Flask download route
parts = s3_key.strip("/").split("/") parts = [part for part in s3_key.strip("/").split("/") if part]
# key = "outputs/<task_id>/<filename>" task_id = parts[1] if len(parts) >= 3 else parts[0]
if len(parts) >= 3: filename = parts[-1]
task_id = parts[1]
filename = parts[2]
else:
task_id = parts[0]
filename = parts[-1]
download_name = original_filename or filename download_name = original_filename or filename
return f"/api/download/{task_id}/{filename}?name={download_name}" return f"/api/download/{task_id}/{filename}?name={download_name}"
def delete_file(self, s3_key: str): def delete_file(self, s3_key: str):
"""Delete a file from S3 (no-op in local mode).""" """Delete a file from S3 (no-op in local mode)."""
local_path = self._resolve_local_path(s3_key)
if local_path and os.path.isfile(local_path):
try:
os.remove(local_path)
except OSError:
logger.warning("Failed to delete local fallback file: %s", local_path)
if self.use_s3: if self.use_s3:
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
try: try:
@@ -131,6 +193,10 @@ class StorageService:
def file_exists(self, s3_key: str) -> bool: def file_exists(self, s3_key: str) -> bool:
"""Check if a file exists.""" """Check if a file exists."""
local_path = self._resolve_local_path(s3_key)
if local_path and os.path.isfile(local_path):
return True
if self.use_s3: if self.use_s3:
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
try: try:
@@ -138,16 +204,8 @@ class StorageService:
return True return True
except ClientError: except ClientError:
return False return False
else:
parts = s3_key.strip("/").split("/") return False
if len(parts) >= 3:
task_id = parts[1]
filename = parts[2]
else:
task_id = parts[0]
filename = parts[-1]
output_dir = current_app.config["OUTPUT_FOLDER"]
return os.path.isfile(os.path.join(output_dir, task_id, filename))
# Singleton instance # Singleton instance

View File

@@ -1,18 +1,47 @@
"""Stripe payment service — checkout sessions, webhooks, and subscription management.""" """Stripe payment service — checkout sessions, webhooks, and subscription management."""
import logging import logging
import os
import stripe import stripe
from flask import current_app from flask import current_app
from app.services.account_service import update_user_plan, get_user_by_id, _connect, _utc_now from app.services.account_service import update_user_plan, _connect, _utc_now
from app.utils.config_placeholders import normalize_optional_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_stripe_secret_key() -> str:
"""Return the configured Stripe secret key, ignoring copied sample values."""
return normalize_optional_config(
current_app.config.get("STRIPE_SECRET_KEY", ""),
("replace-with",),
)
def get_stripe_price_id(billing: str = "monthly") -> str:
"""Return the configured Stripe price id for the requested billing cycle."""
monthly = normalize_optional_config(
current_app.config.get("STRIPE_PRICE_ID_PRO_MONTHLY", ""),
("replace-with",),
)
yearly = normalize_optional_config(
current_app.config.get("STRIPE_PRICE_ID_PRO_YEARLY", ""),
("replace-with",),
)
if billing == "yearly" and yearly:
return yearly
return monthly
def is_stripe_configured() -> bool:
"""Return True when billing has a usable secret key and at least one price id."""
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")))
def _init_stripe(): def _init_stripe():
"""Configure stripe with the app's secret key.""" """Configure stripe with the app's secret key."""
stripe.api_key = current_app.config.get("STRIPE_SECRET_KEY", "") stripe.api_key = get_stripe_secret_key()
def _ensure_stripe_columns(): def _ensure_stripe_columns():
@@ -109,7 +138,10 @@ def create_portal_session(user_id: int, return_url: str) -> str:
def handle_webhook_event(payload: bytes, sig_header: str) -> dict: def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
"""Process a Stripe webhook event. Returns a status dict.""" """Process a Stripe webhook event. Returns a status dict."""
webhook_secret = current_app.config.get("STRIPE_WEBHOOK_SECRET", "") webhook_secret = normalize_optional_config(
current_app.config.get("STRIPE_WEBHOOK_SECRET", ""),
("replace-with",),
)
if not webhook_secret: if not webhook_secret:
logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.") logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.")
return {"status": "ignored", "reason": "no webhook secret"} return {"status": "ignored", "reason": "no webhook secret"}

View File

@@ -0,0 +1,31 @@
"""Helpers for treating sample config values as missing runtime configuration."""
import re
_MASKED_SEQUENCE_RE = re.compile(r"(x{6,}|\*{4,})", re.IGNORECASE)
def normalize_optional_config(
value: str | None,
placeholder_markers: tuple[str, ...] = (),
) -> str:
"""Return a stripped config value, or blank when it still looks like a sample."""
normalized = str(value or "").strip()
if not normalized:
return ""
lowered = normalized.lower()
if any(marker.lower() in lowered for marker in placeholder_markers if marker):
return ""
if _MASKED_SEQUENCE_RE.search(normalized):
return ""
return normalized
def has_real_config(
value: str | None,
placeholder_markers: tuple[str, ...] = (),
) -> bool:
"""Return True when the value is present and not an obvious placeholder."""
return bool(normalize_optional_config(value, placeholder_markers))

View File

@@ -70,8 +70,17 @@ def cleanup_task_files(task_id: str, keep_outputs: bool = False):
if os.path.exists(upload_task_dir): if os.path.exists(upload_task_dir):
shutil.rmtree(upload_task_dir, ignore_errors=True) shutil.rmtree(upload_task_dir, ignore_errors=True)
# Only clean outputs when using S3 (files already uploaded to S3) # Preserve local outputs whenever local fallback is enabled so download links remain valid.
if not keep_outputs: preserve_outputs = keep_outputs
if not preserve_outputs:
try:
from app.services.storage_service import storage
preserve_outputs = storage.allow_local_fallback
except Exception:
preserve_outputs = False
if not preserve_outputs:
output_task_dir = os.path.join(output_dir, task_id) output_task_dir = os.path.join(output_dir, task_id)
if os.path.exists(output_task_dir): if os.path.exists(output_task_dir):
shutil.rmtree(output_task_dir, ignore_errors=True) shutil.rmtree(output_task_dir, ignore_errors=True)

View File

@@ -37,6 +37,9 @@ class BaseConfig:
UPLOAD_FOLDER = _env_or_default("UPLOAD_FOLDER", "/tmp/uploads") UPLOAD_FOLDER = _env_or_default("UPLOAD_FOLDER", "/tmp/uploads")
OUTPUT_FOLDER = _env_or_default("OUTPUT_FOLDER", "/tmp/outputs") OUTPUT_FOLDER = _env_or_default("OUTPUT_FOLDER", "/tmp/outputs")
FILE_EXPIRY_SECONDS = int(os.getenv("FILE_EXPIRY_SECONDS", 1800)) FILE_EXPIRY_SECONDS = int(os.getenv("FILE_EXPIRY_SECONDS", 1800))
STORAGE_ALLOW_LOCAL_FALLBACK = os.getenv(
"STORAGE_ALLOW_LOCAL_FALLBACK", "true"
).lower() == "true"
DATABASE_PATH = _env_or_default( DATABASE_PATH = _env_or_default(
"DATABASE_PATH", os.path.join(BASE_DIR, "data", "dociva.db") "DATABASE_PATH", os.path.join(BASE_DIR, "data", "dociva.db")
) )

View File

@@ -18,6 +18,7 @@ ffmpeg-python>=0.2,<1.0
# PDF Processing # PDF Processing
PyPDF2>=3.0,<4.0 PyPDF2>=3.0,<4.0
pycryptodome>=3.20,<4.0
reportlab>=4.0,<5.0 reportlab>=4.0,<5.0
pdf2image>=1.16,<2.0 pdf2image>=1.16,<2.0

View File

@@ -0,0 +1,19 @@
"""Tests for SMTP configuration normalization."""
from app.services.email_service import send_email
def test_placeholder_smtp_host_is_treated_as_unconfigured(app, monkeypatch):
"""A copied sample SMTP host should not trigger a network call."""
with app.app_context():
app.config.update({
"SMTP_HOST": "smtp.your-provider.com",
"SMTP_PORT": 587,
"SMTP_USER": "noreply@dociva.io",
"SMTP_PASSWORD": "replace-with-smtp-password",
})
def fail_if_called(*args, **kwargs):
raise AssertionError("SMTP should not be contacted for placeholder config")
monkeypatch.setattr("smtplib.SMTP", fail_if_called)
assert send_email("user@example.com", "Subject", "<p>Body</p>") is False

View File

@@ -1,6 +1,10 @@
"""Tests for shared OpenRouter configuration resolution across AI services.""" """Tests for shared OpenRouter configuration resolution across AI services."""
from app.services.openrouter_config_service import get_openrouter_settings from app.services.openrouter_config_service import (
LEGACY_SAMPLE_OPENROUTER_API_KEY,
extract_openrouter_text,
get_openrouter_settings,
)
from app.services.pdf_ai_service import _call_openrouter from app.services.pdf_ai_service import _call_openrouter
from app.services.site_assistant_service import _request_ai_reply from app.services.site_assistant_service import _request_ai_reply
@@ -85,7 +89,7 @@ class TestOpenRouterConfigService:
monkeypatch.setattr( monkeypatch.setattr(
'app.services.openrouter_config_service._load_dotenv_settings', 'app.services.openrouter_config_service._load_dotenv_settings',
lambda: { lambda: {
'OPENROUTER_API_KEY': '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', 'OPENROUTER_API_KEY': LEGACY_SAMPLE_OPENROUTER_API_KEY,
}, },
) )
@@ -95,6 +99,27 @@ class TestOpenRouterConfigService:
assert settings.api_key == '' assert settings.api_key == ''
def test_extract_openrouter_text_supports_string_and_list_content(self):
assert extract_openrouter_text({
'choices': [{'message': {'content': ' plain text reply '}}],
}) == 'plain text reply'
assert extract_openrouter_text({
'choices': [{
'message': {
'content': [
{'type': 'text', 'text': 'First part'},
{'type': 'text', 'content': 'Second part'},
None,
],
},
}],
}) == 'First part\nSecond part'
assert extract_openrouter_text({
'choices': [{'message': {'content': None}}],
}) == ''
class TestAiServicesUseSharedConfig: class TestAiServicesUseSharedConfig:
def test_pdf_ai_uses_flask_config(self, app, monkeypatch): def test_pdf_ai_uses_flask_config(self, app, monkeypatch):

View File

@@ -0,0 +1,24 @@
"""Service-level tests for PDF AI helpers."""
import pytest
from app.services.pdf_ai_service import PdfAiError, _extract_text_from_pdf
def test_extract_text_from_pdf_rejects_password_protected_documents(monkeypatch):
"""Password-protected PDFs should surface a specific actionable error."""
class FakeReader:
def __init__(self, input_path):
self.is_encrypted = True
self.pages = []
def decrypt(self, password):
return 0
monkeypatch.setattr("PyPDF2.PdfReader", FakeReader)
with pytest.raises(PdfAiError) as exc:
_extract_text_from_pdf("/tmp/protected.pdf")
assert exc.value.error_code == "PDF_ENCRYPTED"
assert "unlock" in exc.value.user_message.lower()

View File

@@ -1,5 +1,6 @@
"""Tests for storage service — local mode (S3 not configured in tests).""" """Tests for storage service — local mode (S3 not configured in tests)."""
import os import os
from unittest.mock import Mock
from app.services.storage_service import StorageService from app.services.storage_service import StorageService
@@ -54,3 +55,46 @@ class TestStorageServiceLocal:
f.write('test') f.write('test')
assert svc.file_exists(f'outputs/{task_id}/test.pdf') is True assert svc.file_exists(f'outputs/{task_id}/test.pdf') is True
def test_placeholder_s3_credentials_disable_s3(self, app):
"""Copied sample AWS credentials should not activate S3 mode."""
with app.app_context():
app.config.update({
'AWS_ACCESS_KEY_ID': 'your-access-key',
'AWS_SECRET_ACCESS_KEY': 'your-secret-key',
'AWS_S3_BUCKET': 'dociva-temp-files',
})
svc = StorageService()
assert svc.use_s3 is False
def test_upload_falls_back_to_local_when_s3_upload_fails(self, app, monkeypatch):
"""A broken S3 upload should still preserve a working local download."""
with app.app_context():
app.config.update({
'AWS_ACCESS_KEY_ID': 'real-looking-key',
'AWS_SECRET_ACCESS_KEY': 'real-looking-secret',
'AWS_S3_BUCKET': 'dociva-temp-files',
'STORAGE_ALLOW_LOCAL_FALLBACK': True,
})
svc = StorageService()
task_id = 's3-fallback-test'
input_path = '/tmp/test_storage_fallback.pdf'
with open(input_path, 'wb') as f:
f.write(b'%PDF-1.4 fallback')
class DummyClientError(Exception):
pass
failing_client = Mock()
failing_client.upload_file.side_effect = DummyClientError('boom')
monkeypatch.setattr('botocore.exceptions.ClientError', DummyClientError)
monkeypatch.setattr(StorageService, 'client', property(lambda self: failing_client))
key = svc.upload_file(input_path, task_id)
url = svc.generate_presigned_url(key, original_filename='fallback.pdf')
assert key == f'outputs/{task_id}/test_storage_fallback.pdf'
assert svc.file_exists(key) is True
assert '/api/download/s3-fallback-test/test_storage_fallback.pdf' in url
os.unlink(input_path)

View File

@@ -32,10 +32,34 @@ class TestStripeRoutes:
}) })
assert response.status_code == 503 assert response.status_code == 503
def test_checkout_placeholder_config_returns_503(self, client, app):
"""Copied sample Stripe values should be treated as not configured."""
self._login(client, email="stripe-placeholder@test.com")
app.config.update({
"STRIPE_SECRET_KEY": "sk_test_XXXXXXXXXXXXXXXXXXXXXXXX",
"STRIPE_PRICE_ID_PRO_MONTHLY": "price_XXXXXXXXXXXXXXXX",
"STRIPE_PRICE_ID_PRO_YEARLY": "price_XXXXXXXXXXXXXXXX",
})
response = client.post("/api/stripe/create-checkout-session", json={
"billing": "monthly",
})
assert response.status_code == 503
def test_portal_requires_auth(self, client): def test_portal_requires_auth(self, client):
response = client.post("/api/stripe/create-portal-session") response = client.post("/api/stripe/create-portal-session")
assert response.status_code == 401 assert response.status_code == 401
def test_portal_placeholder_config_returns_503(self, client, app):
"""Portal access should not attempt Stripe calls when config is only sample data."""
self._login(client, email="stripe-portal@test.com")
app.config.update({
"STRIPE_SECRET_KEY": "sk_test_XXXXXXXXXXXXXXXXXXXXXXXX",
"STRIPE_PRICE_ID_PRO_MONTHLY": "price_XXXXXXXXXXXXXXXX",
"STRIPE_PRICE_ID_PRO_YEARLY": "price_XXXXXXXXXXXXXXXX",
})
response = client.post("/api/stripe/create-portal-session")
assert response.status_code == 503
def test_webhook_missing_signature(self, client): def test_webhook_missing_signature(self, client):
"""Webhook without config returns ignored status.""" """Webhook without config returns ignored status."""
response = client.post( response = client.post(