Skip to content

Commit

Permalink
Add mypy typing
Browse files Browse the repository at this point in the history
  • Loading branch information
smallwat3r committed Jun 5, 2020
1 parent e9a75c6 commit 57aee61
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 47 deletions.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ help:
@echo " Run tests"
@echo "make lint"
@echo " Run pylint"
@echo "make mypy"
@echo " Run mypy"
@echo "make secure"
@echo " Run bandit"

Expand All @@ -31,5 +33,8 @@ tests: env test-env
lint: env test-env
@pylint --rcfile=.pylintrc shhh

mypy: env test-env
@mypy --ignore-missing-imports shhh

secure: env test-env
@bandit -r shhh
18 changes: 5 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,13 @@ access:

</details>

## Run the tests
## Run the checks

You can run the tests using:
```sh
make tests
```

Run Pylint report using:
```sh
make lint
```

Run Bandit report using:
```sh
make secure
make tests # run tests
make lint # run pylint report
make secure # run bandit report
make mypy # run mypy report
```

## Credits
Expand Down
10 changes: 8 additions & 2 deletions shhh/api/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=no-self-use,too-many-arguments
import functools
from typing import Dict, Tuple

from marshmallow import Schema, fields, validates_schema
from flask import Blueprint
Expand Down Expand Up @@ -38,7 +39,12 @@ def haveibeenpwned_checker(self, data, **kwargs):
class Create(Resource):
"""/api/c Create secret API."""
@json(CreateParams())
def post(self, passphrase, secret, days=3, tries=5, haveibeenpwned=False):
def post(self,
passphrase: str,
secret: str,
days: int = 3,
tries: int = 5,
haveibeenpwned: bool = False) -> Tuple[Dict, int]:
"""Post request handler."""
response, code = create_secret(passphrase, secret, days, tries,
haveibeenpwned)
Expand All @@ -56,7 +62,7 @@ class ReadParams(Schema):
class Read(Resource):
"""/api/r Read secret API."""
@query(ReadParams())
def get(self, slug, passphrase):
def get(self, slug: str, passphrase: str) -> Tuple[Dict, int]:
"""Get request handler."""
response, code = read_secret(slug, passphrase)
return {"response": response}, code
Expand Down
27 changes: 14 additions & 13 deletions shhh/api/services.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
import html
import secrets

from base64 import urlsafe_b64decode, urlsafe_b64encode
from datetime import datetime, timedelta, timezone

from flask import current_app as app
from flask import request
from typing import Dict, Tuple

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from flask import current_app as app
from flask import request

from shhh.models import Entries
from shhh.api.validators import Status
from shhh.models import Entries


class Secret:
"""Secrets encryption / decryption management."""

__slots__ = ("secret", "passphrase")

def __init__(self, secret, passphrase):
def __init__(self, secret: bytes, passphrase: str):
self.secret = secret
self.passphrase = passphrase

def __derive_key(self, salt, iterations):
def __derive_key(self, salt: bytes, iterations: int) -> bytes:
"""Derive a secret key from a given passphrase and salt."""
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(),
length=32,
Expand All @@ -34,15 +33,15 @@ def __derive_key(self, salt, iterations):
backend=default_backend())
return urlsafe_b64encode(kdf.derive(self.passphrase.encode()))

def encrypt(self, iterations=100_000):
def encrypt(self, iterations: int = 100_000) -> bytes:
"""Encrypt secret."""
salt = secrets.token_bytes(16)
key = self.__derive_key(salt, iterations)
return urlsafe_b64encode(
b"%b%b%b" % (salt, iterations.to_bytes(4, "big"),
urlsafe_b64decode(Fernet(key).encrypt(self.secret))))

def decrypt(self):
def decrypt(self) -> str:
"""Decrypt secret."""
decoded = urlsafe_b64decode(self.secret)
salt, iteration, message = (decoded[:16], decoded[16:20],
Expand All @@ -52,7 +51,7 @@ def decrypt(self):
return Fernet(key).decrypt(message).decode("utf-8")


def _generate_unique_slug():
def _generate_unique_slug() -> str:
"""Generates a unique slug link.
This function will loop recursively on itself to make sure the slug
Expand All @@ -65,7 +64,7 @@ def _generate_unique_slug():
return _generate_unique_slug()


def read_secret(slug, passphrase):
def read_secret(slug: str, passphrase: str) -> Tuple[Dict, int]:
"""Read a secret.
Args:
Expand Down Expand Up @@ -106,7 +105,8 @@ def read_secret(slug, passphrase):
return dict(status=Status.SUCCESS.value, msg=html.escape(msg)), 200


def create_secret(passphrase, secret, expire, tries, haveibeenpwned):
def create_secret(passphrase: str, secret: str, expire: int, tries: int,
haveibeenpwned: bool) -> Tuple[Dict, int]:
"""Create a secret.
Args:
Expand Down Expand Up @@ -137,4 +137,5 @@ def create_secret(passphrase, secret, expire, tries, haveibeenpwned):
details="Secret successfully created.",
slug=slug,
link=f"{request.url_root}r/{slug}",
expires_on=f"{expiration_date.strftime('%Y-%m-%d at %H:%M')} {timez}"), 201
expires_on=f"{expiration_date.strftime('%Y-%m-%d at %H:%M')} {timez}"
), 201
3 changes: 2 additions & 1 deletion shhh/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import hashlib
from typing import Union

import requests


def pwned_password(passphrase):
def pwned_password(passphrase: str) -> Union[int, bool]:
"""Check passphrase with Troy's Hunt haveibeenpwned API.
Query the API to check if the passphrase has already been pwned in the
Expand Down
16 changes: 8 additions & 8 deletions shhh/api/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def handle_parsing_error(err, req, schema, *, error_status_code,
abort(422, response=dict(details=err.messages, status=Status.ERROR.value))


def validate_strength(passphrase):
def validate_strength(passphrase: str) -> None:
"""Passphrase strength validation handler.
Minimum 8 characters containing at least one number and one uppercase.
Expand All @@ -41,21 +41,21 @@ def validate_strength(passphrase):
"1 number and 1 uppercase.")


def validate_haveibeenpwned(passphrase):
def validate_haveibeenpwned(passphrase: str) -> None:
"""Validate passphrase against haveibeenpwned API."""
try:
times_pwned = utils.pwned_password(passphrase)
except Exception as err: # pylint: disable=broad-except
app.logger.error(err)
times_pwned = None # don't break if service isn't reachable.
times_pwned = False # don't break if service isn't reachable.

if times_pwned:
raise ValidationError(
f"This password has been pwned {times_pwned} times "
"(haveibeenpwned.com), please chose another one.")


def validate_secret(secret):
def validate_secret(secret: str) -> None:
"""Secret validation handler."""
if not secret:
raise ValidationError("Missing a secret to encrypt.")
Expand All @@ -64,13 +64,13 @@ def validate_secret(secret):
"The secret needs to have less than 150 characters.")


def validate_passphrase(passphrase):
def validate_passphrase(passphrase: str) -> None:
"""Passphrase validation handler."""
if not passphrase:
raise ValidationError("Missing a passphrase.")


def validate_days(days):
def validate_days(days: int) -> None:
"""Expiration validation handler."""
if days == 0:
raise ValidationError(
Expand All @@ -80,7 +80,7 @@ def validate_days(days):
"The maximum number of days to keep the secret alive is 7.")


def validate_tries(tries):
def validate_tries(tries: int) -> None:
"""Maximum tries validation handler."""
if tries < 3:
raise ValidationError(
Expand All @@ -90,7 +90,7 @@ def validate_tries(tries):
"The maximum number of tries to decrypt the secret is 10.")


def validate_slug(slug):
def validate_slug(slug: str) -> None:
"""Link validation handler."""
if not slug:
raise ValidationError("Missing a secret link.")
3 changes: 2 additions & 1 deletion shhh/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Optional

from shhh.scheduler import delete_expired_links

Expand Down Expand Up @@ -27,7 +28,7 @@ class DefaultConfig:
# SqlAlchemy
SQLALCHEMY_ECHO = True
SQLALCHEMY_TRACK_MODIFICATIONS = False
SQLALCHEMY_DATABASE_URI = (
SQLALCHEMY_DATABASE_URI: Optional[str] = (
f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}"
f"@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}")

Expand Down
11 changes: 7 additions & 4 deletions shhh/extensions.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
# All extensions are used as singletons and initialized in application factory.
from typing import Any, Union

from flask_apscheduler import APScheduler
from flask_assets import Environment
from flask_sqlalchemy import SQLAlchemy, Model


class CRUDMixin(Model):
"""Add convenience methods for CRUD operations with SQLAlchemy."""

@classmethod
def create(cls, **kwargs):
def create(cls, **kwargs) -> Union[bool, Any]:
"""Create a new record and save it the database."""
instance = cls(**kwargs)
return instance.save()

def update(self, commit=True, **kwargs):
def update(self, commit: bool = True, **kwargs) -> Union[bool, Any]:
"""Update specific fields of a record."""
for attr, value in kwargs.items():
setattr(self, attr, value)
return self.save() if commit else self

def save(self, commit=True):
def save(self, commit: bool = True) -> Union[bool, Any]:
"""Save the record."""
db.session.add(self)
if commit:
db.session.commit()
return self

def delete(self, commit=True):
def delete(self, commit: bool = True) -> Union[bool, Any]:
"""Remove the record from the database."""
db.session.delete(self)
return commit and db.session.commit()
Expand Down
12 changes: 7 additions & 5 deletions shhh/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from typing import Tuple

from flask import current_app as app
from flask import redirect
Expand All @@ -15,6 +16,7 @@ def qs_to_args(f):
querystring. Check that the query keys are matching the args.
"""

def wrapper(*args, **kwargs):

if sorted(inspect.getfullargspec(f).args) != sorted(
Expand All @@ -26,32 +28,32 @@ def wrapper(*args, **kwargs):


@app.route("/")
def create():
def create() -> str:
"""View to create a secret."""
return rt("create.html")


@app.route("/c")
@qs_to_args
def created(link, expires_on):
def created(link: str, expires_on: str) -> str:
"""View to see the link for the created secret."""
return rt("created.html", link=link, expires_on=expires_on)


@app.route("/r/<slug>")
def read(slug):
def read(slug: str) -> str:
"""View to read a secret."""
return rt("read.html", slug=slug)


@app.errorhandler(404)
def not_found(error):
def not_found(error: str) -> Tuple[str, int]:
"""404 handler."""
return rt("404.html", error=error), 404


@app.route("/robots.txt")
def robots():
def robots() -> str:
"""Robots handler."""
return send_from_directory(app.static_folder, request.path[1:])

Expand Down
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pylint
bandit
responses
mypy

0 comments on commit 57aee61

Please sign in to comment.