Skip to content

Commit

Permalink
Improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
smallwat3r committed Sep 12, 2023
1 parent 7a5695a commit d06588a
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
23 changes: 15 additions & 8 deletions shhh/api/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import re
import functools
import re
from typing import Mapping

from flask import Blueprint, Response, make_response
from flask import current_app as app
from flask import (Blueprint,
Request,
Response,
current_app as app,
make_response)
from flask.views import MethodView
from marshmallow import Schema, fields, pre_load, validate
from marshmallow import ValidationError
from marshmallow import Schema, ValidationError, fields, pre_load, validate
from webargs.flaskparser import abort, parser, use_kwargs

from shhh.api import handlers
Expand Down Expand Up @@ -52,15 +55,19 @@ class _CreateSchema(Schema):
validate=validate.OneOf(READ_TRIES_VALUES))

@pre_load
def secret_sanitise_newline(self, data: dict, **kwargs) -> dict[str, str]:
def secret_sanitise_newline(self, data, **kwargs):
if isinstance(data.get("secret"), str):
data["secret"] = "\n".join(data["secret"].splitlines())
return data


@parser.error_handler
def handle_parsing_error(err, req, schema, *, error_status_code,
error_headers):
def handle_parsing_error(err: ValidationError,
req: Request,
schema: Schema,
*,
error_status_code: int,
error_headers: Mapping[str, str]):
response, code = handlers.parse_error(err)
return abort(make_response(response.make(), code))

Expand Down
7 changes: 5 additions & 2 deletions shhh/api/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cryptography.fernet import InvalidToken
from flask import current_app as app
from marshmallow import ValidationError
from sqlalchemy.orm.exc import NoResultFound

from shhh.api.responses import (ErrorResponse,
Expand Down Expand Up @@ -67,9 +68,11 @@ def write_secret(passphrase: str, message: str, expire_code: str,
HTTPStatus.CREATED)


def parse_error(errors) -> tuple[ErrorResponse, HTTPStatus]:
def parse_error(
error_exc: ValidationError) -> tuple[ErrorResponse, HTTPStatus]:
messages = error_exc.normalized_messages()
error = ""
for source in ("json", "query"):
for _, message in errors.messages.get(source, {}).items():
for _, message in messages.get(source, {}).items():
error += f"{message[0]} "
return ErrorResponse(error.strip()), HTTPStatus.UNPROCESSABLE_ENTITY
4 changes: 2 additions & 2 deletions shhh/api/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from urllib.parse import urljoin

from flask import current_app as app, jsonify, request, url_for
from flask import Response, current_app as app, jsonify, request, url_for


class Status(str, Enum):
Expand All @@ -27,7 +27,7 @@ class Message(str, Enum):
@dataclass
class BaseResponse:

def make(self):
def make(self) -> Response:
return jsonify(
{"response": {
f.name: getattr(self, f.name)
Expand Down
3 changes: 2 additions & 1 deletion shhh/scheduler/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Iterable

from shhh.constants import ClientType
from shhh.domain import model
Expand All @@ -19,7 +20,7 @@ def delete_expired_records() -> None:
len(expired_secrets))


def _delete_records(records: list[model.Secret]) -> None:
def _delete_records(records: Iterable[model.Secret]) -> None:
for record in records:
db.session.delete(record)
db.session.commit()
4 changes: 1 addition & 3 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def test_api_get_exceeded_tries(app, secret):
def test_api_message_expired(app):
with app.test_request_context(), app.test_client() as test_client:
response = test_client.get(
url_for("api.secret",
external_id="123456",
passphrase="Hello123"))
url_for("api.secret", external_id="123456", passphrase="Hello123"))
assert response.status_code == HTTPStatus.NOT_FOUND
data = response.get_json()
assert data["response"]["status"] == Status.EXPIRED
Expand Down

0 comments on commit d06588a

Please sign in to comment.