Skip to content

Commit

Permalink
Better handler structure
Browse files Browse the repository at this point in the history
  • Loading branch information
smallwat3r committed Sep 30, 2023
1 parent 8bb0528 commit 9f65afc
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 69 deletions.
17 changes: 5 additions & 12 deletions shhh/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,23 @@
import functools
from typing import TYPE_CHECKING

from flask import Blueprint, make_response
from flask import Blueprint
from flask.views import MethodView
from webargs.flaskparser import abort, parser, use_kwargs

from shhh.api import handlers
from shhh.api.handlers import ErrorHandler, ReadHandler, WriteHandler
from shhh.api.schemas import ReadRequest, WriteRequest

if TYPE_CHECKING:
from http import HTTPStatus
from typing import NoReturn

Check warning on line 14 in shhh/api/api.py

View check run for this annotation

Codecov / codecov/patch

shhh/api/api.py#L14

Added line #L14 was not covered by tests

from flask import Response
from marshmallow import ValidationError

from shhh.api.schemas import CallableResponse


def _handle(response: CallableResponse, code: HTTPStatus) -> Response:
return make_response(response(), code)


Check warning on line 19 in shhh/api/api.py

View check run for this annotation

Codecov / codecov/patch

shhh/api/api.py#L16-L19

Added lines #L16 - L19 were not covered by tests
@parser.error_handler
def handle_parsing_error(err: ValidationError, *args, **kwargs) -> NoReturn:
abort(_handle(*handlers.parse_error(err)))
abort(ErrorHandler(err).make_response())


body = functools.partial(use_kwargs, location="json")
Expand All @@ -37,11 +30,11 @@ class Api(MethodView):

@query(ReadRequest())
def get(self, *args, **kwargs) -> Response:
return _handle(*handlers.read(*args, **kwargs))
return ReadHandler(*args, **kwargs).make_response()

@body(WriteRequest())
def post(self, *args, **kwargs) -> Response:
return _handle(*handlers.write(*args, **kwargs))
return WriteHandler(*args, **kwargs).make_response()


api = Blueprint("api", __name__, url_prefix="/api")
Expand Down
149 changes: 92 additions & 57 deletions shhh/api/handlers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from http import HTTPStatus
from typing import TYPE_CHECKING

from cryptography.fernet import InvalidToken
from flask import current_app as app
from flask import current_app as app, make_response
from sqlalchemy.orm.exc import NoResultFound

from shhh.api.schemas import ErrorResponse, ReadResponse, WriteResponse
Expand All @@ -14,66 +15,100 @@
from shhh.liveness import db_liveness_ping

if TYPE_CHECKING:
from flask import Response
from marshmallow import ValidationError

from shhh.api.schemas import CallableResponse

Check warning on line 21 in shhh/api/handlers.py

View check run for this annotation

Codecov / codecov/patch

shhh/api/handlers.py#L21

Added line #L21 was not covered by tests

@db_liveness_ping(ClientType.WEB)
def read(external_id: str, passphrase: str) -> tuple[ReadResponse, HTTPStatus]:
try:
secret = db.session.query(model.Secret).filter(
model.Secret.has_external_id(external_id)).one()
except NoResultFound:
return (ReadResponse(Status.EXPIRED, Message.NOT_FOUND),
HTTPStatus.NOT_FOUND)

try:
message = secret.decrypt(passphrase)
except InvalidToken:
remaining = secret.tries - 1
if remaining == 0:
# number of tries exceeded, delete secret
app.logger.info("%s tries to open secret exceeded", str(secret))
db.session.delete(secret)

class Handler(ABC):

@abstractmethod
def handle(self) -> tuple[CallableResponse, HTTPStatus]:
pass

Check warning on line 28 in shhh/api/handlers.py

View check run for this annotation

Codecov / codecov/patch

shhh/api/handlers.py#L28

Added line #L28 was not covered by tests

def make_response(self) -> Response:
response, code = self.handle()
return make_response(response(), code)


class ReadHandler(Handler):

def __init__(self, external_id: str, passphrase: str) -> None:
self.external_id = external_id
self.passphrase = passphrase

@db_liveness_ping(ClientType.WEB)
def handle(self) -> tuple[ReadResponse, HTTPStatus]:
try:
secret = db.session.query(model.Secret).filter(
model.Secret.has_external_id(self.external_id)).one()
except NoResultFound:
return (ReadResponse(Status.EXPIRED, Message.NOT_FOUND),
HTTPStatus.NOT_FOUND)

try:
message = secret.decrypt(self.passphrase)
except InvalidToken:
remaining = secret.tries - 1
if remaining == 0:
# number of tries exceeded, delete secret
app.logger.info("%s tries to open secret exceeded",
str(secret))
db.session.delete(secret)
db.session.commit()
return (ReadResponse(Status.INVALID, Message.EXCEEDED),
HTTPStatus.UNAUTHORIZED)

secret.tries = remaining
db.session.commit()
return (ReadResponse(Status.INVALID, Message.EXCEEDED),
app.logger.info(
"%s wrong passphrase used. Number of tries remaining: %s",
str(secret),
remaining)
return (ReadResponse(
Status.INVALID,
Message.INVALID.value.format(remaining=remaining)),
HTTPStatus.UNAUTHORIZED)

secret.tries = remaining
db.session.delete(secret)
db.session.commit()
app.logger.info(
"%s wrong passphrase used. Number of tries remaining: %s",
str(secret),
remaining)
return (ReadResponse(
Status.INVALID, Message.INVALID.value.format(remaining=remaining)),
HTTPStatus.UNAUTHORIZED)

db.session.delete(secret)
db.session.commit()
app.logger.info("%s was decrypted and deleted", str(secret))
return ReadResponse(Status.SUCCESS, message), HTTPStatus.OK


@db_liveness_ping(ClientType.WEB)
def write(passphrase: str, secret: str, expire: str,
tries: int) -> tuple[WriteResponse, HTTPStatus]:
encrypted_secret = model.Secret.encrypt(message=secret,
passphrase=passphrase,
expire_code=expire,
tries=tries)
db.session.add(encrypted_secret)
db.session.commit()
app.logger.info("%s created", str(encrypted_secret))
return (WriteResponse(encrypted_secret.external_id,
encrypted_secret.expires_on_text),
HTTPStatus.CREATED)


def parse_error(
error_exc: ValidationError) -> tuple[ErrorResponse, HTTPStatus]:
messages = error_exc.normalized_messages()
error = ""
for source in ("json", "query"):
for _, message in messages.get(source, {}).items():
error += f"{message[0]} "
return ErrorResponse(error.strip()), HTTPStatus.UNPROCESSABLE_ENTITY
app.logger.info("%s was decrypted and deleted", str(secret))
return ReadResponse(Status.SUCCESS, message), HTTPStatus.OK


class WriteHandler(Handler):

def __init__(self, passphrase: str, secret: str, expire: str,
tries: int) -> None:
self.passphrase = passphrase
self.secret = secret
self.expire = expire
self.tries = tries

@db_liveness_ping(ClientType.WEB)
def handle(self) -> tuple[WriteResponse, HTTPStatus]:
encrypted_secret = model.Secret.encrypt(message=self.secret,
passphrase=self.passphrase,
expire_code=self.expire,
tries=self.tries)
db.session.add(encrypted_secret)
db.session.commit()
app.logger.info("%s created", str(encrypted_secret))
return (WriteResponse(encrypted_secret.external_id,
encrypted_secret.expires_on_text),
HTTPStatus.CREATED)


class ErrorHandler(Handler):

def __init__(self, error_exc: ValidationError) -> None:
self.error_exc = error_exc

def handle(self) -> tuple[ErrorResponse, HTTPStatus]:
messages = self.error_exc.normalized_messages()
error = ""
for source in ("json", "query"):
for _, message in messages.get(source, {}).items():
error += f"{message[0]} "
return ErrorResponse(error.strip()), HTTPStatus.UNPROCESSABLE_ENTITY

0 comments on commit 9f65afc

Please sign in to comment.