Skip to content

Commit

Permalink
Make API layer thiner (#395)
Browse files Browse the repository at this point in the history
* Drop deprecated features

* Improve typing

* Bump version

* Edit README

* Make API thiner

* Add test for length validation

* Simplify using custom host
  • Loading branch information
smallwat3r authored Sep 12, 2023
1 parent 0e84366 commit 8b6057c
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 158 deletions.
2 changes: 1 addition & 1 deletion shhh/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.0.4"
__version__ = "3.0.5"
92 changes: 19 additions & 73 deletions shhh/api/api.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,39 @@
import functools
import re
from typing import Mapping
from http import HTTPStatus
from typing import NoReturn

from flask import (Blueprint,
Request,
Response,
current_app as app,
make_response)
from flask import Blueprint, Response, make_response
from flask.views import MethodView
from marshmallow import Schema, ValidationError, fields, pre_load, validate
from marshmallow import ValidationError
from webargs.flaskparser import abort, parser, use_kwargs

from shhh.api import handlers
from shhh.constants import (DEFAULT_EXPIRATION_TIME_VALUE,
DEFAULT_READ_TRIES_VALUE,
EXPIRATION_TIME_VALUES,
READ_TRIES_VALUES)
from shhh.api.schemas import CallableResponse, ReadRequest, WriteRequest

api = Blueprint("api", __name__, url_prefix="/api")

json = functools.partial(use_kwargs, location="json")
query = functools.partial(use_kwargs, location="query")


class _ReadSchema(Schema):
external_id = fields.Str(required=True)
passphrase = fields.Str(required=True)


def _passphrase_validator(passphrase: str) -> None:
regex = re.compile(r"^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9]).{8,}$")
if regex.search(passphrase):
return
raise ValidationError(
"Sorry, your passphrase is too weak. It needs minimum 8 characters, "
"with 1 number and 1 uppercase.")

def _handle(response: CallableResponse, code: HTTPStatus) -> Response:
return make_response(response(), code)

def _secret_validator(secret: str) -> None:
max_length = app.config["SHHH_SECRET_MAX_LENGTH"]
if len(secret) <= max_length:
return
raise ValidationError(
f"The secret should not exceed {max_length} characters")


class _CreateSchema(Schema):
passphrase = fields.Str(required=True, validate=_passphrase_validator)
secret = fields.Str(required=True, validate=_secret_validator)
expire = fields.Str(dump_default=DEFAULT_EXPIRATION_TIME_VALUE,
validate=validate.OneOf(
EXPIRATION_TIME_VALUES.values()))
tries = fields.Int(dump_default=DEFAULT_READ_TRIES_VALUE,
validate=validate.OneOf(READ_TRIES_VALUES))

@pre_load
def secret_sanitise_newline(self, data, **kwargs):
if isinstance(data.get("secret"), str):
data["secret"] = "\n".join(data["secret"].splitlines())
return data
@parser.error_handler
def handle_parsing_error(err: ValidationError, *args, **kwargs) -> NoReturn:
abort(_handle(*handlers.parse_error(err)))


@parser.error_handler
def handle_parsing_error(err: ValidationError,
req: Request,
schema: Schema,
*,
error_status_code: int,
error_headers: Mapping[str, str]):
response, code = handlers.parse_error(err)
return abort(make_response(response.make(), code))
body = functools.partial(use_kwargs, location="json")
query = functools.partial(use_kwargs, location="query")


class Api(MethodView):

@query(_ReadSchema())
def get(self, external_id: str, passphrase: str) -> Response:
response, code = handlers.read_secret(external_id, passphrase)
return make_response(response.make(), code)
@query(ReadRequest())
def get(self, *args, **kwargs) -> Response:
return _handle(*handlers.read(*args, **kwargs))

@json(_CreateSchema())
def post(self,
passphrase: str,
secret: str,
expire: str = DEFAULT_EXPIRATION_TIME_VALUE,
tries: int = DEFAULT_READ_TRIES_VALUE) -> Response:
response, code = handlers.write_secret(passphrase, secret,
expire, tries)
return make_response(response.make(), code)
@body(WriteRequest())
def post(self, *args, **kwargs) -> Response:
return _handle(*handlers.write(*args, **kwargs))


api = Blueprint("api", __name__, url_prefix="/api")
api.add_url_rule("/secret", view_func=Api.as_view("secret"))
32 changes: 16 additions & 16 deletions shhh/api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@
from marshmallow import ValidationError
from sqlalchemy.orm.exc import NoResultFound

from shhh.api.responses import (ErrorResponse,
Message,
ReadResponse,
Status,
WriteResponse)
from shhh.api.schemas import (ErrorResponse,
Message,
ReadResponse,
Status,
WriteResponse)
from shhh.constants import ClientType
from shhh.domain import model
from shhh.extensions import db
from shhh.liveness import db_liveness_ping


@db_liveness_ping(ClientType.WEB)
def read_secret(external_id: str,
passphrase: str) -> tuple[ReadResponse, HTTPStatus]:
def read(external_id: str, passphrase: str) -> tuple[ReadResponse, HTTPStatus]:
try:
secret = db.session.query(model.Secret).filter(
model.Secret.has_external_id(external_id)).one()
Expand Down Expand Up @@ -55,16 +54,17 @@ def read_secret(external_id: str,


@db_liveness_ping(ClientType.WEB)
def write_secret(passphrase: str, message: str, expire_code: str,
tries: int) -> tuple[WriteResponse, HTTPStatus]:
secret = model.Secret.encrypt(message=message,
passphrase=passphrase,
expire_code=expire_code,
tries=tries)
db.session.add(secret)
def write(passphrase: str, secret: str, expire: str,
tries: int) -> tuple[WriteResponse, HTTPStatus]:
encrypted_secret = model.Secret.encrypt(message=secret,
passphrase=passphrase,
expire_code=expire,
tries=tries)
db.session.add(encrypted_secret)
db.session.commit()
app.logger.info("%s created", str(secret))
return (WriteResponse(secret.external_id, secret.expires_on_text),
app.logger.info("%s created", str(encrypted_secret))
return (WriteResponse(encrypted_secret.external_id,
encrypted_secret.expires_on_text),
HTTPStatus.CREATED)


Expand Down
63 changes: 0 additions & 63 deletions shhh/api/responses.py

This file was deleted.

95 changes: 95 additions & 0 deletions shhh/api/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import re
from dataclasses import dataclass, field, fields as dfields
from urllib.parse import urljoin

from flask import Response, current_app as app, jsonify, request, url_for
from marshmallow import Schema, ValidationError, fields, pre_load, validate

from shhh.constants import (DEFAULT_EXPIRATION_TIME_VALUE,
DEFAULT_READ_TRIES_VALUE,
EXPIRATION_TIME_VALUES,
READ_TRIES_VALUES,
Message,
Status)


class ReadRequest(Schema):
"""Schema for inbound read requests."""
external_id = fields.Str(required=True)
passphrase = fields.Str(required=True)


def _passphrase_validator(passphrase: str) -> None:
regex = re.compile(r"^(?=.*?[A-Z])(?=.*?[a-z])(?=.*?[0-9]).{8,}$")
if not regex.search(passphrase):
raise ValidationError("Sorry, your passphrase is too weak. It needs "
"minimum 8 characters, with 1 number and 1 "
"uppercase.")


def _secret_validator(secret: str) -> None:
max_length = app.config["SHHH_SECRET_MAX_LENGTH"]
if len(secret) > max_length:
raise ValidationError(f"The secret should not exceed {max_length} "
"characters.")


class WriteRequest(Schema):
"""Schema for inbound write requests."""
passphrase = fields.Str(required=True, validate=_passphrase_validator)
secret = fields.Str(required=True, validate=_secret_validator)
expire = fields.Str(load_default=DEFAULT_EXPIRATION_TIME_VALUE,
validate=validate.OneOf(
EXPIRATION_TIME_VALUES.values()))
tries = fields.Int(load_default=DEFAULT_READ_TRIES_VALUE,
validate=validate.OneOf(READ_TRIES_VALUES))

@pre_load
def secret_sanitise_newline(self, data, **kwargs):
if isinstance(data.get("secret"), str):
data["secret"] = "\n".join(data["secret"].splitlines())
return data


@dataclass
class CallableResponse:

def __call__(self) -> Response:
return jsonify({
"response": {
f.name: getattr(self, f.name)
for f in dfields(self)
}
})


@dataclass
class ReadResponse(CallableResponse):
"""Schema for outbound read responses."""
status: Status
msg: str


def _build_link_url(external_id: str) -> str:
root_host = app.config.get("SHHH_HOST") or request.url_root
return urljoin(root_host, url_for("web.read", external_id=external_id))


@dataclass
class WriteResponse(CallableResponse):
"""Schema for outbound write responses."""
external_id: str
expires_on: str
link: str = field(init=False)
status: Status = Status.CREATED
details: Message = Message.CREATED

def __post_init__(self):
self.link = _build_link_url(self.external_id)


@dataclass
class ErrorResponse(CallableResponse):
"""Schema for outbound error responses."""
details: str
status: Status = Status.ERROR
1 change: 1 addition & 0 deletions shhh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class TestConfig(DefaultConfig):
SQLALCHEMY_DATABASE_URI = "sqlite://"

SHHH_HOST = "http://test.test"
SHHH_SECRET_MAX_LENGTH = 20
SHHH_DB_LIVENESS_RETRY_COUNT = 1
SHHH_DB_LIVENESS_SLEEP_INTERVAL = 0.1

Expand Down
19 changes: 19 additions & 0 deletions shhh/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,22 @@ class EnvConfig(str, Enum):
DEV_DOCKER = "dev-docker"
HEROKU = "heroku"
PRODUCTION = "production"


class Status(str, Enum):
CREATED = "created"
SUCCESS = "success"
EXPIRED = "expired"
INVALID = "invalid"
ERROR = "error"


class Message(str, Enum):
NOT_FOUND = ("Sorry, we can't find a secret, it has expired, been deleted "
"or has already been read.")
EXCEEDED = ("The passphrase is not valid. You've exceeded the number of "
"tries and the secret has been deleted.")
INVALID = ("Sorry, the passphrase is not valid. Number of tries "
"remaining: {remaining}")
CREATED = "Secret successfully created."
UNEXPECTED = "An unexpected error has occurred, please try again."
7 changes: 3 additions & 4 deletions shhh/liveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from flask import Flask, Response, abort, current_app as app, make_response
from sqlalchemy import text

from shhh.api.responses import ErrorResponse, Message
from shhh.constants import ClientType
from shhh.api.schemas import ErrorResponse
from shhh.constants import ClientType, Message
from shhh.extensions import db, scheduler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,8 +74,7 @@ def _check_web_liveness(f: Callable[..., RT], *args,
return f(*args, **kwargs)

response = ErrorResponse(Message.UNEXPECTED)
return abort(make_response(response.make(),
HTTPStatus.SERVICE_UNAVAILABLE))
abort(make_response(response(), HTTPStatus.SERVICE_UNAVAILABLE))


def _check_liveness(client_type: ClientType,
Expand Down
Loading

0 comments on commit 8b6057c

Please sign in to comment.