Refactor configuration handling and improve error management across services; normalize placeholder values for SMTP and Stripe configurations; enhance local storage fallback logic in StorageService; add tests for new behaviors and edge cases.
This commit is contained in:
23
.env.example
23
.env.example
@@ -16,13 +16,13 @@ CELERY_BROKER_URL=redis://redis:6379/0
|
|||||||
CELERY_RESULT_BACKEND=redis://redis:6379/1
|
CELERY_RESULT_BACKEND=redis://redis:6379/1
|
||||||
|
|
||||||
# OpenRouter AI
|
# OpenRouter AI
|
||||||
OPENROUTER_API_KEY=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
|
OPENROUTER_API_KEY=
|
||||||
OPENROUTER_MODEL=nvidia/nemotron-3-super-120b-a12b:free
|
OPENROUTER_MODEL=nvidia/nemotron-3-super-120b-a12b:free
|
||||||
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1/chat/completions
|
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1/chat/completions
|
||||||
|
|
||||||
# AWS S3
|
# AWS S3
|
||||||
AWS_ACCESS_KEY_ID=your-access-key
|
AWS_ACCESS_KEY_ID=
|
||||||
AWS_SECRET_ACCESS_KEY=your-secret-key
|
AWS_SECRET_ACCESS_KEY=
|
||||||
AWS_S3_BUCKET=dociva-temp-files
|
AWS_S3_BUCKET=dociva-temp-files
|
||||||
AWS_S3_REGION=eu-west-1
|
AWS_S3_REGION=eu-west-1
|
||||||
|
|
||||||
@@ -31,31 +31,32 @@ MAX_CONTENT_LENGTH_MB=50
|
|||||||
UPLOAD_FOLDER=/tmp/uploads
|
UPLOAD_FOLDER=/tmp/uploads
|
||||||
OUTPUT_FOLDER=/tmp/outputs
|
OUTPUT_FOLDER=/tmp/outputs
|
||||||
FILE_EXPIRY_SECONDS=1800
|
FILE_EXPIRY_SECONDS=1800
|
||||||
|
STORAGE_ALLOW_LOCAL_FALLBACK=true
|
||||||
DATABASE_PATH=/app/data/dociva.db
|
DATABASE_PATH=/app/data/dociva.db
|
||||||
|
|
||||||
# CORS
|
# CORS
|
||||||
CORS_ORIGINS=https://dociva.io,https://www.dociva.io
|
CORS_ORIGINS=https://dociva.io,https://www.dociva.io
|
||||||
|
|
||||||
# SMTP (Password reset + contact notifications)
|
# SMTP (Password reset + contact notifications)
|
||||||
SMTP_HOST=smtp.your-provider.com
|
SMTP_HOST=
|
||||||
SMTP_PORT=587
|
SMTP_PORT=587
|
||||||
SMTP_USER=noreply@dociva.io
|
SMTP_USER=
|
||||||
SMTP_PASSWORD=replace-with-smtp-password
|
SMTP_PASSWORD=
|
||||||
SMTP_FROM=noreply@dociva.io
|
SMTP_FROM=noreply@dociva.io
|
||||||
SMTP_USE_TLS=true
|
SMTP_USE_TLS=true
|
||||||
|
|
||||||
# Stripe Payments
|
# Stripe Payments
|
||||||
STRIPE_SECRET_KEY=sk_test_XXXXXXXXXXXXXXXXXXXXXXXX
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=whsec_XXXXXXXXXXXXXXXXXXXXXXXX
|
STRIPE_WEBHOOK_SECRET=
|
||||||
STRIPE_PRICE_ID_PRO_MONTHLY=price_XXXXXXXXXXXXXXXX
|
STRIPE_PRICE_ID_PRO_MONTHLY=
|
||||||
STRIPE_PRICE_ID_PRO_YEARLY=price_XXXXXXXXXXXXXXXX
|
STRIPE_PRICE_ID_PRO_YEARLY=
|
||||||
|
|
||||||
# Sentry Error Monitoring
|
# Sentry Error Monitoring
|
||||||
SENTRY_DSN=
|
SENTRY_DSN=
|
||||||
SENTRY_ENVIRONMENT=production
|
SENTRY_ENVIRONMENT=production
|
||||||
|
|
||||||
# PostgreSQL (production) — leave empty to use SQLite
|
# PostgreSQL (production) — leave empty to use SQLite
|
||||||
DATABASE_URL=sqlite3 /app/data/dociva.db
|
DATABASE_URL=
|
||||||
POSTGRES_DB=dociva
|
POSTGRES_DB=dociva
|
||||||
POSTGRES_USER=dociva
|
POSTGRES_USER=dociva
|
||||||
POSTGRES_PASSWORD=replace-with-strong-postgres-password
|
POSTGRES_PASSWORD=replace-with-strong-postgres-password
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from app.services.stripe_service import (
|
|||||||
create_checkout_session,
|
create_checkout_session,
|
||||||
create_portal_session,
|
create_portal_session,
|
||||||
handle_webhook_event,
|
handle_webhook_event,
|
||||||
|
get_stripe_price_id,
|
||||||
|
is_stripe_configured,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -31,11 +33,9 @@ def checkout():
|
|||||||
data = request.get_json(silent=True) or {}
|
data = request.get_json(silent=True) or {}
|
||||||
billing = data.get("billing", "monthly")
|
billing = data.get("billing", "monthly")
|
||||||
|
|
||||||
monthly_price = current_app.config.get("STRIPE_PRICE_ID_PRO_MONTHLY", "")
|
price_id = get_stripe_price_id(billing)
|
||||||
yearly_price = current_app.config.get("STRIPE_PRICE_ID_PRO_YEARLY", "")
|
|
||||||
price_id = yearly_price if billing == "yearly" and yearly_price else monthly_price
|
|
||||||
|
|
||||||
if not price_id:
|
if not is_stripe_configured() or not price_id:
|
||||||
return jsonify({"error": "Payment is not configured yet."}), 503
|
return jsonify({"error": "Payment is not configured yet."}), 503
|
||||||
|
|
||||||
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173")
|
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173")
|
||||||
@@ -62,6 +62,9 @@ def portal():
|
|||||||
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173")
|
frontend_url = current_app.config.get("FRONTEND_URL", "http://localhost:5173")
|
||||||
return_url = f"{frontend_url}/account"
|
return_url = f"{frontend_url}/account"
|
||||||
|
|
||||||
|
if not is_stripe_configured():
|
||||||
|
return jsonify({"error": "Payment is not configured yet."}), 503
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = create_portal_session(user_id, return_url)
|
url = create_portal_session(user_id, return_url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from app.services.openrouter_config_service import get_openrouter_settings
|
from app.services.openrouter_config_service import (
|
||||||
|
extract_openrouter_text,
|
||||||
|
get_openrouter_settings,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -72,12 +75,7 @@ def chat_about_flowchart(message: str, flow_data: dict | None = None) -> dict:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
reply = (
|
reply = extract_openrouter_text(data)
|
||||||
data.get("choices", [{}])[0]
|
|
||||||
.get("message", {})
|
|
||||||
.get("content", "")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not reply:
|
if not reply:
|
||||||
reply = "I couldn't generate a response. Please try again."
|
reply = "I couldn't generate a response. Please try again."
|
||||||
|
|||||||
@@ -63,7 +63,10 @@ def save_message(name: str, email: str, category: str, subject: str, message: st
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# Send notification email to admin
|
# Send notification email to admin
|
||||||
admin_email = current_app.config.get("SMTP_FROM", "noreply@dociva.io")
|
admin_emails = tuple(current_app.config.get("INTERNAL_ADMIN_EMAILS", ()))
|
||||||
|
admin_email = admin_emails[0] if admin_emails else current_app.config.get(
|
||||||
|
"SMTP_FROM", "noreply@dociva.io"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
send_email(
|
send_email(
|
||||||
to=admin_email,
|
to=admin_email,
|
||||||
|
|||||||
@@ -6,16 +6,27 @@ from email.mime.multipart import MIMEMultipart
|
|||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
|
from app.utils.config_placeholders import normalize_optional_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_smtp_config() -> dict:
|
def _get_smtp_config() -> dict:
|
||||||
"""Read SMTP settings from Flask config."""
|
"""Read SMTP settings from Flask config."""
|
||||||
return {
|
return {
|
||||||
"host": current_app.config.get("SMTP_HOST", ""),
|
"host": normalize_optional_config(
|
||||||
|
current_app.config.get("SMTP_HOST", ""),
|
||||||
|
("your-provider", "replace-with"),
|
||||||
|
),
|
||||||
"port": current_app.config.get("SMTP_PORT", 587),
|
"port": current_app.config.get("SMTP_PORT", 587),
|
||||||
"user": current_app.config.get("SMTP_USER", ""),
|
"user": normalize_optional_config(
|
||||||
"password": current_app.config.get("SMTP_PASSWORD", ""),
|
current_app.config.get("SMTP_USER", ""),
|
||||||
|
("replace-with",),
|
||||||
|
),
|
||||||
|
"password": normalize_optional_config(
|
||||||
|
current_app.config.get("SMTP_PASSWORD", ""),
|
||||||
|
("replace-with",),
|
||||||
|
),
|
||||||
"from_addr": current_app.config.get("SMTP_FROM", "noreply@dociva.io"),
|
"from_addr": current_app.config.get("SMTP_FROM", "noreply@dociva.io"),
|
||||||
"use_tls": current_app.config.get("SMTP_USE_TLS", True),
|
"use_tls": current_app.config.get("SMTP_USE_TLS", True),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,42 @@ class OpenRouterSettings:
|
|||||||
base_url: str
|
base_url: str
|
||||||
|
|
||||||
|
|
||||||
|
def extract_openrouter_text(payload: dict) -> str:
|
||||||
|
"""Extract assistant text from OpenRouter/OpenAI-style payloads safely."""
|
||||||
|
choices = payload.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
message = choices[0].get("message") or {}
|
||||||
|
content = message.get("content")
|
||||||
|
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
if item.strip():
|
||||||
|
text_parts.append(item.strip())
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(item.get("text"), str) and item["text"].strip():
|
||||||
|
text_parts.append(item["text"].strip())
|
||||||
|
continue
|
||||||
|
|
||||||
|
nested_text = item.get("content")
|
||||||
|
if item.get("type") == "text" and isinstance(nested_text, str) and nested_text.strip():
|
||||||
|
text_parts.append(nested_text.strip())
|
||||||
|
|
||||||
|
return "\n".join(text_parts).strip()
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _load_dotenv_settings() -> dict[str, str]:
|
def _load_dotenv_settings() -> dict[str, str]:
|
||||||
"""Read .env values directly so workers can recover from blank in-app config."""
|
"""Read .env values directly so workers can recover from blank in-app config."""
|
||||||
service_dir = os.path.abspath(os.path.dirname(__file__))
|
service_dir = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|||||||
@@ -4,7 +4,10 @@ import logging
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from app.services.openrouter_config_service import get_openrouter_settings
|
from app.services.openrouter_config_service import (
|
||||||
|
extract_openrouter_text,
|
||||||
|
get_openrouter_settings,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -34,6 +37,12 @@ def _extract_text_from_pdf(input_path: str, max_pages: int = 50) -> str:
|
|||||||
from PyPDF2 import PdfReader
|
from PyPDF2 import PdfReader
|
||||||
|
|
||||||
reader = PdfReader(input_path)
|
reader = PdfReader(input_path)
|
||||||
|
if reader.is_encrypted and reader.decrypt("") == 0:
|
||||||
|
raise PdfAiError(
|
||||||
|
"This PDF is password-protected. Please unlock it first.",
|
||||||
|
error_code="PDF_ENCRYPTED",
|
||||||
|
)
|
||||||
|
|
||||||
pages = reader.pages[:max_pages]
|
pages = reader.pages[:max_pages]
|
||||||
texts = []
|
texts = []
|
||||||
for i, page in enumerate(pages):
|
for i, page in enumerate(pages):
|
||||||
@@ -41,6 +50,8 @@ def _extract_text_from_pdf(input_path: str, max_pages: int = 50) -> str:
|
|||||||
if text.strip():
|
if text.strip():
|
||||||
texts.append(f"[Page {i + 1}]\n{text}")
|
texts.append(f"[Page {i + 1}]\n{text}")
|
||||||
return "\n\n".join(texts)
|
return "\n\n".join(texts)
|
||||||
|
except PdfAiError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise PdfAiError(
|
raise PdfAiError(
|
||||||
"Failed to extract text from PDF.",
|
"Failed to extract text from PDF.",
|
||||||
@@ -98,29 +109,31 @@ def _call_openrouter(
|
|||||||
timeout=60,
|
timeout=60,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 401:
|
status_code = getattr(response, "status_code", 200)
|
||||||
|
|
||||||
|
if status_code == 401:
|
||||||
logger.error("OpenRouter API key is invalid or expired (401).")
|
logger.error("OpenRouter API key is invalid or expired (401).")
|
||||||
raise PdfAiError(
|
raise PdfAiError(
|
||||||
"AI features are temporarily unavailable due to a configuration issue. Our team has been notified.",
|
"AI features are temporarily unavailable due to a configuration issue. Our team has been notified.",
|
||||||
error_code="OPENROUTER_UNAUTHORIZED",
|
error_code="OPENROUTER_UNAUTHORIZED",
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 402:
|
if status_code == 402:
|
||||||
logger.error("OpenRouter account has insufficient credits (402).")
|
logger.error("OpenRouter account has insufficient credits (402).")
|
||||||
raise PdfAiError(
|
raise PdfAiError(
|
||||||
"AI processing credits have been exhausted. Please try again later.",
|
"AI processing credits have been exhausted. Please try again later.",
|
||||||
error_code="OPENROUTER_INSUFFICIENT_CREDITS",
|
error_code="OPENROUTER_INSUFFICIENT_CREDITS",
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 429:
|
if status_code == 429:
|
||||||
logger.warning("OpenRouter rate limit reached (429).")
|
logger.warning("OpenRouter rate limit reached (429).")
|
||||||
raise PdfAiError(
|
raise PdfAiError(
|
||||||
"AI service is experiencing high demand. Please wait a moment and try again.",
|
"AI service is experiencing high demand. Please wait a moment and try again.",
|
||||||
error_code="OPENROUTER_RATE_LIMIT",
|
error_code="OPENROUTER_RATE_LIMIT",
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code >= 500:
|
if status_code >= 500:
|
||||||
logger.error("OpenRouter server error (%s).", response.status_code)
|
logger.error("OpenRouter server error (%s).", status_code)
|
||||||
raise PdfAiError(
|
raise PdfAiError(
|
||||||
"AI service provider is experiencing issues. Please try again shortly.",
|
"AI service provider is experiencing issues. Please try again shortly.",
|
||||||
error_code="OPENROUTER_SERVER_ERROR",
|
error_code="OPENROUTER_SERVER_ERROR",
|
||||||
@@ -139,12 +152,7 @@ def _call_openrouter(
|
|||||||
detail=error_msg,
|
detail=error_msg,
|
||||||
)
|
)
|
||||||
|
|
||||||
reply = (
|
reply = extract_openrouter_text(data)
|
||||||
data.get("choices", [{}])[0]
|
|
||||||
.get("message", {})
|
|
||||||
.get("content", "")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not reply:
|
if not reply:
|
||||||
raise PdfAiError(
|
raise PdfAiError(
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ from datetime import datetime, timezone
|
|||||||
import requests
|
import requests
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
from app.services.openrouter_config_service import get_openrouter_settings
|
from app.services.openrouter_config_service import (
|
||||||
|
extract_openrouter_text,
|
||||||
|
get_openrouter_settings,
|
||||||
|
)
|
||||||
from app.services.ai_cost_service import AiBudgetExceededError, check_ai_budget, log_ai_usage
|
from app.services.ai_cost_service import AiBudgetExceededError, check_ai_budget, log_ai_usage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -465,12 +468,7 @@ def _request_ai_reply(
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
reply = (
|
reply = extract_openrouter_text(data)
|
||||||
data.get("choices", [{}])[0]
|
|
||||||
.get("message", {})
|
|
||||||
.get("content", "")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
if not reply:
|
if not reply:
|
||||||
raise RuntimeError("Assistant returned an empty reply.")
|
raise RuntimeError("Assistant returned an empty reply.")
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,32 @@ import logging
|
|||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
|
from app.utils.config_placeholders import normalize_optional_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolved_s3_settings() -> tuple[str, str, str]:
|
||||||
|
"""Return sanitized S3 credentials, treating copied sample values as blank."""
|
||||||
|
key = normalize_optional_config(
|
||||||
|
current_app.config.get("AWS_ACCESS_KEY_ID"),
|
||||||
|
("your-access-key", "replace-with"),
|
||||||
|
)
|
||||||
|
secret = normalize_optional_config(
|
||||||
|
current_app.config.get("AWS_SECRET_ACCESS_KEY"),
|
||||||
|
("your-secret-key", "replace-with"),
|
||||||
|
)
|
||||||
|
bucket = normalize_optional_config(
|
||||||
|
current_app.config.get("AWS_S3_BUCKET"),
|
||||||
|
("your-bucket-name", "replace-with"),
|
||||||
|
)
|
||||||
|
return key, secret, bucket
|
||||||
|
|
||||||
|
|
||||||
def _is_s3_configured() -> bool:
|
def _is_s3_configured() -> bool:
|
||||||
"""Check if AWS S3 credentials are provided."""
|
"""Check if AWS S3 credentials are provided."""
|
||||||
key = current_app.config.get("AWS_ACCESS_KEY_ID")
|
key, secret, bucket = _resolved_s3_settings()
|
||||||
secret = current_app.config.get("AWS_SECRET_ACCESS_KEY")
|
return bool(key and secret and bucket)
|
||||||
return bool(key and secret and key.strip() and secret.strip())
|
|
||||||
|
|
||||||
|
|
||||||
class StorageService:
|
class StorageService:
|
||||||
@@ -25,22 +43,60 @@ class StorageService:
|
|||||||
def use_s3(self) -> bool:
|
def use_s3(self) -> bool:
|
||||||
return _is_s3_configured()
|
return _is_s3_configured()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_local_fallback(self) -> bool:
|
||||||
|
value = current_app.config.get("STORAGE_ALLOW_LOCAL_FALLBACK", True)
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return str(value).strip().lower() != "false"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self):
|
def client(self):
|
||||||
"""Lazy-initialize S3 client (only when S3 is configured)."""
|
"""Lazy-initialize S3 client (only when S3 is configured)."""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
import boto3
|
import boto3
|
||||||
|
key, secret, _ = _resolved_s3_settings()
|
||||||
self._client = boto3.client(
|
self._client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
region_name=current_app.config["AWS_S3_REGION"],
|
region_name=current_app.config["AWS_S3_REGION"],
|
||||||
aws_access_key_id=current_app.config["AWS_ACCESS_KEY_ID"],
|
aws_access_key_id=key,
|
||||||
aws_secret_access_key=current_app.config["AWS_SECRET_ACCESS_KEY"],
|
aws_secret_access_key=secret,
|
||||||
)
|
)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def bucket(self):
|
def bucket(self):
|
||||||
return current_app.config["AWS_S3_BUCKET"]
|
_, _, bucket = _resolved_s3_settings()
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
def _local_key(self, task_id: str, filename: str, folder: str = "outputs") -> str:
|
||||||
|
return f"{folder}/{task_id}/{filename}"
|
||||||
|
|
||||||
|
def _local_destination(self, task_id: str, filename: str) -> str:
|
||||||
|
output_dir = current_app.config["OUTPUT_FOLDER"]
|
||||||
|
dest_dir = os.path.join(output_dir, task_id)
|
||||||
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
|
return os.path.join(dest_dir, filename)
|
||||||
|
|
||||||
|
def _store_locally(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
|
||||||
|
"""Copy a generated file into the app's local download storage."""
|
||||||
|
filename = os.path.basename(local_path)
|
||||||
|
dest_path = self._local_destination(task_id, filename)
|
||||||
|
|
||||||
|
if os.path.abspath(local_path) != os.path.abspath(dest_path):
|
||||||
|
shutil.copy2(local_path, dest_path)
|
||||||
|
|
||||||
|
logger.info("[Local] Stored file: %s", dest_path)
|
||||||
|
return self._local_key(task_id, filename, folder=folder)
|
||||||
|
|
||||||
|
def _resolve_local_path(self, storage_key: str) -> str | None:
|
||||||
|
parts = [part for part in storage_key.strip("/").split("/") if part]
|
||||||
|
if len(parts) < 3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
task_id = parts[1]
|
||||||
|
filename = parts[-1]
|
||||||
|
return os.path.join(current_app.config["OUTPUT_FOLDER"], task_id, filename)
|
||||||
|
|
||||||
def upload_file(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
|
def upload_file(self, local_path: str, task_id: str, folder: str = "outputs") -> str:
|
||||||
"""
|
"""
|
||||||
@@ -53,7 +109,7 @@ class StorageService:
|
|||||||
S3 key or local relative path (used as identifier)
|
S3 key or local relative path (used as identifier)
|
||||||
"""
|
"""
|
||||||
filename = os.path.basename(local_path)
|
filename = os.path.basename(local_path)
|
||||||
key = f"{folder}/{task_id}/{filename}"
|
key = self._local_key(task_id, filename, folder=folder)
|
||||||
|
|
||||||
if self.use_s3:
|
if self.use_s3:
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
@@ -61,19 +117,16 @@ class StorageService:
|
|||||||
self.client.upload_file(local_path, self.bucket, key)
|
self.client.upload_file(local_path, self.bucket, key)
|
||||||
return key
|
return key
|
||||||
except ClientError as e:
|
except ClientError as e:
|
||||||
raise RuntimeError(f"Failed to upload file to S3: {e}")
|
if not self.allow_local_fallback:
|
||||||
else:
|
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||||
# Local mode — keep file in the outputs directory
|
|
||||||
output_dir = current_app.config["OUTPUT_FOLDER"]
|
|
||||||
dest_dir = os.path.join(output_dir, task_id)
|
|
||||||
os.makedirs(dest_dir, exist_ok=True)
|
|
||||||
dest_path = os.path.join(dest_dir, filename)
|
|
||||||
|
|
||||||
if os.path.abspath(local_path) != os.path.abspath(dest_path):
|
logger.exception(
|
||||||
shutil.copy2(local_path, dest_path)
|
"S3 upload failed for %s. Falling back to local storage.",
|
||||||
|
key,
|
||||||
|
)
|
||||||
|
return self._store_locally(local_path, task_id, folder=folder)
|
||||||
|
|
||||||
logger.info(f"[Local] Stored file: {dest_path}")
|
return self._store_locally(local_path, task_id, folder=folder)
|
||||||
return key
|
|
||||||
|
|
||||||
def generate_presigned_url(
|
def generate_presigned_url(
|
||||||
self, s3_key: str, expiry: int | None = None, original_filename: str | None = None
|
self, s3_key: str, expiry: int | None = None, original_filename: str | None = None
|
||||||
@@ -84,6 +137,14 @@ class StorageService:
|
|||||||
S3 mode: presigned URL.
|
S3 mode: presigned URL.
|
||||||
Local mode: /api/download/<task_id>/<filename>
|
Local mode: /api/download/<task_id>/<filename>
|
||||||
"""
|
"""
|
||||||
|
local_path = self._resolve_local_path(s3_key)
|
||||||
|
if local_path and os.path.isfile(local_path):
|
||||||
|
parts = [part for part in s3_key.strip("/").split("/") if part]
|
||||||
|
task_id = parts[1]
|
||||||
|
filename = parts[-1]
|
||||||
|
download_name = original_filename or filename
|
||||||
|
return f"/api/download/{task_id}/{filename}?name={download_name}"
|
||||||
|
|
||||||
if self.use_s3:
|
if self.use_s3:
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
if expiry is None:
|
if expiry is None:
|
||||||
@@ -108,20 +169,21 @@ class StorageService:
|
|||||||
raise RuntimeError(f"Failed to generate presigned URL: {e}")
|
raise RuntimeError(f"Failed to generate presigned URL: {e}")
|
||||||
else:
|
else:
|
||||||
# Local mode — return path to Flask download route
|
# Local mode — return path to Flask download route
|
||||||
parts = s3_key.strip("/").split("/")
|
parts = [part for part in s3_key.strip("/").split("/") if part]
|
||||||
# key = "outputs/<task_id>/<filename>"
|
task_id = parts[1] if len(parts) >= 3 else parts[0]
|
||||||
if len(parts) >= 3:
|
filename = parts[-1]
|
||||||
task_id = parts[1]
|
|
||||||
filename = parts[2]
|
|
||||||
else:
|
|
||||||
task_id = parts[0]
|
|
||||||
filename = parts[-1]
|
|
||||||
|
|
||||||
download_name = original_filename or filename
|
download_name = original_filename or filename
|
||||||
return f"/api/download/{task_id}/{filename}?name={download_name}"
|
return f"/api/download/{task_id}/{filename}?name={download_name}"
|
||||||
|
|
||||||
def delete_file(self, s3_key: str):
|
def delete_file(self, s3_key: str):
|
||||||
"""Delete a file from S3 (no-op in local mode)."""
|
"""Delete a file from S3 (no-op in local mode)."""
|
||||||
|
local_path = self._resolve_local_path(s3_key)
|
||||||
|
if local_path and os.path.isfile(local_path):
|
||||||
|
try:
|
||||||
|
os.remove(local_path)
|
||||||
|
except OSError:
|
||||||
|
logger.warning("Failed to delete local fallback file: %s", local_path)
|
||||||
|
|
||||||
if self.use_s3:
|
if self.use_s3:
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
try:
|
try:
|
||||||
@@ -131,6 +193,10 @@ class StorageService:
|
|||||||
|
|
||||||
def file_exists(self, s3_key: str) -> bool:
|
def file_exists(self, s3_key: str) -> bool:
|
||||||
"""Check if a file exists."""
|
"""Check if a file exists."""
|
||||||
|
local_path = self._resolve_local_path(s3_key)
|
||||||
|
if local_path and os.path.isfile(local_path):
|
||||||
|
return True
|
||||||
|
|
||||||
if self.use_s3:
|
if self.use_s3:
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
try:
|
try:
|
||||||
@@ -138,16 +204,8 @@ class StorageService:
|
|||||||
return True
|
return True
|
||||||
except ClientError:
|
except ClientError:
|
||||||
return False
|
return False
|
||||||
else:
|
|
||||||
parts = s3_key.strip("/").split("/")
|
return False
|
||||||
if len(parts) >= 3:
|
|
||||||
task_id = parts[1]
|
|
||||||
filename = parts[2]
|
|
||||||
else:
|
|
||||||
task_id = parts[0]
|
|
||||||
filename = parts[-1]
|
|
||||||
output_dir = current_app.config["OUTPUT_FOLDER"]
|
|
||||||
return os.path.isfile(os.path.join(output_dir, task_id, filename))
|
|
||||||
|
|
||||||
|
|
||||||
# Singleton instance
|
# Singleton instance
|
||||||
|
|||||||
@@ -1,18 +1,47 @@
|
|||||||
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
"""Stripe payment service — checkout sessions, webhooks, and subscription management."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
import stripe
|
import stripe
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
from app.services.account_service import update_user_plan, get_user_by_id, _connect, _utc_now
|
from app.services.account_service import update_user_plan, _connect, _utc_now
|
||||||
|
from app.utils.config_placeholders import normalize_optional_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stripe_secret_key() -> str:
|
||||||
|
"""Return the configured Stripe secret key, ignoring copied sample values."""
|
||||||
|
return normalize_optional_config(
|
||||||
|
current_app.config.get("STRIPE_SECRET_KEY", ""),
|
||||||
|
("replace-with",),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_stripe_price_id(billing: str = "monthly") -> str:
|
||||||
|
"""Return the configured Stripe price id for the requested billing cycle."""
|
||||||
|
monthly = normalize_optional_config(
|
||||||
|
current_app.config.get("STRIPE_PRICE_ID_PRO_MONTHLY", ""),
|
||||||
|
("replace-with",),
|
||||||
|
)
|
||||||
|
yearly = normalize_optional_config(
|
||||||
|
current_app.config.get("STRIPE_PRICE_ID_PRO_YEARLY", ""),
|
||||||
|
("replace-with",),
|
||||||
|
)
|
||||||
|
|
||||||
|
if billing == "yearly" and yearly:
|
||||||
|
return yearly
|
||||||
|
return monthly
|
||||||
|
|
||||||
|
|
||||||
|
def is_stripe_configured() -> bool:
|
||||||
|
"""Return True when billing has a usable secret key and at least one price id."""
|
||||||
|
return bool(get_stripe_secret_key() and (get_stripe_price_id("monthly") or get_stripe_price_id("yearly")))
|
||||||
|
|
||||||
|
|
||||||
def _init_stripe():
|
def _init_stripe():
|
||||||
"""Configure stripe with the app's secret key."""
|
"""Configure stripe with the app's secret key."""
|
||||||
stripe.api_key = current_app.config.get("STRIPE_SECRET_KEY", "")
|
stripe.api_key = get_stripe_secret_key()
|
||||||
|
|
||||||
|
|
||||||
def _ensure_stripe_columns():
|
def _ensure_stripe_columns():
|
||||||
@@ -109,7 +138,10 @@ def create_portal_session(user_id: int, return_url: str) -> str:
|
|||||||
|
|
||||||
def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
|
def handle_webhook_event(payload: bytes, sig_header: str) -> dict:
|
||||||
"""Process a Stripe webhook event. Returns a status dict."""
|
"""Process a Stripe webhook event. Returns a status dict."""
|
||||||
webhook_secret = current_app.config.get("STRIPE_WEBHOOK_SECRET", "")
|
webhook_secret = normalize_optional_config(
|
||||||
|
current_app.config.get("STRIPE_WEBHOOK_SECRET", ""),
|
||||||
|
("replace-with",),
|
||||||
|
)
|
||||||
if not webhook_secret:
|
if not webhook_secret:
|
||||||
logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.")
|
logger.warning("STRIPE_WEBHOOK_SECRET not configured — ignoring webhook.")
|
||||||
return {"status": "ignored", "reason": "no webhook secret"}
|
return {"status": "ignored", "reason": "no webhook secret"}
|
||||||
|
|||||||
31
backend/app/utils/config_placeholders.py
Normal file
31
backend/app/utils/config_placeholders.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Helpers for treating sample config values as missing runtime configuration."""
|
||||||
|
import re
|
||||||
|
|
||||||
|
_MASKED_SEQUENCE_RE = re.compile(r"(x{6,}|\*{4,})", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_optional_config(
|
||||||
|
value: str | None,
|
||||||
|
placeholder_markers: tuple[str, ...] = (),
|
||||||
|
) -> str:
|
||||||
|
"""Return a stripped config value, or blank when it still looks like a sample."""
|
||||||
|
normalized = str(value or "").strip()
|
||||||
|
if not normalized:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lowered = normalized.lower()
|
||||||
|
if any(marker.lower() in lowered for marker in placeholder_markers if marker):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if _MASKED_SEQUENCE_RE.search(normalized):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def has_real_config(
|
||||||
|
value: str | None,
|
||||||
|
placeholder_markers: tuple[str, ...] = (),
|
||||||
|
) -> bool:
|
||||||
|
"""Return True when the value is present and not an obvious placeholder."""
|
||||||
|
return bool(normalize_optional_config(value, placeholder_markers))
|
||||||
@@ -70,8 +70,17 @@ def cleanup_task_files(task_id: str, keep_outputs: bool = False):
|
|||||||
if os.path.exists(upload_task_dir):
|
if os.path.exists(upload_task_dir):
|
||||||
shutil.rmtree(upload_task_dir, ignore_errors=True)
|
shutil.rmtree(upload_task_dir, ignore_errors=True)
|
||||||
|
|
||||||
# Only clean outputs when using S3 (files already uploaded to S3)
|
# Preserve local outputs whenever local fallback is enabled so download links remain valid.
|
||||||
if not keep_outputs:
|
preserve_outputs = keep_outputs
|
||||||
|
if not preserve_outputs:
|
||||||
|
try:
|
||||||
|
from app.services.storage_service import storage
|
||||||
|
|
||||||
|
preserve_outputs = storage.allow_local_fallback
|
||||||
|
except Exception:
|
||||||
|
preserve_outputs = False
|
||||||
|
|
||||||
|
if not preserve_outputs:
|
||||||
output_task_dir = os.path.join(output_dir, task_id)
|
output_task_dir = os.path.join(output_dir, task_id)
|
||||||
if os.path.exists(output_task_dir):
|
if os.path.exists(output_task_dir):
|
||||||
shutil.rmtree(output_task_dir, ignore_errors=True)
|
shutil.rmtree(output_task_dir, ignore_errors=True)
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ class BaseConfig:
|
|||||||
UPLOAD_FOLDER = _env_or_default("UPLOAD_FOLDER", "/tmp/uploads")
|
UPLOAD_FOLDER = _env_or_default("UPLOAD_FOLDER", "/tmp/uploads")
|
||||||
OUTPUT_FOLDER = _env_or_default("OUTPUT_FOLDER", "/tmp/outputs")
|
OUTPUT_FOLDER = _env_or_default("OUTPUT_FOLDER", "/tmp/outputs")
|
||||||
FILE_EXPIRY_SECONDS = int(os.getenv("FILE_EXPIRY_SECONDS", 1800))
|
FILE_EXPIRY_SECONDS = int(os.getenv("FILE_EXPIRY_SECONDS", 1800))
|
||||||
|
STORAGE_ALLOW_LOCAL_FALLBACK = os.getenv(
|
||||||
|
"STORAGE_ALLOW_LOCAL_FALLBACK", "true"
|
||||||
|
).lower() == "true"
|
||||||
DATABASE_PATH = _env_or_default(
|
DATABASE_PATH = _env_or_default(
|
||||||
"DATABASE_PATH", os.path.join(BASE_DIR, "data", "dociva.db")
|
"DATABASE_PATH", os.path.join(BASE_DIR, "data", "dociva.db")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ ffmpeg-python>=0.2,<1.0
|
|||||||
|
|
||||||
# PDF Processing
|
# PDF Processing
|
||||||
PyPDF2>=3.0,<4.0
|
PyPDF2>=3.0,<4.0
|
||||||
|
pycryptodome>=3.20,<4.0
|
||||||
reportlab>=4.0,<5.0
|
reportlab>=4.0,<5.0
|
||||||
pdf2image>=1.16,<2.0
|
pdf2image>=1.16,<2.0
|
||||||
|
|
||||||
|
|||||||
19
backend/tests/test_email_service.py
Normal file
19
backend/tests/test_email_service.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Tests for SMTP configuration normalization."""
|
||||||
|
from app.services.email_service import send_email
|
||||||
|
|
||||||
|
|
||||||
|
def test_placeholder_smtp_host_is_treated_as_unconfigured(app, monkeypatch):
|
||||||
|
"""A copied sample SMTP host should not trigger a network call."""
|
||||||
|
with app.app_context():
|
||||||
|
app.config.update({
|
||||||
|
"SMTP_HOST": "smtp.your-provider.com",
|
||||||
|
"SMTP_PORT": 587,
|
||||||
|
"SMTP_USER": "noreply@dociva.io",
|
||||||
|
"SMTP_PASSWORD": "replace-with-smtp-password",
|
||||||
|
})
|
||||||
|
|
||||||
|
def fail_if_called(*args, **kwargs):
|
||||||
|
raise AssertionError("SMTP should not be contacted for placeholder config")
|
||||||
|
|
||||||
|
monkeypatch.setattr("smtplib.SMTP", fail_if_called)
|
||||||
|
assert send_email("user@example.com", "Subject", "<p>Body</p>") is False
|
||||||
@@ -1,6 +1,10 @@
|
|||||||
"""Tests for shared OpenRouter configuration resolution across AI services."""
|
"""Tests for shared OpenRouter configuration resolution across AI services."""
|
||||||
|
|
||||||
from app.services.openrouter_config_service import get_openrouter_settings
|
from app.services.openrouter_config_service import (
|
||||||
|
LEGACY_SAMPLE_OPENROUTER_API_KEY,
|
||||||
|
extract_openrouter_text,
|
||||||
|
get_openrouter_settings,
|
||||||
|
)
|
||||||
from app.services.pdf_ai_service import _call_openrouter
|
from app.services.pdf_ai_service import _call_openrouter
|
||||||
from app.services.site_assistant_service import _request_ai_reply
|
from app.services.site_assistant_service import _request_ai_reply
|
||||||
|
|
||||||
@@ -85,7 +89,7 @@ class TestOpenRouterConfigService:
|
|||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
'app.services.openrouter_config_service._load_dotenv_settings',
|
'app.services.openrouter_config_service._load_dotenv_settings',
|
||||||
lambda: {
|
lambda: {
|
||||||
'OPENROUTER_API_KEY': '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',
|
'OPENROUTER_API_KEY': LEGACY_SAMPLE_OPENROUTER_API_KEY,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,6 +99,27 @@ class TestOpenRouterConfigService:
|
|||||||
|
|
||||||
assert settings.api_key == ''
|
assert settings.api_key == ''
|
||||||
|
|
||||||
|
def test_extract_openrouter_text_supports_string_and_list_content(self):
|
||||||
|
assert extract_openrouter_text({
|
||||||
|
'choices': [{'message': {'content': ' plain text reply '}}],
|
||||||
|
}) == 'plain text reply'
|
||||||
|
|
||||||
|
assert extract_openrouter_text({
|
||||||
|
'choices': [{
|
||||||
|
'message': {
|
||||||
|
'content': [
|
||||||
|
{'type': 'text', 'text': 'First part'},
|
||||||
|
{'type': 'text', 'content': 'Second part'},
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}) == 'First part\nSecond part'
|
||||||
|
|
||||||
|
assert extract_openrouter_text({
|
||||||
|
'choices': [{'message': {'content': None}}],
|
||||||
|
}) == ''
|
||||||
|
|
||||||
|
|
||||||
class TestAiServicesUseSharedConfig:
|
class TestAiServicesUseSharedConfig:
|
||||||
def test_pdf_ai_uses_flask_config(self, app, monkeypatch):
|
def test_pdf_ai_uses_flask_config(self, app, monkeypatch):
|
||||||
|
|||||||
24
backend/tests/test_pdf_ai_service.py
Normal file
24
backend/tests/test_pdf_ai_service.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Service-level tests for PDF AI helpers."""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.pdf_ai_service import PdfAiError, _extract_text_from_pdf
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_text_from_pdf_rejects_password_protected_documents(monkeypatch):
|
||||||
|
"""Password-protected PDFs should surface a specific actionable error."""
|
||||||
|
|
||||||
|
class FakeReader:
|
||||||
|
def __init__(self, input_path):
|
||||||
|
self.is_encrypted = True
|
||||||
|
self.pages = []
|
||||||
|
|
||||||
|
def decrypt(self, password):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
monkeypatch.setattr("PyPDF2.PdfReader", FakeReader)
|
||||||
|
|
||||||
|
with pytest.raises(PdfAiError) as exc:
|
||||||
|
_extract_text_from_pdf("/tmp/protected.pdf")
|
||||||
|
|
||||||
|
assert exc.value.error_code == "PDF_ENCRYPTED"
|
||||||
|
assert "unlock" in exc.value.user_message.lower()
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Tests for storage service — local mode (S3 not configured in tests)."""
|
"""Tests for storage service — local mode (S3 not configured in tests)."""
|
||||||
import os
|
import os
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from app.services.storage_service import StorageService
|
from app.services.storage_service import StorageService
|
||||||
|
|
||||||
@@ -54,3 +55,46 @@ class TestStorageServiceLocal:
|
|||||||
f.write('test')
|
f.write('test')
|
||||||
|
|
||||||
assert svc.file_exists(f'outputs/{task_id}/test.pdf') is True
|
assert svc.file_exists(f'outputs/{task_id}/test.pdf') is True
|
||||||
|
|
||||||
|
def test_placeholder_s3_credentials_disable_s3(self, app):
|
||||||
|
"""Copied sample AWS credentials should not activate S3 mode."""
|
||||||
|
with app.app_context():
|
||||||
|
app.config.update({
|
||||||
|
'AWS_ACCESS_KEY_ID': 'your-access-key',
|
||||||
|
'AWS_SECRET_ACCESS_KEY': 'your-secret-key',
|
||||||
|
'AWS_S3_BUCKET': 'dociva-temp-files',
|
||||||
|
})
|
||||||
|
svc = StorageService()
|
||||||
|
assert svc.use_s3 is False
|
||||||
|
|
||||||
|
def test_upload_falls_back_to_local_when_s3_upload_fails(self, app, monkeypatch):
|
||||||
|
"""A broken S3 upload should still preserve a working local download."""
|
||||||
|
with app.app_context():
|
||||||
|
app.config.update({
|
||||||
|
'AWS_ACCESS_KEY_ID': 'real-looking-key',
|
||||||
|
'AWS_SECRET_ACCESS_KEY': 'real-looking-secret',
|
||||||
|
'AWS_S3_BUCKET': 'dociva-temp-files',
|
||||||
|
'STORAGE_ALLOW_LOCAL_FALLBACK': True,
|
||||||
|
})
|
||||||
|
svc = StorageService()
|
||||||
|
task_id = 's3-fallback-test'
|
||||||
|
input_path = '/tmp/test_storage_fallback.pdf'
|
||||||
|
with open(input_path, 'wb') as f:
|
||||||
|
f.write(b'%PDF-1.4 fallback')
|
||||||
|
|
||||||
|
class DummyClientError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
failing_client = Mock()
|
||||||
|
failing_client.upload_file.side_effect = DummyClientError('boom')
|
||||||
|
monkeypatch.setattr('botocore.exceptions.ClientError', DummyClientError)
|
||||||
|
monkeypatch.setattr(StorageService, 'client', property(lambda self: failing_client))
|
||||||
|
|
||||||
|
key = svc.upload_file(input_path, task_id)
|
||||||
|
url = svc.generate_presigned_url(key, original_filename='fallback.pdf')
|
||||||
|
|
||||||
|
assert key == f'outputs/{task_id}/test_storage_fallback.pdf'
|
||||||
|
assert svc.file_exists(key) is True
|
||||||
|
assert '/api/download/s3-fallback-test/test_storage_fallback.pdf' in url
|
||||||
|
|
||||||
|
os.unlink(input_path)
|
||||||
|
|||||||
@@ -32,10 +32,34 @@ class TestStripeRoutes:
|
|||||||
})
|
})
|
||||||
assert response.status_code == 503
|
assert response.status_code == 503
|
||||||
|
|
||||||
|
def test_checkout_placeholder_config_returns_503(self, client, app):
|
||||||
|
"""Copied sample Stripe values should be treated as not configured."""
|
||||||
|
self._login(client, email="stripe-placeholder@test.com")
|
||||||
|
app.config.update({
|
||||||
|
"STRIPE_SECRET_KEY": "sk_test_XXXXXXXXXXXXXXXXXXXXXXXX",
|
||||||
|
"STRIPE_PRICE_ID_PRO_MONTHLY": "price_XXXXXXXXXXXXXXXX",
|
||||||
|
"STRIPE_PRICE_ID_PRO_YEARLY": "price_XXXXXXXXXXXXXXXX",
|
||||||
|
})
|
||||||
|
response = client.post("/api/stripe/create-checkout-session", json={
|
||||||
|
"billing": "monthly",
|
||||||
|
})
|
||||||
|
assert response.status_code == 503
|
||||||
|
|
||||||
def test_portal_requires_auth(self, client):
|
def test_portal_requires_auth(self, client):
|
||||||
response = client.post("/api/stripe/create-portal-session")
|
response = client.post("/api/stripe/create-portal-session")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_portal_placeholder_config_returns_503(self, client, app):
|
||||||
|
"""Portal access should not attempt Stripe calls when config is only sample data."""
|
||||||
|
self._login(client, email="stripe-portal@test.com")
|
||||||
|
app.config.update({
|
||||||
|
"STRIPE_SECRET_KEY": "sk_test_XXXXXXXXXXXXXXXXXXXXXXXX",
|
||||||
|
"STRIPE_PRICE_ID_PRO_MONTHLY": "price_XXXXXXXXXXXXXXXX",
|
||||||
|
"STRIPE_PRICE_ID_PRO_YEARLY": "price_XXXXXXXXXXXXXXXX",
|
||||||
|
})
|
||||||
|
response = client.post("/api/stripe/create-portal-session")
|
||||||
|
assert response.status_code == 503
|
||||||
|
|
||||||
def test_webhook_missing_signature(self, client):
|
def test_webhook_missing_signature(self, client):
|
||||||
"""Webhook without config returns ignored status."""
|
"""Webhook without config returns ignored status."""
|
||||||
response = client.post(
|
response = client.post(
|
||||||
|
|||||||
Reference in New Issue
Block a user