diff --git a/.github/workflows/Tests.yaml b/.github/workflows/Tests.yaml index 24e9709..ec6dc36 100644 --- a/.github/workflows/Tests.yaml +++ b/.github/workflows/Tests.yaml @@ -28,7 +28,7 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python }} + - name: Set up Python uses: actions/setup-python@v4 with: python-version-file: "backend/pyproject.toml" @@ -51,11 +51,9 @@ jobs: run: inv coverage --args "-vvv" - name: Upload coverage to Codecov - if: matrix.python == '3.11' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: - root_dir: backend - working-directory: backend + directory: backend fail_ci_if_error: true token: ${{ secrets.CODECOV_TOKEN }} diff --git a/backend/README.md b/backend/README.md index 07f1313..af8a827 100644 --- a/backend/README.md +++ b/backend/README.md @@ -1,31 +1,27 @@ -# Contribution - -## Dependencies -```bash -# Install all the dependencies. -pipenv sync -# Update dependencies. -pipenv install -``` +# backend -## Development +Leverages great things to achieve great results -If you want to link to Postgresql, create the `.env` file and set the `POSTGRES_URI` environment variable in it, example: +[![CodeFactor](https://www.codefactor.io/repository/github/openzim/nautilus-webui/badge)](https://www.codefactor.io/repository/github/openzim/nautilus-webui) +[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) +[![codecov](https://codecov.io/gh/openzim/nautilus-webui/branch/main/graph/badge.svg)](https://codecov.io/gh/openzim/nautilus-webui) +![Python Version from PEP 621 TOML](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fgithub.com%2Fopenzim%2Fnautilus-webui%2Fraw%2Fmain%2Fbackend%2Fpyproject.toml) -```env -POSTGRES_URI=postgresql+psycopg://username:password@host/database -``` -Dev commands: -```bash -# Init database -pipenv run init -# Start FastAPI -pipenv run start -# Run tests -pipenv run tests -# Format code -pipenv run format -# Check format. -pipenv run format:check +## Usage + +**CAUTION**: this is not a standalone, installable Python package. + +- It's the backend of a web service that is intended to be deployed using OCI images. +- See the sample Composefile in the dev folder of the repository. +- It has external dependencies (including [S3 Storage](https://wasabi.com/), [Mailgun](https://www.mailgun.com/) account and a full-fledged [Zimfarm](https://github.com/openzim/zimfarm). +- It **must be configured** via environment variables (see `constants.py` and Compose's Envfile) +- There is no CHANGELOG nor release management. Production is tied to CD on `main` branch. + +```sh +❯ hatch run serve ``` + +nautilus-webui backend adheres to openZIM's [Contribution Guidelines](https://github.com/openzim/overview/wiki/Contributing). + +nautilus-webui backend has implemented openZIM's [Python bootstrap, conventions and policies](https://github.com/openzim/_python-bootstrap/docs/Policy.md) **v1.0.1**. diff --git a/backend/alembic.ini b/backend/alembic.ini index d174d12..e60e36b 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -68,18 +68,17 @@ sqlalchemy.url = driver://user:pass@localhost/dbname # on newly generated revision scripts. 
See the documentation for further # detail and examples -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME -hooks = black isort -black.type = console_scripts -black.entrypoint = black -black.options = REVISION_SCRIPT_FILENAME -isort.type = console_scripts -isort.entrypoint = isort -isort.options = --profile black REVISION_SCRIPT_FILENAME +hooks = ruff, ruff_format + +# lint with attempts to fix using ruff +ruff.type = exec +ruff.executable = ruff +ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# format using ruff +ruff_format.type = exec +ruff_format.executable = ruff +ruff_format.options = format REVISION_SCRIPT_FILENAME # Logging configuration diff --git a/backend/api/constants.py b/backend/api/constants.py index 465a079..d80e5bd 100644 --- a/backend/api/constants.py +++ b/backend/api/constants.py @@ -2,6 +2,7 @@ import logging import os import tempfile +import uuid from dataclasses import dataclass, field from pathlib import Path @@ -21,45 +22,106 @@ class BackendConf: Backend configuration, read from environment variables and set default values. """ - logger: logging.Logger = field(init=False) + # Configuration + project_expire_after: datetime.timedelta = datetime.timedelta(days=7) + project_quota: int = 0 + chunk_size: int = 1024 # reading/writing received files + illustration_quota: int = 0 + api_version_prefix: str = "/v1" # our API + + # Database + postgres_uri: str = os.getenv("POSTGRES_URI") or "nodb" + + # Scheduler process + redis_uri: str = os.getenv("REDIS_URI") or "redis://localhost:6379/0" + channel_name: str = os.getenv("CHANNEL_NAME") or "s3_upload" - # Mandatory configurations - postgres_uri = os.getenv("POSTGRES_URI", "nodb") - s3_url_with_credentials = os.getenv("S3_URL_WITH_CREDENTIALS") - private_salt = os.getenv("PRIVATE_SALT") + # Transient (on host disk) Storage + transient_storage_path: Path = Path() - # Optional configuration. 
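For context on the configuration values below: everything is read from environment variables, and human-readable size/duration strings are parsed with humanfriendly before being stored as plain ints. A minimal sketch of the parsing this relies on (env names are the ones used in `constants.py`, printed values are illustrative defaults):

```python
import datetime
import os

import humanfriendly

# decimal and binary suffixes are both understood by parse_size()
print(humanfriendly.parse_size("100MB"))  # 100000000  (PROJECT_QUOTA default)
print(humanfriendly.parse_size("2MiB"))   # 2097152    (CHUNK_SIZE / ILLUSTRATION_QUOTA default)

# parse_timespan() returns seconds as a float; the backend casts it to int
print(int(humanfriendly.parse_timespan(os.getenv("S3_RETRY_TIMES") or "10s")))  # 10

# plain integer variables still go through int()
print(datetime.timedelta(hours=int(os.getenv("S3_REMOVE_DELETEDUPLOADING_AFTER_HOURS", "12"))))
```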
- s3_max_tries = int(os.getenv("S3_MAX_TRIES", "3")) - s3_retry_wait = humanfriendly.parse_timespan(os.getenv("S3_RETRY_TIMES", "10s")) - s3_deletion_delay = datetime.timedelta( + # S3 Storage + s3_url_with_credentials: str = os.getenv("S3_URL_WITH_CREDENTIALS") or "" + s3_max_tries: int = int(os.getenv("S3_MAX_TRIES", "3")) + s3_retry_wait: int = int( + humanfriendly.parse_timespan(os.getenv("S3_RETRY_TIMES") or "10s") + ) + s3_deletion_delay: datetime.timedelta = datetime.timedelta( hours=int(os.getenv("S3_REMOVE_DELETEDUPLOADING_AFTER_HOURS", "12")) ) - transient_storage_path = Path( - os.getenv("TRANSIENT_STORAGE_PATH", tempfile.gettempdir()) - ).resolve() - redis_uri = os.getenv("REDIS_URI", "redis://localhost:6379/0") - channel_name = os.getenv("CHANNEL_NAME", "s3_upload") + private_salt = os.getenv( + "PRIVATE_SALT", uuid.uuid4().hex + ) # used to make S3 keys unguessable + + # Cookies cookie_domain = os.getenv("COOKIE_DOMAIN", None) cookie_expiration_days = int(os.getenv("COOKIE_EXPIRATION_DAYS", "30")) - project_quota = humanfriendly.parse_size(os.getenv("PROJECT_QUOTA", "100MB")) - chunk_size = humanfriendly.parse_size(os.getenv("CHUNK_SIZE", "2MiB")) - illustration_quota = humanfriendly.parse_size( - os.getenv("ILLUSTRATION_QUOTA", "2MiB") + authentication_cookie_name: str = "user_id" + + # Deployment + public_url: str = os.getenv("PUBLIC_URL") or "http://localhost" + download_url: str = ( + os.getenv("DOWNLOAD_URL") + or "https://s3.us-west-1.wasabisys.com/org-kiwix-zimit/zim" ) allowed_origins = os.getenv( "ALLOWED_ORIGINS", "http://localhost", ).split("|") - authentication_cookie_name: str = "user_id" - api_version_prefix = "/v1" - project_expire_after = datetime.timedelta(days=7) + # Zimfarm (3rd party API creating ZIMs and calling back with feedback) + zimfarm_api_url: str = ( + os.getenv("ZIMFARM_API_URL") or "https://api.farm.zimit.kiwix.org/v1" + ) + zimfarm_username: str = os.getenv("ZIMFARM_API_USERNAME") or "" + zimfarm_password: str = os.getenv("ZIMFARM_API_PASSWORD") or "" + zimfarm_nautilus_image: str = ( + os.getenv("ZIMFARM_NAUTILUS_IMAGE") or "ghcr.io/openzim/nautilus:latest" + ) + zimfarm_task_cpu: int = int(os.getenv("ZIMFARM_TASK_CPU") or "3") + zimfarm_task_memory: int = 0 + zimfarm_task_disk: int = 0 + zimfarm_callback_base_url = os.getenv("ZIMFARM_CALLBACK_BASE_URL", "") + zimfarm_callback_token = os.getenv("ZIMFARM_CALLBACK_TOKEN", uuid.uuid4().hex) + zimfarm_task_worker: str = os.getenv("ZIMFARM_TASK_WORKDER") or "-" + zimfarm_request_timeout_sec: int = 10 + + # Mailgun (3rd party API to send emails) + mailgun_api_url: str = os.getenv("MAILGUN_API_URL") or "" + mailgun_api_key: str = os.getenv("MAILGUN_API_KEY") or "" + mailgun_from: str = os.getenv("MAILGUN_FROM") or "Nautilus ZIM" + mailgun_request_timeout_sec: int = 10 + + logger: logging.Logger = field(init=False) def __post_init__(self): self.logger = logging.getLogger(Path(__file__).parent.name) self.transient_storage_path.mkdir(exist_ok=True) self.job_retry = Retry(max=self.s3_max_tries, interval=int(self.s3_retry_wait)) + self.transient_storage_path = Path( + os.getenv("TRANSIENT_STORAGE_PATH") or tempfile.gettempdir() + ).resolve() + + self.project_quota = humanfriendly.parse_size( + os.getenv("PROJECT_QUOTA") or "100MB" + ) + + self.chunk_size = humanfriendly.parse_size(os.getenv("CHUNK_SIZE", "2MiB")) + + self.illustration_quota = humanfriendly.parse_size( + os.getenv("ILLUSTRATION_QUOTA", "2MiB") + ) + + self.zimfarm_task_memory = humanfriendly.parse_size( + os.getenv("ZIMFARM_TASK_MEMORY") 
or "1000MiB" + ) + self.zimfarm_task_disk = humanfriendly.parse_size( + os.getenv("ZIMFARM_TASK_DISK") or "200MiB" + ) + + if not self.zimfarm_callback_base_url: + self.zimfarm_callback_base_url = f"{self.zimfarm_api_url}/requests/hook" + constants = BackendConf() logger = constants.logger diff --git a/backend/api/database/__init__.py b/backend/api/database/__init__.py index 6d936d4..9e0d0d6 100644 --- a/backend/api/database/__init__.py +++ b/backend/api/database/__init__.py @@ -1,7 +1,8 @@ from collections.abc import Generator from uuid import UUID -from bson.json_util import DEFAULT_JSON_OPTIONS, dumps, loads +import pydantic_core +from bson.json_util import DEFAULT_JSON_OPTIONS, loads from sqlalchemy import create_engine from sqlalchemy.orm import Session as OrmSession from sqlalchemy.orm import sessionmaker @@ -25,7 +26,7 @@ def my_loads(s, *args, **kwargs): bind=create_engine( constants.postgres_uri, echo=False, - json_serializer=dumps, # use bson serializer to handle datetime naively + json_serializer=pydantic_core.to_json, json_deserializer=my_loads, # use custom bson deserializer for same reason ) ) diff --git a/backend/api/database/models.py b/backend/api/database/models.py index 42ac4a3..81ef3d4 100644 --- a/backend/api/database/models.py +++ b/backend/api/database/models.py @@ -1,8 +1,9 @@ from datetime import datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar from uuid import UUID -from sqlalchemy import DateTime, ForeignKey, String, text +from pydantic import BaseModel +from sqlalchemy import DateTime, ForeignKey, String, text, types from sqlalchemy.dialects.postgresql import ARRAY, JSONB from sqlalchemy.orm import ( DeclarativeBase, @@ -12,9 +13,74 @@ relationship, ) from sqlalchemy.sql.schema import MetaData +from zimscraperlib.zim.metadata import ( + validate_description, + validate_language, + validate_required_values, + validate_tags, + validate_title, +) from api.database import get_local_fpath_for +T = TypeVar("T", bound="ArchiveConfig") + + +class ArchiveConfig(BaseModel): + title: str + description: str + name: str + publisher: str + creator: str + languages: str + tags: list[str] + illustration: str + filename: str + + @classmethod + def init_with(cls: type[T], filename: str, **kwargs) -> T: + default = {"tags": []} + data: dict = {key: default.get(key, "") for key in cls.model_fields.keys()} + data.update({"filename": filename}) + if kwargs: + data.update(kwargs) + return cls.model_validate(data) + + def is_ready(self) -> bool: + try: + for key in self.model_fields.keys(): + validate_required_values(key.title(), getattr(self, key, "")) + validate_title("Title", self.title) + validate_description("Description", self.description) + validate_language("Language", self.languages) + validate_tags("Tags", self.tags) + + except ValueError: + return False + return True + + +class ArchiveConfigType(types.TypeDecorator): + cache_ok = True + impl = JSONB + + def process_bind_param(self, value, dialect): # noqa: ARG002 + if isinstance(value, ArchiveConfig): + return value.model_dump() + if isinstance(value, dict): + return value + return dict(value) if value else {} + + def process_result_value(self, value, dialect) -> ArchiveConfig: # noqa: ARG002 + if isinstance(value, ArchiveConfig): + return value + return ArchiveConfig.model_validate(dict(value) if value else {}) + + def coerce_compared_value(self, op, value): + return self.impl.coerce_compared_value( + op, value + ) # pyright: ignore [reportCallIssue] + class Base(MappedAsDataclass, 
DeclarativeBase): # This map details the specific transformation of types between Python and @@ -22,6 +88,7 @@ class Base(MappedAsDataclass, DeclarativeBase): # type has to be used or when we want to ensure a specific setting (like the # timezone below) type_annotation_map: ClassVar = { + ArchiveConfig: ArchiveConfigType, dict[str, Any]: JSONB, # transform Python Dict[str, Any] into PostgreSQL JSONB list[dict[str, Any]]: JSONB, datetime: DateTime( @@ -137,9 +204,10 @@ class Archive(Base): filesize: Mapped[int | None] created_on: Mapped[datetime] requested_on: Mapped[datetime | None] + completed_on: Mapped[datetime | None] download_url: Mapped[str | None] collection_json_path: Mapped[str | None] status: Mapped[str] zimfarm_task_id: Mapped[UUID | None] email: Mapped[str | None] - config: Mapped[dict[str, Any]] + config: Mapped[ArchiveConfig] diff --git a/backend/api/database/utils.py b/backend/api/database/utils.py new file mode 100644 index 0000000..d7972d4 --- /dev/null +++ b/backend/api/database/utils.py @@ -0,0 +1,28 @@ +from uuid import UUID + +from sqlalchemy import select + +from api.database import Session as DBSession +from api.database.models import File, Project + + +def get_file_by_id(file_id: UUID) -> File: + """Get File instance by its id.""" + with DBSession.begin() as session: + stmt = select(File).where(File.id == file_id) + file = session.execute(stmt).scalar() + if not file: + raise ValueError(f"File not found: {file_id}") + session.expunge(file) + return file + + +def get_project_by_id(project_id: UUID) -> Project: + """Get Project instance by its id.""" + with DBSession.begin() as session: + stmt = select(Project).where(Project.id == project_id) + project = session.execute(stmt).scalar() + if not project: + raise ValueError(f"Project not found: {project_id}") + session.expunge(project) + return project diff --git a/backend/api/email.py b/backend/api/email.py new file mode 100644 index 0000000..0ec984f --- /dev/null +++ b/backend/api/email.py @@ -0,0 +1,75 @@ +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import humanfriendly +import requests +from jinja2 import Environment, FileSystemLoader, select_autoescape +from werkzeug.datastructures import MultiDict + +from api.constants import constants, logger +from api.database.models import Archive + +jinja_env = Environment( + loader=FileSystemLoader("templates"), + autoescape=select_autoescape(["html", "txt"]), +) +jinja_env.filters["short_id"] = lambda value: str(value)[:5] +jinja_env.filters["format_size"] = lambda value: humanfriendly.format_size( + value, binary=True +) + + +def send_email_via_mailgun( + to: Iterable[str] | str, + subject: str, + contents: str, + cc: Iterable[str] | None = None, + bcc: Iterable[str] | None = None, + attachments: Iterable[Path] | None = None, +) -> str: + if not constants.mailgun_api_url or not constants.mailgun_api_key: + logger.warn(f"Mailgun not configured, ignoring email request to: {to!s}") + return "" + + values = [ + ("from", constants.mailgun_from), + ("subject", subject), + ("html", contents), + ] + + values += [("to", list(to) if isinstance(to, Iterable) else [to])] + values += [("cc", list(cc) if isinstance(cc, Iterable) else [cc])] + values += [("bcc", list(bcc) if isinstance(bcc, Iterable) else [bcc])] + data = MultiDict(values) + + try: + resp = requests.post( + url=f"{constants.mailgun_api_url}/messages", + auth=("api", constants.mailgun_api_key), + data=data, + files=( + [ + ("attachment", (fpath.name, open(fpath, "rb").read())) + 
for fpath in attachments + ] + if attachments + else [] + ), + timeout=constants.mailgun_request_timeout_sec, + ) + resp.raise_for_status() + except Exception as exc: + logger.error(f"Failed to send mailgun notif: {exc}") + logger.exception(exc) + return resp.json().get("id") or resp.text + + +def get_context(task: dict[str, Any], archive: Archive): + """Jinja context dict for email notifications""" + return { + "base_url": constants.public_url, + "download_url": constants.download_url, + "task": task, + "archive": archive, + } diff --git a/backend/api/files.py b/backend/api/files.py new file mode 100644 index 0000000..6029feb --- /dev/null +++ b/backend/api/files.py @@ -0,0 +1,67 @@ +import hashlib +from collections.abc import Iterator +from pathlib import Path +from typing import BinaryIO +from uuid import UUID + +from api.constants import constants +from api.database import get_local_fpath_for + + +def calculate_file_size(file: BinaryIO) -> int: + """Calculate the size of a file chunk by chunk""" + size = 0 + for chunk in read_file_in_chunks(file): + size += len(chunk) + return size + + +def read_file_in_chunks( + reader: BinaryIO, chunk_size=constants.chunk_size +) -> Iterator[bytes]: + """Read Big file chunk by chunk. Default chunk size is 2k""" + while True: + chunk = reader.read(chunk_size) + if not chunk: + break + yield chunk + reader.seek(0) + + +def save_file(file: BinaryIO, file_name: str, project_id: UUID) -> Path: + """Saves a binary file to a specific location and returns its path.""" + fpath = get_local_fpath_for(file_name, project_id) + if not fpath.is_file(): + with open(fpath, "wb") as file_object: + for chunk in read_file_in_chunks(file): + file_object.write(chunk) + return fpath + + +def generate_file_hash(file: BinaryIO) -> str: + """Generate sha256 hash of a file, optimized for large files""" + hasher = hashlib.sha256() + for chunk in read_file_in_chunks(file): + hasher.update(chunk) + return hasher.hexdigest() + + +def normalize_filename(filename: str) -> str: + """filesystem (ext4,apfs,hfs+,ntfs,exfat) and S3 compliant filename""" + + normalized = str(filename) + + # we replace / with __ as it would have a meaning + replacements = (("/", "__"),) + for pattern, repl in replacements: + normalized = filename.replace(pattern, repl) + + # other prohibited chars are removed (mostly for Windows context) + removals = ["\\", ":", "*", "?", '"', "<", ">", "|"] + [ + chr(idx) for idx in range(1, 32) + ] + for char in removals: + normalized.replace(char, "") + + # ext4/exfat has a 255B filename limit (s3 is 1KiB) + return normalized.encode("utf-8")[:255].decode("utf-8") diff --git a/backend/api/routes/__init__.py b/backend/api/routes/__init__.py index 88f22d8..0d19d94 100644 --- a/backend/api/routes/__init__.py +++ b/backend/api/routes/__init__.py @@ -1,8 +1,5 @@ -import hashlib -from collections.abc import Iterator from http import HTTPStatus -from pathlib import Path -from typing import Annotated, BinaryIO +from typing import Annotated from uuid import UUID from fastapi import Cookie, Depends, HTTPException, Response @@ -10,7 +7,7 @@ from sqlalchemy.orm import Session from api.constants import constants -from api.database import gen_session, get_local_fpath_for +from api.database import gen_session from api.database.models import Project, User @@ -56,62 +53,3 @@ async def validated_project( if not project: raise HTTPException(HTTPStatus.NOT_FOUND, f"Project not found: {project_id}") return project - - -def calculate_file_size(file: BinaryIO) -> int: - """Calculate the size 
of a file chunk by chunk""" - size = 0 - for chunk in read_file_in_chunks(file): - size += len(chunk) - return size - - -def read_file_in_chunks( - reader: BinaryIO, chunk_size=constants.chunk_size -) -> Iterator[bytes]: - """Read Big file chunk by chunk. Default chunk size is 2k""" - while True: - chunk = reader.read(chunk_size) - if not chunk: - break - yield chunk - reader.seek(0) - - -def save_file(file: BinaryIO, file_name: str, project_id: UUID) -> Path: - """Saves a binary file to a specific location and returns its path.""" - fpath = get_local_fpath_for(file_name, project_id) - if not fpath.is_file(): - with open(fpath, "wb") as file_object: - for chunk in read_file_in_chunks(file): - file_object.write(chunk) - return fpath - - -def generate_file_hash(file: BinaryIO) -> str: - """Generate sha256 hash of a file, optimized for large files""" - hasher = hashlib.sha256() - for chunk in read_file_in_chunks(file): - hasher.update(chunk) - return hasher.hexdigest() - - -def normalize_filename(filename: str) -> str: - """filesystem (ext4,apfs,hfs+,ntfs,exfat) and S3 compliant filename""" - - normalized = str(filename) - - # we replace / with __ as it would have a meaning - replacements = (("/", "__"),) - for pattern, repl in replacements: - normalized = filename.replace(pattern, repl) - - # other prohibited chars are removed (mostly for Windows context) - removals = ["\\", ":", "*", "?", '"', "<", ">", "|"] + [ - chr(idx) for idx in range(1, 32) - ] - for char in removals: - normalized.replace(char, "") - - # ext4/exfat has a 255B filename limit (s3 is 1KiB) - return normalized.encode("utf-8")[:255].decode("utf-8") diff --git a/backend/api/routes/archives.py b/backend/api/routes/archives.py index 60fa09a..786b180 100644 --- a/backend/api/routes/archives.py +++ b/backend/api/routes/archives.py @@ -1,9 +1,10 @@ import base64 import datetime import io +import json from enum import Enum from http import HTTPStatus -from typing import Any +from typing import Any, BinaryIO from uuid import UUID import zimscraperlib.image @@ -11,17 +12,22 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter from sqlalchemy import select, update from sqlalchemy.orm import Session +from sqlalchemy.sql.base import Executable as ExecutableStatement from zimscraperlib import filesystem -from api.constants import constants +from api.constants import constants, logger from api.database import gen_session -from api.database.models import Archive, Project -from api.routes import ( +from api.database.models import Archive, ArchiveConfig, Project +from api.email import get_context, jinja_env, send_email_via_mailgun +from api.files import ( calculate_file_size, + generate_file_hash, normalize_filename, read_file_in_chunks, - validated_project, ) +from api.routes import validated_project +from api.s3 import s3_file_key, s3_storage +from api.zimfarm import RequestSchema, WebhookPayload, request_task router = APIRouter() @@ -38,17 +44,6 @@ class ArchiveStatus(str, Enum): FAILED = "FAILED" -class ArchiveConfig(BaseModel): - title: str | None - description: str | None - name: str | None - publisher: str | None - creator: str | None - languages: list[str] | None - tags: list[str] | None - filename: str - - class ArchiveRequest(BaseModel): email: str | None config: ArchiveConfig @@ -66,7 +61,7 @@ class ArchiveModel(BaseModel): download_url: str | None status: str email: str | None - config: dict[str, Any] + config: ArchiveConfig model_config = ConfigDict(from_attributes=True) @@ -108,14 +103,15 @@ async def update_archive( 
session: Session = Depends(gen_session), ): """Update a metadata of a archive""" - config = archive_request.config.model_dump() - config["filename"] = normalize_filename(config["filename"]) + archive_request.config.filename = normalize_filename( + archive_request.config.filename + ) stmt = ( update(Archive) .filter_by(id=archive.id) .values( email=archive_request.email, - config=archive_request.config.model_dump(), + config=archive_request.config, ) ) session.execute(stmt) @@ -196,7 +192,201 @@ async def upload_illustration( detail="Illustration cannot be resized", ) from exc else: - new_config = archive.config - new_config["illustration"] = base64.b64encode(dst.getvalue()).decode("utf-8") - stmt = update(Archive).filter_by(id=archive.id).values(config=new_config) + archive.config.illustration = base64.b64encode(dst.getvalue()).decode("utf-8") + stmt = update(Archive).filter_by(id=archive.id).values(config=archive.config) session.execute(stmt) + + +def gen_collection_for(project: Project) -> tuple[list[dict[str, Any]], BinaryIO, str]: + collection = [] + # project = get_project_by_id(project_id) + for file in project.files: + entry = {} + if file.title: + entry["title"] = file.title + if file.description: + entry["title"] = file.description + if file.authors: + entry["authors"] = ", ".join(file.authors) + entry["files"] = [ + { + "uri": f"{constants.download_url}/{s3_file_key(project.id, file.hash)}", + "filename": file.filename, + } + ] + collection.append(entry) + + file = io.BytesIO() + file.write(json.dumps(collection, indent=2, ensure_ascii=False).encode("UTF-8")) + file.seek(0) + + digest = generate_file_hash(file) + + return collection, file, digest + + +def get_collection_key(project_id: UUID, collection_hash: str) -> str: + # using .json suffix (for now) so we can debug live URLs in-browser + return f"{s3_file_key(project_id=project_id, file_hash=collection_hash)}.json" + + +def upload_collection_to_s3(project: Project, collection_file: BinaryIO, s3_key: str): + + try: + if s3_storage.storage.has_object(s3_key): + logger.debug(f"Object `{s3_key}` already in S3… weird but OK") + return + logger.debug(f"Uploading collection to `{s3_key}`") + s3_storage.storage.upload_fileobj(fileobj=collection_file, key=s3_key) + s3_storage.storage.set_object_autodelete_on(s3_key, project.expire_on) + except Exception as exc: + logger.error(f"Collection failed to upload to s3 `{s3_key}`: {exc}") + raise exc + + +@router.post( + "/{project_id}/archives/{archive_id}/request", status_code=HTTPStatus.CREATED +) +async def request_archive( + archive: Archive = Depends(validated_archive), + project: Project = Depends(validated_project), + session: Session = Depends(gen_session), +): + if archive.status != ArchiveStatus.PENDING: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Non-pending archive cannot be requested", + ) + + if not archive.config.is_ready(): + raise HTTPException( + status_code=HTTPStatus.CONFLICT, + detail="Project is not ready (Archive config missing mandatory metadata)", + ) + + # this should guard the creation of Archive instead !! 
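As context for `get_collection_key()` above: the generated collection JSON is stored under the same salted, per-project key scheme as uploaded files (`s3_file_key()` in `api/s3.py`), with a `.json` suffix added so live URLs can be inspected in a browser. A self-contained sketch of the derivation; unlike the real helper, the salt is passed in explicitly here and all values are illustrative:

```python
import hashlib
from uuid import UUID


def salted_s3_key(project_id: UUID, file_hash: str, private_salt: str) -> str:
    # sha256 over "<project>-<hash>-<salt>" makes keys unguessable without the salt;
    # the project-id prefix keeps the bucket browsable per project
    digest = hashlib.sha256(f"{project_id}-{file_hash}-{private_salt}".encode("utf-8")).hexdigest()
    return f"{project_id}/{digest}"


project_id = UUID("94e430c6-8888-456a-9440-c10e4a04627c")  # sample id borrowed from the tests
collection_hash = hashlib.sha256(b"[]").hexdigest()        # hash of an (empty) collection document
print(salted_s3_key(project_id, collection_hash, "not-the-real-salt") + ".json")
```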
+ if not project.expire_on: + raise HTTPException( + status_code=HTTPStatus.CONFLICT, + detail="Project is not ready (no archive or no files)", + ) + + # gen collection and stream + collection, collection_file, collection_hash = gen_collection_for(project=project) + collection_key = get_collection_key( + project_id=archive.project_id, collection_hash=collection_hash + ) + + # upload it to S3 + upload_collection_to_s3( + project=project, + collection_file=collection_file, + s3_key=collection_key, + ) + + # Everything's on S3, prepare and submit a ZF request + request_def = RequestSchema( + collection_url=f"{constants.download_url}/{collection_key}", + name=archive.config.name, + title=archive.config.title, + description=archive.config.description, + long_description=None, + language=archive.config.languages, + creator=archive.config.creator, + publisher=archive.config.publisher, + tags=archive.config.tags, + main_logo_url=None, + illustration_url=f"{constants.download_url}/{collection_key}", + ) + task_id = request_task( + archive_id=archive.id, request_def=request_def, email=archive.email + ) + + # request new statis in DB (requested with the ZF ID) + stmt = ( + update(Archive) + .filter_by(id=archive.id) + .values( + requested_on=datetime.datetime.now(tz=datetime.UTC), + collection_json_path=collection_key, + status=ArchiveStatus.REQUESTED, + zimfarm_task_id=task_id, + ) + ) + session.execute(stmt) + + +@router.post("/{project_id}/archives/{archive_id}/hook", status_code=HTTPStatus.CREATED) +async def record_task_feedback( + payload: WebhookPayload, + archive: Archive = Depends(validated_archive), + session: Session = Depends(gen_session), + token: str = "", + target: str = "", +): + + # we require a `token` arg equal to a setting string so we can ensure + # hook requests are from know senders. + # otherwises exposes us to spam abuse + if token != constants.zimfarm_callback_token: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Identify via proper token to use hook", + ) + + # discard statuses we're not interested in + if payload.status not in ("requested", "succeeded", "failed", "canceled"): + return {"status": "success"} + + # record task request results to DB + stmt: ExecutableStatement | None = None + if payload.status == "succeeded": + try: + # should we check for file["status"] == "uploaded"? 
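For reference, the slice of the Zimfarm webhook payload consumed by the success branch looks roughly like the sketch below. Only `size`, `uploaded_timestamp` and `name` are read (plus `config["warehouse_path"]`); the per-file `status` field mentioned in the comment above belongs to the same entry. The key and all values here are hypothetical:

```python
import datetime

# hypothetical, trimmed payload: only the keys the success branch reads
files = {
    "nautilus_demo.zim": {
        "name": "nautilus_demo.zim",
        "size": 123_456_789,
        "status": "uploaded",
        "uploaded_timestamp": "2024-06-06T18:04:36.401616+00:00",
    }
}
config = {"warehouse_path": "/other"}

file = next(iter(files.values()))
print(file["size"])
print(datetime.datetime.fromisoformat(file["uploaded_timestamp"]))
print(f"https://download.example{config['warehouse_path']}/{file['name']}")  # download_url shape
```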
+ file: dict = next(iter(payload.files.values())) + filesize = file["size"] + completed_on = datetime.datetime.fromisoformat(file["uploaded_timestamp"]) + download_url = ( + f"{constants.download_url}/" + f"{payload.config['warehouse_path']}/" + f"{file['name']}" + ) + status = ArchiveStatus.READY + except Exception as exc: + logger.error(f"Failed to parse callback payload: {exc!s}") + payload.status = "failed" + else: + stmt = ( + update(Archive) + .filter_by(id=archive.id) + .values( + filesize=filesize, + completed_on=completed_on, + download_url=download_url, + status=status, + ) + ) + if payload.status in ("failed", "canceled"): + stmt = ( + update(Archive).filter_by(id=archive.id).values(status=ArchiveStatus.FAILED) + ) + if stmt is not None: + try: + session.execute(stmt) + session.commit() + except Exception as exc: + logger.error( + "Failed to update Archive with FAILED status {archive.id}: {exc!s}" + ) + logger.exception(exc) + + # ensure we have a target otherwise there's no point in preparing an email + if not target: + return {"status": "success"} + + context = get_context(task=payload.model_dump(), archive=archive) + subject = jinja_env.get_template("email_subject.txt").render(**context) + body = jinja_env.get_template("email_body.html").render(**context) + send_email_via_mailgun(target, subject, body) + + return {"status": "success"} diff --git a/backend/api/routes/files.py b/backend/api/routes/files.py index a4c7931..a3a6a53 100644 --- a/backend/api/routes/files.py +++ b/backend/api/routes/files.py @@ -1,5 +1,4 @@ import datetime -import hashlib from enum import Enum from http import HTTPStatus from uuid import UUID @@ -14,13 +13,10 @@ from api.database import Session as DBSession from api.database import gen_session from api.database.models import File, Project -from api.routes import ( - calculate_file_size, - generate_file_hash, - save_file, - validated_project, -) -from api.s3 import s3_storage +from api.database.utils import get_file_by_id, get_project_by_id +from api.files import calculate_file_size, generate_file_hash, save_file +from api.routes import validated_project +from api.s3 import s3_file_key, s3_storage from api.store import task_queue router = APIRouter() @@ -116,12 +112,6 @@ def validate_project_quota(file_size: int, project: Project): ) -def s3_file_key(project_id: UUID, file_hash: str) -> str: - """Generate s3 file key.""" - to_be_hashed_str = f"{project_id}-{file_hash}-{constants.private_salt}" - return hashlib.sha256(bytes(to_be_hashed_str, "utf-8")).hexdigest() - - def update_file_status_and_path(file: File, status: str, path: str): """Update file's Status and Path.""" with DBSession.begin() as session: @@ -140,28 +130,6 @@ def update_file_path(file: File, path: str): update_file_status_and_path(file, file.status, path) -def get_file_by_id(file_id: UUID) -> File: - """Get File instance by its id.""" - with DBSession.begin() as session: - stmt = select(File).where(File.id == file_id) - file = session.execute(stmt).scalar() - if not file: - raise ValueError(f"File not found: {file_id}") - session.expunge(file) - return file - - -def get_project_by_id(project_id: UUID) -> Project: - """Get Project instance by its id.""" - with DBSession.begin() as session: - stmt = select(Project).where(Project.id == project_id) - project = session.execute(stmt).scalar() - if not project: - raise ValueError(f"Project not found: {project_id}") - session.expunge(project) - return project - - def upload_file_to_s3(new_file_id: UUID): """Update local file to S3 storage and 
update file status""" new_file = get_file_by_id(new_file_id) diff --git a/backend/api/routes/projects.py b/backend/api/routes/projects.py index 9eb2ef5..a91bb9e 100644 --- a/backend/api/routes/projects.py +++ b/backend/api/routes/projects.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import Session from api.database import gen_session -from api.database.models import Archive, Project, User +from api.database.models import Archive, ArchiveConfig, Project, User from api.routes import validated_project, validated_user from api.routes.archives import ArchiveStatus @@ -46,9 +46,10 @@ async def create_project( new_archive = Archive( created_on=now, status=ArchiveStatus.PENDING, - config={}, + config=ArchiveConfig.init_with(filename="-"), filesize=None, requested_on=None, + completed_on=None, download_url=None, collection_json_path=None, zimfarm_task_id=None, diff --git a/backend/api/s3.py b/backend/api/s3.py index 61d19be..a91fe6e 100644 --- a/backend/api/s3.py +++ b/backend/api/s3.py @@ -1,3 +1,6 @@ +import hashlib +from uuid import UUID + from kiwixstorage import KiwixStorage from api.constants import constants, logger @@ -32,3 +35,12 @@ def storage(self): s3_storage = S3Storage() + + +def s3_file_key(project_id: UUID, file_hash: str) -> str: + """S3 key for a Project's File""" + digest = hashlib.sha256( + bytes(f"{project_id}-{file_hash}-{constants.private_salt}", "utf-8") + ).hexdigest() + # using project_id/ pattern to ease browsing bucket for objects + return f"{project_id}/{digest}" diff --git a/backend/api/templates/email_body.html b/backend/api/templates/email_body.html new file mode 100644 index 0000000..7621ed7 --- /dev/null +++ b/backend/api/templates/email_body.html @@ -0,0 +1,22 @@ + +
+{% if task.status == "requested" %} +Your request for a Nautilus ZIM of “{{ task.config.flags.title }}” has been recorded.
+We'll send you another email once your ZIM file is ready to download.
+{% endif %} + +{% if task.status == "succeeded" %} +Your request for a Nautilus ZIM of “{{ task.config.flags.title }}” has completed.
+Here it is:
+{% if task.files %}We are really sorry.
+Please double-check your inputs and try again. If it fails again, please contact us.
{% endif %} + + diff --git a/backend/api/templates/email_subject.txt b/backend/api/templates/email_subject.txt new file mode 100644 index 0000000..7d2b0cc --- /dev/null +++ b/backend/api/templates/email_subject.txt @@ -0,0 +1 @@ +Nautilus archive “{{ archive.config.title }}” {{ task.status }} diff --git a/backend/api/zimfarm.py b/backend/api/zimfarm.py new file mode 100644 index 0000000..a03660d --- /dev/null +++ b/backend/api/zimfarm.py @@ -0,0 +1,322 @@ +import datetime +import json +import logging +from dataclasses import dataclass +from http import HTTPStatus +from typing import Any, NamedTuple +from uuid import UUID, uuid4 + +import requests +from pydantic import BaseModel + +from api.constants import constants + +GET = "GET" +POST = "POST" +PATCH = "PATCH" +DELETE = "DELETE" + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class RequestSchema: + """Flags sent to ZF for the schedule/task""" + + collection_url: str + name: str + title: str + description: str + long_description: str | None + language: str + creator: str + publisher: str + tags: list[str] + main_logo_url: str | None + illustration_url: str + + +class WebhookPayload(BaseModel): + """Webhook payload sent by ZF""" + + _id: str + status: str + timestamp: dict + schedule_name: str + worker_name: str + updated_at: str + config: dict + original_schedule_name: str + events: list[dict] + debug: dict + requested_by: str + canceled_by: str + container: str + priority: int + notification: dict + files: dict[str, dict] + upload: dict + + +class TokenData: + """In-memory persistence of ZF credentials""" + + ACCESS_TOKEN: str = "" + ACCESS_TOKEN_EXPIRY: datetime.datetime = datetime.datetime( + 2000, 1, 1, tzinfo=datetime.UTC + ) + REFRESH_TOKEN: str = "" + REFRESH_TOKEN_EXPIRY: datetime.datetime = datetime.datetime( + 2000, 1, 1, tzinfo=datetime.UTC + ) + + +class ZimfarmAPIError(Exception): + def __init__(self, message: str, code: int = -1) -> None: + super().__init__(message) + self.code = code + + def __str__(self): + if self.code: + return f"HTTP {self.code}: {', '.join(self.args)}" + return ", ".join(self.args) + + +class ZimfarmResponse(NamedTuple): + succeeded: bool + code: int + data: str | dict[str, Any] + + +def get_url(path: str) -> str: + return "/".join([constants.zimfarm_api_url, path[1:] if path[0] == "/" else path]) + + +def get_token_headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Token {token}", + "Content-type": "application/json", + } + + +def get_token(username: str, password: str) -> tuple[str, str]: + req = requests.post( + url=get_url("/auth/authorize"), + headers={ + "username": username, + "password": password, + "Content-type": "application/json", + }, + timeout=constants.zimfarm_request_timeout_sec, + ) + req.raise_for_status() + return req.json().get("access_token", ""), req.json().get("refresh_token", "") + + +def authenticate(*, force: bool = False): + if ( + not force + and TokenData.ACCESS_TOKEN + and TokenData.ACCESS_TOKEN_EXPIRY + > datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta(minutes=2) + ): + return + + logger.debug(f"authenticate() with {force=}") + + try: + access_token, refresh_token = get_token( + username=constants.zimfarm_username, password=constants.zimfarm_password + ) + except Exception: + TokenData.ACCESS_TOKEN = TokenData.REFRESH_TOKEN = "" + TokenData.ACCESS_TOKEN_EXPIRY = datetime.datetime = datetime.datetime( + 2000, 1, 1, tzinfo=datetime.UTC + ) + else: + TokenData.ACCESS_TOKEN, TokenData.REFRESH_TOKEN = access_token, 
refresh_token + TokenData.ACCESS_TOKEN_EXPIRY = datetime.datetime.now( + tz=datetime.UTC + ) + datetime.timedelta(minutes=59) + TokenData.REFRESH_TOKEN_EXPIRY = datetime.datetime.now( + tz=datetime.UTC + ) + datetime.timedelta(days=29) + + +def auth_required(func): + def wrapper(*args, **kwargs): + authenticate() + return func(*args, **kwargs) + + return wrapper + + +@auth_required +def query_api( + method: str, + path: str, + payload: dict[str, str | list[str]] | None = None, + params: dict[str, str] | None = None, +) -> ZimfarmResponse: + func = { + GET: requests.get, + POST: requests.post, + PATCH: requests.patch, + DELETE: requests.delete, + }.get(method.upper(), requests.get) + try: + req = func( + url=get_url(path), + headers=get_token_headers(TokenData.ACCESS_TOKEN), + json=payload, + params=params, + timeout=constants.zimfarm_request_timeout_sec, + ) + except Exception as exc: + logger.exception(exc) + return ZimfarmResponse(False, 900, f"ConnectionError -- {exc!s}") + + try: + resp = req.json() if req.text else {} + except json.JSONDecodeError: + return ZimfarmResponse( + False, + req.status_code, + f"ResponseError (not JSON): -- {req.text}", + ) + except Exception as exc: + return ZimfarmResponse( + False, + req.status_code, + f"ResponseError -- {exc!s} -- {req.text}", + ) + + if ( + req.status_code >= HTTPStatus.OK + and req.status_code < HTTPStatus.MULTIPLE_CHOICES + ): + return ZimfarmResponse(True, req.status_code, resp) + + # Unauthorised error: attempt to re-auth as scheduler might have restarted? + if req.status_code == HTTPStatus.UNAUTHORIZED: + authenticate(force=True) + + reason = resp["error"] if "error" in resp else str(resp) + if "error_description" in resp: + reason = f"{reason}: {resp['error_description']}" + return ZimfarmResponse(False, req.status_code, reason) + + +@auth_required +def test_connection(): + return query_api(GET, "/auth/test") + + +def request_task( + archive_id: UUID, request_def: RequestSchema, email: str | None +) -> UUID: + ident = uuid4().hex + + flags = { + "collection": request_def.collection_url, + "name": request_def.name, + "output": "/output", + "zim_file": f"nautilus_{archive_id}_{ident}.zim", + "language": request_def.language, + "title": request_def.title, + "description": request_def.description, + "creator": request_def.creator, + "publisher": request_def.publisher, + "tags": request_def.tags, + "main_logo": request_def.main_logo_url, + "favicon": request_def.illustration_url, + } + + config = { + "task_name": "nautilus", + "warehouse_path": "/other", + "image": { + "name": constants.zimfarm_nautilus_image.split(":")[0], + "tag": constants.zimfarm_nautilus_image.split(":")[1], + }, + "resources": { + "cpu": constants.zimfarm_task_cpu, + "memory": constants.zimfarm_task_memory, + "disk": constants.zimfarm_task_disk, + }, + "platform": None, + "monitor": False, + "flags": flags, + } + + # gen schedule name + schedule_name = f"nautilus_{archive_id}_{ident}" + # create schedule payload + payload = { + "name": schedule_name, + "language": {"code": "eng", "name_en": "English", "name_native": "English"}, + "category": "other", + "periodicity": "manually", + "tags": [], + "enabled": True, + "config": config, + } + + # add notification callback if email supplied + if email: + url = ( + f"{constants.zimfarm_callback_base_url}" + f"?token={constants.zimfarm_callback_token}&target={email}" + ) + payload.update( + { + "notification": { + "requested": {"webhook": [url]}, + "ended": {"webhook": [url]}, + } + } + ) + + # create a unique schedule 
for that request on the zimfarm + success, status, resp = query_api("POST", "/schedules/", payload=payload) + if not success: + logger.error(f"Unable to create schedule via HTTP {status}: {resp}") + message = f"Unable to create schedule via HTTP {status}: {resp}" + if status == HTTPStatus.BAD_REQUEST: + # if Zimfarm replied this is a bad request, then this is most probably + # a bad request due to user input so we can track it like a bad request + raise ZimfarmAPIError(message, status) + else: + # otherwise, this is most probably an internal problem in our systems + raise ZimfarmAPIError(message, status) + + # request a task for that newly created schedule + success, status, resp = query_api( + "POST", + "/requested-tasks/", + payload={ + "schedule_names": [schedule_name], + "worker": constants.zimfarm_task_worker, + }, + ) + if not success: + logger.error(f"Unable to request {schedule_name} via HTTP {status}: {resp}") + raise ZimfarmAPIError(f"Unable to request schedule: {resp}", status) + + if not isinstance(resp, dict): + raise ZimfarmAPIError(f"response is unexpected format ({type(resp)})") + + try: + task_id = resp["requested"].pop() + if not task_id: + raise ValueError(f"task_id is empty? `{task_id}`") + except Exception as exc: + raise ZimfarmAPIError(f"Couldn't retrieve requested task id: {exc!s}") from exc + + # remove newly created schedule (not needed anymore) + success, status, resp = query_api("DELETE", f"/schedules/{schedule_name}") + if not success: + logger.error( + f"Unable to remove schedule {schedule_name} via HTTP {status}: {resp}" + ) + return UUID(task_id) diff --git a/backend/migrations/versions/be9763d49e5f_archives_completed_on.py b/backend/migrations/versions/be9763d49e5f_archives_completed_on.py new file mode 100644 index 0000000..782d85a --- /dev/null +++ b/backend/migrations/versions/be9763d49e5f_archives_completed_on.py @@ -0,0 +1,28 @@ +"""archives-completed-on + +Revision ID: be9763d49e5f +Revises: 4d766c1cc6a3 +Create Date: 2024-06-06 18:04:36.401616 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "be9763d49e5f" +down_revision = "4d766c1cc6a3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("archive", sa.Column("completed_on", sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("archive", "completed_on") + # ### end Alembic commands ### diff --git a/backend/pyproject.toml b/backend/pyproject.toml index f1e474f..1e07ae8 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "zimscraperlib==3.3.2", "humanfriendly==10.0", "rq==1.16.2", + "werkzeug==3.0.3", ] dynamic = ["authors", "license", "version", "urls"] @@ -39,14 +40,17 @@ lint = [ "ruff==0.4.3", ] check = [ - "pyright==1.1.362", + "pyright==1.1.367", + "pytest == 8.2.0", # import pytest in tests ] test = [ "pytest==8.2.0", "coverage==7.5.1", "pytest-mock==3.14.0", + "trio == 0.25.1" ] dev = [ + "ipython==8.25.0", "pre-commit==3.7.0", "nautilus-api[scripts]", "nautilus-api[lint]", @@ -196,8 +200,9 @@ extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"] ban-relative-imports = "all" [tool.ruff.lint.per-file-ignores] -# Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +# Tests can use magic values, assertions, relative imports, print, and unused args (mock) +"tests/**/*" = ["PLR2004", "S101", "TID252","T201", "ARG001", "ARG002"] +"**/migrations/**/*" = ["F401", "ISC001"] [tool.pytest.ini_options] minversion = "7.3" diff --git a/backend/tasks.py b/backend/tasks.py index ed5e6c0..65e505e 100644 --- a/backend/tasks.py +++ b/backend/tasks.py @@ -24,6 +24,7 @@ def report_cov(ctx: Context): """report coverage""" ctx.run("coverage combine", warn=True, pty=use_pty) ctx.run("coverage report --show-missing", pty=use_pty) + ctx.run("coverage xml", pty=use_pty) @task(optional=["args"], help={"args": "pytest additional arguments"}) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 772ac7b..b30a9a8 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,16 +1,26 @@ import datetime import os +import urllib.parse +import uuid +from collections.abc import AsyncGenerator +from http import HTTPStatus from io import BytesIO from pathlib import Path +from typing import Any import pytest # pyright: ignore [reportMissingImports] -from fastapi.testclient import TestClient +import requests +from httpx import AsyncClient +from starlette.testclient import TestClient from api.database import Session -from api.database.models import Archive, File, Project, User +from api.database.models import Archive, ArchiveConfig, File, Project, User from api.entrypoint import app +from api.files import save_file from api.routes.archives import ArchiveStatus -from api.routes.files import save_file +from api.s3 import s3_storage + +pytestmark = pytest.mark.asyncio(scope="package") @pytest.fixture() @@ -33,6 +43,19 @@ def client(): return TestClient(app) +@pytest.fixture(scope="module") # pyright: ignore +async def aclient() -> AsyncGenerator[AsyncClient, Any]: + async with AsyncClient(app=app, base_url="http://localhost") as client: + yield client + + +@pytest.fixture() +async def alogged_in_client(user_id: str): + async with AsyncClient(app=app, base_url="http://localhost") as client: + client.cookies = {"user_id": str(user_id)} + yield client + + @pytest.fixture def non_existent_project_id(): return "94e430c6-8888-456a-9440-c10e4a04627c" @@ -48,23 +71,28 @@ def missing_user_cookie(missing_user_id): return {"user_id": missing_user_id} -@pytest.fixture +@pytest.fixture() def test_project_name(): return "test_project_name" -@pytest.fixture +@pytest.fixture() +def test_expiring_project_name(): + return "test_expiring_project_name" + + +@pytest.fixture() def test_archive_name(): return 
"test_archive_name.zim" -@pytest.fixture +@pytest.fixture() def missing_archive_id(): return "55a345a6-20d2-40a7-b85a-7ec37e55b986" @pytest.fixture() -def logged_in_client(client, user_id) -> str: +def logged_in_client(client, user_id: str) -> str: cookie = {"user_id": str(user_id)} client.cookies = cookie return client @@ -147,9 +175,34 @@ def project_id(test_project_name, user_id): created_id = new_project.id yield created_id with Session.begin() as session: - user = session.get(User, created_id) + project = session.get(Project, created_id) + if project: + session.delete(project) + + +@pytest.fixture() +def expiring_project_id(test_expiring_project_name, user_id): + now = datetime.datetime.now(datetime.UTC) + new_project = Project( + name=test_expiring_project_name, + created_on=now, + expire_on=now + datetime.timedelta(minutes=30), + files=[], + archives=[], + ) + with Session.begin() as session: + user = session.get(User, user_id) if user: - session.delete(user) + user.projects.append(new_project) + session.add(new_project) + session.flush() + session.refresh(new_project) + created_id = new_project.id + yield created_id + with Session.begin() as session: + project = session.get(Project, created_id) + if project: + session.delete(project) @pytest.fixture() @@ -158,9 +211,19 @@ def archive_id(test_archive_name, project_id): new_archive = Archive( created_on=now, status=ArchiveStatus.PENDING, - config={"filename": test_archive_name}, + config=ArchiveConfig.init_with( + filename=test_archive_name, + title="A Title", + description="A Description", + name="a_name", + creator="a creator", + publisher="a publisher", + languages="eng", + tags=[], + ), filesize=None, requested_on=None, + completed_on=None, download_url=None, collection_json_path=None, zimfarm_task_id=None, @@ -176,6 +239,135 @@ def archive_id(test_archive_name, project_id): created_id = new_archive.id yield created_id with Session.begin() as session: - archives = session.get(Archive, created_id) - if archives: - session.delete(archives) + archive = session.get(Archive, created_id) + if archive: + session.delete(archive) + + +@pytest.fixture() +def expiring_archive_id(test_archive_name, expiring_project_id): + now = datetime.datetime.now(datetime.UTC) + new_archive = Archive( + created_on=now, + status=ArchiveStatus.PENDING, + config=ArchiveConfig.init_with( + filename=test_archive_name, + title="A Title", + description="A Description", + name="a_name", + creator="a creator", + publisher="a publisher", + languages="eng", + tags=[], + ), + filesize=None, + requested_on=None, + completed_on=None, + download_url=None, + collection_json_path=None, + zimfarm_task_id=None, + email=None, + ) + with Session.begin() as session: + project = session.get(Project, expiring_project_id) + if project: + project.archives.append(new_archive) + session.add(new_archive) + session.flush() + session.refresh(new_archive) + created_id = new_archive.id + yield created_id + with Session.begin() as session: + archive = session.get(Archive, created_id) + if archive: + session.delete(archive) + + +class SuccessStorage: + + def upload_file(*args, **kwargs): ... + + def upload_fileobj(*args, **kwargs): ... + + def set_object_autodelete_on(*args, **kwargs): ... + + def has_object(*args, **kwargs): + return True + + def check_credentials(*args, **kwargs): + return True + + def delete_object(*args, **kwargs): ... 
+ + +@pytest.fixture +def successful_s3_upload_file(monkeypatch): + """Requests.get() mocked to return {'mock_key':'mock_response'}.""" + + monkeypatch.setattr(s3_storage, "_storage", SuccessStorage()) + yield True + + +class SuccessfulRequestResponse: + status_code = HTTPStatus.OK + text = "text" + + @staticmethod + def raise_for_status(): ... + + +class SuccessfulAuthResponse(SuccessfulRequestResponse): + @staticmethod + def json(): + return { + "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9." + "eyJpc3MiOiJkaXNwYXRjaGVyIiwiZXhwIj", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "aea891db-090b-4cbb-6qer-57c0928b42e6", + } + + +class ScheduleCreatedResponse(SuccessfulRequestResponse): + status_code = HTTPStatus.CREATED + + @staticmethod + def json(): + return {"_id": uuid.uuid4().hex} + + +class TaskRequestedResponse(SuccessfulRequestResponse): + status_code = HTTPStatus.CREATED + + @staticmethod + def json(): + return {"requested": [uuid.uuid4().hex]} + + +class ScheduleDeletedResponse(SuccessfulRequestResponse): + + @staticmethod + def json(): + return {} + + +@pytest.fixture +def successful_zimfarm_request_task(monkeypatch): + """Requests.get() mocked to return {'mock_key':'mock_response'}.""" + + def requests_post(**kwargs): + uri = urllib.parse.urlparse(kwargs.get("url")) + if uri.path == "/v1/auth/authorize": + return SuccessfulAuthResponse() + if uri.path == "/v1/schedules/": + return ScheduleCreatedResponse() + if uri.path == "/v1/requested-tasks/": + return TaskRequestedResponse() + raise ValueError(f"Unhandled {kwargs}") + + def requests_delete(*args, **kwargs): + return ScheduleDeletedResponse() + + monkeypatch.setattr(requests, "post", requests_post) + monkeypatch.setattr(requests, "delete", requests_delete) + yield True diff --git a/backend/tests/routes/test_archives.py b/backend/tests/routes/test_archives.py index b02d6fb..99da268 100644 --- a/backend/tests/routes/test_archives.py +++ b/backend/tests/routes/test_archives.py @@ -1,6 +1,9 @@ import uuid from http import HTTPStatus +import pytest +from httpx import AsyncClient + from api.constants import constants @@ -69,8 +72,9 @@ def test_update_archive_correct_data(logged_in_client, project_id, archive_id): "name": "test_name", "publisher": "test_publisher", "creator": "test_creator", - "languages": ["en"], + "languages": "en", "tags": ["test_tags"], + "illustration": "", }, } response = logged_in_client.patch( @@ -91,7 +95,7 @@ def test_update_archive_correct_data(logged_in_client, project_id, archive_id): assert json_result.get("config").get("name") == "test_name" assert json_result.get("config").get("publisher") == "test_publisher" assert json_result.get("config").get("creator") == "test_creator" - assert json_result.get("config").get("languages")[0] == "en" + assert json_result.get("config").get("languages") == "en" assert json_result.get("config").get("tags")[0] == "test_tags" @@ -105,7 +109,7 @@ def test_update_archive_wrong_id(logged_in_client, project_id, missing_archive_i "name": "test_name", "publisher": "test_publisher", "creator": "test_creator", - "languages": ["en"], + "languages": "en", "tags": ["test_tags"], }, } @@ -214,3 +218,30 @@ def test_upload_illustration_without_wrong_authorization( files=file, ) assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.anyio +async def test_request_archive_not_ready(alogged_in_client, project_id, archive_id): + response = await alogged_in_client.post( + f"{constants.api_version_prefix}/projects/" + 
f"{project_id}/archives/{archive_id}/request" + ) + assert response.status_code == HTTPStatus.CONFLICT + + +@pytest.mark.anyio +async def test_request_archive_ready( + alogged_in_client: AsyncClient, + archive_id, + project_id, + expiring_project_id, + expiring_archive_id, + successful_s3_upload_file, + successful_zimfarm_request_task, +): + + response = await alogged_in_client.post( + f"{constants.api_version_prefix}/projects/" + f"{expiring_project_id}/archives/{expiring_archive_id}/request" + ) + assert response.status_code == HTTPStatus.CREATED diff --git a/backend/tests/routes/test_projects.py b/backend/tests/routes/test_projects.py index 485ea3e..bd24749 100644 --- a/backend/tests/routes/test_projects.py +++ b/backend/tests/routes/test_projects.py @@ -28,6 +28,14 @@ def test_create_project_wrong_authorization(client, missing_user_cookie): assert response.status_code == HTTPStatus.UNAUTHORIZED +def test_get_all_projects_no_data(logged_in_client): + response = logged_in_client.get(f"{constants.api_version_prefix}/projects") + json_result = response.json() + assert response.status_code == HTTPStatus.OK + assert json_result is not None + assert len(json_result) == 0 + + def test_get_all_projects_correct_data(logged_in_client, project_id): response = logged_in_client.get(f"{constants.api_version_prefix}/projects") json_result = response.json()