diff --git a/backend/app/routes/tasks.py b/backend/app/routes/tasks.py index b2bddc0..74b07bc 100644 --- a/backend/app/routes/tasks.py +++ b/backend/app/routes/tasks.py @@ -1,9 +1,12 @@ """Task status polling endpoint.""" +from urllib.parse import urlparse + from flask import Blueprint, jsonify, request from celery.result import AsyncResult from app.extensions import celery from app.middleware.rate_limiter import limiter +from app.services.account_service import has_task_access, record_usage_event from app.services.policy_service import ( PolicyError, assert_api_task_access, @@ -11,10 +14,46 @@ from app.services.policy_service import ( resolve_api_actor, resolve_web_actor, ) +from app.utils.auth import remember_task_access tasks_bp = Blueprint("tasks", __name__) +def _extract_download_task_id(download_url: str | None) -> str | None: + """Return the local download identifier embedded in one download URL.""" + if not download_url: + return None + + path_parts = [part for part in urlparse(download_url).path.split("/") if part] + if len(path_parts) >= 4 and path_parts[0] == "api" and path_parts[1] == "download": + return path_parts[2] + + return None + + +def _remember_download_alias(actor, download_task_id: str | None): + """Grant access to one local download identifier returned after task success.""" + if not download_task_id: + return + + remember_task_access(download_task_id) + + if actor.user_id is None: + return + + if has_task_access(actor.user_id, actor.source, download_task_id): + return + + record_usage_event( + user_id=actor.user_id, + api_key_id=actor.api_key_id, + source=actor.source, + tool="download", + task_id=download_task_id, + event_type="download_alias", + ) + + @tasks_bp.route("//status", methods=["GET"]) @limiter.limit("300/minute", override_defaults=True) def get_task_status(task_id: str): @@ -50,6 +89,7 @@ def get_task_status(task_id: str): elif result.state == "SUCCESS": task_result = result.result or {} + _remember_download_alias(actor, _extract_download_task_id(task_result.get("download_url"))) response["result"] = task_result elif result.state == "FAILURE": diff --git a/backend/app/services/account_service.py b/backend/app/services/account_service.py index 88af715..f2d6e9f 100644 --- a/backend/app/services/account_service.py +++ b/backend/app/services/account_service.py @@ -669,7 +669,8 @@ def has_task_access(user_id: int, source: str, task_id: str) -> bool: """ SELECT 1 FROM usage_events - WHERE user_id = ? AND source = ? AND task_id = ? AND event_type = 'accepted' + WHERE user_id = ? AND source = ? AND task_id = ? + AND event_type IN ('accepted', 'download_alias') LIMIT 1 """, (user_id, source, task_id), diff --git a/backend/celerybeat-schedule b/backend/celerybeat-schedule index 4f922dd..065fd14 100644 Binary files a/backend/celerybeat-schedule and b/backend/celerybeat-schedule differ diff --git a/backend/tests/test_tasks_route.py b/backend/tests/test_tasks_route.py index 9c6e5f6..1869f1e 100644 --- a/backend/tests/test_tasks_route.py +++ b/backend/tests/test_tasks_route.py @@ -1,6 +1,7 @@ """Tests for task status polling route.""" from unittest.mock import patch, MagicMock +from app.services.account_service import create_user, has_task_access from app.utils.auth import TASK_ACCESS_SESSION_KEY @@ -62,6 +63,30 @@ class TestTaskStatus: assert data['result']['status'] == 'completed' assert 'download_url' in data['result'] + with client.session_transaction() as session: + assert 'task-id' in session[TASK_ACCESS_SESSION_KEY] + + def test_success_task_persists_download_alias_for_authenticated_user(self, client): + """Should persist download aliases for logged-in users as authorized task ids.""" + user = create_user('tasks-route@example.com', 'secretpass123') + mock_result = MagicMock() + mock_result.state = 'SUCCESS' + mock_result.result = { + 'status': 'completed', + 'download_url': '/api/download/local-download-id/output.pdf', + 'filename': 'output.pdf', + } + + with client.session_transaction() as session: + session['user_id'] = user['id'] + session[TASK_ACCESS_SESSION_KEY] = ['success-id'] + + with patch('app.routes.tasks.AsyncResult', return_value=mock_result): + response = client.get('/api/tasks/success-id/status') + + assert response.status_code == 200 + assert has_task_access(user['id'], 'web', 'local-download-id') is True + def test_failure_task(self, client, monkeypatch): """Should return FAILURE state with error message.""" mock_result = MagicMock() diff --git a/frontend/src/components/shared/FileUploader.tsx b/frontend/src/components/shared/FileUploader.tsx index 228489c..220f01a 100644 --- a/frontend/src/components/shared/FileUploader.tsx +++ b/frontend/src/components/shared/FileUploader.tsx @@ -1,5 +1,5 @@ -import { useCallback } from 'react'; -import { useDropzone, type Accept } from 'react-dropzone'; +import { useState, useCallback } from 'react'; +import { useDropzone, type Accept, type FileRejection } from 'react-dropzone'; import { useTranslation } from 'react-i18next'; import { Upload, File, X } from 'lucide-react'; import { formatFileSize } from '@/utils/textTools'; @@ -37,9 +37,11 @@ export default function FileUploader({ acceptLabel, }: FileUploaderProps) { const { t } = useTranslation(); + const [sizeError, setSizeError] = useState(null); const onDrop = useCallback( (acceptedFiles: File[]) => { + setSizeError(null); if (acceptedFiles.length > 0) { onFileSelect(acceptedFiles[0]); } @@ -47,8 +49,19 @@ export default function FileUploader({ [onFileSelect] ); + const onDropRejected = useCallback( + (rejectedFiles: FileRejection[]) => { + const code = rejectedFiles[0]?.errors[0]?.code; + if (code === 'file-too-large') { + setSizeError(t('errors.fileTooLarge', { size: maxSizeMB })); + } + }, + [maxSizeMB, t] + ); + const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop, + onDropRejected, accept, maxFiles: 1, maxSize: maxSizeMB * 1024 * 1024, @@ -122,9 +135,9 @@ export default function FileUploader({ )} {/* Error */} - {error && ( + {(sizeError || error) && (
-

{error}

+

⚠️ {sizeError || error}

)}