diff --git a/backend/alembic.ini b/backend/alembic.ini index d174d12..e60e36b 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -68,18 +68,17 @@ sqlalchemy.url = driver://user:pass@localhost/dbname # on newly generated revision scripts. See the documentation for further # detail and examples -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME -hooks = black isort -black.type = console_scripts -black.entrypoint = black -black.options = REVISION_SCRIPT_FILENAME -isort.type = console_scripts -isort.entrypoint = isort -isort.options = --profile black REVISION_SCRIPT_FILENAME +hooks = ruff, ruff_format + +# lint with attempts to fix using ruff +ruff.type = exec +ruff.executable = ruff +ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# format using ruff +ruff_format.type = exec +ruff_format.executable = ruff +ruff_format.options = format REVISION_SCRIPT_FILENAME # Logging configuration diff --git a/backend/api/constants.py b/backend/api/constants.py index 465a079..d80e5bd 100644 --- a/backend/api/constants.py +++ b/backend/api/constants.py @@ -2,6 +2,7 @@ import logging import os import tempfile +import uuid from dataclasses import dataclass, field from pathlib import Path @@ -21,45 +22,106 @@ class BackendConf: Backend configuration, read from environment variables and set default values. """ - logger: logging.Logger = field(init=False) + # Configuration + project_expire_after: datetime.timedelta = datetime.timedelta(days=7) + project_quota: int = 0 + chunk_size: int = 1024 # reading/writing received files + illustration_quota: int = 0 + api_version_prefix: str = "/v1" # our API + + # Database + postgres_uri: str = os.getenv("POSTGRES_URI") or "nodb" + + # Scheduler process + redis_uri: str = os.getenv("REDIS_URI") or "redis://localhost:6379/0" + channel_name: str = os.getenv("CHANNEL_NAME") or "s3_upload" - # Mandatory configurations - postgres_uri = os.getenv("POSTGRES_URI", "nodb") - s3_url_with_credentials = os.getenv("S3_URL_WITH_CREDENTIALS") - private_salt = os.getenv("PRIVATE_SALT") + # Transient (on host disk) Storage + transient_storage_path: Path = Path() - # Optional configuration. 
- s3_max_tries = int(os.getenv("S3_MAX_TRIES", "3")) - s3_retry_wait = humanfriendly.parse_timespan(os.getenv("S3_RETRY_TIMES", "10s")) - s3_deletion_delay = datetime.timedelta( + # S3 Storage + s3_url_with_credentials: str = os.getenv("S3_URL_WITH_CREDENTIALS") or "" + s3_max_tries: int = int(os.getenv("S3_MAX_TRIES", "3")) + s3_retry_wait: int = int( + humanfriendly.parse_timespan(os.getenv("S3_RETRY_TIMES") or "10s") + ) + s3_deletion_delay: datetime.timedelta = datetime.timedelta( hours=int(os.getenv("S3_REMOVE_DELETEDUPLOADING_AFTER_HOURS", "12")) ) - transient_storage_path = Path( - os.getenv("TRANSIENT_STORAGE_PATH", tempfile.gettempdir()) - ).resolve() - redis_uri = os.getenv("REDIS_URI", "redis://localhost:6379/0") - channel_name = os.getenv("CHANNEL_NAME", "s3_upload") + private_salt = os.getenv( + "PRIVATE_SALT", uuid.uuid4().hex + ) # used to make S3 keys unguessable + + # Cookies cookie_domain = os.getenv("COOKIE_DOMAIN", None) cookie_expiration_days = int(os.getenv("COOKIE_EXPIRATION_DAYS", "30")) - project_quota = humanfriendly.parse_size(os.getenv("PROJECT_QUOTA", "100MB")) - chunk_size = humanfriendly.parse_size(os.getenv("CHUNK_SIZE", "2MiB")) - illustration_quota = humanfriendly.parse_size( - os.getenv("ILLUSTRATION_QUOTA", "2MiB") + authentication_cookie_name: str = "user_id" + + # Deployment + public_url: str = os.getenv("PUBLIC_URL") or "http://localhost" + download_url: str = ( + os.getenv("DOWNLOAD_URL") + or "https://s3.us-west-1.wasabisys.com/org-kiwix-zimit/zim" ) allowed_origins = os.getenv( "ALLOWED_ORIGINS", "http://localhost", ).split("|") - authentication_cookie_name: str = "user_id" - api_version_prefix = "/v1" - project_expire_after = datetime.timedelta(days=7) + # Zimfarm (3rd party API creating ZIMs and calling back with feedback) + zimfarm_api_url: str = ( + os.getenv("ZIMFARM_API_URL") or "https://api.farm.zimit.kiwix.org/v1" + ) + zimfarm_username: str = os.getenv("ZIMFARM_API_USERNAME") or "" + zimfarm_password: str = os.getenv("ZIMFARM_API_PASSWORD") or "" + zimfarm_nautilus_image: str = ( + os.getenv("ZIMFARM_NAUTILUS_IMAGE") or "ghcr.io/openzim/nautilus:latest" + ) + zimfarm_task_cpu: int = int(os.getenv("ZIMFARM_TASK_CPU") or "3") + zimfarm_task_memory: int = 0 + zimfarm_task_disk: int = 0 + zimfarm_callback_base_url = os.getenv("ZIMFARM_CALLBACK_BASE_URL", "") + zimfarm_callback_token = os.getenv("ZIMFARM_CALLBACK_TOKEN", uuid.uuid4().hex) + zimfarm_task_worker: str = os.getenv("ZIMFARM_TASK_WORKDER") or "-" + zimfarm_request_timeout_sec: int = 10 + + # Mailgun (3rd party API to send emails) + mailgun_api_url: str = os.getenv("MAILGUN_API_URL") or "" + mailgun_api_key: str = os.getenv("MAILGUN_API_KEY") or "" + mailgun_from: str = os.getenv("MAILGUN_FROM") or "Nautilus ZIM" + mailgun_request_timeout_sec: int = 10 + + logger: logging.Logger = field(init=False) def __post_init__(self): self.logger = logging.getLogger(Path(__file__).parent.name) self.transient_storage_path.mkdir(exist_ok=True) self.job_retry = Retry(max=self.s3_max_tries, interval=int(self.s3_retry_wait)) + self.transient_storage_path = Path( + os.getenv("TRANSIENT_STORAGE_PATH") or tempfile.gettempdir() + ).resolve() + + self.project_quota = humanfriendly.parse_size( + os.getenv("PROJECT_QUOTA") or "100MB" + ) + + self.chunk_size = humanfriendly.parse_size(os.getenv("CHUNK_SIZE", "2MiB")) + + self.illustration_quota = humanfriendly.parse_size( + os.getenv("ILLUSTRATION_QUOTA", "2MiB") + ) + + self.zimfarm_task_memory = humanfriendly.parse_size( + os.getenv("ZIMFARM_TASK_MEMORY") 
or "1000MiB" + ) + self.zimfarm_task_disk = humanfriendly.parse_size( + os.getenv("ZIMFARM_TASK_DISK") or "200MiB" + ) + + if not self.zimfarm_callback_base_url: + self.zimfarm_callback_base_url = f"{self.zimfarm_api_url}/requests/hook" + constants = BackendConf() logger = constants.logger diff --git a/backend/api/database/models.py b/backend/api/database/models.py index 42ac4a3..ad13783 100644 --- a/backend/api/database/models.py +++ b/backend/api/database/models.py @@ -137,6 +137,7 @@ class Archive(Base): filesize: Mapped[int | None] created_on: Mapped[datetime] requested_on: Mapped[datetime | None] + completed_on: Mapped[datetime | None] download_url: Mapped[str | None] collection_json_path: Mapped[str | None] status: Mapped[str] diff --git a/backend/api/database/utils.py b/backend/api/database/utils.py new file mode 100644 index 0000000..d7972d4 --- /dev/null +++ b/backend/api/database/utils.py @@ -0,0 +1,28 @@ +from uuid import UUID + +from sqlalchemy import select + +from api.database import Session as DBSession +from api.database.models import File, Project + + +def get_file_by_id(file_id: UUID) -> File: + """Get File instance by its id.""" + with DBSession.begin() as session: + stmt = select(File).where(File.id == file_id) + file = session.execute(stmt).scalar() + if not file: + raise ValueError(f"File not found: {file_id}") + session.expunge(file) + return file + + +def get_project_by_id(project_id: UUID) -> Project: + """Get Project instance by its id.""" + with DBSession.begin() as session: + stmt = select(Project).where(Project.id == project_id) + project = session.execute(stmt).scalar() + if not project: + raise ValueError(f"Project not found: {project_id}") + session.expunge(project) + return project diff --git a/backend/api/email.py b/backend/api/email.py new file mode 100644 index 0000000..0ec984f --- /dev/null +++ b/backend/api/email.py @@ -0,0 +1,75 @@ +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import humanfriendly +import requests +from jinja2 import Environment, FileSystemLoader, select_autoescape +from werkzeug.datastructures import MultiDict + +from api.constants import constants, logger +from api.database.models import Archive + +jinja_env = Environment( + loader=FileSystemLoader("templates"), + autoescape=select_autoescape(["html", "txt"]), +) +jinja_env.filters["short_id"] = lambda value: str(value)[:5] +jinja_env.filters["format_size"] = lambda value: humanfriendly.format_size( + value, binary=True +) + + +def send_email_via_mailgun( + to: Iterable[str] | str, + subject: str, + contents: str, + cc: Iterable[str] | None = None, + bcc: Iterable[str] | None = None, + attachments: Iterable[Path] | None = None, +) -> str: + if not constants.mailgun_api_url or not constants.mailgun_api_key: + logger.warn(f"Mailgun not configured, ignoring email request to: {to!s}") + return "" + + values = [ + ("from", constants.mailgun_from), + ("subject", subject), + ("html", contents), + ] + + values += [("to", list(to) if isinstance(to, Iterable) else [to])] + values += [("cc", list(cc) if isinstance(cc, Iterable) else [cc])] + values += [("bcc", list(bcc) if isinstance(bcc, Iterable) else [bcc])] + data = MultiDict(values) + + try: + resp = requests.post( + url=f"{constants.mailgun_api_url}/messages", + auth=("api", constants.mailgun_api_key), + data=data, + files=( + [ + ("attachment", (fpath.name, open(fpath, "rb").read())) + for fpath in attachments + ] + if attachments + else [] + ), + 
timeout=constants.mailgun_request_timeout_sec, + ) + resp.raise_for_status() + except Exception as exc: + logger.error(f"Failed to send mailgun notif: {exc}") + logger.exception(exc) + return resp.json().get("id") or resp.text + + +def get_context(task: dict[str, Any], archive: Archive): + """Jinja context dict for email notifications""" + return { + "base_url": constants.public_url, + "download_url": constants.download_url, + "task": task, + "archive": archive, + } diff --git a/backend/api/files.py b/backend/api/files.py new file mode 100644 index 0000000..6029feb --- /dev/null +++ b/backend/api/files.py @@ -0,0 +1,67 @@ +import hashlib +from collections.abc import Iterator +from pathlib import Path +from typing import BinaryIO +from uuid import UUID + +from api.constants import constants +from api.database import get_local_fpath_for + + +def calculate_file_size(file: BinaryIO) -> int: + """Calculate the size of a file chunk by chunk""" + size = 0 + for chunk in read_file_in_chunks(file): + size += len(chunk) + return size + + +def read_file_in_chunks( + reader: BinaryIO, chunk_size=constants.chunk_size +) -> Iterator[bytes]: + """Read Big file chunk by chunk. Default chunk size is 2k""" + while True: + chunk = reader.read(chunk_size) + if not chunk: + break + yield chunk + reader.seek(0) + + +def save_file(file: BinaryIO, file_name: str, project_id: UUID) -> Path: + """Saves a binary file to a specific location and returns its path.""" + fpath = get_local_fpath_for(file_name, project_id) + if not fpath.is_file(): + with open(fpath, "wb") as file_object: + for chunk in read_file_in_chunks(file): + file_object.write(chunk) + return fpath + + +def generate_file_hash(file: BinaryIO) -> str: + """Generate sha256 hash of a file, optimized for large files""" + hasher = hashlib.sha256() + for chunk in read_file_in_chunks(file): + hasher.update(chunk) + return hasher.hexdigest() + + +def normalize_filename(filename: str) -> str: + """filesystem (ext4,apfs,hfs+,ntfs,exfat) and S3 compliant filename""" + + normalized = str(filename) + + # we replace / with __ as it would have a meaning + replacements = (("/", "__"),) + for pattern, repl in replacements: + normalized = filename.replace(pattern, repl) + + # other prohibited chars are removed (mostly for Windows context) + removals = ["\\", ":", "*", "?", '"', "<", ">", "|"] + [ + chr(idx) for idx in range(1, 32) + ] + for char in removals: + normalized.replace(char, "") + + # ext4/exfat has a 255B filename limit (s3 is 1KiB) + return normalized.encode("utf-8")[:255].decode("utf-8") diff --git a/backend/api/routes/__init__.py b/backend/api/routes/__init__.py index 88f22d8..0d19d94 100644 --- a/backend/api/routes/__init__.py +++ b/backend/api/routes/__init__.py @@ -1,8 +1,5 @@ -import hashlib -from collections.abc import Iterator from http import HTTPStatus -from pathlib import Path -from typing import Annotated, BinaryIO +from typing import Annotated from uuid import UUID from fastapi import Cookie, Depends, HTTPException, Response @@ -10,7 +7,7 @@ from sqlalchemy.orm import Session from api.constants import constants -from api.database import gen_session, get_local_fpath_for +from api.database import gen_session from api.database.models import Project, User @@ -56,62 +53,3 @@ async def validated_project( if not project: raise HTTPException(HTTPStatus.NOT_FOUND, f"Project not found: {project_id}") return project - - -def calculate_file_size(file: BinaryIO) -> int: - """Calculate the size of a file chunk by chunk""" - size = 0 - for chunk in 
read_file_in_chunks(file): - size += len(chunk) - return size - - -def read_file_in_chunks( - reader: BinaryIO, chunk_size=constants.chunk_size -) -> Iterator[bytes]: - """Read Big file chunk by chunk. Default chunk size is 2k""" - while True: - chunk = reader.read(chunk_size) - if not chunk: - break - yield chunk - reader.seek(0) - - -def save_file(file: BinaryIO, file_name: str, project_id: UUID) -> Path: - """Saves a binary file to a specific location and returns its path.""" - fpath = get_local_fpath_for(file_name, project_id) - if not fpath.is_file(): - with open(fpath, "wb") as file_object: - for chunk in read_file_in_chunks(file): - file_object.write(chunk) - return fpath - - -def generate_file_hash(file: BinaryIO) -> str: - """Generate sha256 hash of a file, optimized for large files""" - hasher = hashlib.sha256() - for chunk in read_file_in_chunks(file): - hasher.update(chunk) - return hasher.hexdigest() - - -def normalize_filename(filename: str) -> str: - """filesystem (ext4,apfs,hfs+,ntfs,exfat) and S3 compliant filename""" - - normalized = str(filename) - - # we replace / with __ as it would have a meaning - replacements = (("/", "__"),) - for pattern, repl in replacements: - normalized = filename.replace(pattern, repl) - - # other prohibited chars are removed (mostly for Windows context) - removals = ["\\", ":", "*", "?", '"', "<", ">", "|"] + [ - chr(idx) for idx in range(1, 32) - ] - for char in removals: - normalized.replace(char, "") - - # ext4/exfat has a 255B filename limit (s3 is 1KiB) - return normalized.encode("utf-8")[:255].decode("utf-8") diff --git a/backend/api/routes/archives.py b/backend/api/routes/archives.py index 60fa09a..d559315 100644 --- a/backend/api/routes/archives.py +++ b/backend/api/routes/archives.py @@ -1,9 +1,10 @@ import base64 import datetime import io +import json from enum import Enum from http import HTTPStatus -from typing import Any +from typing import Any, BinaryIO from uuid import UUID import zimscraperlib.image @@ -11,17 +12,22 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter from sqlalchemy import select, update from sqlalchemy.orm import Session +from sqlalchemy.sql.base import Executable as ExecutableStatement from zimscraperlib import filesystem -from api.constants import constants +from api.constants import constants, logger from api.database import gen_session from api.database.models import Archive, Project -from api.routes import ( +from api.email import get_context, jinja_env, send_email_via_mailgun +from api.files import ( calculate_file_size, + generate_file_hash, normalize_filename, read_file_in_chunks, - validated_project, ) +from api.routes import validated_project +from api.s3 import s3_file_key, s3_storage +from api.zimfarm import RequestSchema, WebhookPayload, request_task router = APIRouter() @@ -44,7 +50,7 @@ class ArchiveConfig(BaseModel): name: str | None publisher: str | None creator: str | None - languages: list[str] | None + languages: str | None tags: list[str] | None filename: str @@ -200,3 +206,185 @@ async def upload_illustration( new_config["illustration"] = base64.b64encode(dst.getvalue()).decode("utf-8") stmt = update(Archive).filter_by(id=archive.id).values(config=new_config) session.execute(stmt) + + +def gen_collection_for(project: Project) -> tuple[list[dict[str, Any]], BinaryIO, str]: + collection = [] + # project = get_project_by_id(project_id) + for file in project.files: + entry = {} + if file.title: + entry["title"] = file.title + if file.description: + entry["title"] = file.description + 
if file.authors: + entry["authors"] = ", ".join(file.authors) + entry["files"] = [ + { + "uri": f"{constants.download_url}/{s3_file_key(project.id, file.hash)}", + "filename": file.filename, + } + ] + collection.append(entry) + + file = io.BytesIO() + file.write(json.dumps(collection, indent=2, ensure_ascii=False).encode("UTF-8")) + file.seek(0) + + digest = generate_file_hash(file) + + return collection, file, digest + + +def get_collection_key(project_id: UUID, collection_hash: str) -> str: + # using .json suffix (for now) so we can debug live URLs in-browser + return f"{s3_file_key(project_id=project_id, file_hash=collection_hash)}.json" + + +def upload_collection_to_s3(project: Project, collection_file: BinaryIO, s3_key: str): + + try: + if s3_storage.storage.has_object(s3_key): + logger.debug(f"Object `{s3_key}` already in S3… weird but OK") + return + logger.debug(f"Uploading collection to `{s3_key}`") + s3_storage.storage.upload_fileobj(fileobj=collection_file, key=s3_key) + s3_storage.storage.set_object_autodelete_on(s3_key, project.expire_on) + except Exception as exc: + logger.error(f"Collection failed to upload to s3 `{s3_key}`: {exc}") + raise exc + + +@router.post( + "/{project_id}/archives/{archive_id}/request", status_code=HTTPStatus.CREATED +) +async def request_archive( + archive: Archive = Depends(validated_archive), + project: Project = Depends(validated_project), + session: Session = Depends(gen_session), +): + if archive.status != ArchiveStatus.PENDING: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="Non-pending archive cannot be requested", + ) + + # gen collection and stream + collection, collection_file, collection_hash = gen_collection_for(project=project) + collection_key = get_collection_key( + project_id=archive.project_id, collection_hash=collection_hash + ) + + # upload it to S3 + upload_collection_to_s3( + project=project, + collection_file=collection_file, + s3_key=collection_key, + ) + + # Everything's on S3, prepare and submit a ZF request + request_def = RequestSchema( + collection_url=f"{constants.download_url}/{collection_key}", + name=archive.config["name"], + title=archive.config["title"], + description=archive.config["description"], + long_description=archive.config["long_description"], + language=archive.config["language"], + creator=archive.config["creator"], + publisher=archive.config["publisher"], + tags=archive.config["tags"], + main_logo_url=None, + illustration_url=f"{constants.download_url}/{collection_key}", + ) + task_id = request_task( + archive_id=archive.id, request_def=request_def, email=archive.email + ) + + # request new statis in DB (requested with the ZF ID) + stmt = ( + update(Archive) + .filter_by(id=archive.id) + .values( + requested_on=datetime.datetime.now(tz=datetime.UTC), + collection_json_path=collection_key, + status=ArchiveStatus.REQUESTED, + zimfarm_task_id=task_id, + ) + ) + session.execute(stmt) + + +@router.post("/{project_id}/archives/{archive_id}/hook", status_code=HTTPStatus.CREATED) +async def record_task_feedback( + payload: WebhookPayload, + archive: Archive = Depends(validated_archive), + session: Session = Depends(gen_session), + token: str = "", + target: str = "", +): + + # we require a `token` arg equal to a setting string so we can ensure + # hook requests are from know senders. 
+ # otherwises exposes us to spam abuse + if token != constants.zimfarm_callback_token: + raise HTTPException( + status_code=HTTPStatus.UNAUTHORIZED, + detail="Identify via proper token to use hook", + ) + + # discard statuses we're not interested in + if payload.status not in ("requested", "succeeded", "failed", "canceled"): + return {"status": "success"} + + # record task request results to DB + stmt: ExecutableStatement | None = None + if payload.status == "succeeded": + try: + # should we check for file["status"] == "uploaded"? + file: dict = next(iter(payload.files.values())) + filesize = file["size"] + completed_on = datetime.datetime.fromisoformat(file["uploaded_timestamp"]) + download_url = ( + f"{constants.download_url}/" + f"{payload.config['warehouse_path']}/" + f"{file['name']}" + ) + status = ArchiveStatus.READY + except Exception as exc: + logger.error(f"Failed to parse callback payload: {exc!s}") + payload.status = "failed" + else: + stmt = ( + update(Archive) + .filter_by(id=archive.id) + .values( + filesize=filesize, + completed_on=completed_on, + download_url=download_url, + status=status, + ) + ) + if payload.status in ("failed", "canceled"): + stmt = ( + update(Archive).filter_by(id=archive.id).values(status=ArchiveStatus.FAILED) + ) + if stmt is not None: + try: + session.execute(stmt) + session.commit() + except Exception as exc: + logger.error( + "Failed to update Archive with FAILED status {archive.id}: {exc!s}" + ) + logger.exception(exc) + + # ensure we have a target otherwise there's no point in preparing an email + if not target: + return {"status": "success"} + + context = get_context(task=payload.dict(), archive=archive) + subject = jinja_env.get_template("email_subject.txt").render(**context) + body = jinja_env.get_template("email_body.html").render(**context) + send_email_via_mailgun(target, subject, body) + + return {"status": "success"} diff --git a/backend/api/routes/files.py b/backend/api/routes/files.py index a4c7931..a3a6a53 100644 --- a/backend/api/routes/files.py +++ b/backend/api/routes/files.py @@ -1,5 +1,4 @@ import datetime -import hashlib from enum import Enum from http import HTTPStatus from uuid import UUID @@ -14,13 +13,10 @@ from api.database import Session as DBSession from api.database import gen_session from api.database.models import File, Project -from api.routes import ( - calculate_file_size, - generate_file_hash, - save_file, - validated_project, -) -from api.s3 import s3_storage +from api.database.utils import get_file_by_id, get_project_by_id +from api.files import calculate_file_size, generate_file_hash, save_file +from api.routes import validated_project +from api.s3 import s3_file_key, s3_storage from api.store import task_queue router = APIRouter() @@ -116,12 +112,6 @@ def validate_project_quota(file_size: int, project: Project): ) -def s3_file_key(project_id: UUID, file_hash: str) -> str: - """Generate s3 file key.""" - to_be_hashed_str = f"{project_id}-{file_hash}-{constants.private_salt}" - return hashlib.sha256(bytes(to_be_hashed_str, "utf-8")).hexdigest() - - def update_file_status_and_path(file: File, status: str, path: str): """Update file's Status and Path.""" with DBSession.begin() as session: @@ -140,28 +130,6 @@ def update_file_path(file: File, path: str): update_file_status_and_path(file, file.status, path) -def get_file_by_id(file_id: UUID) -> File: - """Get File instance by its id.""" - with DBSession.begin() as session: - stmt = select(File).where(File.id == file_id) - file = session.execute(stmt).scalar() - if 
not file: - raise ValueError(f"File not found: {file_id}") - session.expunge(file) - return file - - -def get_project_by_id(project_id: UUID) -> Project: - """Get Project instance by its id.""" - with DBSession.begin() as session: - stmt = select(Project).where(Project.id == project_id) - project = session.execute(stmt).scalar() - if not project: - raise ValueError(f"Project not found: {project_id}") - session.expunge(project) - return project - - def upload_file_to_s3(new_file_id: UUID): """Update local file to S3 storage and update file status""" new_file = get_file_by_id(new_file_id) diff --git a/backend/api/routes/projects.py b/backend/api/routes/projects.py index 9eb2ef5..411749c 100644 --- a/backend/api/routes/projects.py +++ b/backend/api/routes/projects.py @@ -49,6 +49,7 @@ async def create_project( config={}, filesize=None, requested_on=None, + completed_on=None, download_url=None, collection_json_path=None, zimfarm_task_id=None, diff --git a/backend/api/s3.py b/backend/api/s3.py index 61d19be..a91fe6e 100644 --- a/backend/api/s3.py +++ b/backend/api/s3.py @@ -1,3 +1,6 @@ +import hashlib +from uuid import UUID + from kiwixstorage import KiwixStorage from api.constants import constants, logger @@ -32,3 +35,12 @@ def storage(self): s3_storage = S3Storage() + + +def s3_file_key(project_id: UUID, file_hash: str) -> str: + """S3 key for a Project's File""" + digest = hashlib.sha256( + bytes(f"{project_id}-{file_hash}-{constants.private_salt}", "utf-8") + ).hexdigest() + # using project_id/ pattern to ease browsing bucket for objects + return f"{project_id}/{digest}" diff --git a/backend/api/templates/email_body.html b/backend/api/templates/email_body.html new file mode 100644 index 0000000..7621ed7 --- /dev/null +++ b/backend/api/templates/email_body.html @@ -0,0 +1,22 @@ + + +{% if task.status == "requested" %} +

<h2>Zim requested!</h2>
+
+<p>Your Zim request of a Nautilus ZIM for “{{ task.config.flags.title }}” has been recorded.</p>
+
+<p>We'll send you another email once your Zim file is ready to download.</p>
+{% endif %}
+
+{% if task.status == "succeeded" %}
+<h2>Zim is ready!</h2>
+
+<p>Your Zim request of a Nautilus ZIM for “{{ task.config.flags.title }}” has completed.</p>
+
+<p>Here it is:</p>
+{% if task.files %}{% endif %}
+{% endif %}
+
+{% if task.status in ("failed", "canceled") %}
+<p>Your ZIM request of a Nautilus ZIM for “{{ task.config.flags.title }}” has failed!</p>
+
+<p>We are really sorry.</p>
+
+<p>Please double check your inputs and try again. If it fails again, please contact us.</p>
+
{% endif %} + + diff --git a/backend/api/templates/email_subject.txt b/backend/api/templates/email_subject.txt new file mode 100644 index 0000000..7d2b0cc --- /dev/null +++ b/backend/api/templates/email_subject.txt @@ -0,0 +1 @@ +Nautilus archive “{{ archive.config.title }}” {{ task.status }} diff --git a/backend/api/zimfarm.py b/backend/api/zimfarm.py new file mode 100644 index 0000000..870f177 --- /dev/null +++ b/backend/api/zimfarm.py @@ -0,0 +1,322 @@ +import datetime +import json +import logging +from dataclasses import dataclass +from http import HTTPStatus +from typing import Any, NamedTuple +from uuid import UUID, uuid4 + +import requests +from pydantic import BaseModel + +from api.constants import constants + +GET = "GET" +POST = "POST" +PATCH = "PATCH" +DELETE = "DELETE" + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class RequestSchema: + """Flags sent to ZF for the schedule/task""" + + collection_url: str + name: str + title: str + description: str + long_description: str | None + language: str + creator: str + publisher: str + tags: list[str] + main_logo_url: str | None + illustration_url: str + + +class WebhookPayload(BaseModel): + """Webhook payload sent by ZF""" + + _id: str + status: str + timestamp: dict + schedule_name: str + worker_name: str + updated_at: str + config: dict + original_schedule_name: str + events: list[dict] + debug: dict + requested_by: str + canceled_by: str + container: str + priority: int + notification: dict + files: dict[str, dict] + upload: dict + + +class TokenData: + """In-memory persistence of ZF credentials""" + + ACCESS_TOKEN: str = "" + ACCESS_TOKEN_EXPIRY: datetime.datetime = datetime.datetime( + 2000, 1, 1, tzinfo=datetime.UTC + ) + REFRESH_TOKEN: str = "" + REFRESH_TOKEN_EXPIRY: datetime.datetime = datetime.datetime( + 2000, 1, 1, tzinfo=datetime.UTC + ) + + +class ZimfarmAPIError(Exception): + def __init__(self, message: str, code: int = -1) -> None: + super().__init__(message) + self.code = code + + def __str__(self): + if self.code: + return f"HTTP {self.code}: {', '.join(self.args)}" + return ", ".join(self.args) + + +class ZimfarmResponse(NamedTuple): + succeeded: bool + code: int + data: str | dict[str, Any] + + +def get_url(path: str) -> str: + return "/".join([constants.zimfarm_api_url, path[1:] if path[0] == "/" else path]) + + +def get_token_headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Token {token}", + "Content-type": "application/json", + } + + +def get_token(username: str, password: str) -> tuple[str, str]: + req = requests.post( + url=get_url("/auth/authorize"), + headers={ + "username": username, + "password": password, + "Content-type": "application/json", + }, + timeout=constants.zimfarm_request_timeout_sec, + ) + req.raise_for_status() + return req.json().get("access_token", ""), req.json().get("refresh_token", "") + + +def authenticate(*, force: bool = False): + if ( + not force + and TokenData.ACCESS_TOKEN + and TokenData.ACCESS_TOKEN_EXPIRY + > datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta(minutes=2) + ): + return + + logger.debug(f"authenticate() with {force=}") + + try: + access_token, refresh_token = get_token( + username=constants.zimfarm_username, password=constants.zimfarm_password + ) + except Exception: + TokenData.ACCESS_TOKEN = TokenData.REFRESH_TOKEN = "" + TokenData.ACCESS_TOKEN_EXPIRY = datetime.datetime = datetime.datetime( + 2000, 1, 1, tzinfo=datetime.UTC + ) + else: + TokenData.ACCESS_TOKEN, TokenData.REFRESH_TOKEN = access_token, 
refresh_token + TokenData.ACCESS_TOKEN_EXPIRY = datetime.datetime.now( + tz=datetime.UTC + ) + datetime.timedelta(minutes=59) + TokenData.REFRESH_TOKEN_EXPIRY = datetime.datetime.now( + tz=datetime.UTC + ) + datetime.timedelta(days=29) + + +def auth_required(func): + def wrapper(*args, **kwargs): + authenticate() + return func(*args, **kwargs) + + return wrapper + + +@auth_required +def query_api( + method: str, + path: str, + payload: dict[str, str | list[str]] | None = None, + params: dict[str, str] | None = None, +) -> ZimfarmResponse: + func = { + GET: requests.get, + POST: requests.post, + PATCH: requests.patch, + DELETE: requests.delete, + }.get(method.upper(), requests.get) + try: + req = func( + url=get_url(path), + headers=get_token_headers(TokenData.ACCESS_TOKEN), + json=payload, + params=params, + timeout=constants.zimfarm_request_timeout_sec, + ) + except Exception as exc: + logger.exception(exc) + return ZimfarmResponse(False, 900, f"ConnectionError -- {exc!s}") + + try: + resp = req.json() if req.text else {} + except json.JSONDecodeError: + return ZimfarmResponse( + False, + req.status_code, + f"ResponseError (not JSON): -- {req.text}", + ) + except Exception as exc: + return ZimfarmResponse( + False, + req.status_code, + f"ResponseError -- {exc!s} -- {req.text}", + ) + + if ( + req.status_code >= HTTPStatus.OK + and req.status_code < HTTPStatus.MULTIPLE_CHOICES + ): + return ZimfarmResponse(True, req.status_code, resp) + + # Unauthorised error: attempt to re-auth as scheduler might have restarted? + if req.status_code == HTTPStatus.UNAUTHORIZED: + authenticate(force=True) + + reason = resp["error"] if "error" in resp else str(resp) + if "error_description" in resp: + reason = f"{reason}: {resp['error_description']}" + return ZimfarmResponse(False, req.status_code, reason) + + +@auth_required +def test_connection(): + return query_api(GET, "/auth/test") + + +def request_task( + archive_id: UUID, request_def: RequestSchema, email: str | None +) -> str: + ident = uuid4().hex + + flags = { + "collection": request_def.collection_url, + "name": request_def.name, + "output": "/output", + "zim_file": f"nautilus_{archive_id}_{ident}.zim", + "language": request_def.language, + "title": request_def.title, + "description": request_def.description, + "creator": request_def.creator, + "publisher": request_def.publisher, + "tags": request_def.tags, + "main_logo": request_def.main_logo_url, + "favicon": request_def.illustration_url, + } + + config = { + "task_name": "nautilus", + "warehouse_path": "/other", + "image": { + "name": constants.zimfarm_nautilus_image.split(":")[0], + "tag": constants.zimfarm_nautilus_image.split(":")[1], + }, + "resources": { + "cpu": constants.zimfarm_task_cpu, + "memory": constants.zimfarm_task_memory, + "disk": constants.zimfarm_task_disk, + }, + "platform": None, + "monitor": False, + "flags": flags, + } + + # gen schedule name + schedule_name = f"nautilus_{archive_id}_{ident}" + # create schedule payload + payload = { + "name": schedule_name, + "language": {"code": "eng", "name_en": "English", "name_native": "English"}, + "category": "other", + "periodicity": "manually", + "tags": [], + "enabled": True, + "config": config, + } + + # add notification callback if email supplied + if email: + url = ( + f"{constants.zimfarm_callback_base_url}" + f"?token={constants.zimfarm_callback_token}&target={email}" + ) + payload.update( + { + "notification": { + "requested": {"webhook": [url]}, + "ended": {"webhook": [url]}, + } + } + ) + + # create a unique schedule for 
that request on the zimfarm + success, status, resp = query_api("POST", "/schedules/", payload=payload) + if not success: + logger.error(f"Unable to create schedule via HTTP {status}: {resp}") + message = f"Unable to create schedule via HTTP {status}: {resp}" + if status == HTTPStatus.BAD_REQUEST: + # if Zimfarm replied this is a bad request, then this is most probably + # a bad request due to user input so we can track it like a bad request + raise ZimfarmAPIError(message, status) + else: + # otherwise, this is most probably an internal problem in our systems + raise ZimfarmAPIError(message, status) + + # request a task for that newly created schedule + success, status, resp = query_api( + "POST", + "/requested-tasks/", + payload={ + "schedule_names": [schedule_name], + "worker": constants.zimfarm_task_worker, + }, + ) + if not success: + logger.error(f"Unable to request {schedule_name} via HTTP {status}: {resp}") + raise ZimfarmAPIError(f"Unable to request schedule: {resp}", status) + + if not isinstance(resp, dict): + raise ZimfarmAPIError(f"response is unexpected format ({type(resp)})") + + try: + task_id = resp["requested"].pop() + if not task_id: + raise ValueError(f"task_id is empty? `{task_id}`") + except Exception as exc: + raise ZimfarmAPIError(f"Couldn't retrieve requested task id: {exc!s}") from exc + + # remove newly created schedule (not needed anymore) + success, status, resp = query_api("DELETE", f"/schedules/{schedule_name}") + if not success: + logger.error( + f"Unable to remove schedule {schedule_name} via HTTP {status}: {resp}" + ) + return str(task_id) diff --git a/backend/migrations/versions/be9763d49e5f_archives_completed_on.py b/backend/migrations/versions/be9763d49e5f_archives_completed_on.py new file mode 100644 index 0000000..782d85a --- /dev/null +++ b/backend/migrations/versions/be9763d49e5f_archives_completed_on.py @@ -0,0 +1,28 @@ +"""archives-completed-on + +Revision ID: be9763d49e5f +Revises: 4d766c1cc6a3 +Create Date: 2024-06-06 18:04:36.401616 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "be9763d49e5f" +down_revision = "4d766c1cc6a3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("archive", sa.Column("completed_on", sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("archive", "completed_on") + # ### end Alembic commands ### diff --git a/backend/pyproject.toml b/backend/pyproject.toml index f1e474f..d4ccba1 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "zimscraperlib==3.3.2", "humanfriendly==10.0", "rq==1.16.2", + "werkzeug==3.0.3", ] dynamic = ["authors", "license", "version", "urls"] @@ -39,7 +40,7 @@ lint = [ "ruff==0.4.3", ] check = [ - "pyright==1.1.362", + "pyright==1.1.366", ] test = [ "pytest==8.2.0", @@ -47,6 +48,7 @@ test = [ "pytest-mock==3.14.0", ] dev = [ + "ipython==8.25.0", "pre-commit==3.7.0", "nautilus-api[scripts]", "nautilus-api[lint]", @@ -198,6 +200,7 @@ ban-relative-imports = "all" [tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] +"**/migrations/**/*" = ["F401", "ISC001"] [tool.pytest.ini_options] minversion = "7.3" diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 772ac7b..1343938 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -9,8 +9,8 @@ from api.database import Session from api.database.models import Archive, File, Project, User from api.entrypoint import app +from api.files import save_file from api.routes.archives import ArchiveStatus -from api.routes.files import save_file @pytest.fixture() @@ -161,6 +161,7 @@ def archive_id(test_archive_name, project_id): config={"filename": test_archive_name}, filesize=None, requested_on=None, + completed_on=None, download_url=None, collection_json_path=None, zimfarm_task_id=None, diff --git a/backend/tests/routes/test_archives.py b/backend/tests/routes/test_archives.py index b02d6fb..eae4d78 100644 --- a/backend/tests/routes/test_archives.py +++ b/backend/tests/routes/test_archives.py @@ -69,7 +69,7 @@ def test_update_archive_correct_data(logged_in_client, project_id, archive_id): "name": "test_name", "publisher": "test_publisher", "creator": "test_creator", - "languages": ["en"], + "languages": "en", "tags": ["test_tags"], }, } @@ -91,7 +91,7 @@ def test_update_archive_correct_data(logged_in_client, project_id, archive_id): assert json_result.get("config").get("name") == "test_name" assert json_result.get("config").get("publisher") == "test_publisher" assert json_result.get("config").get("creator") == "test_creator" - assert json_result.get("config").get("languages")[0] == "en" + assert json_result.get("config").get("languages") == "en" assert json_result.get("config").get("tags")[0] == "test_tags" @@ -105,7 +105,7 @@ def test_update_archive_wrong_id(logged_in_client, project_id, missing_archive_i "name": "test_name", "publisher": "test_publisher", "creator": "test_creator", - "languages": ["en"], + "languages": "en", "tags": ["test_tags"], }, }
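
For reference, here is a minimal sketch of how the new callback endpoint added in `backend/api/routes/archives.py` could be exercised once deployed. This is not part of the diff: the host, router mount prefix (the archives router's prefix is not visible in this diff and is guessed as `/v1/projects`), IDs and token are placeholders, and the payload is a hand-written subset shaped after the `WebhookPayload` model above rather than an actual Zimfarm notification.

```python
# Hypothetical call against the new hook endpoint; all values are placeholders.
import requests

BASE_URL = "http://localhost"  # assumes the PUBLIC_URL default from constants.py
PROJECT_ID = "11111111-1111-1111-1111-111111111111"
ARCHIVE_ID = "22222222-2222-2222-2222-222222222222"
CALLBACK_TOKEN = "value-of-ZIMFARM_CALLBACK_TOKEN"

# Dummy payload covering the WebhookPayload fields for a "succeeded" task.
payload = {
    "_id": "0123456789abcdef",
    "status": "succeeded",
    "timestamp": {},
    "schedule_name": "nautilus_demo",
    "worker_name": "worker",
    "updated_at": "2024-06-06T18:30:00+00:00",
    "config": {"warehouse_path": "/other", "flags": {"title": "Demo collection"}},
    "original_schedule_name": "nautilus_demo",
    "events": [],
    "debug": {},
    "requested_by": "nautilus",
    "canceled_by": "",
    "container": "",
    "priority": 5,
    "notification": {},
    "files": {
        "demo.zim": {
            "name": "demo.zim",
            "size": 123_456_789,
            "uploaded_timestamp": "2024-06-06T18:30:00+00:00",
        }
    },
    "upload": {},
}

resp = requests.post(
    f"{BASE_URL}/v1/projects/{PROJECT_ID}/archives/{ARCHIVE_ID}/hook",
    params={"token": CALLBACK_TOKEN, "target": "user@example.com"},
    json=payload,
    timeout=10,
)
print(resp.status_code, resp.json())
```

Per the handler in the diff, a request with a wrong `token` is rejected with 401, statuses other than requested/succeeded/failed/canceled are acknowledged but ignored, and an email is only rendered and sent when a `target` address is supplied.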