Skip to content

Commit

Permalink
Fix static typing
Browse files Browse the repository at this point in the history
Fixed the following files:

pycardano/certificate.py
pycardano/cip/cip8.py
pycardano/coinselection.py
pycardano/key.py
pycardano/metadata.py
pycardano/plutus.py
pycardano/serialization.py
pycardano/transaction.py
pycardano/txbuilder.py
pycardano/utils.py
pycardano/witness.py
  • Loading branch information
cffls committed Dec 24, 2022
1 parent 5e37db1 commit fed5a70
Show file tree
Hide file tree
Showing 13 changed files with 216 additions and 151 deletions.
4 changes: 2 additions & 2 deletions pycardano/certificate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Union
from typing import Optional, Union

from pycardano.hash import PoolKeyHash, ScriptHash, VerificationKeyHash
from pycardano.serialization import ArrayCBORSerializable
Expand All @@ -16,7 +16,7 @@
@dataclass(repr=False)
class StakeCredential(ArrayCBORSerializable):

_CODE: int = field(init=False, default=None)
_CODE: Optional[int] = field(init=False, default=None)

credential: Union[VerificationKeyHash, ScriptHash]

Expand Down
11 changes: 10 additions & 1 deletion pycardano/cip/cip8.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,19 @@ def verify(
if attach_cose_key:
# The cose key is attached as a dict object which contains the verification key
# the headers of the signature are emtpy
assert isinstance(
signed_message, dict
), "signed_message must be a dict if attach_cose_key is True"
key = signed_message.get("key")
signed_message = signed_message.get("signature")
signed_message = signed_message.get("signature") # type: ignore

else:
key = "" # key will be extracted later from the payload headers

# Add back the "D2" header byte and decode
assert isinstance(
signed_message, str
), "signed_message must be a hex string at this point"
decoded_message = CoseMessage.decode(bytes.fromhex("d2" + signed_message))

# generate/extract the cose key
Expand All @@ -146,6 +152,9 @@ def verify(

else:
# i,e key is sent separately
assert isinstance(
key, str
), "key must be a hex string if attach_cose_key is True"
cose_key = CoseKey.decode(bytes.fromhex(key))
verification_key = cose_key[OKPKpX]

Expand Down
30 changes: 19 additions & 11 deletions pycardano/coinselection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def select(
utxos: List[UTxO],
outputs: List[TransactionOutput],
context: ChainContext,
max_input_count: int = None,
include_max_fee: bool = True,
respect_min_utxo: bool = True,
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
) -> Tuple[List[UTxO], Value]:
"""From an input list of UTxOs, select a subset of UTxOs whose sum (including ADA and multi-assets)
is equal to or larger than the sum of a set of outputs.
Expand Down Expand Up @@ -115,7 +115,11 @@ def select(
if change.coin < min_change_amount:
additional, _ = self.select(
available,
[TransactionOutput(None, min_change_amount - change.coin)],
[
TransactionOutput(
_FAKE_ADDR, Value(min_change_amount - change.coin)
)
],
context,
max_input_count - len(selected) if max_input_count else None,
include_max_fee=False,
Expand Down Expand Up @@ -230,13 +234,13 @@ def _improve(
remaining: List[UTxO],
ideal: Value,
upper_bound: Value,
max_input_count: int,
max_input_count: Optional[int] = None,
):
if not remaining or self._find_diff_by_former(ideal, selected_amount) <= 0:
# In case where there is no remaining UTxOs or we already selected more than ideal,
# we cannot improve by randomly adding more UTxOs, therefore return immediate.
return
if max_input_count and len(selected) > max_input_count:
if max_input_count is not None and len(selected) > max_input_count:
raise MaxInputCountExceededException(
f"Max input count: {max_input_count} exceeded!"
)
Expand Down Expand Up @@ -269,9 +273,9 @@ def select(
utxos: List[UTxO],
outputs: List[TransactionOutput],
context: ChainContext,
max_input_count: int = None,
include_max_fee: bool = True,
respect_min_utxo: bool = True,
max_input_count: Optional[int] = None,
include_max_fee: Optional[bool] = True,
respect_min_utxo: Optional[bool] = True,
) -> Tuple[List[UTxO], Value]:
# Shallow copy the list
remaining = list(utxos)
Expand All @@ -284,7 +288,7 @@ def select(
request_sorted = sorted(assets, key=self._get_single_asset_val, reverse=True)

# Phase 1 - random select
selected = []
selected: List[UTxO] = []
selected_amount = Value()
for r in request_sorted:
self._random_select_subset(r, remaining, selected, selected_amount)
Expand Down Expand Up @@ -321,7 +325,11 @@ def select(
if change.coin < min_change_amount:
additional, _ = self.select(
remaining,
[TransactionOutput(None, min_change_amount - change.coin)],
[
TransactionOutput(
_FAKE_ADDR, Value(min_change_amount - change.coin)
)
],
context,
max_input_count - len(selected) if max_input_count else None,
include_max_fee=False,
Expand Down
37 changes: 23 additions & 14 deletions pycardano/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import json
import os
from typing import Type
from typing import Optional, Type

from nacl.encoding import RawEncoder
from nacl.hash import blake2b
Expand Down Expand Up @@ -41,7 +41,12 @@ class Key(CBORSerializable):
KEY_TYPE = ""
DESCRIPTION = ""

def __init__(self, payload: bytes, key_type: str = None, description: str = None):
def __init__(
self,
payload: bytes,
key_type: Optional[str] = None,
description: Optional[str] = None,
):
self._payload = payload
self._key_type = key_type or self.KEY_TYPE
self._description = description or self.KEY_TYPE
Expand Down Expand Up @@ -83,7 +88,7 @@ def to_json(self) -> str:
)

@classmethod
def from_json(cls, data: str, validate_type=False) -> Key:
def from_json(cls: Type[Key], data: str, validate_type=False) -> Key:
"""Restore a key from a JSON string.
Args:
Expand All @@ -105,8 +110,12 @@ def from_json(cls, data: str, validate_type=False) -> Key:
f"Expect key type: {cls.KEY_TYPE}, got {obj['type']} instead."
)

k = cls.from_cbor(obj["cborHex"])

assert isinstance(k, cls)

return cls(
cls.from_cbor(obj["cborHex"]).payload,
k.payload,
key_type=obj["type"],
description=obj["description"],
)
Expand Down Expand Up @@ -244,19 +253,19 @@ class PaymentExtendedVerificationKey(ExtendedVerificationKey):


class PaymentKeyPair:
def __init__(
self, signing_key: PaymentSigningKey, verification_key: PaymentVerificationKey
):
def __init__(self, signing_key: SigningKey, verification_key: VerificationKey):
self.signing_key = signing_key
self.verification_key = verification_key

@classmethod
def generate(cls) -> PaymentKeyPair:
def generate(cls: Type[PaymentKeyPair]) -> PaymentKeyPair:
signing_key = PaymentSigningKey.generate()
return cls.from_signing_key(signing_key)

@classmethod
def from_signing_key(cls, signing_key: PaymentSigningKey) -> PaymentKeyPair:
def from_signing_key(
cls: Type[PaymentKeyPair], signing_key: SigningKey
) -> PaymentKeyPair:
return cls(signing_key, PaymentVerificationKey.from_signing_key(signing_key))

def __eq__(self, other):
Expand Down Expand Up @@ -288,17 +297,17 @@ class StakeExtendedVerificationKey(ExtendedVerificationKey):


class StakeKeyPair:
def __init__(
self, signing_key: StakeSigningKey, verification_key: StakeVerificationKey
):
def __init__(self, signing_key: SigningKey, verification_key: VerificationKey):
self.signing_key = signing_key
self.verification_key = verification_key

@classmethod
def generate(cls) -> StakeKeyPair:
def generate(cls: Type[StakeKeyPair]) -> StakeKeyPair:
signing_key = StakeSigningKey.generate()
return cls.from_signing_key(signing_key)

@classmethod
def from_signing_key(cls, signing_key: StakeSigningKey) -> StakeKeyPair:
def from_signing_key(
cls: Type[StakeKeyPair], signing_key: SigningKey
) -> StakeKeyPair:
return cls(signing_key, StakeVerificationKey.from_signing_key(signing_key))
24 changes: 13 additions & 11 deletions pycardano/metadata.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, ClassVar, List, Type, Union
from typing import Any, ClassVar, List, Optional, Type, Union

from cbor2 import CBORTag
from nacl.encoding import RawEncoder
Expand All @@ -20,7 +20,7 @@
list_hook,
)

__all__ = ["Metadata", "ShellayMarryMetadata", "AlonzoMetadata", "AuxiliaryData"]
__all__ = ["Metadata", "ShelleyMarryMetadata", "AlonzoMetadata", "AuxiliaryData"]


class Metadata(DictCBORSerializable):
Expand Down Expand Up @@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs):


@dataclass
class ShellayMarryMetadata(ArrayCBORSerializable):
class ShelleyMarryMetadata(ArrayCBORSerializable):
metadata: Metadata
native_scripts: List[NativeScript] = field(
native_scripts: Optional[List[NativeScript]] = field(
default=None, metadata={"object_hook": list_hook(NativeScript)}
)

Expand All @@ -79,12 +79,14 @@ class ShellayMarryMetadata(ArrayCBORSerializable):
class AlonzoMetadata(MapCBORSerializable):
TAG: ClassVar[int] = 259

metadata: Metadata = field(default=None, metadata={"optional": True, "key": 0})
native_scripts: List[NativeScript] = field(
metadata: Optional[Metadata] = field(
default=None, metadata={"optional": True, "key": 0}
)
native_scripts: Optional[List[NativeScript]] = field(
default=None,
metadata={"optional": True, "key": 1, "object_hook": list_hook(NativeScript)},
)
plutus_scripts: List[bytes] = field(
plutus_scripts: Optional[List[bytes]] = field(
default=None, metadata={"optional": True, "key": 2}
)

Expand All @@ -107,23 +109,23 @@ def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:

@dataclass
class AuxiliaryData(CBORSerializable):
data: Union[Metadata, ShellayMarryMetadata, AlonzoMetadata]
data: Union[Metadata, ShelleyMarryMetadata, AlonzoMetadata]

def to_primitive(self) -> Primitive:
return self.data.to_primitive()

@classmethod
def from_primitive(cls: Type[AuxiliaryData], value: Primitive) -> AuxiliaryData:
for t in [AlonzoMetadata, ShellayMarryMetadata, Metadata]:
for t in [AlonzoMetadata, ShelleyMarryMetadata, Metadata]:
# The schema of metadata in different eras are mutually exclusive, so we can try deserializing
# them one by one without worrying about mismatch.
try:
return AuxiliaryData(t.from_primitive(value))
return AuxiliaryData(t.from_primitive(value)) # type: ignore
except DeserializeException:
pass
raise DeserializeException(f"Couldn't parse auxiliary data: {value}")

def hash(self) -> AuxiliaryDataHash:
return AuxiliaryDataHash(
blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder)
blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder) # type: ignore
)
36 changes: 20 additions & 16 deletions pycardano/plutus.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CBORSerializable,
DictCBORSerializable,
IndefiniteList,
Primitive,
RawCBOR,
default_encoder,
limit_primitive_type,
Expand All @@ -39,6 +40,7 @@
"PlutusV2Script",
"RawPlutusData",
"Redeemer",
"ScriptType",
"datum_hash",
"plutus_script_hash",
"script_hash",
Expand Down Expand Up @@ -471,7 +473,7 @@ def __post_init__(self):
)

def to_shallow_primitive(self) -> CBORTag:
primitives = super().to_shallow_primitive()
primitives: Primitive = super().to_shallow_primitive()
if primitives:
primitives = IndefiniteList(primitives)
tag = get_tag(self.CONSTR_ID)
Expand Down Expand Up @@ -544,7 +546,7 @@ def _dfs(obj):
return json.dumps(_dfs(self), **kwargs)

@classmethod
def from_dict(cls: PlutusData, data: dict) -> PlutusData:
def from_dict(cls: Type[PlutusData], data: dict) -> PlutusData:
"""Convert a dictionary to PlutusData
Args:
Expand Down Expand Up @@ -606,7 +608,7 @@ def _dfs(obj):
return _dfs(data)

@classmethod
def from_json(cls: PlutusData, data: str) -> PlutusData:
def from_json(cls: Type[PlutusData], data: str) -> PlutusData:
"""Restore a json encoded string to a PlutusData.
Args:
Expand Down Expand Up @@ -701,7 +703,7 @@ class Redeemer(ArrayCBORSerializable):

data: Any

ex_units: ExecutionUnits = None
ex_units: Optional[ExecutionUnits] = None

@classmethod
@limit_primitive_type(list)
Expand Down Expand Up @@ -729,13 +731,23 @@ def plutus_script_hash(
return script_hash(script)


def script_hash(
script: Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]
) -> ScriptHash:
class PlutusV1Script(bytes):
pass


class PlutusV2Script(bytes):
pass


ScriptType = Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]
"""Script type. A Union type that contains all valid script types."""


def script_hash(script: ScriptType) -> ScriptHash:
"""Calculates the hash of a script, which could be either native script or plutus script.
Args:
script (Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]): A script.
script (ScriptType): A script.
Returns:
ScriptHash: blake2b hash of the script.
Expand All @@ -752,11 +764,3 @@ def script_hash(
)
else:
raise TypeError(f"Unexpected script type: {type(script)}")


class PlutusV1Script(bytes):
pass


class PlutusV2Script(bytes):
pass
Loading

0 comments on commit fed5a70

Please sign in to comment.