Skip to content

Commit

Permalink
fix command tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Quexington committed Sep 26, 2024
1 parent d6c80c9 commit db75520
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 21 deletions.
29 changes: 23 additions & 6 deletions chia/_tests/cmds/test_cmd_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from chia._tests.environments.wallet import STANDARD_TX_ENDPOINT_ARGS, WalletTestFramework
from chia._tests.wallet.conftest import * # noqa
from chia.cmds.cmd_classes import (
_DECORATOR_APPLIED,
ChiaCommand,
Context,
NeedsCoinSelectionConfig,
Expand All @@ -21,12 +22,14 @@
TransactionEndpointWithTimelocks,
chia_command,
option,
transaction_endpoint_runner,
)
from chia.cmds.cmds_util import coin_selection_args, tx_config_args, tx_out_cmd
from chia.cmds.param_types import CliAmount
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
from chia.wallet.conditions import ConditionValidTimes
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.tx_config import CoinSelectionConfig, TXConfig


Expand All @@ -50,6 +53,9 @@ def new_run(self: Any) -> None:
# cmd is appropriately not recognized as a dataclass but I'm not sure how to hint that something is a dataclass
dict_compare_with_ignore_context(asdict(cmd), asdict(self)) # type: ignore[call-overload]

# We hack this in because more robust solutions are harder and probably not worth it
setattr(new_run, _DECORATOR_APPLIED, True)

setattr(mock_type, "run", new_run)
chia_command(_cmd, "_", "")(mock_type)

Expand Down Expand Up @@ -494,24 +500,35 @@ def run(self) -> None:
example_tx_config_cmd.run() # trigger inner assert


def test_transaction_endpoint_mixin() -> None:
@pytest.mark.anyio
async def test_transaction_endpoint_mixin() -> None:
@click.group()
def cmd() -> None:
pass # pragma: no cover

with pytest.raises(TypeError, match="transaction_endpoint_runner"):

@chia_command(cmd, "bad_cmd", "blah")
class BadCMD(TransactionEndpoint):

def run(self) -> None:
pass

BadCMD(**STANDARD_TX_ENDPOINT_ARGS)

@chia_command(cmd, "cs_cmd", "blah")
class TxCMD(TransactionEndpoint):

def run(self) -> None:
@transaction_endpoint_runner
async def run(self) -> List[TransactionRecord]:
assert self.load_condition_valid_times() == ConditionValidTimes(
min_time=uint64(10),
max_time=uint64(20),
)
return []

# Check that our default object lines up with the default options
check_click_parsing(
TxCMD(**STANDARD_TX_ENDPOINT_ARGS),
)
check_click_parsing(TxCMD(**STANDARD_TX_ENDPOINT_ARGS))

example_tx_cmd = TxCMD(
**{
Expand All @@ -535,7 +552,7 @@ def run(self) -> None:
"20",
)

example_tx_cmd.run() # trigger inner assert
await example_tx_cmd.run() # trigger inner assert


# While we sit in between two paradigms, this test is in place to ensure they remain in sync.
Expand Down
2 changes: 1 addition & 1 deletion chia/_tests/cmds/wallet/test_tx_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_cmd(**kwargs: Any) -> List[TransactionRecord]:

runner: CliRunner = CliRunner()
with runner.isolated_filesystem():
runner.invoke(test_cmd, ["--transaction-file", "./temp.transaction"])
runner.invoke(test_cmd, ["--transaction-file-out", "./temp.transaction"])
with open("./temp.transaction", "rb") as file:
assert TransactionBundle.from_bytes(file.read()) == TransactionBundle([STD_TX, STD_TX])
with open("./temp.push") as file2:
Expand Down
4 changes: 2 additions & 2 deletions chia/_tests/cmds/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,10 @@ async def cat_spend(
]
with CliRunner().isolated_filesystem():
run_cli_command_and_assert(
capsys, root_dir, command_args + [FINGERPRINT_ARG] + ["--transaction-file=temp"], assert_list
capsys, root_dir, command_args + [FINGERPRINT_ARG] + ["--transaction-file-out=temp"], assert_list
)
run_cli_command_and_assert(
capsys, root_dir, command_args + [CAT_FINGERPRINT_ARG] + ["--transaction-file=temp2"], cat_assert_list
capsys, root_dir, command_args + [CAT_FINGERPRINT_ARG] + ["--transaction-file-out=temp2"], cat_assert_list
)

with open("temp", "rb") as file:
Expand Down
23 changes: 23 additions & 0 deletions chia/cmds/cmd_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
List,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
get_args,
get_origin,
Expand Down Expand Up @@ -420,6 +422,9 @@ def load_tx_config(self, mojo_per_unit: int, config: Dict[str, Any], fingerprint
).to_tx_config(mojo_per_unit, config, fingerprint)


_DECORATOR_APPLIED = "_DECORATOR_APPLIED"


@dataclass(frozen=True)
class TransactionEndpoint:
rpc_info: NeedsWalletRPC
Expand Down Expand Up @@ -454,6 +459,10 @@ class TransactionEndpoint:
hidden=True,
)

def __post_init__(self) -> None:
if not hasattr(self.run, "_DECORATOR_APPLIED"): # type: ignore[attr-defined]
raise TypeError("TransactionEndpoints must utilize @transaction_endpoint_runner on their `run` method")

def load_condition_valid_times(self) -> ConditionValidTimes:
return ConditionValidTimes(
min_time=uint64.construct_optional(self.valid_at),
Expand All @@ -477,3 +486,17 @@ class TransactionEndpointWithTimelocks(TransactionEndpoint):
required=False,
default=None,
)


_T_TransactionEndpoint = TypeVar("_T_TransactionEndpoint", bound=TransactionEndpoint)


def transaction_endpoint_runner(
func: Callable[[_T_TransactionEndpoint], Coroutine[Any, Any, List[TransactionRecord]]]
) -> Callable[[_T_TransactionEndpoint], Coroutine[Any, Any, None]]:
async def wrapped_func(self: _T_TransactionEndpoint) -> None:
txs = await func(self)
self.transaction_writer.handle_transaction_output(txs)

setattr(wrapped_func, _DECORATOR_APPLIED, True)
return wrapped_func
8 changes: 4 additions & 4 deletions chia/cmds/cmds_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,11 @@ def tx_out_cmd(

def _tx_out_cmd(func: Callable[..., List[TransactionRecord]]) -> Callable[..., None]:
@timelock_args(enable=enable_timelock_args)
def original_cmd(transaction_file: Optional[str] = None, **kwargs: Any) -> None:
def original_cmd(transaction_file_out: Optional[str] = None, **kwargs: Any) -> None:
txs: List[TransactionRecord] = func(**kwargs)
if transaction_file is not None:
print(f"Writing transactions to file {transaction_file}:")
with open(Path(transaction_file), "wb") as file:
if transaction_file_out is not None:
print(f"Writing transactions to file {transaction_file_out}:")
with open(Path(transaction_file_out), "wb") as file:
file.write(bytes(TransactionBundle(txs)))

return click.option(
Expand Down
27 changes: 19 additions & 8 deletions chia/cmds/coins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

import click

from chia.cmds.cmd_classes import NeedsCoinSelectionConfig, NeedsWalletRPC, TransactionEndpoint, chia_command, option
from chia.cmds.cmd_classes import (
NeedsCoinSelectionConfig,
NeedsWalletRPC,
TransactionEndpoint,
chia_command,
option,
transaction_endpoint_runner,
)
from chia.cmds.cmds_util import cli_confirm
from chia.cmds.param_types import AmountParamType, Bytes32ParamType, CliAmount
from chia.cmds.wallet_funcs import get_mojo_per_unit, get_wallet_type, print_balance
Expand Down Expand Up @@ -164,17 +171,18 @@ class CombineCMD(TransactionEndpoint):
help="Sort coins from largest to smallest or smallest to largest.",
)

async def run(self) -> None:
@transaction_endpoint_runner
async def run(self) -> List[TransactionRecord]:
async with self.rpc_info.wallet_rpc() as wallet_rpc:
try:
wallet_type = await get_wallet_type(wallet_id=self.id, wallet_client=wallet_rpc.client)
mojo_per_unit = get_mojo_per_unit(wallet_type)
except LookupError:
print(f"Wallet id: {self.id} not found.")
return
return []
if not await wallet_rpc.client.get_synced():
print("Wallet not synced. Please wait.")
return
return []

tx_config = self.tx_config_loader.load_tx_config(mojo_per_unit, wallet_rpc.config, wallet_rpc.fingerprint)

Expand Down Expand Up @@ -212,6 +220,8 @@ async def run(self) -> None:
f"-f {wallet_rpc.fingerprint} -tx 0x{tx.name}"
)

return resp.transactions


@chia_command(
coins_cmd,
Expand Down Expand Up @@ -244,17 +254,18 @@ class SplitCMD(TransactionEndpoint):
help="The coin id of the coin we are splitting.",
)

async def run(self) -> None:
@transaction_endpoint_runner
async def run(self) -> List[TransactionRecord]:
async with self.rpc_info.wallet_rpc() as wallet_rpc:
try:
wallet_type = await get_wallet_type(wallet_id=self.id, wallet_client=wallet_rpc.client)
mojo_per_unit = get_mojo_per_unit(wallet_type)
except LookupError:
print(f"Wallet id: {self.id} not found.")
return
return []
if not await wallet_rpc.client.get_synced():
print("Wallet not synced. Please wait.")
return
return []

final_amount_per_coin = self.amount_per_coin.convert_amount(mojo_per_unit)

Expand Down Expand Up @@ -297,4 +308,4 @@ async def run(self) -> None:
"mojos or disable it by setting it to 0."
)

self.transaction_writer.handle_transaction_output(transactions)
return transactions

0 comments on commit db75520

Please sign in to comment.