diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 4edb772..689db5d 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -1,4 +1,5 @@ """Flask Application Factory.""" + import os from flask import Flask, jsonify @@ -11,7 +12,12 @@ from app.services.ai_cost_service import init_ai_cost_db from app.services.site_assistant_service import init_site_assistant_db from app.services.contact_service import init_contact_db from app.services.stripe_service import init_stripe_db -from app.utils.csrf import CSRFError, apply_csrf_cookie, should_enforce_csrf, validate_csrf_request +from app.utils.csrf import ( + CSRFError, + apply_csrf_cookie, + should_enforce_csrf, + validate_csrf_request, +) def _init_sentry(app): @@ -35,13 +41,15 @@ def _init_sentry(app): app.logger.warning("sentry-sdk not installed — monitoring disabled.") -def create_app(config_name=None): +def create_app(config_name=None, config_overrides=None): """Create and configure the Flask application.""" if config_name is None: config_name = os.getenv("FLASK_ENV", "development") app = Flask(__name__) app.config.from_object(config[config_name]) + if config_overrides: + app.config.update(config_overrides) # Initialize Sentry early _init_sentry(app) diff --git a/backend/tests/test_rate_limiter.py b/backend/tests/test_rate_limiter.py index 35f4c31..526f224 100644 --- a/backend/tests/test_rate_limiter.py +++ b/backend/tests/test_rate_limiter.py @@ -1,6 +1,8 @@ """Tests for rate limiting middleware.""" + import pytest from app import create_app +from tests.conftest import CSRFTestClient @pytest.fixture @@ -11,33 +13,24 @@ def rate_limited_app(tmp_path): never throttled. Here we force the extension's internal flag back to True *after* init_app so the decorator limits are enforced. """ - app = create_app('testing') - app.config.update({ - 'TESTING': True, - 'RATELIMIT_STORAGE_URI': 'memory://', - 'UPLOAD_FOLDER': str(tmp_path / 'uploads'), - 'OUTPUT_FOLDER': str(tmp_path / 'outputs'), - }) + app = create_app( + "testing", + { + "TESTING": True, + "RATELIMIT_ENABLED": True, + "RATELIMIT_STORAGE_URI": "memory://", + "UPLOAD_FOLDER": str(tmp_path / "uploads"), + "OUTPUT_FOLDER": str(tmp_path / "outputs"), + }, + ) + app.test_client_class = CSRFTestClient import os - os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) - os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True) - # flask-limiter 3.x returns from init_app immediately when - # RATELIMIT_ENABLED=False (TestingConfig default), so `initialized` - # stays False and no limits are enforced. We override the config key - # and call init_app a SECOND time so the extension fully initialises. - # It is safe to call twice — flask-limiter guards against duplicate - # before_request hook registration via app.extensions["limiter"]. - from app.extensions import limiter as _limiter - app.config['RATELIMIT_ENABLED'] = True - _limiter.init_app(app) # second call — now RATELIMIT_ENABLED=True + os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True) + os.makedirs(app.config["OUTPUT_FOLDER"], exist_ok=True) yield app - # Restore so other tests are unaffected - _limiter.enabled = False - _limiter.initialized = False - @pytest.fixture def rate_limited_client(rate_limited_app): @@ -48,12 +41,12 @@ class TestRateLimiter: def test_health_endpoint_not_rate_limited(self, client): """Health endpoint should handle many rapid requests.""" for _ in range(20): - response = client.get('/api/health') + response = client.get("/api/health") assert response.status_code == 200 def test_rate_limit_header_present(self, client): """Response should include a valid HTTP status code.""" - response = client.get('/api/health') + response = client.get("/api/health") assert response.status_code == 200 @@ -68,7 +61,7 @@ class TestRateLimitEnforcement: """ blocked = False for i in range(15): - r = rate_limited_client.post('/api/compress/pdf') + r = rate_limited_client.post("/api/compress/pdf") if r.status_code == 429: blocked = True break @@ -81,7 +74,7 @@ class TestRateLimitEnforcement: """POST /api/convert/pdf-to-word is also rate-limited.""" blocked = False for _ in range(15): - r = rate_limited_client.post('/api/convert/pdf-to-word') + r = rate_limited_client.post("/api/convert/pdf-to-word") if r.status_code == 429: blocked = True break @@ -94,8 +87,8 @@ class TestRateLimitEnforcement: """ # Exhaust compress limit for _ in range(15): - rate_limited_client.post('/api/compress/pdf') + rate_limited_client.post("/api/compress/pdf") # Health should still respond normally - r = rate_limited_client.get('/api/health') - assert r.status_code == 200 \ No newline at end of file + r = rate_limited_client.get("/api/health") + assert r.status_code == 200