diff --git a/docs/Account.md b/docs/Account.md index b041e4019..7a3bf4de3 100644 --- a/docs/Account.md +++ b/docs/Account.md @@ -13,6 +13,7 @@ A more detailed writeup on the topic can be found on [Perama's blogpost](https:/ * [Keys, signatures and signers](#keys-signatures-and-signers) * [Signer](#signer) * [MockSigner utility](#mocksigner-utility) + * [MockEthSigner utility](#mockethsigner-utility) * [Account entrypoint](#account-entrypoint) * [Call and AccountCallArray format](#call-and-accountcallarray-format) * [Call](#call) @@ -24,6 +25,12 @@ A more detailed writeup on the topic can be found on [Perama's blogpost](https:/ * [`set_public_key`](#set_public_key) * [`is_valid_signature`](#is_valid_signature) * [`__execute__`](#__execute__) + * [`is_valid_eth_signature`](#is_valid_eth_signature) + * [`eth_execute`](#eth_execute) + * [`_unsafe_execute`](#_unsafe_execute) +* [Presets](#presets) + * [Account](#account) + * [Eth Account](#eth-account) * [Account differentiation with ERC165](#account-differentiation-with-erc165) * [Extending the Account contract](#extending-the-account-contract) * [L1 escape hatch mechanism](#l1-escape-hatch-mechanism) @@ -159,6 +166,13 @@ If utilizing multicall, send multiple transactions with the `send_transactions` ) ``` +### MockEthSigner utility + +The `MockEthSigner` class in [utils.py](../tests/utils.py) is used to perform transactions on a given Account with a secp256k1 curve key pair, crafting the transaction and managing nonces. It differs from the `MockSigner` implementation by: + +* not using the public key but its derived address instead (the last 20 bytes of the keccak256 hash of the public key and adding `0x` to the beginning) +* signing the message with a secp256k1 curve address + ## Account entrypoint `__execute__` acts as a single entrypoint for all user interaction with any contract, including managing the account contract itself. That's why if you want to change the public key controlling the Account, you would send a transaction targeting the very Account contract: @@ -361,11 +375,10 @@ is_valid: felt This is the only external entrypoint to interact with the Account contract. It: -1. Takes the input and builds a `Call` for each iterated message. See [Multicall transactions](#multicall-transactions) for more information -2. Validates the transaction signature matches the message (including the nonce) -3. Increments the nonce -4. Calls the target contract with the intended function selector and calldata parameters -5. Forwards the contract call response data as return value +1. Validates the transaction signature matches the message (including the nonce) +2. Increments the nonce +3. Calls the target contract with the intended function selector and calldata parameters +4. Forwards the contract call response data as return value Parameters: @@ -386,6 +399,69 @@ response_len: felt response: felt* ``` +### `is_valid_eth_signature` + +Returns `TRUE` if a given signature in the secp256k1 curve is valid, otherwise it reverts. In the future it will return `FALSE` if a given signature is invalid (for more info please check [this issue](https://github.com/OpenZeppelin/cairo-contracts/issues/327)). + +Parameters: + +```cairo +signature_len: felt +signature: felt* +``` + +Returns: + +```cairo +is_valid: felt +``` + +> returns `TRUE` if a given signature is valid. Otherwise, reverts. In the future it will return `FALSE` if a given signature is invalid (for more info please check [this issue](https://github.com/OpenZeppelin/cairo-contracts/issues/327)). + +### `eth_execute` + +This follows the same idea as the vanilla version of `execute` with the sole difference that signature verification is on the secp256k1 curve. + +Parameters: + +```cairo +call_array_len: felt +call_array: AccountCallArray* +calldata_len: felt +calldata: felt* +nonce: felt +``` + +> Note that the current signature scheme expects a 7-element array like `[sig_v, uint256_sig_r_low, uint256_sig_r_high, uint256_sig_s_low, uint256_sig_s_high, uint256_hash_low, uint256_hash_high]` given that the parameters of the verification are bigger than a felt. + +Returns: + +```cairo +response_len: felt +response: felt* +``` + +### `_unsafe_execute` + +It's an internal method that performs the following tasks: + +1. Increments the nonce. +2. Takes the input and builds a `Call` for each iterated message. See [Multicall transactions](#multicall-transactions) for more information. +3. Calls the target contract with the intended function selector and calldata parameters +4. Forwards the contract call response data as return value + +## Presets + +The following contract presets are ready to deploy and can be used as-is for quick prototyping and testing. Each preset differs on the signature type being used by the Account. + +### Account + +The [`Account`](../src/openzeppelin/account/Account.cairo) preset uses StarkNet keys to validate transactions. + +### Eth Account + +The [`EthAccount`](../src/openzeppelin/account/EthAccount.cairo) preset supports Ethereum addresses, validating transactions with secp256k1 keys. + ## Account differentiation with ERC165 Certain contracts like ERC721 require a means to differentiate between account contracts and non-account contracts. For a contract to declare itself as an account, it should implement [ERC165](https://eips.ethereum.org/EIPS/eip-165) as proposed in [#100](https://github.com/OpenZeppelin/cairo-contracts/discussions/100). To be in compliance with ERC165 specifications, the idea is to calculate the XOR of `IAccount`'s EVM selectors (not StarkNet selectors). The resulting magic value of `IAccount` is 0x50b70dcb. @@ -394,13 +470,18 @@ Our ERC165 integration on StarkNet is inspired by OpenZeppelin's Solidity implem ## Extending the Account contract -Account contracts can be extended by following the [extensibility pattern](../docs/Extensibility.md#the-pattern). The basic idea behind integrating the pattern is to import the requisite methods from the Account library and incorporate the extended logic thereafter. +Account contracts can be extended by following the [extensibility pattern](../docs/Extensibility.md#the-pattern). + +To implement custom account contracts, a pair of `validate` and `execute` functions should be exposed. This is why the Account library comes with different flavors of such pairs, like the vanilla `is_valid_signature` and `execute`, or the Ethereum flavored `is_valid_eth_signature` and `eth_execute` pair. + +Account contract developers are encouraged to implement the [standard Account interface](https://github.com/OpenZeppelin/cairo-contracts/discussions/41) and incorporate the custom logic thereafter. + +To implement alternative `execute` functions, make sure to check their corresponding `validate` function before calling the `_unsafe_execute` building block, as each of the current presets is doing. Do not expose `_unsafe_execute` directly. -Currently, there's only a single library/preset Account scheme, but we're looking for feedback and new presets to emerge. Some new validation schemes to look out for in the future: +Some other validation schemes to look out for in the future: * multisig * guardian logic like in [Argent's account](https://github.com/argentlabs/argent-contracts-starknet/blob/de5654555309fa76160ba3d7393d32d2b12e7349/contracts/ArgentAccount.cairo) -* [Ethereum signatures](https://github.com/OpenZeppelin/cairo-contracts/issues/161) ## L1 escape hatch mechanism diff --git a/src/openzeppelin/account/Account.cairo b/src/openzeppelin/account/Account.cairo index b2a19d139..56f85d818 100644 --- a/src/openzeppelin/account/Account.cairo +++ b/src/openzeppelin/account/Account.cairo @@ -3,7 +3,7 @@ %lang starknet -from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin, BitwiseBuiltin from openzeppelin.account.library import Account, AccountCallArray @@ -95,7 +95,8 @@ func __execute__{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, - ecdsa_ptr: SignatureBuiltin* + ecdsa_ptr: SignatureBuiltin*, + bitwise_ptr: BitwiseBuiltin* }( call_array_len: felt, call_array: AccountCallArray*, diff --git a/src/openzeppelin/account/EthAccount.cairo b/src/openzeppelin/account/EthAccount.cairo new file mode 100644 index 000000000..87ec6f87c --- /dev/null +++ b/src/openzeppelin/account/EthAccount.cairo @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: MIT +# OpenZeppelin Contracts for Cairo v0.1.0 (account/EthAccount.cairo) + +%lang starknet +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin, BitwiseBuiltin +from openzeppelin.account.library import Account, AccountCallArray + +from openzeppelin.introspection.ERC165 import ERC165 + +# +# Constructor +# + +@constructor +func constructor{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr + }(eth_address: felt): + Account.initializer(eth_address) + return () +end + +# +# Getters +# + +@view +func get_eth_address{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr + }() -> (res: felt): + let (res) = Account.get_public_key() + return (res=res) +end + +@view +func get_nonce{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr + }() -> (res: felt): + let (res) = Account.get_nonce() + return (res=res) +end + +@view +func supportsInterface{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + range_check_ptr + } (interfaceId: felt) -> (success: felt): + let (success) = ERC165.supports_interface(interfaceId) + return (success) +end + +# +# Setters +# + +@external +func set_eth_address{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr + }(new_eth_address: felt): + Account.set_public_key(new_eth_address) + return () +end + +# +# Business logic +# + +@view +func is_valid_signature{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr, + ecdsa_ptr: SignatureBuiltin*, + bitwise_ptr: BitwiseBuiltin* + }( + hash: felt, + signature_len: felt, + signature: felt* + ) -> (is_valid: felt): + let (is_valid) = Account.is_valid_eth_signature(hash, signature_len, signature) + return (is_valid=is_valid) +end + +@external +func __execute__{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr, + bitwise_ptr: BitwiseBuiltin* + }( + call_array_len: felt, + call_array: AccountCallArray*, + calldata_len: felt, + calldata: felt*, + nonce: felt + ) -> (response_len: felt, response: felt*): + let (response_len, response) = Account.eth_execute( + call_array_len, + call_array, + calldata_len, + calldata, + nonce + ) + return (response_len=response_len, response=response) +end diff --git a/src/openzeppelin/account/library.cairo b/src/openzeppelin/account/library.cairo index c004f4a6b..3ae289a51 100644 --- a/src/openzeppelin/account/library.cairo +++ b/src/openzeppelin/account/library.cairo @@ -3,12 +3,14 @@ from starkware.cairo.common.registers import get_fp_and_pc from starkware.starknet.common.syscalls import get_contract_address from starkware.cairo.common.signature import verify_ecdsa_signature -from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin +from starkware.cairo.common.cairo_builtins import HashBuiltin, SignatureBuiltin, BitwiseBuiltin from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.uint256 import Uint256 from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.common.math import split_felt from starkware.cairo.common.bool import TRUE from starkware.starknet.common.syscalls import call_contract, get_caller_address, get_tx_info - +from starkware.cairo.common.cairo_secp.signature import verify_eth_signature_uint256 from openzeppelin.introspection.ERC165 import ERC165 from openzeppelin.utils.constants import IACCOUNT_ID @@ -141,12 +143,49 @@ namespace Account: return (is_valid=TRUE) end + func is_valid_eth_signature{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + bitwise_ptr: BitwiseBuiltin*, + range_check_ptr + }( + hash: felt, + signature_len: felt, + signature: felt* + ) -> (is_valid: felt): + alloc_locals + let (_public_key) = get_public_key() + let (__fp__, _) = get_fp_and_pc() + + # This interface expects a signature pointer and length to make + # no assumption about signature validation schemes. + # But this implementation does, and it expects a the sig_v, sig_r, + # sig_s, and hash elements. + let sig_v : felt = signature[0] + let sig_r : Uint256 = Uint256(low=signature[1], high=signature[2]) + let sig_s : Uint256 = Uint256(low=signature[3], high=signature[4]) + let (high, low) = split_felt(hash) + let msg_hash : Uint256 = Uint256(low=low, high=high) + + let (local keccak_ptr : felt*) = alloc() + + with keccak_ptr: + verify_eth_signature_uint256( + msg_hash=msg_hash, + r=sig_r, + s=sig_s, + v=sig_v, + eth_address=_public_key) + end + + return (is_valid=TRUE) + end func execute{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr, - ecdsa_ptr: SignatureBuiltin* + bitwise_ptr: BitwiseBuiltin* }( call_array_len: felt, call_array: AccountCallArray*, @@ -156,34 +195,81 @@ namespace Account: ) -> (response_len: felt, response: felt*): alloc_locals + let (__fp__, _) = get_fp_and_pc() + let (tx_info) = get_tx_info() + let (local ecdsa_ptr : SignatureBuiltin*) = alloc() + with ecdsa_ptr: + # validate transaction + with_attr error_message("Account: invalid signature"): + let (is_valid) = is_valid_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature) + assert is_valid = TRUE + end + end + + return _unsafe_execute(call_array_len, call_array, calldata_len, calldata, nonce) + end + + func eth_execute{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr, + bitwise_ptr: BitwiseBuiltin* + }( + call_array_len: felt, + call_array: AccountCallArray*, + calldata_len: felt, + calldata: felt*, + nonce: felt + ) -> (response_len: felt, response: felt*): + alloc_locals + + let (__fp__, _) = get_fp_and_pc() + let (tx_info) = get_tx_info() + + # validate transaction + with_attr error_message("Account: invalid secp256k1 signature"): + let (is_valid) = is_valid_eth_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature) + assert is_valid = TRUE + end + + return _unsafe_execute(call_array_len, call_array, calldata_len, calldata, nonce) + end + + func _unsafe_execute{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr, + bitwise_ptr: BitwiseBuiltin* + }( + call_array_len: felt, + call_array: AccountCallArray*, + calldata_len: felt, + calldata: felt*, + nonce: felt + ) -> (response_len: felt, response: felt*): + alloc_locals + let (caller) = get_caller_address() with_attr error_message("Account: no reentrant call"): assert caller = 0 end - let (__fp__, _) = get_fp_and_pc() - let (tx_info) = get_tx_info() - let (_current_nonce) = Account_current_nonce.read() - # validate nonce + + let (_current_nonce) = Account_current_nonce.read() + with_attr error_message("Account: nonce is invalid"): assert _current_nonce = nonce end + # bump nonce + Account_current_nonce.write(_current_nonce + 1) + # TMP: Convert `AccountCallArray` to 'Call'. let (calls : Call*) = alloc() _from_call_array_to_call(call_array_len, call_array, calldata, calls) let calls_len = call_array_len - # validate transaction - let (is_valid) = is_valid_signature(tx_info.transaction_hash, tx_info.signature_len, tx_info.signature) - with_attr error_message("Account: invalid signature"): - assert is_valid = TRUE - end - - # bump nonce - Account_current_nonce.write(_current_nonce + 1) - # execute call let (response : felt*) = alloc() let (response_len) = _execute_list(calls_len, calls, response) @@ -240,4 +326,5 @@ namespace Account: _from_call_array_to_call(call_array_len - 1, call_array + AccountCallArray.SIZE, calldata, calls + Call.SIZE) return () end + end diff --git a/tests/access/test_Ownable.py b/tests/access/test_Ownable.py index f23bd6501..a23539c8c 100644 --- a/tests/access/test_Ownable.py +++ b/tests/access/test_Ownable.py @@ -1,7 +1,7 @@ import pytest +from signers import MockSigner from starkware.starknet.testing.starknet import Starknet from utils import ( - MockSigner, ZERO_ADDRESS, assert_event_emitted, get_contract_class, diff --git a/tests/account/test_Account.py b/tests/account/test_Account.py index 1b5705486..2ed100a2f 100644 --- a/tests/account/test_Account.py +++ b/tests/account/test_Account.py @@ -1,6 +1,7 @@ import pytest from starkware.starknet.testing.starknet import Starknet -from utils import MockSigner, assert_revert, get_contract_class, cached_contract, TRUE +from signers import MockSigner +from utils import assert_revert, get_contract_class, cached_contract, TRUE signer = MockSigner(123456789987654321) diff --git a/tests/account/test_AddressRegistry.py b/tests/account/test_AddressRegistry.py index 2e8c6c600..bb9e2461b 100644 --- a/tests/account/test_AddressRegistry.py +++ b/tests/account/test_AddressRegistry.py @@ -1,6 +1,7 @@ import pytest from starkware.starknet.testing.starknet import Starknet -from utils import MockSigner, get_contract_class, cached_contract +from signers import MockSigner +from utils import get_contract_class, cached_contract signer = MockSigner(123456789987654321) diff --git a/tests/account/test_EthAccount.py b/tests/account/test_EthAccount.py new file mode 100644 index 000000000..30d0ecb14 --- /dev/null +++ b/tests/account/test_EthAccount.py @@ -0,0 +1,203 @@ +import pytest +from starkware.starknet.testing.starknet import Starknet +from starkware.starkware_utils.error_handling import StarkException +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from utils import assert_revert, get_contract_class, cached_contract, TRUE, FALSE +from signers import MockEthSigner + +private_key = b'\x01' * 32 +signer = MockEthSigner(b'\x01' * 32) +other = MockEthSigner(b'\x02' * 32) + +IACCOUNT_ID = 0xf10dbd44 + + +@pytest.fixture(scope='module') +def contract_defs(): + account_cls = get_contract_class('openzeppelin/account/EthAccount.cairo') + init_cls = get_contract_class("tests/mocks/Initializable.cairo") + attacker_cls = get_contract_class("tests/mocks/account_reentrancy.cairo") + + return account_cls, init_cls, attacker_cls + + +@pytest.fixture(scope='module') +async def account_init(contract_defs): + account_cls, init_cls, attacker_cls = contract_defs + starknet = await Starknet.empty() + + account1 = await starknet.deploy( + contract_class=account_cls, + constructor_calldata=[signer.eth_address] + ) + account2 = await starknet.deploy( + contract_class=account_cls, + constructor_calldata=[signer.eth_address] + ) + initializable1 = await starknet.deploy( + contract_class=init_cls, + constructor_calldata=[], + ) + initializable2 = await starknet.deploy( + contract_class=init_cls, + constructor_calldata=[], + ) + attacker = await starknet.deploy( + contract_class=attacker_cls, + constructor_calldata=[], + ) + + return starknet.state, account1, account2, initializable1, initializable2, attacker + + +@pytest.fixture +def account_factory(contract_defs, account_init): + account_cls, init_cls, attacker_cls = contract_defs + state, account1, account2, initializable1, initializable2, attacker = account_init + _state = state.copy() + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + initializable1 = cached_contract(_state, init_cls, initializable1) + initializable2 = cached_contract(_state, init_cls, initializable2) + attacker = cached_contract(_state, attacker_cls, attacker) + + return account1, account2, initializable1, initializable2, attacker + + +@pytest.mark.asyncio +async def test_constructor(account_factory): + account, *_ = account_factory + + execution_info = await account.get_eth_address().call() + assert execution_info.result == (signer.eth_address,) + + execution_info = await account.supportsInterface(IACCOUNT_ID).call() + assert execution_info.result == (TRUE,) + + +@pytest.mark.asyncio +async def test_execute(account_factory): + account, _, initializable, *_ = account_factory + + execution_info = await initializable.initialized().call() + assert execution_info.result == (FALSE,) + + _, hash, signature = await signer.send_transactions(account, [(initializable.contract_address, 'initialize', [])]) + + validity_info, *_ = await signer.send_transactions(account, [(account.contract_address, 'is_valid_signature', [hash, len(signature), *signature])]) + assert validity_info.result.response[0] == TRUE + + execution_info = await initializable.initialized().call() + assert execution_info.result == (TRUE,) + + # should revert if signature is not correct + await assert_revert( + signer.send_transactions(account, [(account.contract_address, 'is_valid_signature', [hash-1, len(signature), *signature])]), + reverted_with="Invalid signature" + ) + + +@pytest.mark.asyncio +async def test_multicall(account_factory): + account, _, initializable_1, initializable_2, _ = account_factory + + execution_info = await initializable_1.initialized().call() + assert execution_info.result == (FALSE,) + execution_info = await initializable_2.initialized().call() + assert execution_info.result == (FALSE,) + + await signer.send_transactions( + account, + [ + (initializable_1.contract_address, 'initialize', []), + (initializable_2.contract_address, 'initialize', []) + ] + ) + + execution_info = await initializable_1.initialized().call() + assert execution_info.result == (TRUE,) + execution_info = await initializable_2.initialized().call() + assert execution_info.result == (TRUE,) + + +@pytest.mark.asyncio +async def test_return_value(account_factory): + account, _, initializable, *_ = account_factory + + # initialize, set `initialized = 1` + await signer.send_transactions(account, [(initializable.contract_address, 'initialize', [])]) + + read_info, *_ = await signer.send_transactions(account, [(initializable.contract_address, 'initialized', [])]) + call_info = await initializable.initialized().call() + (call_result, ) = call_info.result + assert read_info.result.response == [call_result] # 1 + + +@ pytest.mark.asyncio +async def test_nonce(account_factory): + account, _, initializable, *_ = account_factory + + # bump nonce + _, hash, signature = await signer.send_transactions(account, [(initializable.contract_address, 'initialized', [])]) + + execution_info = await account.get_nonce().call() + current_nonce = execution_info.result.res + + # lower nonce + await assert_revert( + signer.send_transactions(account, [(initializable.contract_address, 'initialize', [])], current_nonce - 1), + reverted_with="Account: nonce is invalid" + ) + + # higher nonce + await assert_revert( + signer.send_transactions(account, [(initializable.contract_address, 'initialize', [])], current_nonce + 1), + reverted_with="Account: nonce is invalid" + ) + + # right nonce + await signer.send_transactions(account, [(initializable.contract_address, 'initialize', [])], current_nonce) + + execution_info = await initializable.initialized().call() + assert execution_info.result == (TRUE,) + + +@pytest.mark.asyncio +async def test_eth_address_setter(account_factory): + account, *_ = account_factory + + execution_info = await account.get_eth_address().call() + assert execution_info.result == (signer.eth_address,) + + # set new pubkey + await signer.send_transactions(account, [(account.contract_address, 'set_eth_address', [other.eth_address])]) + + execution_info = await account.get_eth_address().call() + assert execution_info.result == (other.eth_address,) + + +@pytest.mark.asyncio +async def test_eth_address_setter_different_account(account_factory): + account, bad_account, *_ = account_factory + + # set new pubkey + await assert_revert( + signer.send_transactions( + bad_account, + [(account.contract_address, 'set_eth_address', [other.eth_address])] + ), + reverted_with="Account: caller is not this account" + ) + + +@pytest.mark.asyncio +async def test_account_takeover_with_reentrant_call(account_factory): + account, _, _, _, attacker = account_factory + + await assert_revert( + signer.send_transaction(account, attacker.contract_address, 'account_takeover', []), + reverted_with="Account: no reentrant call" + ) + + execution_info = await account.get_eth_address().call() + assert execution_info.result == (signer.eth_address,) diff --git a/tests/security/test_pausable.py b/tests/security/test_pausable.py index 9a3c80808..4deb6365a 100644 --- a/tests/security/test_pausable.py +++ b/tests/security/test_pausable.py @@ -1,10 +1,12 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( TRUE, FALSE, assert_revert, assert_event_emitted, - get_contract_class, cached_contract, MockSigner + get_contract_class, cached_contract ) + signer = MockSigner(12345678987654321) @pytest.fixture diff --git a/tests/signers.py b/tests/signers.py new file mode 100644 index 000000000..14f7a81b8 --- /dev/null +++ b/tests/signers.py @@ -0,0 +1,95 @@ +from nile.signer import Signer, from_call_to_call_array, get_transaction_hash +from utils import to_uint +import eth_keys + +class MockSigner(): + """ + Utility for sending signed transactions to an Account on Starknet. + + Parameters + ---------- + + private_key : int + + Examples + --------- + Constructing a MockSigner object + + >>> signer = MockSigner(1234) + + Sending a transaction + + >>> await signer.send_transaction( + account, contract_address, 'contract_method', [arg_1] + ) + + Sending multiple transactions + + >>> await signer.send_transaction( + account, [ + (contract_address, 'contract_method', [arg_1]), + (contract_address, 'another_method', [arg_1, arg_2]) + ] + ) + + """ + def __init__(self, private_key): + self.signer = Signer(private_key) + self.public_key = self.signer.public_key + + async def send_transaction(self, account, to, selector_name, calldata, nonce=None, max_fee=0): + return await self.send_transactions(account, [(to, selector_name, calldata)], nonce, max_fee) + + async def send_transactions(self, account, calls, nonce=None, max_fee=0): + if nonce is None: + execution_info = await account.get_nonce().call() + nonce, = execution_info.result + + build_calls = [] + for call in calls: + build_call = list(call) + build_call[0] = hex(build_call[0]) + build_calls.append(build_call) + + (call_array, calldata, sig_r, sig_s) = self.signer.sign_transaction(hex(account.contract_address), build_calls, nonce, max_fee) + return await account.__execute__(call_array, calldata, nonce).invoke(signature=[sig_r, sig_s]) + +class MockEthSigner(): + """ + Utility for sending signed transactions to an Account on Starknet, like MockSigner, but using a secp256k1 signature. + Parameters + ---------- + private_key : int + + """ + def __init__(self, private_key): + self.signer = eth_keys.keys.PrivateKey(private_key) + self.eth_address = int(self.signer.public_key.to_checksum_address(),0) + + async def send_transaction(self, account, to, selector_name, calldata, nonce=None, max_fee=0): + return await self.send_transactions(account, [(to, selector_name, calldata)], nonce, max_fee) + + async def send_transactions(self, account, calls, nonce=None, max_fee=0): + if nonce is None: + execution_info = await account.get_nonce().call() + nonce, = execution_info.result + + build_calls = [] + for call in calls: + build_call = list(call) + build_call[0] = hex(build_call[0]) + build_calls.append(build_call) + + (call_array, calldata) = from_call_to_call_array(build_calls) + message_hash = get_transaction_hash( + account.contract_address, call_array, calldata, nonce, max_fee + ) + + signature = self.signer.sign_msg_hash((message_hash).to_bytes(32, byteorder="big")) + sig_r = to_uint(signature.r) + sig_s = to_uint(signature.s) + + # the hash and signature are returned for other tests to use + return await account.__execute__(call_array, calldata, nonce).invoke( + signature=[signature.v, *sig_r, *sig_s] + ), message_hash, [signature.v, *sig_r, *sig_s] diff --git a/tests/token/erc20/test_ERC20.py b/tests/token/erc20/test_ERC20.py index 5e85f0f7b..ca79d46a8 100644 --- a/tests/token/erc20/test_ERC20.py +++ b/tests/token/erc20/test_ERC20.py @@ -1,11 +1,13 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, to_uint, add_uint, sub_uint, str_to_felt, MAX_UINT256, + to_uint, add_uint, sub_uint, str_to_felt, MAX_UINT256, ZERO_ADDRESS, INVALID_UINT256, TRUE, get_contract_class, cached_contract, assert_revert, assert_event_emitted, contract_path ) + signer = MockSigner(123456789987654321) # testing vars diff --git a/tests/token/erc20/test_ERC20_Burnable_mock.py b/tests/token/erc20/test_ERC20_Burnable_mock.py index caff5f936..1a67d7e28 100644 --- a/tests/token/erc20/test_ERC20_Burnable_mock.py +++ b/tests/token/erc20/test_ERC20_Burnable_mock.py @@ -1,10 +1,12 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, to_uint, add_uint, sub_uint, str_to_felt, ZERO_ADDRESS, INVALID_UINT256, - get_contract_class, cached_contract, assert_revert, assert_event_emitted, + to_uint, add_uint, sub_uint, str_to_felt, ZERO_ADDRESS, INVALID_UINT256, + get_contract_class, cached_contract, assert_revert, assert_event_emitted ) + signer = MockSigner(123456789987654321) # testing vars diff --git a/tests/token/erc20/test_ERC20_Mintable.py b/tests/token/erc20/test_ERC20_Mintable.py index 94689750b..ec213bcd6 100644 --- a/tests/token/erc20/test_ERC20_Mintable.py +++ b/tests/token/erc20/test_ERC20_Mintable.py @@ -1,11 +1,13 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, to_uint, add_uint, sub_uint, str_to_felt, + to_uint, add_uint, sub_uint, str_to_felt, MAX_UINT256, ZERO_ADDRESS, INVALID_UINT256, get_contract_class, cached_contract, assert_revert, assert_event_emitted ) + signer = MockSigner(123456789987654321) # testing vars diff --git a/tests/token/erc20/test_ERC20_Pausable.py b/tests/token/erc20/test_ERC20_Pausable.py index a890e9ae9..9295b28e4 100644 --- a/tests/token/erc20/test_ERC20_Pausable.py +++ b/tests/token/erc20/test_ERC20_Pausable.py @@ -1,10 +1,12 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, TRUE, FALSE, to_uint, str_to_felt, assert_revert, + TRUE, FALSE, to_uint, str_to_felt, assert_revert, get_contract_class, cached_contract ) + signer = MockSigner(123456789987654321) # testing vars diff --git a/tests/token/erc20/test_ERC20_Upgradeable.py b/tests/token/erc20/test_ERC20_Upgradeable.py index 65aed3386..053b3ce36 100644 --- a/tests/token/erc20/test_ERC20_Upgradeable.py +++ b/tests/token/erc20/test_ERC20_Upgradeable.py @@ -1,11 +1,13 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, to_uint, sub_uint, str_to_felt, assert_revert, + to_uint, sub_uint, str_to_felt, assert_revert, get_contract_class, cached_contract ) + signer = MockSigner(123456789987654321) USER = 999 diff --git a/tests/token/erc721/test_ERC721_Mintable_Burnable.py b/tests/token/erc721/test_ERC721_Mintable_Burnable.py index 72a265afb..2bc21c62f 100644 --- a/tests/token/erc721/test_ERC721_Mintable_Burnable.py +++ b/tests/token/erc721/test_ERC721_Mintable_Burnable.py @@ -1,7 +1,8 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, str_to_felt, ZERO_ADDRESS, TRUE, FALSE, assert_revert, INVALID_UINT256, + str_to_felt, ZERO_ADDRESS, TRUE, FALSE, assert_revert, INVALID_UINT256, assert_event_emitted, get_contract_class, cached_contract, to_uint, sub_uint, add_uint ) diff --git a/tests/token/erc721/test_ERC721_Mintable_Pausable.py b/tests/token/erc721/test_ERC721_Mintable_Pausable.py index 5249774f9..d6f9ef698 100644 --- a/tests/token/erc721/test_ERC721_Mintable_Pausable.py +++ b/tests/token/erc721/test_ERC721_Mintable_Pausable.py @@ -1,7 +1,8 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, str_to_felt, TRUE, FALSE, get_contract_class, cached_contract, + str_to_felt, TRUE, FALSE, get_contract_class, cached_contract, assert_revert, to_uint ) diff --git a/tests/token/erc721/test_ERC721_SafeMintable_mock.py b/tests/token/erc721/test_ERC721_SafeMintable_mock.py index 9e757a43e..7fc8463e6 100644 --- a/tests/token/erc721/test_ERC721_SafeMintable_mock.py +++ b/tests/token/erc721/test_ERC721_SafeMintable_mock.py @@ -1,7 +1,8 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, str_to_felt, ZERO_ADDRESS, INVALID_UINT256, assert_revert, + str_to_felt, ZERO_ADDRESS, INVALID_UINT256, assert_revert, assert_event_emitted, get_contract_class, cached_contract, to_uint ) diff --git a/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py b/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py index e6d01ccc3..11ef5c08f 100644 --- a/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py +++ b/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py @@ -1,7 +1,8 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, str_to_felt, MAX_UINT256, get_contract_class, cached_contract, + str_to_felt, MAX_UINT256, get_contract_class, cached_contract, TRUE, assert_revert, to_uint, sub_uint, add_uint ) diff --git a/tests/upgrades/test_Proxy.py b/tests/upgrades/test_Proxy.py index 4a9a94538..11ff100ae 100644 --- a/tests/upgrades/test_Proxy.py +++ b/tests/upgrades/test_Proxy.py @@ -1,7 +1,7 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, assert_revert, get_contract_class, cached_contract, diff --git a/tests/upgrades/test_upgrades.py b/tests/upgrades/test_upgrades.py index 8d4af3573..f9e2fc71a 100644 --- a/tests/upgrades/test_upgrades.py +++ b/tests/upgrades/test_upgrades.py @@ -1,7 +1,7 @@ import pytest from starkware.starknet.testing.starknet import Starknet +from signers import MockSigner from utils import ( - MockSigner, assert_revert, assert_revert_entry_point, assert_event_emitted, diff --git a/tests/utils.py b/tests/utils.py index a97a8548a..7a86e87cb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,6 @@ from starkware.starkware_utils.error_handling import StarkException from starkware.starknet.testing.starknet import StarknetContract from starkware.starknet.business_logic.execution.objects import Event -from nile.signer import Signer MAX_UINT256 = (2**128 - 1, 2**128 - 1) @@ -130,56 +129,3 @@ def cached_contract(state, _class, deployed): deploy_execution_info=deployed.deploy_execution_info ) return contract - - -class MockSigner(): - """ - Utility for sending signed transactions to an Account on Starknet. - - Parameters - ---------- - - private_key : int - - Examples - --------- - Constructing a MockSigner object - - >>> signer = MockSigner(1234) - - Sending a transaction - - >>> await signer.send_transaction( - account, contract_address, 'contract_method', [arg_1] - ) - - Sending multiple transactions - - >>> await signer.send_transaction( - account, [ - (contract_address, 'contract_method', [arg_1]), - (contract_address, 'another_method', [arg_1, arg_2]) - ] - ) - - """ - def __init__(self, private_key): - self.signer = Signer(private_key) - self.public_key = self.signer.public_key - - async def send_transaction(self, account, to, selector_name, calldata, nonce=None, max_fee=0): - return await self.send_transactions(account, [(to, selector_name, calldata)], nonce, max_fee) - - async def send_transactions(self, account, calls, nonce=None, max_fee=0): - if nonce is None: - execution_info = await account.get_nonce().call() - nonce, = execution_info.result - - build_calls = [] - for call in calls: - build_call = list(call) - build_call[0] = hex(build_call[0]) - build_calls.append(build_call) - - (call_array, calldata, sig_r, sig_s) = self.signer.sign_transaction(hex(account.contract_address), build_calls, nonce, max_fee) - return await account.__execute__(call_array, calldata, nonce).invoke(signature=[sig_r, sig_s])