refactor: improve app initialization and update rate limiter tests for consistency
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Flask Application Factory."""
|
||||
|
||||
import os
|
||||
|
||||
from flask import Flask, jsonify
|
||||
@@ -11,7 +12,12 @@ from app.services.ai_cost_service import init_ai_cost_db
|
||||
from app.services.site_assistant_service import init_site_assistant_db
|
||||
from app.services.contact_service import init_contact_db
|
||||
from app.services.stripe_service import init_stripe_db
|
||||
from app.utils.csrf import CSRFError, apply_csrf_cookie, should_enforce_csrf, validate_csrf_request
|
||||
from app.utils.csrf import (
|
||||
CSRFError,
|
||||
apply_csrf_cookie,
|
||||
should_enforce_csrf,
|
||||
validate_csrf_request,
|
||||
)
|
||||
|
||||
|
||||
def _init_sentry(app):
|
||||
@@ -35,13 +41,15 @@ def _init_sentry(app):
|
||||
app.logger.warning("sentry-sdk not installed — monitoring disabled.")
|
||||
|
||||
|
||||
def create_app(config_name=None):
|
||||
def create_app(config_name=None, config_overrides=None):
|
||||
"""Create and configure the Flask application."""
|
||||
if config_name is None:
|
||||
config_name = os.getenv("FLASK_ENV", "development")
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config.from_object(config[config_name])
|
||||
if config_overrides:
|
||||
app.config.update(config_overrides)
|
||||
|
||||
# Initialize Sentry early
|
||||
_init_sentry(app)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
import pytest
|
||||
from app import create_app
|
||||
from tests.conftest import CSRFTestClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -11,33 +13,24 @@ def rate_limited_app(tmp_path):
|
||||
never throttled. Here we force the extension's internal flag back to
|
||||
True *after* init_app so the decorator limits are enforced.
|
||||
"""
|
||||
app = create_app('testing')
|
||||
app.config.update({
|
||||
'TESTING': True,
|
||||
'RATELIMIT_STORAGE_URI': 'memory://',
|
||||
'UPLOAD_FOLDER': str(tmp_path / 'uploads'),
|
||||
'OUTPUT_FOLDER': str(tmp_path / 'outputs'),
|
||||
})
|
||||
app = create_app(
|
||||
"testing",
|
||||
{
|
||||
"TESTING": True,
|
||||
"RATELIMIT_ENABLED": True,
|
||||
"RATELIMIT_STORAGE_URI": "memory://",
|
||||
"UPLOAD_FOLDER": str(tmp_path / "uploads"),
|
||||
"OUTPUT_FOLDER": str(tmp_path / "outputs"),
|
||||
},
|
||||
)
|
||||
app.test_client_class = CSRFTestClient
|
||||
import os
|
||||
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
||||
os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)
|
||||
|
||||
# flask-limiter 3.x returns from init_app immediately when
|
||||
# RATELIMIT_ENABLED=False (TestingConfig default), so `initialized`
|
||||
# stays False and no limits are enforced. We override the config key
|
||||
# and call init_app a SECOND time so the extension fully initialises.
|
||||
# It is safe to call twice — flask-limiter guards against duplicate
|
||||
# before_request hook registration via app.extensions["limiter"].
|
||||
from app.extensions import limiter as _limiter
|
||||
app.config['RATELIMIT_ENABLED'] = True
|
||||
_limiter.init_app(app) # second call — now RATELIMIT_ENABLED=True
|
||||
os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
|
||||
os.makedirs(app.config["OUTPUT_FOLDER"], exist_ok=True)
|
||||
|
||||
yield app
|
||||
|
||||
# Restore so other tests are unaffected
|
||||
_limiter.enabled = False
|
||||
_limiter.initialized = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limited_client(rate_limited_app):
|
||||
@@ -48,12 +41,12 @@ class TestRateLimiter:
|
||||
def test_health_endpoint_not_rate_limited(self, client):
|
||||
"""Health endpoint should handle many rapid requests."""
|
||||
for _ in range(20):
|
||||
response = client.get('/api/health')
|
||||
response = client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_rate_limit_header_present(self, client):
|
||||
"""Response should include a valid HTTP status code."""
|
||||
response = client.get('/api/health')
|
||||
response = client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -68,7 +61,7 @@ class TestRateLimitEnforcement:
|
||||
"""
|
||||
blocked = False
|
||||
for i in range(15):
|
||||
r = rate_limited_client.post('/api/compress/pdf')
|
||||
r = rate_limited_client.post("/api/compress/pdf")
|
||||
if r.status_code == 429:
|
||||
blocked = True
|
||||
break
|
||||
@@ -81,7 +74,7 @@ class TestRateLimitEnforcement:
|
||||
"""POST /api/convert/pdf-to-word is also rate-limited."""
|
||||
blocked = False
|
||||
for _ in range(15):
|
||||
r = rate_limited_client.post('/api/convert/pdf-to-word')
|
||||
r = rate_limited_client.post("/api/convert/pdf-to-word")
|
||||
if r.status_code == 429:
|
||||
blocked = True
|
||||
break
|
||||
@@ -94,8 +87,8 @@ class TestRateLimitEnforcement:
|
||||
"""
|
||||
# Exhaust compress limit
|
||||
for _ in range(15):
|
||||
rate_limited_client.post('/api/compress/pdf')
|
||||
rate_limited_client.post("/api/compress/pdf")
|
||||
|
||||
# Health should still respond normally
|
||||
r = rate_limited_client.get('/api/health')
|
||||
r = rate_limited_client.get("/api/health")
|
||||
assert r.status_code == 200
|
||||
Reference in New Issue
Block a user