refactor: improve app initialization and update rate limiter tests for consistency
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
"""Flask Application Factory."""
|
"""Flask Application Factory."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from flask import Flask, jsonify
|
from flask import Flask, jsonify
|
||||||
@@ -11,7 +12,12 @@ from app.services.ai_cost_service import init_ai_cost_db
|
|||||||
from app.services.site_assistant_service import init_site_assistant_db
|
from app.services.site_assistant_service import init_site_assistant_db
|
||||||
from app.services.contact_service import init_contact_db
|
from app.services.contact_service import init_contact_db
|
||||||
from app.services.stripe_service import init_stripe_db
|
from app.services.stripe_service import init_stripe_db
|
||||||
from app.utils.csrf import CSRFError, apply_csrf_cookie, should_enforce_csrf, validate_csrf_request
|
from app.utils.csrf import (
|
||||||
|
CSRFError,
|
||||||
|
apply_csrf_cookie,
|
||||||
|
should_enforce_csrf,
|
||||||
|
validate_csrf_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _init_sentry(app):
|
def _init_sentry(app):
|
||||||
@@ -35,13 +41,15 @@ def _init_sentry(app):
|
|||||||
app.logger.warning("sentry-sdk not installed — monitoring disabled.")
|
app.logger.warning("sentry-sdk not installed — monitoring disabled.")
|
||||||
|
|
||||||
|
|
||||||
def create_app(config_name=None):
|
def create_app(config_name=None, config_overrides=None):
|
||||||
"""Create and configure the Flask application."""
|
"""Create and configure the Flask application."""
|
||||||
if config_name is None:
|
if config_name is None:
|
||||||
config_name = os.getenv("FLASK_ENV", "development")
|
config_name = os.getenv("FLASK_ENV", "development")
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config.from_object(config[config_name])
|
app.config.from_object(config[config_name])
|
||||||
|
if config_overrides:
|
||||||
|
app.config.update(config_overrides)
|
||||||
|
|
||||||
# Initialize Sentry early
|
# Initialize Sentry early
|
||||||
_init_sentry(app)
|
_init_sentry(app)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Tests for rate limiting middleware."""
|
"""Tests for rate limiting middleware."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from app import create_app
|
from app import create_app
|
||||||
|
from tests.conftest import CSRFTestClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -11,33 +13,24 @@ def rate_limited_app(tmp_path):
|
|||||||
never throttled. Here we force the extension's internal flag back to
|
never throttled. Here we force the extension's internal flag back to
|
||||||
True *after* init_app so the decorator limits are enforced.
|
True *after* init_app so the decorator limits are enforced.
|
||||||
"""
|
"""
|
||||||
app = create_app('testing')
|
app = create_app(
|
||||||
app.config.update({
|
"testing",
|
||||||
'TESTING': True,
|
{
|
||||||
'RATELIMIT_STORAGE_URI': 'memory://',
|
"TESTING": True,
|
||||||
'UPLOAD_FOLDER': str(tmp_path / 'uploads'),
|
"RATELIMIT_ENABLED": True,
|
||||||
'OUTPUT_FOLDER': str(tmp_path / 'outputs'),
|
"RATELIMIT_STORAGE_URI": "memory://",
|
||||||
})
|
"UPLOAD_FOLDER": str(tmp_path / "uploads"),
|
||||||
|
"OUTPUT_FOLDER": str(tmp_path / "outputs"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app.test_client_class = CSRFTestClient
|
||||||
import os
|
import os
|
||||||
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
|
||||||
os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)
|
|
||||||
|
|
||||||
# flask-limiter 3.x returns from init_app immediately when
|
os.makedirs(app.config["UPLOAD_FOLDER"], exist_ok=True)
|
||||||
# RATELIMIT_ENABLED=False (TestingConfig default), so `initialized`
|
os.makedirs(app.config["OUTPUT_FOLDER"], exist_ok=True)
|
||||||
# stays False and no limits are enforced. We override the config key
|
|
||||||
# and call init_app a SECOND time so the extension fully initialises.
|
|
||||||
# It is safe to call twice — flask-limiter guards against duplicate
|
|
||||||
# before_request hook registration via app.extensions["limiter"].
|
|
||||||
from app.extensions import limiter as _limiter
|
|
||||||
app.config['RATELIMIT_ENABLED'] = True
|
|
||||||
_limiter.init_app(app) # second call — now RATELIMIT_ENABLED=True
|
|
||||||
|
|
||||||
yield app
|
yield app
|
||||||
|
|
||||||
# Restore so other tests are unaffected
|
|
||||||
_limiter.enabled = False
|
|
||||||
_limiter.initialized = False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def rate_limited_client(rate_limited_app):
|
def rate_limited_client(rate_limited_app):
|
||||||
@@ -48,12 +41,12 @@ class TestRateLimiter:
|
|||||||
def test_health_endpoint_not_rate_limited(self, client):
|
def test_health_endpoint_not_rate_limited(self, client):
|
||||||
"""Health endpoint should handle many rapid requests."""
|
"""Health endpoint should handle many rapid requests."""
|
||||||
for _ in range(20):
|
for _ in range(20):
|
||||||
response = client.get('/api/health')
|
response = client.get("/api/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
def test_rate_limit_header_present(self, client):
|
def test_rate_limit_header_present(self, client):
|
||||||
"""Response should include a valid HTTP status code."""
|
"""Response should include a valid HTTP status code."""
|
||||||
response = client.get('/api/health')
|
response = client.get("/api/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +61,7 @@ class TestRateLimitEnforcement:
|
|||||||
"""
|
"""
|
||||||
blocked = False
|
blocked = False
|
||||||
for i in range(15):
|
for i in range(15):
|
||||||
r = rate_limited_client.post('/api/compress/pdf')
|
r = rate_limited_client.post("/api/compress/pdf")
|
||||||
if r.status_code == 429:
|
if r.status_code == 429:
|
||||||
blocked = True
|
blocked = True
|
||||||
break
|
break
|
||||||
@@ -81,7 +74,7 @@ class TestRateLimitEnforcement:
|
|||||||
"""POST /api/convert/pdf-to-word is also rate-limited."""
|
"""POST /api/convert/pdf-to-word is also rate-limited."""
|
||||||
blocked = False
|
blocked = False
|
||||||
for _ in range(15):
|
for _ in range(15):
|
||||||
r = rate_limited_client.post('/api/convert/pdf-to-word')
|
r = rate_limited_client.post("/api/convert/pdf-to-word")
|
||||||
if r.status_code == 429:
|
if r.status_code == 429:
|
||||||
blocked = True
|
blocked = True
|
||||||
break
|
break
|
||||||
@@ -94,8 +87,8 @@ class TestRateLimitEnforcement:
|
|||||||
"""
|
"""
|
||||||
# Exhaust compress limit
|
# Exhaust compress limit
|
||||||
for _ in range(15):
|
for _ in range(15):
|
||||||
rate_limited_client.post('/api/compress/pdf')
|
rate_limited_client.post("/api/compress/pdf")
|
||||||
|
|
||||||
# Health should still respond normally
|
# Health should still respond normally
|
||||||
r = rate_limited_client.get('/api/health')
|
r = rate_limited_client.get("/api/health")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
Reference in New Issue
Block a user