Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy checking #472

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .ci/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ function run_style() {
"${PYTHON}" -m black --diff --check "${SRC_ROOT}"
}

function run_mypy_check() {
"${PYTHON}" -m mypy --exclude=docs --exclude=scripts --exclude='setup.py' "${SRC_ROOT}"
}

if [ "x${TEST}" != "x" ]; then
run_test
elif [ "x${WHITESPACE}" != "x" ]; then
Expand All @@ -111,4 +115,6 @@ elif [ "x${STYLE}" != "x" ]; then
run_style
elif [ "x${PUBLISH_PKG}" != "x" ]; then
run_publish_pkg
elif [ "x${MYPY}" != "x" ]; then
run_mypy_check
fi
21 changes: 21 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,24 @@ jobs:
env:
STYLE: 1
run: ./.ci/run.sh

mypy-check:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v2

- name: Install dependencies
env:
TPM2_TSS_VERSION: master
TPM2_TSS_FAPI: true
TPM2_TOOLS_VERSION: master
run: ./.ci/install-deps.sh

- name: Install tpm2-pytss
run: pip install -e .[dev]

- name: MyPy Check
env:
MYPY: 1
run: ./.ci/run.sh
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ install_requires =
packaging
pyyaml

[options.package_data]
tpm2_pytss = py.typed
tpm2_pytss.internal = py.typed

[options.extras_require]
dev =
Expand All @@ -56,3 +59,4 @@ dev =
myst-parser
build
installer
mypy
4 changes: 2 additions & 2 deletions src/tpm2_pytss/TCTI.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-2

from ._libtpm2_pytss import ffi, lib
from ._libtpm2_pytss import ffi, lib # type: ignore[import]

from .internal.utils import _chkrc
from .constants import TSS2_RC, TPM2_RC
Expand Down Expand Up @@ -241,7 +241,7 @@ def cancel(self) -> None:
_chkrc(self._v1.cancel(self._ctx))

@common_checks()
def get_poll_handles(self) -> Tuple[PollData]:
def get_poll_handles(self) -> Tuple[PollData, ...]:
"""Gets the poll handles from the TPM.

Returns:
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/TCTILdr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-2

from ._libtpm2_pytss import lib, ffi
from ._libtpm2_pytss import lib, ffi # type: ignore[import]
from .TCTI import TCTI
from .internal.utils import _chkrc

Expand Down
4 changes: 2 additions & 2 deletions src/tpm2_pytss/TSS2_Exception.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from ._libtpm2_pytss import lib, ffi
from ._libtpm2_pytss import lib, ffi # type: ignore[import]
from typing import Union


class TSS2_Exception(RuntimeError):
"""TSS2_Exception represents an error returned by the TSS APIs."""

# prevent cirular dependency and don't use the types directly here.
def __init__(self, rc: Union["TSS2_RC", "TPM2_RC", int]):
def __init__(self, rc: Union["TSS2_RC", "TPM2_RC", int]): # type: ignore[name-defined]
if isinstance(rc, int):
# defer this to avoid circular dep.
from .constants import TSS2_RC
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: BSD-2


from ._libtpm2_pytss import lib
from ._libtpm2_pytss import lib # type: ignore[import]
from .internal.constants import CALLBACK_BASE_NAME, CALLBACK_COUNT, CallbackType


Expand Down
17 changes: 9 additions & 8 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

Along with helpers to go from string values to constants and constant values to string values.
"""
from ._libtpm2_pytss import lib, ffi
from ._libtpm2_pytss import lib, ffi # type: ignore[import]
from tpm2_pytss.internal.utils import _CLASS_INT_ATTRS_from_string, _lib_version_atleast
from typing import Dict, Tuple


class TPM_FRIENDLY_INT(int):
_FIXUP_MAP = {}
_FIXUP_MAP: Dict[str, str] = {}

@classmethod
def parse(cls, value: str) -> int:
Expand Down Expand Up @@ -231,7 +232,7 @@ def _fix_const_type(cls):


class TPMA_FRIENDLY_INTLIST(TPM_FRIENDLY_INT):
_MASKS = tuple()
_MASKS: Tuple[Tuple[int, int, str], ...] = tuple()

@classmethod
def parse(cls, value: str) -> int:
Expand Down Expand Up @@ -392,7 +393,7 @@ class ESYS_TR(TPM_FRIENDLY_INT):
RH_PLATFORM = lib.ESYS_TR_RH_PLATFORM
RH_PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV

def serialize(self, ectx: "ESAPI") -> bytes:
def serialize(self, ectx: "ESAPI") -> bytes: # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_serialize

Args:
Expand All @@ -405,7 +406,7 @@ def serialize(self, ectx: "ESAPI") -> bytes:
return ectx.tr_serialize(self)

@staticmethod
def deserialize(ectx: "ESAPI", buffer: bytes) -> "ESYS_TR":
def deserialize(ectx: "ESAPI", buffer: bytes) -> "ESYS_TR": # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_derialize

Args:
Expand All @@ -417,7 +418,7 @@ def deserialize(ectx: "ESAPI", buffer: bytes) -> "ESYS_TR":

return ectx.tr_deserialize(buffer)

def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME":
def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME": # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_get_name

Args:
Expand All @@ -428,7 +429,7 @@ def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME":
"""
return ectx.tr_get_name(self)

def close(self, ectx: "ESAPI"):
def close(self, ectx: "ESAPI"): # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_close

Args:
Expand Down Expand Up @@ -1256,7 +1257,7 @@ def parse(cls, value: str) -> "TPMA_LOCALITY":
return cls(value, base=0)
except ValueError:
pass
return super().parse(value)
return TPMA_LOCALITY(super().parse(value))

def __str__(self) -> str:
"""Given a set of localities or an extended locality, return the string representation
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/encoding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from binascii import hexlify, unhexlify
from typing import Any, Union, List, Dict, Tuple
from ._libtpm2_pytss import ffi
from ._libtpm2_pytss import ffi # type: ignore[name-defined]
from .internal.crypto import _get_digest_size
from .constants import (
TPM_FRIENDLY_INT,
Expand Down
49 changes: 28 additions & 21 deletions src/tpm2_pytss/internal/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@
from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature
from typing import Tuple, Type
from typing import Tuple, Type, Any, Union
import secrets
import sys

_curvetable = (
# Despite below, it won't allow us to use the right classes for the
# typehint so we just use Any...
# from cryptography.hazmat.primitives.asymmetric import rsa, ec, padding
# ec.SECP192R1
# <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP192R1'>
# type(ec.SECP192R1)
# <class 'abc.ABCMeta'>
# ec.SECP192R1.__bases__
# (<class 'cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve'>,)

_curvetable: Tuple[Tuple[TPM2_ECC, Any], ...] = (
(TPM2_ECC.NIST_P192, ec.SECP192R1),
(TPM2_ECC.NIST_P224, ec.SECP224R1),
(TPM2_ECC.NIST_P256, ec.SECP256R1),
(TPM2_ECC.NIST_P384, ec.SECP384R1),
(TPM2_ECC.NIST_P521, ec.SECP521R1),
)

_digesttable = (
_digesttable: Tuple[Tuple[TPM2_ALG, Any], ...] = (
(TPM2_ALG.SHA1, hashes.SHA1),
(TPM2_ALG.SHA256, hashes.SHA256),
(TPM2_ALG.SHA384, hashes.SHA384),
Expand All @@ -48,14 +58,14 @@
if hasattr(hashes, "SM3"):
_digesttable += ((TPM2_ALG.SM3_256, hashes.SM3),)

_algtable = (
_algtable: Tuple[Tuple[TPM2_ALG, Any], ...] = (
(TPM2_ALG.AES, AES),
(TPM2_ALG.CAMELLIA, Camellia),
(TPM2_ALG.CFB, modes.CFB),
)

try:
from cryptography.hazmat.primitives.ciphers.algorithms import SM4
from cryptography.hazmat.primitives.ciphers.algorithms import SM4 # type: ignore[attr-defined]

_algtable += ((TPM2_ALG.SM4, SM4),)
except ImportError:
Expand Down Expand Up @@ -274,8 +284,7 @@ def _generate_d(p, q, e, n):
return d


def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC"):
key = None
def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC") -> Union[ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey]: # type: ignore[name-defined]
if private.sensitiveType == TPM2_ALG.RSA:

p = int.from_bytes(bytes(private.sensitive.rsa), byteorder="big")
Expand All @@ -286,7 +295,7 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC")
else 65537
)

key = _MyRSAPrivateNumbers(p, n, e, rsa.RSAPublicNumbers(e, n)).private_key(
return _MyRSAPrivateNumbers(p, n, e, rsa.RSAPublicNumbers(e, n)).private_key(
backend=default_backend()
)
elif private.sensitiveType == TPM2_ALG.ECC:
Expand All @@ -301,13 +310,11 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC")
x = int.from_bytes(bytes(public.unique.ecc.x), byteorder="big")
y = int.from_bytes(bytes(public.unique.ecc.y), byteorder="big")

key = ec.EllipticCurvePrivateNumbers(
return ec.EllipticCurvePrivateNumbers(
p, ec.EllipticCurvePublicNumbers(x, y, curve())
).private_key(backend=default_backend())
else:
raise ValueError(f"unsupported key type: {private.sensitiveType}")

return key
raise ValueError(f"unsupported key type: {private.sensitiveType}")


def _public_to_pem(obj, encoding="pem"):
Expand Down Expand Up @@ -535,7 +542,7 @@ def _generate_ecc_seed(
return (seed, secret)


def _generate_seed(public: "types.TPMT_PUBLIC", label: bytes) -> Tuple[bytes, bytes]:
def _generate_seed(public: "types.TPMT_PUBLIC", label: bytes) -> Tuple[bytes, bytes]: # type: ignore[name-defined]
key = public_to_key(public)
if public.type == TPM2_ALG.RSA:
return _generate_rsa_seed(key, public.nameAlg, label)
Expand Down Expand Up @@ -588,8 +595,8 @@ def __ecc_secret_to_seed(


def _secret_to_seed(
private: "types.TPMT_SENSITIVE",
public: "types.TPMT_PUBLIC",
private: "types.TPMT_SENSITIVE", # type: ignore[name-defined]
public: "types.TPMT_PUBLIC", # type: ignore[name-defined]
label: bytes,
outsymseed: bytes,
):
Expand All @@ -605,7 +612,7 @@ def _secret_to_seed(
def _hmac(
halg: hashes.HashAlgorithm, hmackey: bytes, enc_cred: bytes, name: bytes
) -> bytes:
h = HMAC(hmackey, halg(), backend=default_backend())
h = HMAC(hmackey, halg(), backend=default_backend()) # type: ignore[operator]
h.update(enc_cred)
h.update(name)
return h.finalize()
Expand All @@ -618,7 +625,7 @@ def _check_hmac(
name: bytes,
expected: bytes,
):
h = HMAC(hmackey, halg(), backend=default_backend())
h = HMAC(hmackey, halg(), backend=default_backend()) # type: ignore[operator]
h.update(enc_cred)
h.update(name)
h.verify(expected)
Expand All @@ -628,8 +635,8 @@ def _encrypt(
cipher: Type[CipherAlgorithm], mode: Type[modes.Mode], key: bytes, data: bytes
) -> bytes:
iv = len(key) * b"\x00"
ci = cipher(key)
ciph = Cipher(ci, mode(iv), backend=default_backend())
ci = cipher(key) # type: ignore[call-arg]
ciph = Cipher(ci, mode(iv), backend=default_backend()) # type: ignore[call-arg]
encr = ciph.encryptor()
encdata = encr.update(data) + encr.finalize()
return encdata
Expand All @@ -639,8 +646,8 @@ def _decrypt(
cipher: Type[CipherAlgorithm], mode: Type[modes.Mode], key: bytes, data: bytes
) -> bytes:
iv = len(key) * b"\x00"
ci = cipher(key)
ciph = Cipher(ci, mode(iv), backend=default_backend())
ci = cipher(key) # type: ignore[call-arg]
ciph = Cipher(ci, mode(iv), backend=default_backend()) # type: ignore[call-arg]
decr = ciph.decryptor()
plaintextdata = decr.update(data) + decr.finalize()
return plaintextdata
Empty file.
11 changes: 6 additions & 5 deletions src/tpm2_pytss/internal/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# SPDX-License-Identifier: BSD-2
import logging
import sys
from typing import List
from typing import List, Optional
from packaging.version import Version, InvalidVersion

from .._libtpm2_pytss import ffi, lib
from .._libtpm2_pytss import ffi, lib # type: ignore[import]
from ..TSS2_Exception import TSS2_Exception

try:
from .versions import _versions
# This is generated by install, so just ignore it.
from .versions import _versions # type: ignore[import]
except ImportError as e:
# this is needed so docs can be generated without building
if "sphinx" not in sys.modules:
Expand Down Expand Up @@ -199,7 +200,7 @@ def _check_friendly_int(friendly, varname, clazz):


def is_bug_fixed(
fixed_in=None, backports: List[str] = None, lib: str = "tss2-fapi"
fixed_in=None, backports: Optional[List[str]] = None, lib: str = "tss2-fapi"
) -> bool:
"""Use pkg-config to determine if a bug was fixed in the currently installed tpm2-tss version."""
if fixed_in and _lib_version_atleast(lib, fixed_in):
Expand All @@ -226,7 +227,7 @@ def is_bug_fixed(
def _check_bug_fixed(
details,
fixed_in=None,
backports: List[str] = None,
backports: Optional[List[str]] = None,
lib: str = "tss2-fapi",
error: bool = False,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from .constants import TPM2_ALG, ESYS_TR, TSS2_RC, TPM2_RC
from .TSS2_Exception import TSS2_Exception
from ._libtpm2_pytss import ffi, lib
from ._libtpm2_pytss import ffi, lib # type: ignore[name-defined]
from .ESAPI import ESAPI
from enum import Enum
from typing import Callable, Union
Expand Down
Empty file added src/tpm2_pytss/py.typed
Empty file.
2 changes: 1 addition & 1 deletion src/tpm2_pytss/tsskey.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: BSD-2

import warnings
from ._libtpm2_pytss import lib
from ._libtpm2_pytss import lib # type: ignore[name-defined]
from .types import *
from .constants import TPM2_ECC, TPM2_CAP, ESYS_TR
from asn1crypto.core import ObjectIdentifier, Sequence, Boolean, OctetString, Integer
Expand Down
Loading