diff --git a/pycardano/certificate.py b/pycardano/certificate.py index a2a27115..ece33d3c 100644 --- a/pycardano/certificate.py +++ b/pycardano/certificate.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Union +from typing import Optional, Union from pycardano.hash import PoolKeyHash, ScriptHash, VerificationKeyHash from pycardano.serialization import ArrayCBORSerializable @@ -16,7 +16,7 @@ @dataclass(repr=False) class StakeCredential(ArrayCBORSerializable): - _CODE: int = field(init=False, default=None) + _CODE: Optional[int] = field(init=False, default=None) credential: Union[VerificationKeyHash, ScriptHash] diff --git a/pycardano/cip/cip8.py b/pycardano/cip/cip8.py index d4d50a26..1e909bf1 100644 --- a/pycardano/cip/cip8.py +++ b/pycardano/cip/cip8.py @@ -119,13 +119,19 @@ def verify( if attach_cose_key: # The cose key is attached as a dict object which contains the verification key # the headers of the signature are emtpy + assert isinstance( + signed_message, dict + ), "signed_message must be a dict if attach_cose_key is True" key = signed_message.get("key") - signed_message = signed_message.get("signature") + signed_message = signed_message.get("signature") # type: ignore else: key = "" # key will be extracted later from the payload headers # Add back the "D2" header byte and decode + assert isinstance( + signed_message, str + ), "signed_message must be a hex string at this point" decoded_message = CoseMessage.decode(bytes.fromhex("d2" + signed_message)) # generate/extract the cose key @@ -146,6 +152,9 @@ def verify( else: # i,e key is sent separately + assert isinstance( + key, str + ), "key must be a hex string if attach_cose_key is True" cose_key = CoseKey.decode(bytes.fromhex(key)) verification_key = cose_key[OKPKpX] diff --git a/pycardano/coinselection.py b/pycardano/coinselection.py index 77201218..eff4eaaa 100644 --- a/pycardano/coinselection.py +++ b/pycardano/coinselection.py @@ -33,9 +33,9 @@ def select( utxos: List[UTxO], outputs: List[TransactionOutput], context: ChainContext, - max_input_count: int = None, - include_max_fee: bool = True, - respect_min_utxo: bool = True, + max_input_count: Optional[int] = None, + include_max_fee: Optional[bool] = True, + respect_min_utxo: Optional[bool] = True, ) -> Tuple[List[UTxO], Value]: """From an input list of UTxOs, select a subset of UTxOs whose sum (including ADA and multi-assets) is equal to or larger than the sum of a set of outputs. @@ -115,7 +115,11 @@ def select( if change.coin < min_change_amount: additional, _ = self.select( available, - [TransactionOutput(None, min_change_amount - change.coin)], + [ + TransactionOutput( + _FAKE_ADDR, Value(min_change_amount - change.coin) + ) + ], context, max_input_count - len(selected) if max_input_count else None, include_max_fee=False, @@ -230,13 +234,13 @@ def _improve( remaining: List[UTxO], ideal: Value, upper_bound: Value, - max_input_count: int, + max_input_count: Optional[int] = None, ): if not remaining or self._find_diff_by_former(ideal, selected_amount) <= 0: # In case where there is no remaining UTxOs or we already selected more than ideal, # we cannot improve by randomly adding more UTxOs, therefore return immediate. return - if max_input_count and len(selected) > max_input_count: + if max_input_count is not None and len(selected) > max_input_count: raise MaxInputCountExceededException( f"Max input count: {max_input_count} exceeded!" ) @@ -269,9 +273,9 @@ def select( utxos: List[UTxO], outputs: List[TransactionOutput], context: ChainContext, - max_input_count: int = None, - include_max_fee: bool = True, - respect_min_utxo: bool = True, + max_input_count: Optional[int] = None, + include_max_fee: Optional[bool] = True, + respect_min_utxo: Optional[bool] = True, ) -> Tuple[List[UTxO], Value]: # Shallow copy the list remaining = list(utxos) @@ -284,7 +288,7 @@ def select( request_sorted = sorted(assets, key=self._get_single_asset_val, reverse=True) # Phase 1 - random select - selected = [] + selected: List[UTxO] = [] selected_amount = Value() for r in request_sorted: self._random_select_subset(r, remaining, selected, selected_amount) @@ -321,7 +325,11 @@ def select( if change.coin < min_change_amount: additional, _ = self.select( remaining, - [TransactionOutput(None, min_change_amount - change.coin)], + [ + TransactionOutput( + _FAKE_ADDR, Value(min_change_amount - change.coin) + ) + ], context, max_input_count - len(selected) if max_input_count else None, include_max_fee=False, diff --git a/pycardano/key.py b/pycardano/key.py index 66f72e3e..55f9d01f 100644 --- a/pycardano/key.py +++ b/pycardano/key.py @@ -4,7 +4,7 @@ import json import os -from typing import Type +from typing import Optional, Type from nacl.encoding import RawEncoder from nacl.hash import blake2b @@ -41,7 +41,12 @@ class Key(CBORSerializable): KEY_TYPE = "" DESCRIPTION = "" - def __init__(self, payload: bytes, key_type: str = None, description: str = None): + def __init__( + self, + payload: bytes, + key_type: Optional[str] = None, + description: Optional[str] = None, + ): self._payload = payload self._key_type = key_type or self.KEY_TYPE self._description = description or self.KEY_TYPE @@ -83,7 +88,7 @@ def to_json(self) -> str: ) @classmethod - def from_json(cls, data: str, validate_type=False) -> Key: + def from_json(cls: Type[Key], data: str, validate_type=False) -> Key: """Restore a key from a JSON string. Args: @@ -105,8 +110,12 @@ def from_json(cls, data: str, validate_type=False) -> Key: f"Expect key type: {cls.KEY_TYPE}, got {obj['type']} instead." ) + k = cls.from_cbor(obj["cborHex"]) + + assert isinstance(k, cls) + return cls( - cls.from_cbor(obj["cborHex"]).payload, + k.payload, key_type=obj["type"], description=obj["description"], ) @@ -244,19 +253,19 @@ class PaymentExtendedVerificationKey(ExtendedVerificationKey): class PaymentKeyPair: - def __init__( - self, signing_key: PaymentSigningKey, verification_key: PaymentVerificationKey - ): + def __init__(self, signing_key: SigningKey, verification_key: VerificationKey): self.signing_key = signing_key self.verification_key = verification_key @classmethod - def generate(cls) -> PaymentKeyPair: + def generate(cls: Type[PaymentKeyPair]) -> PaymentKeyPair: signing_key = PaymentSigningKey.generate() return cls.from_signing_key(signing_key) @classmethod - def from_signing_key(cls, signing_key: PaymentSigningKey) -> PaymentKeyPair: + def from_signing_key( + cls: Type[PaymentKeyPair], signing_key: SigningKey + ) -> PaymentKeyPair: return cls(signing_key, PaymentVerificationKey.from_signing_key(signing_key)) def __eq__(self, other): @@ -288,17 +297,17 @@ class StakeExtendedVerificationKey(ExtendedVerificationKey): class StakeKeyPair: - def __init__( - self, signing_key: StakeSigningKey, verification_key: StakeVerificationKey - ): + def __init__(self, signing_key: SigningKey, verification_key: VerificationKey): self.signing_key = signing_key self.verification_key = verification_key @classmethod - def generate(cls) -> StakeKeyPair: + def generate(cls: Type[StakeKeyPair]) -> StakeKeyPair: signing_key = StakeSigningKey.generate() return cls.from_signing_key(signing_key) @classmethod - def from_signing_key(cls, signing_key: StakeSigningKey) -> StakeKeyPair: + def from_signing_key( + cls: Type[StakeKeyPair], signing_key: SigningKey + ) -> StakeKeyPair: return cls(signing_key, StakeVerificationKey.from_signing_key(signing_key)) diff --git a/pycardano/metadata.py b/pycardano/metadata.py index 1dc8c85a..a9aaf211 100644 --- a/pycardano/metadata.py +++ b/pycardano/metadata.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, ClassVar, List, Type, Union +from typing import Any, ClassVar, List, Optional, Type, Union from cbor2 import CBORTag from nacl.encoding import RawEncoder @@ -20,7 +20,7 @@ list_hook, ) -__all__ = ["Metadata", "ShellayMarryMetadata", "AlonzoMetadata", "AuxiliaryData"] +__all__ = ["Metadata", "ShelleyMarryMetadata", "AlonzoMetadata", "AuxiliaryData"] class Metadata(DictCBORSerializable): @@ -68,9 +68,9 @@ def __init__(self, *args, **kwargs): @dataclass -class ShellayMarryMetadata(ArrayCBORSerializable): +class ShelleyMarryMetadata(ArrayCBORSerializable): metadata: Metadata - native_scripts: List[NativeScript] = field( + native_scripts: Optional[List[NativeScript]] = field( default=None, metadata={"object_hook": list_hook(NativeScript)} ) @@ -79,12 +79,14 @@ class ShellayMarryMetadata(ArrayCBORSerializable): class AlonzoMetadata(MapCBORSerializable): TAG: ClassVar[int] = 259 - metadata: Metadata = field(default=None, metadata={"optional": True, "key": 0}) - native_scripts: List[NativeScript] = field( + metadata: Optional[Metadata] = field( + default=None, metadata={"optional": True, "key": 0} + ) + native_scripts: Optional[List[NativeScript]] = field( default=None, metadata={"optional": True, "key": 1, "object_hook": list_hook(NativeScript)}, ) - plutus_scripts: List[bytes] = field( + plutus_scripts: Optional[List[bytes]] = field( default=None, metadata={"optional": True, "key": 2} ) @@ -107,23 +109,23 @@ def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata: @dataclass class AuxiliaryData(CBORSerializable): - data: Union[Metadata, ShellayMarryMetadata, AlonzoMetadata] + data: Union[Metadata, ShelleyMarryMetadata, AlonzoMetadata] def to_primitive(self) -> Primitive: return self.data.to_primitive() @classmethod def from_primitive(cls: Type[AuxiliaryData], value: Primitive) -> AuxiliaryData: - for t in [AlonzoMetadata, ShellayMarryMetadata, Metadata]: + for t in [AlonzoMetadata, ShelleyMarryMetadata, Metadata]: # The schema of metadata in different eras are mutually exclusive, so we can try deserializing # them one by one without worrying about mismatch. try: - return AuxiliaryData(t.from_primitive(value)) + return AuxiliaryData(t.from_primitive(value)) # type: ignore except DeserializeException: pass raise DeserializeException(f"Couldn't parse auxiliary data: {value}") def hash(self) -> AuxiliaryDataHash: return AuxiliaryDataHash( - blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder) + blake2b(self.to_cbor("bytes"), AUXILIARY_DATA_HASH_SIZE, encoder=RawEncoder) # type: ignore ) diff --git a/pycardano/plutus.py b/pycardano/plutus.py index e6a422ef..3906344e 100644 --- a/pycardano/plutus.py +++ b/pycardano/plutus.py @@ -21,6 +21,7 @@ CBORSerializable, DictCBORSerializable, IndefiniteList, + Primitive, RawCBOR, default_encoder, limit_primitive_type, @@ -39,6 +40,7 @@ "PlutusV2Script", "RawPlutusData", "Redeemer", + "ScriptType", "datum_hash", "plutus_script_hash", "script_hash", @@ -471,7 +473,7 @@ def __post_init__(self): ) def to_shallow_primitive(self) -> CBORTag: - primitives = super().to_shallow_primitive() + primitives: Primitive = super().to_shallow_primitive() if primitives: primitives = IndefiniteList(primitives) tag = get_tag(self.CONSTR_ID) @@ -544,7 +546,7 @@ def _dfs(obj): return json.dumps(_dfs(self), **kwargs) @classmethod - def from_dict(cls: PlutusData, data: dict) -> PlutusData: + def from_dict(cls: Type[PlutusData], data: dict) -> PlutusData: """Convert a dictionary to PlutusData Args: @@ -606,7 +608,7 @@ def _dfs(obj): return _dfs(data) @classmethod - def from_json(cls: PlutusData, data: str) -> PlutusData: + def from_json(cls: Type[PlutusData], data: str) -> PlutusData: """Restore a json encoded string to a PlutusData. Args: @@ -701,7 +703,7 @@ class Redeemer(ArrayCBORSerializable): data: Any - ex_units: ExecutionUnits = None + ex_units: Optional[ExecutionUnits] = None @classmethod @limit_primitive_type(list) @@ -729,13 +731,23 @@ def plutus_script_hash( return script_hash(script) -def script_hash( - script: Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script] -) -> ScriptHash: +class PlutusV1Script(bytes): + pass + + +class PlutusV2Script(bytes): + pass + + +ScriptType = Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script] +"""Script type. A Union type that contains all valid script types.""" + + +def script_hash(script: ScriptType) -> ScriptHash: """Calculates the hash of a script, which could be either native script or plutus script. Args: - script (Union[bytes, NativeScript, PlutusV1Script, PlutusV2Script]): A script. + script (ScriptType): A script. Returns: ScriptHash: blake2b hash of the script. @@ -752,11 +764,3 @@ def script_hash( ) else: raise TypeError(f"Unexpected script type: {type(script)}") - - -class PlutusV1Script(bytes): - pass - - -class PlutusV2Script(bytes): - pass diff --git a/pycardano/serialization.py b/pycardano/serialization.py index c8b1e733..e4dcc32f 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -10,7 +10,7 @@ from decimal import Decimal from functools import wraps from inspect import isclass -from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints +from typing import Any, Callable, List, Optional, Type, TypeVar, Union, get_type_hints from cbor2 import CBOREncoder, CBORSimpleValue, CBORTag, dumps, loads, undefined from pprintpp import pformat @@ -33,12 +33,13 @@ "DictCBORSerializable", "RawCBOR", "list_hook", + "limit_primitive_type", ] class IndefiniteList(UserList): - def __init__(self, list: [Primitive]): # type: ignore - super().__init__(list) + def __init__(self, li: Primitive): # type: ignore + super().__init__(li) # type: ignore @dataclass @@ -414,7 +415,9 @@ def _restore_dataclass_field( return f.type.from_primitive(v) elif isclass(f.type) and issubclass(f.type, IndefiniteList): return IndefiniteList(v) - elif hasattr(f.type, "__origin__") and f.type.__origin__ is Union: + elif hasattr(f.type, "__origin__") and ( + f.type.__origin__ is Union or f.type.__origin__ is Optional + ): t_args = f.type.__args__ for t in t_args: if isclass(t) and issubclass(t, IndefiniteList): @@ -424,8 +427,11 @@ def _restore_dataclass_field( return t.from_primitive(v) except DeserializeException: pass - elif t in PRIMITIVE_TYPES and isinstance(v, t): - return v + else: + if not isclass(t) and hasattr(t, "__origin__"): + t = t.__origin__ + if t in PRIMITIVE_TYPES and isinstance(v, t): + return v raise DeserializeException( f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}." ) diff --git a/pycardano/transaction.py b/pycardano/transaction.py index 4bb72af5..58843220 100644 --- a/pycardano/transaction.py +++ b/pycardano/transaction.py @@ -284,7 +284,8 @@ def __post_init__(self): def from_primitive(cls: Type[_Script], values: List[Primitive]) -> _Script: if values[0] == 0: return cls(NativeScript.from_primitive(values[1])) - elif values[0] == 1: + assert isinstance(values[1], bytes) + if values[0] == 1: return cls(PlutusV1Script(values[1])) else: return cls(PlutusV2Script(values[1])) @@ -315,8 +316,10 @@ def from_primitive( cls: Type[_DatumOption], values: List[Primitive] ) -> _DatumOption: if values[0] == 0: + assert isinstance(values[1], bytes) return _DatumOption(DatumHash(values[1])) else: + assert isinstance(values[1], CBORTag) v = cbor2.loads(values[1].value) if isinstance(v, CBORTag): return _DatumOption(RawPlutusData.from_primitive(v)) @@ -334,6 +337,7 @@ def to_primitive(self) -> Primitive: @classmethod def from_primitive(cls: Type[_ScriptRef], value: Primitive) -> _ScriptRef: + assert isinstance(value, CBORTag) return cls(_Script.from_primitive(cbor2.loads(value.value))) @@ -343,9 +347,13 @@ class _TransactionOutputPostAlonzo(MapCBORSerializable): amount: Union[int, Value] = field(metadata={"key": 1}) - datum: _DatumOption = field(default=None, metadata={"key": 2, "optional": True}) + datum: Optional[_DatumOption] = field( + default=None, metadata={"key": 2, "optional": True} + ) - script_ref: _ScriptRef = field(default=None, metadata={"key": 3, "optional": True}) + script_ref: Optional[_ScriptRef] = field( + default=None, metadata={"key": 3, "optional": True} + ) @property def script(self) -> Optional[Union[NativeScript, PlutusV1Script, PlutusV2Script]]: @@ -361,7 +369,7 @@ class _TransactionOutputLegacy(ArrayCBORSerializable): amount: Union[int, Value] - datum_hash: DatumHash = field(default=None, metadata={"optional": True}) + datum_hash: Optional[DatumHash] = field(default=None, metadata={"optional": True}) @dataclass(repr=False) @@ -369,7 +377,7 @@ class TransactionOutput(CBORSerializable): address: Address - amount: Union[int, Value] + amount: Union[Value] datum_hash: Optional[DatumHash] = None @@ -496,30 +504,36 @@ class TransactionBody(MapCBORSerializable): fee: int = field(default=0, metadata={"key": 2}) - ttl: int = field(default=None, metadata={"key": 3, "optional": True}) + ttl: Optional[int] = field(default=None, metadata={"key": 3, "optional": True}) - certificates: List[Certificate] = field( + certificates: Optional[List[Certificate]] = field( default=None, metadata={"key": 4, "optional": True} ) - withdraws: Withdrawals = field(default=None, metadata={"key": 5, "optional": True}) + withdraws: Optional[Withdrawals] = field( + default=None, metadata={"key": 5, "optional": True} + ) # TODO: Add proposal update support update: Any = field(default=None, metadata={"key": 6, "optional": True}) - auxiliary_data_hash: AuxiliaryDataHash = field( + auxiliary_data_hash: Optional[AuxiliaryDataHash] = field( default=None, metadata={"key": 7, "optional": True} ) - validity_start: int = field(default=None, metadata={"key": 8, "optional": True}) + validity_start: Optional[int] = field( + default=None, metadata={"key": 8, "optional": True} + ) - mint: MultiAsset = field(default=None, metadata={"key": 9, "optional": True}) + mint: Optional[MultiAsset] = field( + default=None, metadata={"key": 9, "optional": True} + ) - script_data_hash: ScriptDataHash = field( + script_data_hash: Optional[ScriptDataHash] = field( default=None, metadata={"key": 11, "optional": True} ) - collateral: List[TransactionInput] = field( + collateral: Optional[List[TransactionInput]] = field( default=None, metadata={ "key": 13, @@ -528,7 +542,7 @@ class TransactionBody(MapCBORSerializable): }, ) - required_signers: List[VerificationKeyHash] = field( + required_signers: Optional[List[VerificationKeyHash]] = field( default=None, metadata={ "key": 14, @@ -537,15 +551,19 @@ class TransactionBody(MapCBORSerializable): }, ) - network_id: Network = field(default=None, metadata={"key": 15, "optional": True}) + network_id: Optional[Network] = field( + default=None, metadata={"key": 15, "optional": True} + ) - collateral_return: TransactionOutput = field( + collateral_return: Optional[TransactionOutput] = field( default=None, metadata={"key": 16, "optional": True} ) - total_collateral: int = field(default=None, metadata={"key": 17, "optional": True}) + total_collateral: Optional[int] = field( + default=None, metadata={"key": 17, "optional": True} + ) - reference_inputs: List[TransactionInput] = field( + reference_inputs: Optional[List[TransactionInput]] = field( default=None, metadata={ "key": 18, @@ -555,9 +573,7 @@ class TransactionBody(MapCBORSerializable): ) def hash(self) -> bytes: - return blake2b( - self.to_cbor(encoding="bytes"), TRANSACTION_HASH_SIZE, encoder=RawEncoder - ) + return blake2b(self.to_cbor(encoding="bytes"), TRANSACTION_HASH_SIZE, encoder=RawEncoder) # type: ignore @property def id(self) -> TransactionId: @@ -572,7 +588,7 @@ class Transaction(ArrayCBORSerializable): valid: bool = True - auxiliary_data: Union[AuxiliaryData, type(None)] = None + auxiliary_data: Optional[AuxiliaryData] = None @property def id(self) -> TransactionId: diff --git a/pycardano/txbuilder.py b/pycardano/txbuilder.py index c5fb3da1..0ebd7ff9 100644 --- a/pycardano/txbuilder.py +++ b/pycardano/txbuilder.py @@ -40,6 +40,7 @@ PlutusV2Script, Redeemer, RedeemerTag, + ScriptType, datum_hash, script_hash, ) @@ -87,23 +88,23 @@ class TransactionBuilder: execution_step_buffer: float = 0.2 """Additional amount of execution step (in ratio) that will be added on top of estimation""" - ttl: int = field(default=None) + ttl: Optional[int] = field(default=None) - validity_start: int = field(default=None) + validity_start: Optional[int] = field(default=None) - auxiliary_data: AuxiliaryData = field(default=None) + auxiliary_data: Optional[AuxiliaryData] = field(default=None) - native_scripts: List[NativeScript] = field(default=None) + native_scripts: Optional[List[NativeScript]] = field(default=None) - mint: MultiAsset = field(default=None) + mint: Optional[MultiAsset] = field(default=None) - required_signers: List[VerificationKeyHash] = field(default=None) + required_signers: Optional[List[VerificationKeyHash]] = field(default=None) collaterals: List[UTxO] = field(default_factory=lambda: []) - certificates: List[Certificate] = field(default=None) + certificates: Optional[List[Certificate]] = field(default=None) - withdrawals: Withdrawals = field(default=None) + withdrawals: Optional[Withdrawals] = field(default=None) reference_inputs: Set[TransactionInput] = field( init=False, default_factory=lambda: set() @@ -113,7 +114,9 @@ class TransactionBuilder: _excluded_inputs: List[UTxO] = field(init=False, default_factory=lambda: []) - _input_addresses: List[Address] = field(init=False, default_factory=lambda: []) + _input_addresses: List[Union[Address, str]] = field( + init=False, default_factory=lambda: [] + ) _outputs: List[TransactionOutput] = field(init=False, default_factory=lambda: []) @@ -121,19 +124,19 @@ class TransactionBuilder: _datums: Dict[DatumHash, Datum] = field(init=False, default_factory=lambda: {}) - _collateral_return: TransactionOutput = field(init=False, default=None) + _collateral_return: Optional[TransactionOutput] = field(init=False, default=None) - _total_collateral: int = field(init=False, default=None) + _total_collateral: Optional[int] = field(init=False, default=None) _inputs_to_redeemers: Dict[UTxO, Redeemer] = field( init=False, default_factory=lambda: {} ) - _minting_script_to_redeemers: List[Tuple[bytes, Redeemer]] = field( + _minting_script_to_redeemers: List[Tuple[ScriptType, Optional[Redeemer]]] = field( init=False, default_factory=lambda: [] ) - _inputs_to_scripts: Dict[UTxO, bytes] = field( + _inputs_to_scripts: Dict[UTxO, ScriptType] = field( init=False, default_factory=lambda: {} ) @@ -141,7 +144,7 @@ class TransactionBuilder: Union[NativeScript, PlutusV1Script, PlutusV2Script] ] = field(init=False, default_factory=lambda: []) - _should_estimate_execution_units: bool = field(init=False, default=None) + _should_estimate_execution_units: Optional[bool] = field(init=False, default=None) def add_input(self, utxo: UTxO) -> TransactionBuilder: """Add a specific UTxO to transaction's inputs. @@ -208,13 +211,18 @@ def add_script_input( f"Expect the output address of utxo to be script type, " f"but got {utxo.output.address.address_type} instead." ) - if utxo.output.datum_hash and utxo.output.datum_hash != datum_hash(datum): + + if ( + utxo.output.datum_hash + and datum is not None + and utxo.output.datum_hash != datum_hash(datum) + ): raise InvalidArgumentException( f"Datum hash in transaction output is {utxo.output.datum_hash}, " f"but actual datum hash from input datum is {datum_hash(datum)}." ) - if datum: + if datum is not None: self.datums[datum_hash(datum)] = datum if redeemer: @@ -233,6 +241,7 @@ def add_script_input( self._reference_scripts.append(i.output.script) break elif isinstance(script, UTxO): + assert script.output.script is not None self._inputs_to_scripts[utxo] = script.output.script self.reference_inputs.add(script.input) self._reference_scripts.append(script.output.script) @@ -265,6 +274,7 @@ def add_minting_script( self._consolidate_redeemer(redeemer) if isinstance(script, UTxO): + assert script.output.script is not None self._minting_script_to_redeemers.append((script.output.script, redeemer)) self.reference_inputs.add(script.input) self._reference_scripts.append(script.output.script) @@ -303,10 +313,10 @@ def add_output( Returns: TransactionBuilder: Current transaction builder. """ - if datum: + if datum is not None: tx_out.datum_hash = datum_hash(datum) self.outputs.append(tx_out) - if add_datum_to_witness: + if datum is not None and add_datum_to_witness: self.datums[datum_hash(datum)] = datum return self @@ -339,8 +349,9 @@ def fee(self, fee: int): self._fee = fee @property - def all_scripts(self) -> List[bytes]: - scripts = {} + def all_scripts(self) -> List[ScriptType]: + scripts: Dict[ScriptHash, ScriptType] = {} + s: ScriptType if self.native_scripts: for s in self.native_scripts: @@ -355,8 +366,11 @@ def all_scripts(self) -> List[bytes]: return list(scripts.values()) @property - def scripts(self) -> List[bytes]: - scripts = {script_hash(s): s for s in self.all_scripts} + def scripts(self) -> List[ScriptType]: + scripts: Dict[ScriptHash, ScriptType] = { + script_hash(s): s for s in self.all_scripts + } + s: ScriptType for s in self._reference_scripts: if script_hash(s) in scripts: @@ -370,8 +384,8 @@ def datums(self) -> Dict[DatumHash, Datum]: @property def redeemers(self) -> List[Redeemer]: - return list(self._inputs_to_redeemers.values()) + [ - r for _, r in self._minting_script_to_redeemers + return [r for r in self._inputs_to_redeemers.values() if r is not None] + [ + r for _, r in self._minting_script_to_redeemers if r is not None ] @property @@ -438,7 +452,7 @@ def _calc_change( f"Not enough ADA left for change: {change.coin} but needs " f"{min_lovelace_post_alonzo(TransactionOutput(address, change), self.context)}" ) - lovelace_change = change.coin + lovelace_change = Value(change.coin) change_output_arr.append(TransactionOutput(address, lovelace_change)) # If there are multi asset in the change @@ -582,6 +596,7 @@ def _pack_tokens_for_change( ) -> List[MultiAsset]: multi_asset_arr = [] base_coin = Value(coin=change_estimator.coin) + change_address = change_address or Address(FAKE_VKEY.hash()) output = TransactionOutput(change_address, base_coin) # iteratively add tokens to output @@ -801,15 +816,17 @@ def build_witness_set(self) -> TransactionWitnessSet: TransactionWitnessSet: A transaction witness set without verification key witnesses. """ - native_scripts = [] - plutus_v1_scripts = [] - plutus_v2_scripts = [] + native_scripts: List[NativeScript] = [] + plutus_v1_scripts: List[PlutusV1Script] = [] + plutus_v2_scripts: List[PlutusV2Script] = [] for script in self.scripts: if isinstance(script, NativeScript): native_scripts.append(script) - elif isinstance(script, PlutusV1Script) or type(script) is bytes: + elif isinstance(script, PlutusV1Script): plutus_v1_scripts.append(script) + elif type(script) is bytes: + plutus_v1_scripts.append(PlutusV1Script(script)) elif isinstance(script, PlutusV2Script): plutus_v2_scripts.append(script) else: @@ -944,11 +961,15 @@ def build( additional_utxo_pool.append(utxo) additional_amount += utxo.output.amount - for i, selector in enumerate(self.utxo_selectors): + for index, selector in enumerate(self.utxo_selectors): try: selected, _ = selector.select( additional_utxo_pool, - [TransactionOutput(None, unfulfilled_amount)], + [ + TransactionOutput( + Address(FAKE_VKEY.hash()), unfulfilled_amount + ) + ], self.context, include_max_fee=False, respect_min_utxo=not can_merge_change, @@ -960,7 +981,7 @@ def build( break except UTxOSelectionException as e: - if i < len(self.utxo_selectors) - 1: + if index < len(self.utxo_selectors) - 1: logger.info(e) logger.info(f"{selector} failed. Trying next selector.") else: @@ -1007,7 +1028,7 @@ def build( return tx_body - def _set_collateral_return(self, collateral_return_address: Address): + def _set_collateral_return(self, collateral_return_address: Optional[Address]): """Calculate and set the change returned from the collateral inputs. Args: @@ -1090,7 +1111,7 @@ def _add_collateral_input(cur_total, candidate_inputs): def _update_execution_units( self, change_address: Optional[Address] = None, - merge_change: bool = False, + merge_change: Optional[bool] = False, collateral_change_address: Optional[Address] = None, ): if self._should_estimate_execution_units: @@ -1099,7 +1120,10 @@ def _update_execution_units( ) for r in self.redeemers: key = f"{r.tag.name.lower()}:{r.index}" - if key not in estimated_execution_units: + if ( + key not in estimated_execution_units + or estimated_execution_units[key] is None + ): raise TransactionBuilderException( f"Cannot find execution unit for redeemer: {r} " f"in estimated execution units: {estimated_execution_units}" @@ -1115,9 +1139,9 @@ def _update_execution_units( def _estimate_execution_units( self, change_address: Optional[Address] = None, - merge_change: bool = False, + merge_change: Optional[bool] = False, collateral_change_address: Optional[Address] = None, - ): + ) -> Dict[str, ExecutionUnits]: # Create a deep copy of current builder, so we won't mess up current builder's internal states tmp_builder = TransactionBuilder(self.context) for f in fields(self): diff --git a/pycardano/utils.py b/pycardano/utils.py index eeaf3d82..80034791 100644 --- a/pycardano/utils.py +++ b/pycardano/utils.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import cbor2 from nacl.encoding import RawEncoder @@ -28,8 +28,8 @@ def fee( context: ChainContext, length: int, - exec_steps: Optional[int] = 0, - max_mem_unit: Optional[int] = 0, + exec_steps: int = 0, + max_mem_unit: int = 0, ) -> int: """Calculate fee based on the length of a transaction's CBOR bytes and script execution. @@ -122,7 +122,7 @@ def min_lovelace( def min_lovelace_pre_alonzo( - amount: Union[int, Value], context: ChainContext, has_datum: bool = False + amount: Union[int, Value, None], context: ChainContext, has_datum: bool = False ) -> int: """Calculate minimum lovelace a transaction output needs to hold. @@ -137,7 +137,7 @@ def min_lovelace_pre_alonzo( Returns: int: Minimum required lovelace amount for this transaction output. """ - if isinstance(amount, int) or not amount.multi_asset: + if amount is None or isinstance(amount, int) or not amount.multi_asset: return context.protocol_param.min_utxo b_size = bundle_size(amount.multi_asset) @@ -187,7 +187,7 @@ def min_lovelace_post_alonzo(output: TransactionOutput, context: ChainContext) - def script_data_hash( redeemers: List[Redeemer], datums: List[Datum], - cost_models: Optional[CostModels] = None, + cost_models: Optional[Union[CostModels, Dict]] = None, ) -> ScriptDataHash: """Calculate plutus script data hash diff --git a/pycardano/witness.py b/pycardano/witness.py index ffac2cbf..227cdefc 100644 --- a/pycardano/witness.py +++ b/pycardano/witness.py @@ -1,11 +1,11 @@ """Transaction witness.""" from dataclasses import dataclass, field -from typing import Any, List, Union +from typing import Any, List, Optional, Union from pycardano.key import ExtendedVerificationKey, VerificationKey from pycardano.nativescript import NativeScript -from pycardano.plutus import RawPlutusData, Redeemer +from pycardano.plutus import PlutusV1Script, PlutusV2Script, RawPlutusData, Redeemer from pycardano.serialization import ( ArrayCBORSerializable, MapCBORSerializable, @@ -29,7 +29,7 @@ def __post_init__(self): @dataclass(repr=False) class TransactionWitnessSet(MapCBORSerializable): - vkey_witnesses: List[VerificationKeyWitness] = field( + vkey_witnesses: Optional[List[VerificationKeyWitness]] = field( default=None, metadata={ "optional": True, @@ -38,30 +38,30 @@ class TransactionWitnessSet(MapCBORSerializable): }, ) - native_scripts: List[NativeScript] = field( + native_scripts: Optional[List[NativeScript]] = field( default=None, metadata={"optional": True, "key": 1, "object_hook": list_hook(NativeScript)}, ) # TODO: Add bootstrap witness (byron) support - bootstrap_witness: List[Any] = field( + bootstrap_witness: Optional[List[Any]] = field( default=None, metadata={"optional": True, "key": 2} ) - plutus_v1_script: List[bytes] = field( + plutus_v1_script: Optional[List[PlutusV1Script]] = field( default=None, metadata={"optional": True, "key": 3} ) - plutus_v2_script: List[bytes] = field( + plutus_v2_script: Optional[List[PlutusV2Script]] = field( default=None, metadata={"optional": True, "key": 6} ) - plutus_data: List[Any] = field( + plutus_data: Optional[List[Any]] = field( default=None, metadata={"optional": True, "key": 4, "object_hook": list_hook(RawPlutusData)}, ) - redeemer: List[Redeemer] = field( + redeemer: Optional[List[Redeemer]] = field( default=None, metadata={"optional": True, "key": 5, "object_hook": list_hook(Redeemer)}, ) diff --git a/pyproject.toml b/pyproject.toml index cd9c19c4..aac40106 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,18 +67,5 @@ ignore_missing_imports = true disable_error_code = ["str-bytes-safe"] python_version = 3.7 exclude = [ - '^pycardano/cip/cip8.py$', '^pycardano/crypto/bech32.py$', - '^pycardano/certificate.py$', - '^pycardano/coinselection.py$', - '^pycardano/exception.py$', - '^pycardano/hash.py$', - '^pycardano/key.py$', - '^pycardano/logging.py$', - '^pycardano/metadata.py$', - '^pycardano/plutus.py$', - '^pycardano/transaction.py$', - '^pycardano/txbuilder.py$', - '^pycardano/utils.py$', - '^pycardano/witness.py$', ] diff --git a/test/pycardano/test_metadata.py b/test/pycardano/test_metadata.py index d3d2d255..8a7379d3 100644 --- a/test/pycardano/test_metadata.py +++ b/test/pycardano/test_metadata.py @@ -9,7 +9,7 @@ AlonzoMetadata, AuxiliaryData, Metadata, - ShellayMarryMetadata, + ShelleyMarryMetadata, ) from pycardano.nativescript import ( InvalidBefore, @@ -46,7 +46,7 @@ def test_shelley_marry_metadata(): script = generate_script() m = Metadata(M_PRIMITIVE) - shelley_marry_m = ShellayMarryMetadata(m, [script]) + shelley_marry_m = ShelleyMarryMetadata(m, [script]) check_two_way_cbor(shelley_marry_m) @@ -76,7 +76,7 @@ def test_auxiliary_data(): plutus_scripts = [b"fake_script"] m = Metadata(M_PRIMITIVE) - shelley_marry_m = ShellayMarryMetadata(m, [script]) + shelley_marry_m = ShelleyMarryMetadata(m, [script]) alonzo_m = AlonzoMetadata(m, [script], plutus_scripts) check_two_way_cbor(AuxiliaryData(m))