Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
"""File validation utilities — multi-layer security checks."""
|
||||
import os
|
||||
|
||||
try:
|
||||
import magic
|
||||
HAS_MAGIC = True
|
||||
except (ImportError, OSError):
|
||||
HAS_MAGIC = False
|
||||
import os
|
||||
|
||||
from flask import current_app
|
||||
from werkzeug.utils import secure_filename
|
||||
@@ -45,30 +40,60 @@ def validate_file(
|
||||
if not file_storage or file_storage.filename == "":
|
||||
raise FileValidationError("No file provided.")
|
||||
|
||||
filename = secure_filename(file_storage.filename)
|
||||
if not filename:
|
||||
raise FileValidationError("Invalid filename.")
|
||||
raw_filename = str(file_storage.filename).strip()
|
||||
if not raw_filename:
|
||||
raise FileValidationError("No file provided.")
|
||||
|
||||
# Layer 2: Check file extension against whitelist
|
||||
ext = _get_extension(filename)
|
||||
filename = secure_filename(raw_filename)
|
||||
allowed_extensions = config.get("ALLOWED_EXTENSIONS", {})
|
||||
|
||||
if allowed_types:
|
||||
valid_extensions = {k: v for k, v in allowed_extensions.items() if k in allowed_types}
|
||||
valid_extensions = {
|
||||
k: v for k, v in allowed_extensions.items() if k in allowed_types
|
||||
}
|
||||
else:
|
||||
valid_extensions = allowed_extensions
|
||||
|
||||
# Layer 2: Reject clearly invalid extensions before touching file streams.
|
||||
ext = _get_extension(raw_filename) or _get_extension(filename)
|
||||
if ext and ext not in valid_extensions:
|
||||
raise FileValidationError(
|
||||
f"File type '.{ext}' is not allowed. "
|
||||
f"Allowed types: {', '.join(valid_extensions.keys())}"
|
||||
)
|
||||
|
||||
# Layer 3: Check basic file size and header first so we can recover
|
||||
# from malformed filenames like ".pdf" or "." using content sniffing.
|
||||
file_storage.seek(0, os.SEEK_END)
|
||||
file_size = file_storage.tell()
|
||||
file_storage.seek(0)
|
||||
|
||||
if file_size == 0:
|
||||
raise FileValidationError("File is empty.")
|
||||
|
||||
file_header = file_storage.read(8192)
|
||||
file_storage.seek(0)
|
||||
|
||||
detected_mime = _detect_mime(file_header)
|
||||
|
||||
if not ext:
|
||||
ext = _infer_extension_from_content(
|
||||
file_header, detected_mime, valid_extensions
|
||||
)
|
||||
|
||||
if raw_filename.startswith(".") and not _get_extension(filename):
|
||||
filename = ""
|
||||
|
||||
if not filename:
|
||||
filename = f"upload.{ext}" if ext else "upload"
|
||||
|
||||
if ext not in valid_extensions:
|
||||
raise FileValidationError(
|
||||
f"File type '.{ext}' is not allowed. "
|
||||
f"Allowed types: {', '.join(valid_extensions.keys())}"
|
||||
)
|
||||
|
||||
# Layer 3: Check file size against type-specific limits
|
||||
file_storage.seek(0, os.SEEK_END)
|
||||
file_size = file_storage.tell()
|
||||
file_storage.seek(0)
|
||||
|
||||
# Layer 4: Check file size against type-specific limits
|
||||
size_limits = size_limit_overrides or config.get("FILE_SIZE_LIMITS", {})
|
||||
max_size = size_limits.get(ext, 20 * 1024 * 1024) # Default 20MB
|
||||
|
||||
@@ -78,15 +103,8 @@ def validate_file(
|
||||
f"File too large. Maximum size for .{ext} files is {max_mb:.0f}MB."
|
||||
)
|
||||
|
||||
if file_size == 0:
|
||||
raise FileValidationError("File is empty.")
|
||||
|
||||
# Layer 4: Check MIME type using magic bytes (if libmagic is available)
|
||||
file_header = file_storage.read(8192)
|
||||
file_storage.seek(0)
|
||||
|
||||
if HAS_MAGIC:
|
||||
detected_mime = magic.from_buffer(file_header, mime=True)
|
||||
# Layer 5: Check MIME type using magic bytes (if libmagic is available)
|
||||
if detected_mime:
|
||||
expected_mimes = valid_extensions.get(ext, [])
|
||||
|
||||
if detected_mime not in expected_mimes:
|
||||
@@ -95,7 +113,7 @@ def validate_file(
|
||||
f"Detected type: {detected_mime}"
|
||||
)
|
||||
|
||||
# Layer 5: Additional content checks for specific types
|
||||
# Layer 6: Additional content checks for specific types
|
||||
if ext == "pdf":
|
||||
_check_pdf_safety(file_header)
|
||||
|
||||
@@ -104,9 +122,52 @@ def validate_file(
|
||||
|
||||
def _get_extension(filename: str) -> str:
|
||||
"""Extract and normalize file extension."""
|
||||
if "." not in filename:
|
||||
filename = str(filename or "").strip()
|
||||
if not filename or "." not in filename:
|
||||
return ""
|
||||
return filename.rsplit(".", 1)[1].lower()
|
||||
stem, ext = filename.rsplit(".", 1)
|
||||
if not ext:
|
||||
return ""
|
||||
if not stem and filename.startswith("."):
|
||||
return ext.lower()
|
||||
return ext.lower()
|
||||
|
||||
|
||||
def _detect_mime(file_header: bytes) -> str | None:
|
||||
"""Detect MIME type lazily so environments without libmagic stay usable."""
|
||||
try:
|
||||
import magic as magic_module
|
||||
except (ImportError, OSError):
|
||||
return None
|
||||
|
||||
try:
|
||||
return magic_module.from_buffer(file_header, mime=True)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _infer_extension_from_content(
|
||||
file_header: bytes,
|
||||
detected_mime: str | None,
|
||||
valid_extensions: dict[str, list[str]],
|
||||
) -> str:
|
||||
"""Infer a safe extension from MIME type or common signatures."""
|
||||
if detected_mime:
|
||||
for ext, mimes in valid_extensions.items():
|
||||
if detected_mime in mimes:
|
||||
return ext
|
||||
|
||||
signature_map = {
|
||||
b"%PDF": "pdf",
|
||||
b"\x89PNG\r\n\x1a\n": "png",
|
||||
b"\xff\xd8\xff": "jpg",
|
||||
b"RIFF": "webp",
|
||||
}
|
||||
for signature, ext in signature_map.items():
|
||||
if file_header.startswith(signature) and ext in valid_extensions:
|
||||
return ext
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _check_pdf_safety(file_header: bytes):
|
||||
|
||||
Reference in New Issue
Block a user