refactor: improve app initialization and update rate limiter tests for consistency

This commit is contained in:
Your Name
2026-03-30 00:12:34 +02:00
parent 5ac1d58742
commit 4ac4bf4e42
2 changed files with 32 additions and 31 deletions

View File

@@ -1,4 +1,5 @@
"""Flask Application Factory.""" """Flask Application Factory."""
import os import os
from flask import Flask, jsonify from flask import Flask, jsonify
@@ -11,7 +12,12 @@ from app.services.ai_cost_service import init_ai_cost_db
from app.services.site_assistant_service import init_site_assistant_db from app.services.site_assistant_service import init_site_assistant_db
from app.services.contact_service import init_contact_db from app.services.contact_service import init_contact_db
from app.services.stripe_service import init_stripe_db from app.services.stripe_service import init_stripe_db
from app.utils.csrf import CSRFError, apply_csrf_cookie, should_enforce_csrf, validate_csrf_request from app.utils.csrf import (
CSRFError,
apply_csrf_cookie,
should_enforce_csrf,
validate_csrf_request,
)
def _init_sentry(app): def _init_sentry(app):
@@ -35,13 +41,15 @@ def _init_sentry(app):
app.logger.warning("sentry-sdk not installed — monitoring disabled.") app.logger.warning("sentry-sdk not installed — monitoring disabled.")
def create_app(config_name=None): def create_app(config_name=None, config_overrides=None):
"""Create and configure the Flask application.""" """Create and configure the Flask application."""
if config_name is None: if config_name is None:
config_name = os.getenv("FLASK_ENV", "development") config_name = os.getenv("FLASK_ENV", "development")
app = Flask(__name__) app = Flask(__name__)
app.config.from_object(config[config_name]) app.config.from_object(config[config_name])
if config_overrides:
app.config.update(config_overrides)
# Initialize Sentry early # Initialize Sentry early
_init_sentry(app) _init_sentry(app)

View File

@@ -1,6 +1,8 @@
"""Tests for rate limiting middleware.""" """Tests for rate limiting middleware."""
import pytest import pytest
from app import create_app from app import create_app
from tests.conftest import CSRFTestClient
@pytest.fixture @pytest.fixture
@@ -11,33 +13,24 @@ def rate_limited_app(tmp_path):
never throttled. Here we force the extension's internal flag back to never throttled. Here we force the extension's internal flag back to
True *after* init_app so the decorator limits are enforced. True *after* init_app so the decorator limits are enforced.
""" """
app = create_app('testing') app = create_app(
app.config.update({ "testing",
'TESTING': True, {
'RATELIMIT_STORAGE_URI': 'memory://', "TESTING": True,
'UPLOAD_FOLDER': str(tmp_path / 'uploads'), "RATELIMIT_ENABLED": True,
'OUTPUT_FOLDER': str(tmp_path / 'outputs'), "RATELIMIT_STORAGE_URI": "memory://",
}) "UPLOAD_FOLDER": str(tmp_path / "uploads"),
"OUTPUT_FOLDER": str(tmp_path / "outputs"),
},
)
app.test_client_class = CSRFTestClient
import os import os
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)
# flask-limiter 3.x returns from init_app immediately when os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
# RATELIMIT_ENABLED=False (TestingConfig default), so `initialized` os.makedirs(app.config["OUTPUT_FOLDER"], exist_ok=True)
# stays False and no limits are enforced. We override the config key
# and call init_app a SECOND time so the extension fully initialises.
# It is safe to call twice — flask-limiter guards against duplicate
# before_request hook registration via app.extensions["limiter"].
from app.extensions import limiter as _limiter
app.config['RATELIMIT_ENABLED'] = True
_limiter.init_app(app) # second call — now RATELIMIT_ENABLED=True
yield app yield app
# Restore so other tests are unaffected
_limiter.enabled = False
_limiter.initialized = False
@pytest.fixture @pytest.fixture
def rate_limited_client(rate_limited_app): def rate_limited_client(rate_limited_app):
@@ -48,12 +41,12 @@ class TestRateLimiter:
def test_health_endpoint_not_rate_limited(self, client): def test_health_endpoint_not_rate_limited(self, client):
"""Health endpoint should handle many rapid requests.""" """Health endpoint should handle many rapid requests."""
for _ in range(20): for _ in range(20):
response = client.get('/api/health') response = client.get("/api/health")
assert response.status_code == 200 assert response.status_code == 200
def test_rate_limit_header_present(self, client): def test_rate_limit_header_present(self, client):
"""Response should include a valid HTTP status code.""" """Response should include a valid HTTP status code."""
response = client.get('/api/health') response = client.get("/api/health")
assert response.status_code == 200 assert response.status_code == 200
@@ -68,7 +61,7 @@ class TestRateLimitEnforcement:
""" """
blocked = False blocked = False
for i in range(15): for i in range(15):
r = rate_limited_client.post('/api/compress/pdf') r = rate_limited_client.post("/api/compress/pdf")
if r.status_code == 429: if r.status_code == 429:
blocked = True blocked = True
break break
@@ -81,7 +74,7 @@ class TestRateLimitEnforcement:
"""POST /api/convert/pdf-to-word is also rate-limited.""" """POST /api/convert/pdf-to-word is also rate-limited."""
blocked = False blocked = False
for _ in range(15): for _ in range(15):
r = rate_limited_client.post('/api/convert/pdf-to-word') r = rate_limited_client.post("/api/convert/pdf-to-word")
if r.status_code == 429: if r.status_code == 429:
blocked = True blocked = True
break break
@@ -94,8 +87,8 @@ class TestRateLimitEnforcement:
""" """
# Exhaust compress limit # Exhaust compress limit
for _ in range(15): for _ in range(15):
rate_limited_client.post('/api/compress/pdf') rate_limited_client.post("/api/compress/pdf")
# Health should still respond normally # Health should still respond normally
r = rate_limited_client.get('/api/health') r = rate_limited_client.get("/api/health")
assert r.status_code == 200 assert r.status_code == 200