Skip to content

Commit

Permalink
mypy: WIP
Browse files Browse the repository at this point in the history
Signed-off-by: William Roberts <william.c.roberts@intel.com>
  • Loading branch information
William Roberts committed Nov 21, 2022
1 parent 6e58567 commit 336aa16
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .ci/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function run_style() {
}

function run_mypy_check() {
"${PYTHON}" -m mypy --exclude=docs --exclude=scripts "${SRC_ROOT}"
"${PYTHON}" -m mypy --exclude=docs --exclude=scripts --exclude='setup.py' "${SRC_ROOT}"
}

if [ "x${TEST}" != "x" ]; then
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,3 @@ dev =
build
installer
mypy
types-pycparser
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import os
from setuptools import setup
from setuptools.command.build_ext import build_ext
from pkgconfig import pkgconfig # type: ignore
from pycparser import c_parser, preprocess_file # type: ignore
from pycparser.c_ast import ( # type: ignore
from pkgconfig import pkgconfig
from pycparser import c_parser, preprocess_file
from pycparser.c_ast import (
Typedef,
TypeDecl,
IdentifierType,
Expand Down
4 changes: 2 additions & 2 deletions src/tpm2_pytss/TCTI.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-2

from ._libtpm2_pytss import ffi, lib
from ._libtpm2_pytss import ffi, lib # type: ignore[import]

from .internal.utils import _chkrc
from .constants import TSS2_RC, TPM2_RC
Expand Down Expand Up @@ -241,7 +241,7 @@ def cancel(self) -> None:
_chkrc(self._v1.cancel(self._ctx))

@common_checks()
def get_poll_handles(self) -> Tuple[PollData]:
def get_poll_handles(self) -> Tuple[PollData, ...]:
"""Gets the poll handles from the TPM.
Returns:
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/TCTILdr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-2

from ._libtpm2_pytss import lib, ffi
from ._libtpm2_pytss import lib, ffi # type: ignore[import]
from .TCTI import TCTI
from .internal.utils import _chkrc

Expand Down
5 changes: 3 additions & 2 deletions src/tpm2_pytss/TSS2_Exception.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ._libtpm2_pytss import lib, ffi # type: ignore
from ._libtpm2_pytss import lib, ffi # type: ignore[import]
from typing import Union


class TSS2_Exception(RuntimeError):
"""TSS2_Exception represents an error returned by the TSS APIs."""

# prevent cirular dependency and don't use the types directly here.
def __init__(self, rc: Union["TSS2_RC", "TPM2_RC", int]): # type: ignore
def __init__(self, rc: Union["TSS2_RC", "TPM2_RC", int]): # type: ignore[name-defined]
if isinstance(rc, int):
# defer this to avoid circular dep.
from .constants import TSS2_RC
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: BSD-2


from ._libtpm2_pytss import lib # type: ignore
from ._libtpm2_pytss import lib # type: ignore[import]
from .internal.constants import CALLBACK_BASE_NAME, CALLBACK_COUNT, CallbackType


Expand Down
17 changes: 9 additions & 8 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
Along with helpers to go from string values to constants and constant values to string values.
"""
from ._libtpm2_pytss import lib, ffi # type: ignore
from ._libtpm2_pytss import lib, ffi # type: ignore[import]
from tpm2_pytss.internal.utils import _CLASS_INT_ATTRS_from_string, _lib_version_atleast
from typing import Dict, Tuple


class TPM_FRIENDLY_INT(int):
_FIXUP_MAP = {}
_FIXUP_MAP: Dict[str, str] = {}

@classmethod
def parse(cls, value: str) -> int:
Expand Down Expand Up @@ -231,7 +232,7 @@ def _fix_const_type(cls):


class TPMA_FRIENDLY_INTLIST(TPM_FRIENDLY_INT):
_MASKS = tuple()
_MASKS: Tuple[Tuple[int, int, str], ...] = tuple()

@classmethod
def parse(cls, value: str) -> int:
Expand Down Expand Up @@ -392,7 +393,7 @@ class ESYS_TR(TPM_FRIENDLY_INT):
RH_PLATFORM = lib.ESYS_TR_RH_PLATFORM
RH_PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV

def serialize(self, ectx: "ESAPI") -> bytes:
def serialize(self, ectx: "ESAPI") -> bytes: # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_serialize
Args:
Expand All @@ -405,7 +406,7 @@ def serialize(self, ectx: "ESAPI") -> bytes:
return ectx.tr_serialize(self)

@staticmethod
def deserialize(ectx: "ESAPI", buffer: bytes) -> "ESYS_TR":
def deserialize(ectx: "ESAPI", buffer: bytes) -> "ESYS_TR": # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_derialize
Args:
Expand All @@ -417,7 +418,7 @@ def deserialize(ectx: "ESAPI", buffer: bytes) -> "ESYS_TR":

return ectx.tr_deserialize(buffer)

def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME":
def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME": # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_get_name
Args:
Expand All @@ -428,7 +429,7 @@ def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME":
"""
return ectx.tr_get_name(self)

def close(self, ectx: "ESAPI"):
def close(self, ectx: "ESAPI"): # type: ignore[name-defined]
"""Same as see tpm2_pytss.ESAPI.tr_close
Args:
Expand Down Expand Up @@ -1256,7 +1257,7 @@ def parse(cls, value: str) -> "TPMA_LOCALITY":
return cls(value, base=0)
except ValueError:
pass
return super().parse(value)
return TPMA_LOCALITY(super().parse(value))

def __str__(self) -> str:
"""Given a set of localities or an extended locality, return the string representation
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/encoding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from binascii import hexlify, unhexlify
from typing import Any, Union, List, Dict, Tuple
from ._libtpm2_pytss import ffi # type: ignore
from ._libtpm2_pytss import ffi # type: ignore[name-defined]
from .internal.crypto import _get_digest_size
from .constants import (
TPM_FRIENDLY_INT,
Expand Down
49 changes: 28 additions & 21 deletions src/tpm2_pytss/internal/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@
from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature
from typing import Tuple, Type
from typing import Tuple, Type, Any, Union
import secrets
import sys

_curvetable = (
# Despite below, it won't allow us to use the right classes for the
# typehint so we just use Any...
# from cryptography.hazmat.primitives.asymmetric import rsa, ec, padding
# ec.SECP192R1
# <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP192R1'>
# type(ec.SECP192R1)
# <class 'abc.ABCMeta'>
# ec.SECP192R1.__bases__
# (<class 'cryptography.hazmat.primitives.asymmetric.ec.EllipticCurve'>,)

_curvetable: Tuple[Tuple[TPM2_ECC, Any], ...] = (
(TPM2_ECC.NIST_P192, ec.SECP192R1),
(TPM2_ECC.NIST_P224, ec.SECP224R1),
(TPM2_ECC.NIST_P256, ec.SECP256R1),
(TPM2_ECC.NIST_P384, ec.SECP384R1),
(TPM2_ECC.NIST_P521, ec.SECP521R1),
)

_digesttable = (
_digesttable: Tuple[Tuple[TPM2_ALG, Any], ...] = (
(TPM2_ALG.SHA1, hashes.SHA1),
(TPM2_ALG.SHA256, hashes.SHA256),
(TPM2_ALG.SHA384, hashes.SHA384),
Expand All @@ -48,14 +58,14 @@
if hasattr(hashes, "SM3"):
_digesttable += ((TPM2_ALG.SM3_256, hashes.SM3),)

_algtable = (
_algtable: Tuple[Tuple[TPM2_ALG, Any], ...] = (
(TPM2_ALG.AES, AES),
(TPM2_ALG.CAMELLIA, Camellia),
(TPM2_ALG.CFB, modes.CFB),
)

try:
from cryptography.hazmat.primitives.ciphers.algorithms import SM4
from cryptography.hazmat.primitives.ciphers.algorithms import SM4 # type: ignore[attr-defined]

_algtable += ((TPM2_ALG.SM4, SM4),)
except ImportError:
Expand Down Expand Up @@ -274,8 +284,7 @@ def _generate_d(p, q, e, n):
return d


def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC"):
key = None
def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC") -> Union[ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey]: # type: ignore[name-defined]
if private.sensitiveType == TPM2_ALG.RSA:

p = int.from_bytes(bytes(private.sensitive.rsa), byteorder="big")
Expand All @@ -286,7 +295,7 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC")
else 65537
)

key = _MyRSAPrivateNumbers(p, n, e, rsa.RSAPublicNumbers(e, n)).private_key(
return _MyRSAPrivateNumbers(p, n, e, rsa.RSAPublicNumbers(e, n)).private_key(
backend=default_backend()
)
elif private.sensitiveType == TPM2_ALG.ECC:
Expand All @@ -301,13 +310,11 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC")
x = int.from_bytes(bytes(public.unique.ecc.x), byteorder="big")
y = int.from_bytes(bytes(public.unique.ecc.y), byteorder="big")

key = ec.EllipticCurvePrivateNumbers(
return ec.EllipticCurvePrivateNumbers(
p, ec.EllipticCurvePublicNumbers(x, y, curve())
).private_key(backend=default_backend())
else:
raise ValueError(f"unsupported key type: {private.sensitiveType}")

return key
raise ValueError(f"unsupported key type: {private.sensitiveType}")


def _public_to_pem(obj, encoding="pem"):
Expand Down Expand Up @@ -535,7 +542,7 @@ def _generate_ecc_seed(
return (seed, secret)


def _generate_seed(public: "types.TPMT_PUBLIC", label: bytes) -> Tuple[bytes, bytes]:
def _generate_seed(public: "types.TPMT_PUBLIC", label: bytes) -> Tuple[bytes, bytes]: # type: ignore[name-defined]
key = public_to_key(public)
if public.type == TPM2_ALG.RSA:
return _generate_rsa_seed(key, public.nameAlg, label)
Expand Down Expand Up @@ -588,8 +595,8 @@ def __ecc_secret_to_seed(


def _secret_to_seed(
private: "types.TPMT_SENSITIVE",
public: "types.TPMT_PUBLIC",
private: "types.TPMT_SENSITIVE", # type: ignore[name-defined]
public: "types.TPMT_PUBLIC", # type: ignore[name-defined]
label: bytes,
outsymseed: bytes,
):
Expand All @@ -605,7 +612,7 @@ def _secret_to_seed(
def _hmac(
halg: hashes.HashAlgorithm, hmackey: bytes, enc_cred: bytes, name: bytes
) -> bytes:
h = HMAC(hmackey, halg(), backend=default_backend())
h = HMAC(hmackey, halg(), backend=default_backend()) # type: ignore[operator]
h.update(enc_cred)
h.update(name)
return h.finalize()
Expand All @@ -618,7 +625,7 @@ def _check_hmac(
name: bytes,
expected: bytes,
):
h = HMAC(hmackey, halg(), backend=default_backend())
h = HMAC(hmackey, halg(), backend=default_backend()) # type: ignore[operator]
h.update(enc_cred)
h.update(name)
h.verify(expected)
Expand All @@ -628,8 +635,8 @@ def _encrypt(
cipher: Type[CipherAlgorithm], mode: Type[modes.Mode], key: bytes, data: bytes
) -> bytes:
iv = len(key) * b"\x00"
ci = cipher(key)
ciph = Cipher(ci, mode(iv), backend=default_backend())
ci = cipher(key) # type: ignore[call-arg]
ciph = Cipher(ci, mode(iv), backend=default_backend()) # type: ignore[call-arg]
encr = ciph.encryptor()
encdata = encr.update(data) + encr.finalize()
return encdata
Expand All @@ -639,8 +646,8 @@ def _decrypt(
cipher: Type[CipherAlgorithm], mode: Type[modes.Mode], key: bytes, data: bytes
) -> bytes:
iv = len(key) * b"\x00"
ci = cipher(key)
ciph = Cipher(ci, mode(iv), backend=default_backend())
ci = cipher(key) # type: ignore[call-arg]
ciph = Cipher(ci, mode(iv), backend=default_backend()) # type: ignore[call-arg]
decr = ciph.decryptor()
plaintextdata = decr.update(data) + decr.finalize()
return plaintextdata
10 changes: 5 additions & 5 deletions src/tpm2_pytss/internal/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# SPDX-License-Identifier: BSD-2
import logging
import sys
from typing import List
from typing import List, Optional
from packaging.version import Version, InvalidVersion

from .._libtpm2_pytss import ffi, lib # type: ignore
from .._libtpm2_pytss import ffi, lib # type: ignore[import]
from ..TSS2_Exception import TSS2_Exception

try:
# This is generated by install, so just ignore it.
from .versions import _versions # type: ignore
from .versions import _versions # type: ignore[import]
except ImportError as e:
# this is needed so docs can be generated without building
if "sphinx" not in sys.modules:
Expand Down Expand Up @@ -200,7 +200,7 @@ def _check_friendly_int(friendly, varname, clazz):


def is_bug_fixed(
fixed_in=None, backports: List[str] = None, lib: str = "tss2-fapi"
fixed_in=None, backports: Optional[List[str]] = None, lib: str = "tss2-fapi"
) -> bool:
"""Use pkg-config to determine if a bug was fixed in the currently installed tpm2-tss version."""
if fixed_in and _lib_version_atleast(lib, fixed_in):
Expand All @@ -227,7 +227,7 @@ def is_bug_fixed(
def _check_bug_fixed(
details,
fixed_in=None,
backports: List[str] = None,
backports: Optional[List[str]] = None,
lib: str = "tss2-fapi",
error: bool = False,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from .constants import TPM2_ALG, ESYS_TR, TSS2_RC, TPM2_RC
from .TSS2_Exception import TSS2_Exception
from ._libtpm2_pytss import ffi, lib # type: ignore
from ._libtpm2_pytss import ffi, lib # type: ignore[name-defined]
from .ESAPI import ESAPI
from enum import Enum
from typing import Callable, Union
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/tsskey.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: BSD-2

import warnings
from ._libtpm2_pytss import lib # type: ignore
from ._libtpm2_pytss import lib # type: ignore[name-defined]
from .types import *
from .constants import TPM2_ECC, TPM2_CAP, ESYS_TR
from asn1crypto.core import ObjectIdentifier, Sequence, Boolean, OctetString, Integer
Expand Down
Loading

0 comments on commit 336aa16

Please sign in to comment.