From 57aee61b4c3a5e400e54462e1834cfb04dab202f Mon Sep 17 00:00:00 2001 From: Matthieu Petiteau Date: Fri, 5 Jun 2020 21:46:34 +0100 Subject: [PATCH] Add mypy typing --- Makefile | 5 +++++ README.md | 18 +++++------------- shhh/api/api.py | 10 ++++++++-- shhh/api/services.py | 27 ++++++++++++++------------- shhh/api/utils.py | 3 ++- shhh/api/validators.py | 16 ++++++++-------- shhh/config.py | 3 ++- shhh/extensions.py | 11 +++++++---- shhh/views.py | 12 +++++++----- test-requirements.txt | 1 + 10 files changed, 59 insertions(+), 47 deletions(-) diff --git a/Makefile b/Makefile index 5a216ec1..ab6804dc 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,8 @@ help: @echo " Run tests" @echo "make lint" @echo " Run pylint" + @echo "make mypy" + @echo " Run mypy" @echo "make secure" @echo " Run bandit" @@ -31,5 +33,8 @@ tests: env test-env lint: env test-env @pylint --rcfile=.pylintrc shhh +mypy: env test-env + @mypy --ignore-missing-imports shhh + secure: env test-env @bandit -r shhh diff --git a/README.md b/README.md index 31c6b472..766ab7b3 100644 --- a/README.md +++ b/README.md @@ -178,21 +178,13 @@ access: -## Run the tests +## Run the checks -You can run the tests using: ```sh -make tests -``` - -Run Pylint report using: -```sh -make lint -``` - -Run Bandit report using: -```sh -make secure +make tests # run tests +make lint # run pylint report +make secure # run bandit report +make mypy # run mypy report ``` ## Credits diff --git a/shhh/api/api.py b/shhh/api/api.py index 4b20fa08..70d4b318 100644 --- a/shhh/api/api.py +++ b/shhh/api/api.py @@ -1,5 +1,6 @@ # pylint: disable=no-self-use,too-many-arguments import functools +from typing import Dict, Tuple from marshmallow import Schema, fields, validates_schema from flask import Blueprint @@ -38,7 +39,12 @@ def haveibeenpwned_checker(self, data, **kwargs): class Create(Resource): """/api/c Create secret API.""" @json(CreateParams()) - def post(self, passphrase, secret, days=3, tries=5, haveibeenpwned=False): + def post(self, + passphrase: str, + secret: str, + days: int = 3, + tries: int = 5, + haveibeenpwned: bool = False) -> Tuple[Dict, int]: """Post request handler.""" response, code = create_secret(passphrase, secret, days, tries, haveibeenpwned) @@ -56,7 +62,7 @@ class ReadParams(Schema): class Read(Resource): """/api/r Read secret API.""" @query(ReadParams()) - def get(self, slug, passphrase): + def get(self, slug: str, passphrase: str) -> Tuple[Dict, int]: """Get request handler.""" response, code = read_secret(slug, passphrase) return {"response": response}, code diff --git a/shhh/api/services.py b/shhh/api/services.py index a5c7d581..2ef2cb9b 100644 --- a/shhh/api/services.py +++ b/shhh/api/services.py @@ -1,19 +1,18 @@ import html import secrets - from base64 import urlsafe_b64decode, urlsafe_b64encode from datetime import datetime, timedelta, timezone - -from flask import current_app as app -from flask import request +from typing import Dict, Tuple from cryptography.fernet import Fernet, InvalidToken from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from flask import current_app as app +from flask import request -from shhh.models import Entries from shhh.api.validators import Status +from shhh.models import Entries class Secret: @@ -21,11 +20,11 @@ class Secret: __slots__ = ("secret", "passphrase") - def __init__(self, secret, passphrase): + def __init__(self, secret: bytes, passphrase: str): self.secret = secret self.passphrase = passphrase - def __derive_key(self, salt, iterations): + def __derive_key(self, salt: bytes, iterations: int) -> bytes: """Derive a secret key from a given passphrase and salt.""" kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), length=32, @@ -34,7 +33,7 @@ def __derive_key(self, salt, iterations): backend=default_backend()) return urlsafe_b64encode(kdf.derive(self.passphrase.encode())) - def encrypt(self, iterations=100_000): + def encrypt(self, iterations: int = 100_000) -> bytes: """Encrypt secret.""" salt = secrets.token_bytes(16) key = self.__derive_key(salt, iterations) @@ -42,7 +41,7 @@ def encrypt(self, iterations=100_000): b"%b%b%b" % (salt, iterations.to_bytes(4, "big"), urlsafe_b64decode(Fernet(key).encrypt(self.secret)))) - def decrypt(self): + def decrypt(self) -> str: """Decrypt secret.""" decoded = urlsafe_b64decode(self.secret) salt, iteration, message = (decoded[:16], decoded[16:20], @@ -52,7 +51,7 @@ def decrypt(self): return Fernet(key).decrypt(message).decode("utf-8") -def _generate_unique_slug(): +def _generate_unique_slug() -> str: """Generates a unique slug link. This function will loop recursively on itself to make sure the slug @@ -65,7 +64,7 @@ def _generate_unique_slug(): return _generate_unique_slug() -def read_secret(slug, passphrase): +def read_secret(slug: str, passphrase: str) -> Tuple[Dict, int]: """Read a secret. Args: @@ -106,7 +105,8 @@ def read_secret(slug, passphrase): return dict(status=Status.SUCCESS.value, msg=html.escape(msg)), 200 -def create_secret(passphrase, secret, expire, tries, haveibeenpwned): +def create_secret(passphrase: str, secret: str, expire: int, tries: int, + haveibeenpwned: bool) -> Tuple[Dict, int]: """Create a secret. Args: @@ -137,4 +137,5 @@ def create_secret(passphrase, secret, expire, tries, haveibeenpwned): details="Secret successfully created.", slug=slug, link=f"{request.url_root}r/{slug}", - expires_on=f"{expiration_date.strftime('%Y-%m-%d at %H:%M')} {timez}"), 201 + expires_on=f"{expiration_date.strftime('%Y-%m-%d at %H:%M')} {timez}" + ), 201 diff --git a/shhh/api/utils.py b/shhh/api/utils.py index f721d25b..7bbd307f 100644 --- a/shhh/api/utils.py +++ b/shhh/api/utils.py @@ -1,9 +1,10 @@ import hashlib +from typing import Union import requests -def pwned_password(passphrase): +def pwned_password(passphrase: str) -> Union[int, bool]: """Check passphrase with Troy's Hunt haveibeenpwned API. Query the API to check if the passphrase has already been pwned in the diff --git a/shhh/api/validators.py b/shhh/api/validators.py index f6832318..2aed935a 100644 --- a/shhh/api/validators.py +++ b/shhh/api/validators.py @@ -27,7 +27,7 @@ def handle_parsing_error(err, req, schema, *, error_status_code, abort(422, response=dict(details=err.messages, status=Status.ERROR.value)) -def validate_strength(passphrase): +def validate_strength(passphrase: str) -> None: """Passphrase strength validation handler. Minimum 8 characters containing at least one number and one uppercase. @@ -41,13 +41,13 @@ def validate_strength(passphrase): "1 number and 1 uppercase.") -def validate_haveibeenpwned(passphrase): +def validate_haveibeenpwned(passphrase: str) -> None: """Validate passphrase against haveibeenpwned API.""" try: times_pwned = utils.pwned_password(passphrase) except Exception as err: # pylint: disable=broad-except app.logger.error(err) - times_pwned = None # don't break if service isn't reachable. + times_pwned = False # don't break if service isn't reachable. if times_pwned: raise ValidationError( @@ -55,7 +55,7 @@ def validate_haveibeenpwned(passphrase): "(haveibeenpwned.com), please chose another one.") -def validate_secret(secret): +def validate_secret(secret: str) -> None: """Secret validation handler.""" if not secret: raise ValidationError("Missing a secret to encrypt.") @@ -64,13 +64,13 @@ def validate_secret(secret): "The secret needs to have less than 150 characters.") -def validate_passphrase(passphrase): +def validate_passphrase(passphrase: str) -> None: """Passphrase validation handler.""" if not passphrase: raise ValidationError("Missing a passphrase.") -def validate_days(days): +def validate_days(days: int) -> None: """Expiration validation handler.""" if days == 0: raise ValidationError( @@ -80,7 +80,7 @@ def validate_days(days): "The maximum number of days to keep the secret alive is 7.") -def validate_tries(tries): +def validate_tries(tries: int) -> None: """Maximum tries validation handler.""" if tries < 3: raise ValidationError( @@ -90,7 +90,7 @@ def validate_tries(tries): "The maximum number of tries to decrypt the secret is 10.") -def validate_slug(slug): +def validate_slug(slug: str) -> None: """Link validation handler.""" if not slug: raise ValidationError("Missing a secret link.") diff --git a/shhh/config.py b/shhh/config.py index 6ad9205c..bb85b47b 100644 --- a/shhh/config.py +++ b/shhh/config.py @@ -1,4 +1,5 @@ import os +from typing import Optional from shhh.scheduler import delete_expired_links @@ -27,7 +28,7 @@ class DefaultConfig: # SqlAlchemy SQLALCHEMY_ECHO = True SQLALCHEMY_TRACK_MODIFICATIONS = False - SQLALCHEMY_DATABASE_URI = ( + SQLALCHEMY_DATABASE_URI: Optional[str] = ( f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}" f"@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}") diff --git a/shhh/extensions.py b/shhh/extensions.py index 57a01398..90dcb3cf 100644 --- a/shhh/extensions.py +++ b/shhh/extensions.py @@ -1,4 +1,6 @@ # All extensions are used as singletons and initialized in application factory. +from typing import Any, Union + from flask_apscheduler import APScheduler from flask_assets import Environment from flask_sqlalchemy import SQLAlchemy, Model @@ -6,26 +8,27 @@ class CRUDMixin(Model): """Add convenience methods for CRUD operations with SQLAlchemy.""" + @classmethod - def create(cls, **kwargs): + def create(cls, **kwargs) -> Union[bool, Any]: """Create a new record and save it the database.""" instance = cls(**kwargs) return instance.save() - def update(self, commit=True, **kwargs): + def update(self, commit: bool = True, **kwargs) -> Union[bool, Any]: """Update specific fields of a record.""" for attr, value in kwargs.items(): setattr(self, attr, value) return self.save() if commit else self - def save(self, commit=True): + def save(self, commit: bool = True) -> Union[bool, Any]: """Save the record.""" db.session.add(self) if commit: db.session.commit() return self - def delete(self, commit=True): + def delete(self, commit: bool = True) -> Union[bool, Any]: """Remove the record from the database.""" db.session.delete(self) return commit and db.session.commit() diff --git a/shhh/views.py b/shhh/views.py index bc0ea45b..fd8dce75 100644 --- a/shhh/views.py +++ b/shhh/views.py @@ -1,4 +1,5 @@ import inspect +from typing import Tuple from flask import current_app as app from flask import redirect @@ -15,6 +16,7 @@ def qs_to_args(f): querystring. Check that the query keys are matching the args. """ + def wrapper(*args, **kwargs): if sorted(inspect.getfullargspec(f).args) != sorted( @@ -26,32 +28,32 @@ def wrapper(*args, **kwargs): @app.route("/") -def create(): +def create() -> str: """View to create a secret.""" return rt("create.html") @app.route("/c") @qs_to_args -def created(link, expires_on): +def created(link: str, expires_on: str) -> str: """View to see the link for the created secret.""" return rt("created.html", link=link, expires_on=expires_on) @app.route("/r/") -def read(slug): +def read(slug: str) -> str: """View to read a secret.""" return rt("read.html", slug=slug) @app.errorhandler(404) -def not_found(error): +def not_found(error: str) -> Tuple[str, int]: """404 handler.""" return rt("404.html", error=error), 404 @app.route("/robots.txt") -def robots(): +def robots() -> str: """Robots handler.""" return send_from_directory(app.static_folder, request.path[1:]) diff --git a/test-requirements.txt b/test-requirements.txt index d8a96e6c..73dfee2c 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ pylint bandit responses +mypy