diff --git a/shhh/__init__.py b/shhh/__init__.py index 8e10cb46..e94f36fe 100644 --- a/shhh/__init__.py +++ b/shhh/__init__.py @@ -1 +1 @@ -__version__ = "3.0.4" +__version__ = "3.0.5" diff --git a/shhh/api/api.py b/shhh/api/api.py index bb20376d..614c5e37 100644 --- a/shhh/api/api.py +++ b/shhh/api/api.py @@ -1,93 +1,39 @@ import functools -import re -from typing import Mapping +from http import HTTPStatus +from typing import NoReturn -from flask import (Blueprint, - Request, - Response, - current_app as app, - make_response) +from flask import Blueprint, Response, make_response from flask.views import MethodView -from marshmallow import Schema, ValidationError, fields, pre_load, validate +from marshmallow import ValidationError from webargs.flaskparser import abort, parser, use_kwargs from shhh.api import handlers -from shhh.constants import (DEFAULT_EXPIRATION_TIME_VALUE, - DEFAULT_READ_TRIES_VALUE, - EXPIRATION_TIME_VALUES, - READ_TRIES_VALUES) +from shhh.api.schemas import CallableResponse, ReadRequest, WriteRequest -api = Blueprint("api", __name__, url_prefix="/api") - -json = functools.partial(use_kwargs, location="json") -query = functools.partial(use_kwargs, location="query") - - -class _ReadSchema(Schema): - external_id = fields.Str(required=True) - passphrase = fields.Str(required=True) - - -def _passphrase_validator(passphrase: str) -> None: - regex = re.compile(r"^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9]).{8,}$") - if regex.search(passphrase): - return - raise ValidationError( - "Sorry, your passphrase is too weak. It needs minimum 8 characters, " - "with 1 number and 1 uppercase.") +def _handle(response: CallableResponse, code: HTTPStatus) -> Response: + return make_response(response(), code) -def _secret_validator(secret: str) -> None: - max_length = app.config["SHHH_SECRET_MAX_LENGTH"] - if len(secret) <= max_length: - return - raise ValidationError( - f"The secret should not exceed {max_length} characters") - -class _CreateSchema(Schema): - passphrase = fields.Str(required=True, validate=_passphrase_validator) - secret = fields.Str(required=True, validate=_secret_validator) - expire = fields.Str(dump_default=DEFAULT_EXPIRATION_TIME_VALUE, - validate=validate.OneOf( - EXPIRATION_TIME_VALUES.values())) - tries = fields.Int(dump_default=DEFAULT_READ_TRIES_VALUE, - validate=validate.OneOf(READ_TRIES_VALUES)) - - @pre_load - def secret_sanitise_newline(self, data, **kwargs): - if isinstance(data.get("secret"), str): - data["secret"] = "\n".join(data["secret"].splitlines()) - return data +@parser.error_handler +def handle_parsing_error(err: ValidationError, *args, **kwargs) -> NoReturn: + abort(_handle(*handlers.parse_error(err))) -@parser.error_handler -def handle_parsing_error(err: ValidationError, - req: Request, - schema: Schema, - *, - error_status_code: int, - error_headers: Mapping[str, str]): - response, code = handlers.parse_error(err) - return abort(make_response(response.make(), code)) +body = functools.partial(use_kwargs, location="json") +query = functools.partial(use_kwargs, location="query") class Api(MethodView): - @query(_ReadSchema()) - def get(self, external_id: str, passphrase: str) -> Response: - response, code = handlers.read_secret(external_id, passphrase) - return make_response(response.make(), code) + @query(ReadRequest()) + def get(self, *args, **kwargs) -> Response: + return _handle(*handlers.read(*args, **kwargs)) - @json(_CreateSchema()) - def post(self, - passphrase: str, - secret: str, - expire: str = DEFAULT_EXPIRATION_TIME_VALUE, - tries: int = DEFAULT_READ_TRIES_VALUE) -> Response: - response, code = handlers.write_secret(passphrase, secret, - expire, tries) - return make_response(response.make(), code) + @body(WriteRequest()) + def post(self, *args, **kwargs) -> Response: + return _handle(*handlers.write(*args, **kwargs)) +api = Blueprint("api", __name__, url_prefix="/api") api.add_url_rule("/secret", view_func=Api.as_view("secret")) diff --git a/shhh/api/handlers.py b/shhh/api/handlers.py index 3d3fbcfd..9f1ea640 100644 --- a/shhh/api/handlers.py +++ b/shhh/api/handlers.py @@ -5,11 +5,11 @@ from marshmallow import ValidationError from sqlalchemy.orm.exc import NoResultFound -from shhh.api.responses import (ErrorResponse, - Message, - ReadResponse, - Status, - WriteResponse) +from shhh.api.schemas import (ErrorResponse, + Message, + ReadResponse, + Status, + WriteResponse) from shhh.constants import ClientType from shhh.domain import model from shhh.extensions import db @@ -17,8 +17,7 @@ @db_liveness_ping(ClientType.WEB) -def read_secret(external_id: str, - passphrase: str) -> tuple[ReadResponse, HTTPStatus]: +def read(external_id: str, passphrase: str) -> tuple[ReadResponse, HTTPStatus]: try: secret = db.session.query(model.Secret).filter( model.Secret.has_external_id(external_id)).one() @@ -55,16 +54,17 @@ def read_secret(external_id: str, @db_liveness_ping(ClientType.WEB) -def write_secret(passphrase: str, message: str, expire_code: str, - tries: int) -> tuple[WriteResponse, HTTPStatus]: - secret = model.Secret.encrypt(message=message, - passphrase=passphrase, - expire_code=expire_code, - tries=tries) - db.session.add(secret) +def write(passphrase: str, secret: str, expire: str, + tries: int) -> tuple[WriteResponse, HTTPStatus]: + encrypted_secret = model.Secret.encrypt(message=secret, + passphrase=passphrase, + expire_code=expire, + tries=tries) + db.session.add(encrypted_secret) db.session.commit() - app.logger.info("%s created", str(secret)) - return (WriteResponse(secret.external_id, secret.expires_on_text), + app.logger.info("%s created", str(encrypted_secret)) + return (WriteResponse(encrypted_secret.external_id, + encrypted_secret.expires_on_text), HTTPStatus.CREATED) diff --git a/shhh/api/responses.py b/shhh/api/responses.py deleted file mode 100644 index d1777065..00000000 --- a/shhh/api/responses.py +++ /dev/null @@ -1,63 +0,0 @@ -from dataclasses import dataclass, field, fields -from enum import Enum -from urllib.parse import urljoin - -from flask import Response, current_app as app, jsonify, request, url_for - - -class Status(str, Enum): - CREATED = "created" - SUCCESS = "success" - EXPIRED = "expired" - INVALID = "invalid" - ERROR = "error" - - -class Message(str, Enum): - NOT_FOUND = ("Sorry, we can't find a secret, it has expired, been deleted " - "or has already been read.") - EXCEEDED = ("The passphrase is not valid. You've exceeded the number of " - "tries and the secret has been deleted.") - INVALID = ("Sorry, the passphrase is not valid. Number of tries " - "remaining: {remaining}") - CREATED = "Secret successfully created." - UNEXPECTED = "An unexpected error has occurred, please try again." - - -@dataclass -class BaseResponse: - - def make(self) -> Response: - return jsonify( - {"response": { - f.name: getattr(self, f.name) - for f in fields(self) - }}) - - -@dataclass -class ReadResponse(BaseResponse): - status: Status - msg: str - - -@dataclass -class WriteResponse(BaseResponse): - external_id: str - expires_on: str - link: str = field(init=False) - status: Status = Status.CREATED - details: Message = Message.CREATED - - def __post_init__(self): - self.link = urljoin(request.url_root, - url_for("web.read", external_id=self.external_id)) - if host_config := app.config["SHHH_HOST"]: - self.link = urljoin( - host_config, url_for("web.read", external_id=self.external_id)) - - -@dataclass -class ErrorResponse(BaseResponse): - details: str - status: Status = Status.ERROR diff --git a/shhh/api/schemas.py b/shhh/api/schemas.py new file mode 100644 index 00000000..a0e23141 --- /dev/null +++ b/shhh/api/schemas.py @@ -0,0 +1,95 @@ +import re +from dataclasses import dataclass, field, fields as dfields +from urllib.parse import urljoin + +from flask import Response, current_app as app, jsonify, request, url_for +from marshmallow import Schema, ValidationError, fields, pre_load, validate + +from shhh.constants import (DEFAULT_EXPIRATION_TIME_VALUE, + DEFAULT_READ_TRIES_VALUE, + EXPIRATION_TIME_VALUES, + READ_TRIES_VALUES, + Message, + Status) + + +class ReadRequest(Schema): + """Schema for inbound read requests.""" + external_id = fields.Str(required=True) + passphrase = fields.Str(required=True) + + +def _passphrase_validator(passphrase: str) -> None: + regex = re.compile(r"^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9]).{8,}$") + if not regex.search(passphrase): + raise ValidationError("Sorry, your passphrase is too weak. It needs " + "minimum 8 characters, with 1 number and 1 " + "uppercase.") + + +def _secret_validator(secret: str) -> None: + max_length = app.config["SHHH_SECRET_MAX_LENGTH"] + if len(secret) > max_length: + raise ValidationError(f"The secret should not exceed {max_length} " + "characters.") + + +class WriteRequest(Schema): + """Schema for inbound write requests.""" + passphrase = fields.Str(required=True, validate=_passphrase_validator) + secret = fields.Str(required=True, validate=_secret_validator) + expire = fields.Str(load_default=DEFAULT_EXPIRATION_TIME_VALUE, + validate=validate.OneOf( + EXPIRATION_TIME_VALUES.values())) + tries = fields.Int(load_default=DEFAULT_READ_TRIES_VALUE, + validate=validate.OneOf(READ_TRIES_VALUES)) + + @pre_load + def secret_sanitise_newline(self, data, **kwargs): + if isinstance(data.get("secret"), str): + data["secret"] = "\n".join(data["secret"].splitlines()) + return data + + +@dataclass +class CallableResponse: + + def __call__(self) -> Response: + return jsonify({ + "response": { + f.name: getattr(self, f.name) + for f in dfields(self) + } + }) + + +@dataclass +class ReadResponse(CallableResponse): + """Schema for outbound read responses.""" + status: Status + msg: str + + +def _build_link_url(external_id: str) -> str: + root_host = app.config.get("SHHH_HOST") or request.url_root + return urljoin(root_host, url_for("web.read", external_id=external_id)) + + +@dataclass +class WriteResponse(CallableResponse): + """Schema for outbound write responses.""" + external_id: str + expires_on: str + link: str = field(init=False) + status: Status = Status.CREATED + details: Message = Message.CREATED + + def __post_init__(self): + self.link = _build_link_url(self.external_id) + + +@dataclass +class ErrorResponse(CallableResponse): + """Schema for outbound error responses.""" + details: str + status: Status = Status.ERROR diff --git a/shhh/config.py b/shhh/config.py index 5fc89d48..e9c1708e 100644 --- a/shhh/config.py +++ b/shhh/config.py @@ -67,6 +67,7 @@ class TestConfig(DefaultConfig): SQLALCHEMY_DATABASE_URI = "sqlite://" SHHH_HOST = "http://test.test" + SHHH_SECRET_MAX_LENGTH = 20 SHHH_DB_LIVENESS_RETRY_COUNT = 1 SHHH_DB_LIVENESS_SLEEP_INTERVAL = 0.1 diff --git a/shhh/constants.py b/shhh/constants.py index ca9c5e76..7cc22404 100644 --- a/shhh/constants.py +++ b/shhh/constants.py @@ -23,3 +23,22 @@ class EnvConfig(str, Enum): DEV_DOCKER = "dev-docker" HEROKU = "heroku" PRODUCTION = "production" + + +class Status(str, Enum): + CREATED = "created" + SUCCESS = "success" + EXPIRED = "expired" + INVALID = "invalid" + ERROR = "error" + + +class Message(str, Enum): + NOT_FOUND = ("Sorry, we can't find a secret, it has expired, been deleted " + "or has already been read.") + EXCEEDED = ("The passphrase is not valid. You've exceeded the number of " + "tries and the secret has been deleted.") + INVALID = ("Sorry, the passphrase is not valid. Number of tries " + "remaining: {remaining}") + CREATED = "Secret successfully created." + UNEXPECTED = "An unexpected error has occurred, please try again." diff --git a/shhh/liveness.py b/shhh/liveness.py index 4e9acca0..11a42e7f 100644 --- a/shhh/liveness.py +++ b/shhh/liveness.py @@ -6,8 +6,8 @@ from flask import Flask, Response, abort, current_app as app, make_response from sqlalchemy import text -from shhh.api.responses import ErrorResponse, Message -from shhh.constants import ClientType +from shhh.api.schemas import ErrorResponse +from shhh.constants import ClientType, Message from shhh.extensions import db, scheduler logger = logging.getLogger(__name__) @@ -74,8 +74,7 @@ def _check_web_liveness(f: Callable[..., RT], *args, return f(*args, **kwargs) response = ErrorResponse(Message.UNEXPECTED) - return abort(make_response(response.make(), - HTTPStatus.SERVICE_UNAVAILABLE)) + abort(make_response(response(), HTTPStatus.SERVICE_UNAVAILABLE)) def _check_liveness(client_type: ClientType, diff --git a/tests/test_api.py b/tests/test_api.py index 4cdb78e2..569a7c1a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,7 +5,7 @@ import pytest from flask import url_for -from shhh.api.responses import Message, Status +from shhh.constants import Message, Status from shhh.domain import model from shhh.extensions import db @@ -89,6 +89,17 @@ def test_api_post_weak_passphrase(app, post_payload, passphrase): "characters, with 1 number and 1 uppercase.") +def test_api_post_secret_too_long(app, post_payload): + post_payload["secret"] = "MoreThan20Characters!" + with app.test_request_context(), app.test_client() as test_client: + response = test_client.post(url_for("api.secret"), json=post_payload) + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + data = response.get_json() + assert data["response"]["status"] == Status.ERROR + assert data["response"]["details"] == ("The secret should not exceed " + "20 characters.") + + def test_api_get_wrong_passphrase(app, secret): with app.test_request_context(), app.test_client() as test_client: response = test_client.get(