diff --git a/README.md b/README.md index f3b303e11..36e395fdc 100644 --- a/README.md +++ b/README.md @@ -226,20 +226,20 @@ docker run cairo-tests This repo utilizes the [pytest-xdist](https://pytest-xdist.readthedocs.io/en/latest/) plugin which runs tests in parallel. This feature increases testing speed; however, conflicts with a shared state can occur since tests do not run in order. To overcome this, independent cached versions of contracts being tested should be provisioned to each test case. Here's a simple fixture example: ```python -from utils import get_contract_def, cached_contract +from utils import get_contract_class, cached_contract @pytest.fixture(scope='module') def foo_factory(): - # get contract definition - foo_def = get_contract_def('path/to/foo.cairo') + # get contract class + foo_cls = get_contract_class('path/to/foo.cairo') # deploy contract starknet = await Starknet.empty() - foo = await starknet.deploy(contract_def=foo_def) + foo = await starknet.deploy(contract_class=foo_cls) # copy the state and cache contract state = starknet.state.copy() - cached_foo = cached_contract(state, foo_def, foo) + cached_foo = cached_contract(state, foo_cls, foo) return cached_foo ``` diff --git a/docs/Proxies.md b/docs/Proxies.md index cd537e1d5..a0d9ba770 100644 --- a/docs/Proxies.md +++ b/docs/Proxies.md @@ -16,6 +16,7 @@ * [Events](#events) * [Using proxies](#using-proxies) * [Contract upgrades](#contract-upgrades) + * [Declaring contracts](#declaring-contracts) * [Handling method calls](#handling-method-calls) * [Presets](#presets) @@ -23,24 +24,23 @@ The general workflow is: -1. deploy implementation contract -2. deploy proxy contract with the implementation contract's address set in the proxy's constructor calldata -3. initialize the implementation contract by sending a call to the proxy contract. This will redirect the call to the implementation contract and behave like the implementation contract's constructor +1. declare an implementation [contract class](https://starknet.io/docs/hello_starknet/intro.html#declare-the-contract-on-the-starknet-testnet) +2. deploy proxy contract with the implementation contract's class hash set in the proxy's constructor calldata +3. initialize the implementation contract by sending a call to the proxy contract. This will redirect the call to the implementation contract class and behave like the implementation contract's constructor In Python, this would look as follows: ```python - # deploy implementation - IMPLEMENTATION = await starknet.deploy( + # declare implementation contract + IMPLEMENTATION = await starknet.declare( "path/to/implementation.cairo", - constructor_calldata=[] ) # deploy proxy PROXY = await starknet.deploy( "path/to/proxy.cairo", constructor_calldata=[ - IMPLEMENTATION.contract_address, # set implementation address + IMPLEMENTATION.class_hash, # set implementation contract class hash ] ) @@ -74,7 +74,7 @@ The StarkNet compiler, meanwhile, already creates pseudo-random storage addresse A proxy contract is a contract that delegates function calls to another contract. This type of pattern decouples state and logic. Proxy contracts store the state and redirect function calls to an implementation contract that handles the logic. This allows for different patterns such as upgrades, where implementation contracts can change but the proxy contract (and thus the state) does not; as well as deploying multiple proxy instances pointing to the same implementation. This can be useful to deploy many contracts with identical logic but unique initialization data. -In the case of contract upgrades, it is achieved by simply changing the proxy's reference to the implementation contract. This allows developers to add features, update logic, and fix bugs without touching the state or the contract address to interact with the application. +In the case of contract upgrades, it is achieved by simply changing the proxy's reference to the class hash of the declared implementation. This allows developers to add features, update logic, and fix bugs without touching the state or the contract address to interact with the application. ### Proxy contract @@ -84,7 +84,7 @@ The [Proxy contract](../src/openzeppelin/upgrades/Proxy.cairo) includes two core 2. The `__l1_default__` method is also a fallback method; however, it redirects the function call and associated calldata to a layer one contract. In order to invoke `__l1_default__`, the original function call must include the library function `send_message_to_l1`. See Cairo's [Interacting with L1 contracts](https://www.cairo-lang.org/docs/hello_starknet/l1l2.html) for more information. -Since this proxy is designed to work both as an [UUPS-flavored upgrade proxy](https://eips.ethereum.org/EIPS/eip-1822) as well as a non-upgradeable proxy, it does not know how to handle its own state. Therefore it requires the implementation contract to be deployed beforehand, so its address can be passed to the Proxy on construction time. +Since this proxy is designed to work both as an [UUPS-flavored upgrade proxy](https://eips.ethereum.org/EIPS/eip-1822) as well as a non-upgradeable proxy, it does not know how to handle its own state. Therefore it requires the implementation contract class to be declared beforehand, so its class hash can be passed to the Proxy on construction time. When interacting with the contract, function calls should be sent by the user to the proxy. The proxy's fallback function redirects the function call to the implementation contract to execute. @@ -104,7 +104,8 @@ If the implementation is upgradeable, it should: The implementation contract should NOT: -* deploy with a traditional constructor (decorated with `@constructor`). Instead, use an initializer method that invokes the Proxy `constructor`. +* be deployed like a regular contract. Instead, the implementation contract should be declared (which creates a `DeclaredClass` containing its hash and abi) +* set its initial state with a traditional constructor (decorated with `@constructor`). Instead, use an initializer method that invokes the Proxy `constructor`. > Note that the Proxy `constructor` includes a check the ensures the initializer can only be called once; however, `_set_implementation` does not include this check. It's up to the developers to protect their implementation contract's upgradeability with access controls such as [`assert_only_admin`](#assert_only_admin). @@ -117,26 +118,26 @@ For a full implementation contract example, please see: ### Methods ```cairo -func constructor(proxy_admin: felt): +func initializer(proxy_admin: felt): end -func _set_implementation(new_implementation: felt): +func assert_only_admin(): end -func _set_admin(new_admin: felt): +func get_implementation_hash() -> (implementation: felt): end -func get_implementation() -> (implementation: felt): +func get_admin() -> (admin: felt): end -func get_admin() -> (admin: felt): +func _set_admin(new_admin: felt): end -func assert_only_admin(): +func _set_implementation_hash(new_implementation: felt): end ``` -#### `constructor` +#### `initializer` Initializes the proxy contract with an initial implementation. @@ -150,37 +151,21 @@ Returns: None. -#### `_set_implementation` +#### `assert_only_admin` -Sets the implementation contract. This method is included in the proxy contract's constructor and is furthermore used to upgrade contracts. +Reverts if called by any account other than the admin. Parameters: -```cairo -new_implementation: felt -``` - -Returns: - None. -#### `_set_admin` - -Sets the admin of the proxy contract. - -Parameters: - -```cairo -new_admin: felt -``` - Returns: None. #### `get_implementation` -Returns the current implementation address. +Returns the current implementation hash. Parameters: @@ -206,14 +191,30 @@ Returns: admin: felt ``` -#### `assert_only_admin` +#### `_set_admin` -Throws if called by any account other than the admin. +Sets `new_admin` as the admin of the proxy contract. Parameters: +```cairo +new_admin: felt +``` + +Returns: + None. +#### `_set_implementation_hash` + +Sets `new_implementation` as the implementation's contract class. This method is included in the proxy contract's constructor and can be used to upgrade contracts. + +Parameters: + +```cairo +new_implementation: felt +``` + Returns: None. @@ -223,11 +224,14 @@ None. ```cairo func Upgraded(implementation: felt): end + +func AdminChanged(previousAdmin: felt, newAdmin: felt): +end ``` #### `Upgraded` -Emitted when a proxy contract sets a new implementation address. +Emitted when a proxy contract sets a new implementation class hash. Parameters: @@ -235,6 +239,17 @@ Parameters: implementation: felt ``` +#### `AdminChanged` + +Emitted when the `admin` changes from `previousAdmin` to `newAdmin`. + +Parameters: + +```cairo +previousAdmin: felt +newAdmin: felt +``` + ## Using proxies ### Contract upgrades @@ -242,30 +257,28 @@ implementation: felt To upgrade a contract, the implementation contract should include an `upgrade` method that, when called, changes the reference to a new deployed contract like this: ```python - # deploy first implementation - IMPLEMENTATION = await starknet.deploy( + # declare first implementation + IMPLEMENTATION = await starknet.declare( "path/to/implementation.cairo", - constructor_calldata=[] ) # deploy proxy PROXY = await starknet.deploy( "path/to/proxy.cairo", constructor_calldata=[ - IMPLEMENTATION.contract_address, # set implementation address + IMPLEMENTATION.class_hash, # set implementation hash ] ) - # deploy implementation v2 - IMPLEMENTATION_V2 = await starknet.deploy( + # declare implementation v2 + IMPLEMENTATION_V2 = await starknet.declare( "path/to/implementation_v2.cairo", - constructor_calldata=[] ) - # call upgrade with the new implementation contract address + # call upgrade with the new implementation contract class hash await signer.send_transaction( account, PROXY.contract_address, 'upgrade', [ - IMPLEMENTATION_V2.contract_address + IMPLEMENTATION_V2.class_hash ] ) ``` @@ -275,6 +288,10 @@ For a full deployment and upgrade implementation, please see: * [Upgrades V1](../tests/mocks/upgrades_v1_mock.cairo) * [Upgrades V2](../tests/mocks/upgrades_v2_mock.cairo) +### Declaring contracts + +StarkNet contracts come in two forms: contract classes and contract instances. Contract classes represent the uninstantiated, stateless code; whereas, contract instances are instantiated and include the state. Since the Proxy contract references the implementation contract by its class hash, declaring an implementation contract proves sufficient (as opposed to a full deployment). For more information on declaring classes, see [StarkNet's documentation](https://starknet.io/docs/hello_starknet/intro.html#declare-contract). + ### Handling method calls As with most StarkNet contracts, interacting with a proxy contract requires an [account abstraction](../docs/Account.md#quickstart). One notable difference with proxy contracts versus other contract implementations is that calling `@view` methods also requires an account abstraction. As of now, direct calls to default entrypoints are only supported by StarkNet's `syscalls` from other contracts i.e. account contracts. The differences in getter methods written in Python, for example, are as follows: diff --git a/docs/Utilities.md b/docs/Utilities.md index d45b7afb1..09dc7ee4e 100644 --- a/docs/Utilities.md +++ b/docs/Utilities.md @@ -18,11 +18,12 @@ The following documentation provides context, reasoning, and examples for method * [`sub_uint`](#sub_uint) * [Assertions](#assertions) * [`assert_revert`](#assert_revert) + * [`assert_revert_entry_point`](#assert_revert_entry_point) * [`assert_events_emitted`](#assert_event_emitted) * [Memoization](#memoization) - * [`get_contract_def`](#get_contract_def) + * [`get_contract_class`](#get_contract_class) * [`cached_contract`](#cached_contract) -* [Signer](#signer) +* [MockSigner](#mocksigner) ## Constants @@ -156,6 +157,19 @@ await assert_revert(signer.send_transaction( ) ``` +### `assert_revert_entry_point` + +An extension of `assert_revert` that asserts an entry point error occurs with the given `invalid_selector` parameter. This assertion is especially useful in checking proxy/implementation contracts. To use `assert_revert_entry_point`: + +```python +await assert_revert_entry_point( + signer.send_transaction( + account, contract.contract_address, 'nonexistent_selector', [] + ), + invalid_selector='nonexistent_selector' +) +``` + ### `assert_event_emitted` A helper method that checks a transaction receipt for the contract emitting the event (`from_address`), the emitted event itself (`name`), and the arguments emitted (`data`). To use `assert_event_emitted`: @@ -185,12 +199,12 @@ assert_event_emitted( Memoizing functions allow for quicker and computationally cheaper calculations which is immensely beneficial while testing smart contracts. -### `get_contract_def` +### `get_contract_class` -A helper method that returns the contract definition from the given path. To capture the contract definition, simply add the contracat path as an argument like this: +A helper method that returns the contract class from the given path. To capture the contract class, simply add the contract path as an argument like this: ```python -contract_definition = get_contract_def('path/to/contract.cairo') +contract_class = get_contract_class('path/to/contract.cairo') ``` ### `cached_contract` @@ -198,31 +212,31 @@ contract_definition = get_contract_def('path/to/contract.cairo') A helper method that returns the cached state of a given contract. It's recommended to first deploy all the relevant contracts before caching the state. The requisite contracts in the testing module should each be instantiated with `cached_contract` in a fixture after the state has been copied. The memoization pattern with `cached_contract` should look something like this: ```python -# get contract definitions +# get contract classes @pytest.fixture(scope='module') -def contract_defs(): - foo_def = get_contract_def('path/to/foo.cairo') - return foo_def +def contract_classes(): + foo_cls = get_contract_class('path/to/foo.cairo') + return foo_cls # deploy contracts @pytest.fixture(scope='module') -async def foo_init(contract_defs): - foo_def = contract_defs +async def foo_init(contract_classes): + foo_cls = contract_classes starknet = await Starknet.empty() foo = await starknet.deploy( - contract_def=foo_def, + contract_class=foo_cls, constructor_calldata=[] ) return starknet.state, foo # return state and all deployed contracts # memoization @pytest.fixture(scope='module') -def foo_factory(contract_defs, foo_init): - foo_def = contract_defs # contract definitions - state, foo = foo_init # state and deployed contracts - _state = state.copy() # copy the state - cached_foo = cached_contract(_state, foo_def, foo) # cache contracts - return cached_foo # return cached contracts +def foo_factory(contract_classes, foo_init): + foo_cls = contract_classes # contract classes + state, foo = foo_init # state and deployed contracts + _state = state.copy() # copy the state + cached_foo = cached_contract(_state, foo_cls, foo) # cache contracts + return cached_foo # return cached contracts ``` ## MockSigner diff --git a/src/openzeppelin/token/erc20/ERC20_Upgradeable.cairo b/src/openzeppelin/token/erc20/ERC20_Upgradeable.cairo index 2ec4bc77d..481c32f89 100644 --- a/src/openzeppelin/token/erc20/ERC20_Upgradeable.cairo +++ b/src/openzeppelin/token/erc20/ERC20_Upgradeable.cairo @@ -42,7 +42,7 @@ func upgrade{ range_check_ptr }(new_implementation: felt): Proxy.assert_only_admin() - Proxy._set_implementation(new_implementation) + Proxy._set_implementation_hash(new_implementation) return () end diff --git a/src/openzeppelin/upgrades/Proxy.cairo b/src/openzeppelin/upgrades/Proxy.cairo index bfd672e93..062fe95aa 100644 --- a/src/openzeppelin/upgrades/Proxy.cairo +++ b/src/openzeppelin/upgrades/Proxy.cairo @@ -1,10 +1,14 @@ # SPDX-License-Identifier: MIT -# OpenZeppelin Contracts for Cairo v0.1.0 (upgrades/Proxy.cairo) +# OpenZeppelin Contracts for Cairo v0.2.0 (upgrades/Proxy.cairo) %lang starknet +#%builtins pedersen range_check bitwise from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.starknet.common.syscalls import delegate_l1_handler, delegate_call +from starkware.starknet.common.syscalls import ( + library_call, + library_call_l1_handler +) from openzeppelin.upgrades.library import Proxy # @@ -16,8 +20,8 @@ func constructor{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr - }(implementation_address: felt): - Proxy._set_implementation(implementation_address) + }(implementation_hash: felt): + Proxy._set_implementation_hash(implementation_hash) return () end @@ -40,18 +44,18 @@ func __default__{ retdata_size: felt, retdata: felt* ): - let (address) = Proxy.get_implementation() + let (class_hash) = Proxy.get_implementation_hash() - let (retdata_size: felt, retdata: felt*) = delegate_call( - contract_address=address, + let (retdata_size: felt, retdata: felt*) = library_call( + class_hash=class_hash, function_selector=selector, calldata_size=calldata_size, - calldata=calldata + calldata=calldata, ) - return (retdata_size=retdata_size, retdata=retdata) end + @l1_handler @raw_input func __l1_default__{ @@ -63,14 +67,13 @@ func __l1_default__{ calldata_size: felt, calldata: felt* ): - let (address) = Proxy.get_implementation() + let (class_hash) = Proxy.get_implementation_hash() - delegate_l1_handler( - contract_address=address, + library_call_l1_handler( + class_hash=class_hash, function_selector=selector, calldata_size=calldata_size, - calldata=calldata + calldata=calldata, ) - return () end diff --git a/src/openzeppelin/upgrades/library.cairo b/src/openzeppelin/upgrades/library.cairo index d0d8e567d..d27e03680 100644 --- a/src/openzeppelin/upgrades/library.cairo +++ b/src/openzeppelin/upgrades/library.cairo @@ -1,11 +1,12 @@ # SPDX-License-Identifier: MIT -# OpenZeppelin Contracts for Cairo v0.1.0 (upgrades/library.cairo) +# OpenZeppelin Contracts for Cairo v0.2.0 (upgrades/library.cairo) %lang starknet from starkware.starknet.common.syscalls import get_caller_address from starkware.cairo.common.cairo_builtins import HashBuiltin from starkware.cairo.common.bool import TRUE, FALSE +from starkware.cairo.common.math import assert_not_zero # # Events @@ -15,12 +16,16 @@ from starkware.cairo.common.bool import TRUE, FALSE func Upgraded(implementation: felt): end +@event +func AdminChanged(previousAdmin: felt, newAdmin: felt): +end + # # Storage variables # @storage_var -func Proxy_implementation_address() -> (implementation_address: felt): +func Proxy_implementation_hash() -> (class_hash: felt): end @storage_var @@ -31,12 +36,7 @@ end func Proxy_initialized() -> (initialized: felt): end -# -# Initializer -# - namespace Proxy: - # # Initializer # @@ -52,34 +52,24 @@ namespace Proxy: end Proxy_initialized.write(TRUE) - Proxy_admin.write(proxy_admin) + _set_admin(proxy_admin) return () end # - # Upgrades - # - - func _set_implementation{ - syscall_ptr: felt*, - pedersen_ptr: HashBuiltin*, - range_check_ptr - }(new_implementation: felt): - Proxy_implementation_address.write(new_implementation) - Upgraded.emit(new_implementation) - return () - end - - # - # Setters + # Guards # - func _set_admin{ + func assert_only_admin{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr - }(new_admin: felt): - Proxy_admin.write(new_admin) + }(): + let (caller) = get_caller_address() + let (admin) = Proxy_admin.read() + with_attr error_message("Proxy: caller is not admin"): + assert admin = caller + end return () end @@ -87,12 +77,12 @@ namespace Proxy: # Getters # - func get_implementation{ + func get_implementation_hash{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr }() -> (implementation: felt): - let (implementation) = Proxy_implementation_address.read() + let (implementation) = Proxy_implementation_hash.read() return (implementation) end @@ -106,19 +96,33 @@ namespace Proxy: end # - # Guards + # Unprotected # - func assert_only_admin{ + func _set_admin{ syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr - }(): - let (caller) = get_caller_address() - let (admin) = Proxy_admin.read() - with_attr error_message("Proxy: caller is not admin"): - assert admin = caller + }(new_admin: felt): + let (previous_admin) = get_admin() + Proxy_admin.write(new_admin) + AdminChanged.emit(previous_admin, new_admin) + return () + end + + # + # Upgrade + # + + func _set_implementation_hash{ + syscall_ptr: felt*, + pedersen_ptr: HashBuiltin*, + range_check_ptr + }(new_implementation: felt): + with_attr error_message("Proxy: implementation hash cannot be zero"): + Proxy_implementation_hash.write(new_implementation) end + Upgraded.emit(new_implementation) return () end diff --git a/tests/access/test_Ownable.py b/tests/access/test_Ownable.py index bf713d6ae..f23bd6501 100644 --- a/tests/access/test_Ownable.py +++ b/tests/access/test_Ownable.py @@ -4,7 +4,7 @@ MockSigner, ZERO_ADDRESS, assert_event_emitted, - get_contract_def, + get_contract_class, cached_contract ) @@ -13,35 +13,35 @@ @pytest.fixture(scope='module') -def contract_defs(): +def contract_classes(): return ( - get_contract_def('openzeppelin/account/Account.cairo'), - get_contract_def('tests/mocks/Ownable.cairo') + get_contract_class('openzeppelin/account/Account.cairo'), + get_contract_class('tests/mocks/Ownable.cairo') ) @pytest.fixture(scope='module') -async def ownable_init(contract_defs): - account_def, ownable_def = contract_defs +async def ownable_init(contract_classes): + account_cls, ownable_cls = contract_classes starknet = await Starknet.empty() owner = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) ownable = await starknet.deploy( - contract_def=ownable_def, + contract_class=ownable_cls, constructor_calldata=[owner.contract_address] ) return starknet.state, ownable, owner @pytest.fixture -def ownable_factory(contract_defs, ownable_init): - account_def, ownable_def = contract_defs +def ownable_factory(contract_classes, ownable_init): + account_cls, ownable_cls = contract_classes state, ownable, owner = ownable_init _state = state.copy() - owner = cached_contract(_state, account_def, owner) - ownable = cached_contract(_state, ownable_def, ownable) + owner = cached_contract(_state, account_cls, owner) + ownable = cached_contract(_state, ownable_cls, ownable) return ownable, owner diff --git a/tests/account/test_Account.py b/tests/account/test_Account.py index c322562e5..1b5705486 100644 --- a/tests/account/test_Account.py +++ b/tests/account/test_Account.py @@ -1,8 +1,6 @@ import pytest from starkware.starknet.testing.starknet import Starknet -from starkware.starkware_utils.error_handling import StarkException -from starkware.starknet.definitions.error_codes import StarknetErrorCode -from utils import MockSigner, assert_revert, get_contract_def, cached_contract, TRUE +from utils import MockSigner, assert_revert, get_contract_class, cached_contract, TRUE signer = MockSigner(123456789987654321) @@ -12,37 +10,37 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - init_def = get_contract_def("tests/mocks/Initializable.cairo") - attacker_def = get_contract_def("tests/mocks/account_reentrancy.cairo") +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + init_cls = get_contract_class("tests/mocks/Initializable.cairo") + attacker_cls = get_contract_class("tests/mocks/account_reentrancy.cairo") - return account_def, init_def, attacker_def + return account_cls, init_cls, attacker_cls @pytest.fixture(scope='module') -async def account_init(contract_defs): - account_def, init_def, attacker_def = contract_defs +async def account_init(contract_classes): + account_cls, init_cls, attacker_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) initializable1 = await starknet.deploy( - contract_def=init_def, + contract_class=init_cls, constructor_calldata=[], ) initializable2 = await starknet.deploy( - contract_def=init_def, + contract_class=init_cls, constructor_calldata=[], ) attacker = await starknet.deploy( - contract_def=attacker_def, + contract_class=attacker_cls, constructor_calldata=[], ) @@ -50,15 +48,15 @@ async def account_init(contract_defs): @pytest.fixture -def account_factory(contract_defs, account_init): - account_def, init_def, attacker_def = contract_defs +def account_factory(contract_classes, account_init): + account_cls, init_cls, attacker_cls = contract_classes state, account1, account2, initializable1, initializable2, attacker = account_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - initializable1 = cached_contract(_state, init_def, initializable1) - initializable2 = cached_contract(_state, init_def, initializable2) - attacker = cached_contract(_state, attacker_def, attacker) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + initializable1 = cached_contract(_state, init_cls, initializable1) + initializable2 = cached_contract(_state, init_cls, initializable2) + attacker = cached_contract(_state, attacker_cls, attacker) return account1, account2, initializable1, initializable2, attacker diff --git a/tests/account/test_AddressRegistry.py b/tests/account/test_AddressRegistry.py index a9e1794d8..2e8c6c600 100644 --- a/tests/account/test_AddressRegistry.py +++ b/tests/account/test_AddressRegistry.py @@ -1,6 +1,6 @@ import pytest from starkware.starknet.testing.starknet import Starknet -from utils import MockSigner, contract_path +from utils import MockSigner, get_contract_class, cached_contract signer = MockSigner(123456789987654321) @@ -9,39 +9,54 @@ @pytest.fixture(scope='module') -async def account_factory(): +async def registry_factory(): + # contract classes + registry_cls = get_contract_class("openzeppelin/account/AddressRegistry.cairo") + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + + # deployments starknet = await Starknet.empty() - registry = await starknet.deploy( - contract_path("openzeppelin/account/AddressRegistry.cairo") - ) account = await starknet.deploy( - contract_path("openzeppelin/account/Account.cairo"), + contract_class=account_cls, constructor_calldata=[signer.public_key] ) + registry = await starknet.deploy( + contract_class=registry_cls, + constructor_calldata=[] + ) + + # cache contracts + state = starknet.state.copy() + account = cached_contract(state, account_cls, account) + registry = cached_contract(state, registry_cls, registry) - return starknet, account, registry + return account, registry @pytest.mark.asyncio -async def test_set_address(account_factory): - _, account, registry = account_factory +async def test_set_address(registry_factory): + account, registry = registry_factory - await signer.send_transaction(account, registry.contract_address, 'set_L1_address', [L1_ADDRESS]) + await signer.send_transaction( + account, registry.contract_address, 'set_L1_address', [L1_ADDRESS] + ) execution_info = await registry.get_L1_address(account.contract_address).call() assert execution_info.result == (L1_ADDRESS,) @pytest.mark.asyncio -async def test_update_address(account_factory): - _, account, registry = account_factory - - await signer.send_transaction(account, registry.contract_address, 'set_L1_address', [L1_ADDRESS]) +async def test_update_address(registry_factory): + account, registry = registry_factory + await signer.send_transaction( + account, registry.contract_address, 'set_L1_address', [L1_ADDRESS] + ) execution_info = await registry.get_L1_address(account.contract_address).call() assert execution_info.result == (L1_ADDRESS,) # set new address - await signer.send_transaction(account, registry.contract_address, 'set_L1_address', [ANOTHER_ADDRESS]) - + await signer.send_transaction( + account, registry.contract_address, 'set_L1_address', [ANOTHER_ADDRESS] + ) execution_info = await registry.get_L1_address(account.contract_address).call() assert execution_info.result == (ANOTHER_ADDRESS,) diff --git a/tests/introspection/test_ERC165.py b/tests/introspection/test_ERC165.py index 28a2db140..b7c712a38 100644 --- a/tests/introspection/test_ERC165.py +++ b/tests/introspection/test_ERC165.py @@ -1,6 +1,12 @@ import pytest from starkware.starknet.testing.starknet import Starknet -from utils import assert_revert, contract_path +from utils import ( + assert_revert, + get_contract_class, + cached_contract, + TRUE, + FALSE +) # interface ids @@ -11,48 +17,54 @@ @pytest.fixture(scope='module') async def erc165_factory(): + # class + erc165_cls = get_contract_class("tests/mocks/ERC165.cairo") + + # deployment starknet = await Starknet.empty() - contract = await starknet.deploy( - contract_path("tests/mocks/ERC165.cairo") - ) - return contract + erc165 = await starknet.deploy(contract_class=erc165_cls) + + # cache + state = starknet.state.copy() + erc165 = cached_contract(state, erc165_cls, erc165) + return erc165 @pytest.mark.asyncio async def test_165_interface(erc165_factory): - contract = erc165_factory + erc165 = erc165_factory - execution_info = await contract.supportsInterface(ERC165_ID).call() - assert execution_info.result == (1,) + execution_info = await erc165.supportsInterface(ERC165_ID).call() + assert execution_info.result == (TRUE,) @pytest.mark.asyncio async def test_invalid_id(erc165_factory): - contract = erc165_factory + erc165 = erc165_factory - execution_info = await contract.supportsInterface(INVALID_ID).call() - assert execution_info.result == (0,) + execution_info = await erc165.supportsInterface(INVALID_ID).call() + assert execution_info.result == (FALSE,) @pytest.mark.asyncio async def test_register_interface(erc165_factory): - contract = erc165_factory + erc165 = erc165_factory - execution_info = await contract.supportsInterface(OTHER_ID).call() - assert execution_info.result == (0,) + execution_info = await erc165.supportsInterface(OTHER_ID).call() + assert execution_info.result == (FALSE,) # register interface - await contract.registerInterface(OTHER_ID).invoke() + await erc165.registerInterface(OTHER_ID).invoke() - execution_info = await contract.supportsInterface(OTHER_ID).call() - assert execution_info.result == (1,) + execution_info = await erc165.supportsInterface(OTHER_ID).call() + assert execution_info.result == (TRUE,) @pytest.mark.asyncio async def test_register_invalid_interface(erc165_factory): - contract = erc165_factory + erc165 = erc165_factory await assert_revert( - contract.registerInterface(INVALID_ID).invoke(), + erc165.registerInterface(INVALID_ID).invoke(), reverted_with="ERC165: invalid interface id" ) diff --git a/tests/mocks/proxiable_implementation.cairo b/tests/mocks/proxiable_implementation.cairo index 2f3e9b933..69b52119e 100644 --- a/tests/mocks/proxiable_implementation.cairo +++ b/tests/mocks/proxiable_implementation.cairo @@ -1,10 +1,8 @@ # SPDX-License-Identifier: MIT %lang starknet -%builtins pedersen range_check from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.uint256 import Uint256 from openzeppelin.upgrades.library import Proxy @@ -35,7 +33,7 @@ end # @view -func get_value{ +func getValue{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -45,17 +43,7 @@ func get_value{ end @view -func get_implementation{ - syscall_ptr : felt*, - pedersen_ptr : HashBuiltin*, - range_check_ptr - }() -> (address: felt): - let (address) = Proxy.get_implementation() - return (address) -end - -@view -func get_admin{ +func getAdmin{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -69,7 +57,7 @@ end # @external -func set_value{ +func setValue{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -77,3 +65,14 @@ func set_value{ value.write(val) return () end + +@external +func setAdmin{ + syscall_ptr : felt*, + pedersen_ptr : HashBuiltin*, + range_check_ptr + }(address: felt): + Proxy.assert_only_admin() + Proxy._set_admin(address) + return () +end diff --git a/tests/mocks/upgrades_v1_mock.cairo b/tests/mocks/upgrades_v1_mock.cairo index 94b480ab8..15f245407 100644 --- a/tests/mocks/upgrades_v1_mock.cairo +++ b/tests/mocks/upgrades_v1_mock.cairo @@ -1,10 +1,8 @@ # SPDX-License-Identifier: MIT %lang starknet -%builtins pedersen range_check from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.uint256 import Uint256 from openzeppelin.upgrades.library import Proxy @@ -41,7 +39,7 @@ func upgrade{ range_check_ptr }(new_implementation: felt): Proxy.assert_only_admin() - Proxy._set_implementation(new_implementation) + Proxy._set_implementation_hash(new_implementation) return () end @@ -50,7 +48,7 @@ end # @view -func get_value_1{ +func getValue1{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -64,7 +62,7 @@ end # @external -func set_value_1{ +func setValue1{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr diff --git a/tests/mocks/upgrades_v2_mock.cairo b/tests/mocks/upgrades_v2_mock.cairo index a44687d64..ac585dd78 100644 --- a/tests/mocks/upgrades_v2_mock.cairo +++ b/tests/mocks/upgrades_v2_mock.cairo @@ -1,10 +1,8 @@ # SPDX-License-Identifier: MIT %lang starknet -%builtins pedersen range_check from starkware.cairo.common.cairo_builtins import HashBuiltin -from starkware.cairo.common.uint256 import Uint256 from openzeppelin.upgrades.library import Proxy @@ -45,7 +43,7 @@ func upgrade{ range_check_ptr }(new_implementation: felt): Proxy.assert_only_admin() - Proxy._set_implementation(new_implementation) + Proxy._set_implementation_hash(new_implementation) return () end @@ -54,7 +52,7 @@ end # @view -func get_value_1{ +func getValue1{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -64,7 +62,7 @@ func get_value_1{ end @view -func get_value_2{ +func getValue2{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -74,17 +72,17 @@ func get_value_2{ end @view -func get_implementation{ +func getImplementationHash{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr }() -> (address: felt): - let (address) = Proxy.get_implementation() + let (address) = Proxy.get_implementation_hash() return (address) end @view -func get_admin{ +func getAdmin{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -98,7 +96,7 @@ end # @external -func set_value_1{ +func setValue1{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -108,7 +106,7 @@ func set_value_1{ end @external -func set_value_2{ +func setValue2{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr @@ -118,7 +116,7 @@ func set_value_2{ end @external -func set_admin{ +func setAdmin{ syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr diff --git a/tests/security/test_pausable.py b/tests/security/test_pausable.py index caf53336f..9a3c80808 100644 --- a/tests/security/test_pausable.py +++ b/tests/security/test_pausable.py @@ -2,29 +2,30 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( TRUE, FALSE, assert_revert, assert_event_emitted, - get_contract_def, cached_contract, MockSigner + get_contract_class, cached_contract, MockSigner ) signer = MockSigner(12345678987654321) @pytest.fixture async def pausable_factory(): - pausable_def = get_contract_def("tests/mocks/Pausable.cairo") - account_def = get_contract_def("openzeppelin/account/Account.cairo") + # class + pausable_cls = get_contract_class("tests/mocks/Pausable.cairo") + account_cls = get_contract_class("openzeppelin/account/Account.cairo") starknet = await Starknet.empty() pausable = await starknet.deploy( - contract_def=pausable_def, + contract_class=pausable_cls, constructor_calldata=[] ) account = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) state = starknet.state.copy() - pausable = cached_contract(state, pausable_def, pausable) - account = cached_contract(state, account_def, account) + pausable = cached_contract(state, pausable_cls, pausable) + account = cached_contract(state, account_cls, account) return pausable, account diff --git a/tests/security/test_reentrancy.py b/tests/security/test_reentrancy.py index 4b1edcfdc..292c523ec 100644 --- a/tests/security/test_reentrancy.py +++ b/tests/security/test_reentrancy.py @@ -1,5 +1,4 @@ import pytest -import asyncio from starkware.starknet.testing.starknet import Starknet from utils import ( assert_revert @@ -10,13 +9,16 @@ @pytest.fixture(scope='module') async def reentrancy_mock(): starknet = await Starknet.empty() - contract = await starknet.deploy("tests/mocks/reentrancy_mock.cairo", constructor_calldata=[INITIAL_COUNTER]) + contract = await starknet.deploy( + "tests/mocks/reentrancy_mock.cairo", + constructor_calldata=[INITIAL_COUNTER] + ) return contract, starknet @pytest.mark.asyncio async def test_reentrancy_guard_deploy(reentrancy_mock): - contract, starknet = reentrancy_mock + contract, _ = reentrancy_mock response = await contract.current_count().call() assert response.result == (INITIAL_COUNTER,) @@ -33,7 +35,7 @@ async def test_reentrancy_guard_remote_callback(reentrancy_mock): @pytest.mark.asyncio async def test_reentrancy_guard_local_recursion(reentrancy_mock): - contract, starknet = reentrancy_mock + contract, _ = reentrancy_mock # should not allow local recursion await assert_revert( contract.count_local_recursive(10).invoke(), @@ -47,8 +49,8 @@ async def test_reentrancy_guard_local_recursion(reentrancy_mock): @pytest.mark.asyncio async def test_reentrancy_guard(reentrancy_mock): - contract, starknet = reentrancy_mock - #should allow non reentrant call + contract, _ = reentrancy_mock + # should allow non reentrant call await contract.callback().invoke() response = await contract.current_count().call() diff --git a/tests/token/erc20/test_ERC20.py b/tests/token/erc20/test_ERC20.py index 46b8af251..5e85f0f7b 100644 --- a/tests/token/erc20/test_ERC20.py +++ b/tests/token/erc20/test_ERC20.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, to_uint, add_uint, sub_uint, str_to_felt, MAX_UINT256, - ZERO_ADDRESS, INVALID_UINT256, TRUE, get_contract_def, cached_contract, + ZERO_ADDRESS, INVALID_UINT256, TRUE, get_contract_class, cached_contract, assert_revert, assert_event_emitted, contract_path ) @@ -20,28 +20,28 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc20_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc20_cls = get_contract_class( 'openzeppelin/token/erc20/ERC20.cairo') - return account_def, erc20_def + return account_cls, erc20_cls @pytest.fixture(scope='module') -async def erc20_init(contract_defs): - account_def, erc20_def = contract_defs +async def erc20_init(contract_classes): + account_cls, erc20_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc20 = await starknet.deploy( - contract_def=erc20_def, + contract_class=erc20_cls, constructor_calldata=[ NAME, SYMBOL, @@ -59,13 +59,13 @@ async def erc20_init(contract_defs): @pytest.fixture -def erc20_factory(contract_defs, erc20_init): - account_def, erc20_def = contract_defs +def erc20_factory(contract_classes, erc20_init): + account_cls, erc20_cls = contract_classes state, account1, account2, erc20 = erc20_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - erc20 = cached_contract(_state, erc20_def, erc20) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + erc20 = cached_contract(_state, erc20_cls, erc20) return erc20, account1, account2 diff --git a/tests/token/erc20/test_ERC20_Burnable_mock.py b/tests/token/erc20/test_ERC20_Burnable_mock.py index c0498d4c0..caff5f936 100644 --- a/tests/token/erc20/test_ERC20_Burnable_mock.py +++ b/tests/token/erc20/test_ERC20_Burnable_mock.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, to_uint, add_uint, sub_uint, str_to_felt, ZERO_ADDRESS, INVALID_UINT256, - get_contract_def, cached_contract, assert_revert, assert_event_emitted, + get_contract_class, cached_contract, assert_revert, assert_event_emitted, ) signer = MockSigner(123456789987654321) @@ -17,24 +17,24 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc20_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc20_cls = get_contract_class( 'tests/mocks/ERC20_Burnable_mock.cairo') - return account_def, erc20_def + return account_cls, erc20_cls @pytest.fixture(scope='module') -async def erc20_init(contract_defs): - account_def, erc20_def = contract_defs +async def erc20_init(contract_classes): + account_cls, erc20_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc20 = await starknet.deploy( - contract_def=erc20_def, + contract_class=erc20_cls, constructor_calldata=[ NAME, SYMBOL, @@ -51,12 +51,12 @@ async def erc20_init(contract_defs): @pytest.fixture -def erc20_factory(contract_defs, erc20_init): - account_def, erc20_def = contract_defs +def erc20_factory(contract_classes, erc20_init): + account_cls, erc20_cls = contract_classes state, account1, erc20 = erc20_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - erc20 = cached_contract(_state, erc20_def, erc20) + account1 = cached_contract(_state, account_cls, account1) + erc20 = cached_contract(_state, erc20_cls, erc20) return erc20, account1 diff --git a/tests/token/erc20/test_ERC20_Mintable.py b/tests/token/erc20/test_ERC20_Mintable.py index dea36dc53..94689750b 100644 --- a/tests/token/erc20/test_ERC20_Mintable.py +++ b/tests/token/erc20/test_ERC20_Mintable.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, to_uint, add_uint, sub_uint, str_to_felt, - MAX_UINT256, ZERO_ADDRESS, INVALID_UINT256, get_contract_def, + MAX_UINT256, ZERO_ADDRESS, INVALID_UINT256, get_contract_class, cached_contract, assert_revert, assert_event_emitted ) @@ -19,24 +19,24 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc20_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc20_cls = get_contract_class( 'openzeppelin/token/erc20/ERC20_Mintable.cairo') - return account_def, erc20_def + return account_cls, erc20_cls @pytest.fixture(scope='module') -async def erc20_init(contract_defs): - account_def, erc20_def = contract_defs +async def erc20_init(contract_classes): + account_cls, erc20_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc20 = await starknet.deploy( - contract_def=erc20_def, + contract_class=erc20_cls, constructor_calldata=[ NAME, SYMBOL, @@ -54,12 +54,12 @@ async def erc20_init(contract_defs): @pytest.fixture -def token_factory(contract_defs, erc20_init): - account_def, erc20_def = contract_defs +def token_factory(contract_classes, erc20_init): + account_cls, erc20_cls = contract_classes state, account1, erc20 = erc20_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - erc20 = cached_contract(_state, erc20_def, erc20) + account1 = cached_contract(_state, account_cls, account1) + erc20 = cached_contract(_state, erc20_cls, erc20) return erc20, account1 diff --git a/tests/token/erc20/test_ERC20_Pausable.py b/tests/token/erc20/test_ERC20_Pausable.py index 4816b8970..a890e9ae9 100644 --- a/tests/token/erc20/test_ERC20_Pausable.py +++ b/tests/token/erc20/test_ERC20_Pausable.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, TRUE, FALSE, to_uint, str_to_felt, assert_revert, - get_contract_def, cached_contract + get_contract_class, cached_contract ) signer = MockSigner(123456789987654321) @@ -16,28 +16,28 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc20_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc20_cls = get_contract_class( 'openzeppelin/token/erc20/ERC20_Pausable.cairo') - return account_def, erc20_def + return account_cls, erc20_cls @pytest.fixture(scope='module') -async def erc20_init(contract_defs): - account_def, erc20_def = contract_defs +async def erc20_init(contract_classes): + account_cls, erc20_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc20 = await starknet.deploy( - contract_def=erc20_def, + contract_class=erc20_cls, constructor_calldata=[ NAME, SYMBOL, @@ -56,13 +56,13 @@ async def erc20_init(contract_defs): @pytest.fixture -def token_factory(contract_defs, erc20_init): - account_def, erc20_def = contract_defs +def token_factory(contract_classes, erc20_init): + account_cls, erc20_cls = contract_classes state, account1, account2, erc20 = erc20_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - erc20 = cached_contract(_state, erc20_def, erc20) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + erc20 = cached_contract(_state, erc20_cls, erc20) return erc20, account1, account2 diff --git a/tests/token/erc20/test_ERC20_Upgradeable.py b/tests/token/erc20/test_ERC20_Upgradeable.py index eb49d913f..65aed3386 100644 --- a/tests/token/erc20/test_ERC20_Upgradeable.py +++ b/tests/token/erc20/test_ERC20_Upgradeable.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, to_uint, sub_uint, str_to_felt, assert_revert, - get_contract_def, cached_contract + get_contract_class, cached_contract ) @@ -17,38 +17,36 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - token_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + token_cls = get_contract_class( 'openzeppelin/token/erc20/ERC20_Upgradeable.cairo') - proxy_def = get_contract_def('openzeppelin/upgrades/Proxy.cairo') + proxy_cls = get_contract_class('openzeppelin/upgrades/Proxy.cairo') - return account_def, token_def, proxy_def + return account_cls, token_cls, proxy_cls @pytest.fixture(scope='module') -async def token_init(contract_defs): - account_def, token_def, proxy_def = contract_defs +async def token_init(contract_classes): + account_cls, token_cls, proxy_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) - token_v1 = await starknet.deploy( - contract_def=token_def, - constructor_calldata=[] + token_v1 = await starknet.declare( + contract_class=token_cls, ) - token_v2 = await starknet.deploy( - contract_def=token_def, - constructor_calldata=[] + token_v2 = await starknet.declare( + contract_class=token_cls, ) proxy = await starknet.deploy( - contract_def=proxy_def, - constructor_calldata=[token_v1.contract_address] + contract_class=proxy_cls, + constructor_calldata=[token_v1.class_hash] ) return ( starknet.state, @@ -61,134 +59,109 @@ async def token_init(contract_defs): @pytest.fixture -def token_factory(contract_defs, token_init): - account_def, token_def, proxy_def = contract_defs +def token_factory(contract_classes, token_init): + account_cls, _, proxy_cls = contract_classes state, account1, account2, token_v1, token_v2, proxy = token_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - token_v1 = cached_contract(_state, token_def, token_v1) - token_v2 = cached_contract(_state, token_def, token_v2) - proxy = cached_contract(_state, proxy_def, proxy) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + proxy = cached_contract(_state, proxy_cls, proxy) - return account1, account2, token_v1, token_v2, proxy + return account1, account2, proxy, token_v1, token_v2 @pytest.fixture async def after_initializer(token_factory): - admin, other, token_v1, token_v2, proxy = token_factory + admin, other, proxy, token_v1, token_v2 = token_factory # initialize await signer.send_transaction( admin, proxy.contract_address, 'initializer', [ - NAME, - SYMBOL, - DECIMALS, - *INIT_SUPPLY, - admin.contract_address, - admin.contract_address + NAME, # name + SYMBOL, # symbol + DECIMALS, # decimals + *INIT_SUPPLY, # initial supply + admin.contract_address, # recipient + admin.contract_address # admin ] ) - return admin, other, token_v1, token_v2, proxy + return admin, other, proxy, token_v1, token_v2 @pytest.mark.asyncio async def test_constructor(token_factory): - admin, _, _, _, proxy = token_factory + admin, _, proxy, *_ = token_factory await signer.send_transaction( admin, proxy.contract_address, 'initializer', [ - NAME, - SYMBOL, - DECIMALS, - *INIT_SUPPLY, - admin.contract_address, - admin.contract_address + NAME, # name + SYMBOL, # symbol + DECIMALS, # decimals + *INIT_SUPPLY, # initial supply + admin.contract_address, # recipient + admin.contract_address # admin ]) - # check name - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'name', []) - assert execution_info.result.response == [NAME] - - # check symbol - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'symbol', [] - ) - assert execution_info.result.response == [SYMBOL] - - # check decimals - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'decimals', [] + execution_info = await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'name', []), + (proxy.contract_address, 'symbol', []), + (proxy.contract_address, 'decimals', []), + (proxy.contract_address, 'totalSupply', []) + ] ) - assert execution_info.result.response == [DECIMALS] - # check total supply - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'totalSupply', [] - ) - assert execution_info.result.response == [*INIT_SUPPLY] + # check values + expected = [NAME, SYMBOL, DECIMALS, *INIT_SUPPLY] + assert execution_info.result.response == expected @pytest.mark.asyncio async def test_upgrade(after_initializer): - admin, _, _, token_v2, proxy = after_initializer + admin, _, proxy, _, token_v2 = after_initializer # transfer await signer.send_transaction( - admin, proxy.contract_address, 'transfer', [ - USER, - *AMOUNT - ] + admin, proxy.contract_address, 'transfer', [USER, *AMOUNT] ) # upgrade await signer.send_transaction( - admin, proxy.contract_address, 'upgrade', [ - token_v2.contract_address - ] + admin, proxy.contract_address, 'upgrade', [token_v2.class_hash] ) - # check admin balance - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'balanceOf', [ - admin.contract_address + # fetch values + execution_info = await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'balanceOf', [admin.contract_address]), + (proxy.contract_address, 'balanceOf', [USER]), + (proxy.contract_address, 'totalSupply', []) ] ) - assert execution_info.result.response == [*sub_uint(INIT_SUPPLY, AMOUNT)] - # check USER balance - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'balanceOf', [ - USER - ] - ) - assert execution_info.result.response == [*AMOUNT] + expected = [ + *sub_uint(INIT_SUPPLY, AMOUNT), # balanceOf admin + *AMOUNT, # balanceOf USER + *INIT_SUPPLY # totalSupply + ] - # check total supply - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'totalSupply', [] - ) - assert execution_info.result.response == [*INIT_SUPPLY] + assert execution_info.result.response == expected @pytest.mark.asyncio async def test_upgrade_from_nonadmin(after_initializer): - admin, non_admin, _, token_v2, proxy = after_initializer + admin, non_admin, proxy, _, token_v2 = after_initializer # should revert - await assert_revert( - signer.send_transaction( - non_admin, proxy.contract_address, 'upgrade', [ - token_v2.contract_address - ] - ) + await assert_revert(signer.send_transaction( + non_admin, proxy.contract_address, 'upgrade', [token_v2.class_hash]), + reverted_with="Proxy: caller is not admin" ) # should upgrade from admin await signer.send_transaction( - admin, proxy.contract_address, 'upgrade', [ - token_v2.contract_address - ] + admin, proxy.contract_address, 'upgrade', [token_v2.class_hash] ) diff --git a/tests/token/erc721/test_ERC721_Mintable_Burnable.py b/tests/token/erc721/test_ERC721_Mintable_Burnable.py index 9c4a6ba5c..72a265afb 100644 --- a/tests/token/erc721/test_ERC721_Mintable_Burnable.py +++ b/tests/token/erc721/test_ERC721_Mintable_Burnable.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, str_to_felt, ZERO_ADDRESS, TRUE, FALSE, assert_revert, INVALID_UINT256, - assert_event_emitted, get_contract_def, cached_contract, to_uint, sub_uint, add_uint + assert_event_emitted, get_contract_class, cached_contract, to_uint, sub_uint, add_uint ) @@ -30,32 +30,32 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc721_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc721_cls = get_contract_class( 'openzeppelin/token/erc721/ERC721_Mintable_Burnable.cairo') - erc721_holder_def = get_contract_def( + erc721_holder_cls = get_contract_class( 'openzeppelin/token/erc721/utils/ERC721_Holder.cairo') - unsupported_def = get_contract_def( + unsupported_cls = get_contract_class( 'tests/mocks/Initializable.cairo') - return account_def, erc721_def, erc721_holder_def, unsupported_def + return account_cls, erc721_cls, erc721_holder_cls, unsupported_cls @pytest.fixture(scope='module') -async def erc721_init(contract_defs): - account_def, erc721_def, erc721_holder_def, unsupported_def = contract_defs +async def erc721_init(contract_classes): + account_cls, erc721_cls, erc721_holder_cls, unsupported_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc721 = await starknet.deploy( - contract_def=erc721_def, + contract_class=erc721_cls, constructor_calldata=[ str_to_felt("Non Fungible Token"), # name str_to_felt("NFT"), # ticker @@ -63,11 +63,11 @@ async def erc721_init(contract_defs): ] ) erc721_holder = await starknet.deploy( - contract_def=erc721_holder_def, + contract_class=erc721_holder_cls, constructor_calldata=[] ) unsupported = await starknet.deploy( - contract_def=unsupported_def, + contract_class=unsupported_cls, constructor_calldata=[] ) return ( @@ -81,15 +81,15 @@ async def erc721_init(contract_defs): @pytest.fixture -def erc721_factory(contract_defs, erc721_init): - account_def, erc721_def, erc721_holder_def, unsupported_def = contract_defs +def erc721_factory(contract_classes, erc721_init): + account_cls, erc721_cls, erc721_holder_cls, unsupported_cls = contract_classes state, account1, account2, erc721, erc721_holder, unsupported = erc721_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - erc721 = cached_contract(_state, erc721_def, erc721) - erc721_holder = cached_contract(_state, erc721_holder_def, erc721_holder) - unsupported = cached_contract(_state, unsupported_def, unsupported) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + erc721 = cached_contract(_state, erc721_cls, erc721) + erc721_holder = cached_contract(_state, erc721_holder_cls, erc721_holder) + unsupported = cached_contract(_state, unsupported_cls, unsupported) return erc721, account1, account2, erc721_holder, unsupported diff --git a/tests/token/erc721/test_ERC721_Mintable_Pausable.py b/tests/token/erc721/test_ERC721_Mintable_Pausable.py index 9f614d781..5249774f9 100644 --- a/tests/token/erc721/test_ERC721_Mintable_Pausable.py +++ b/tests/token/erc721/test_ERC721_Mintable_Pausable.py @@ -1,7 +1,7 @@ import pytest from starkware.starknet.testing.starknet import Starknet from utils import ( - MockSigner, str_to_felt, TRUE, FALSE, get_contract_def, cached_contract, + MockSigner, str_to_felt, TRUE, FALSE, get_contract_class, cached_contract, assert_revert, to_uint ) @@ -16,30 +16,30 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc721_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc721_cls = get_contract_class( 'openzeppelin/token/erc721/ERC721_Mintable_Pausable.cairo') - erc721_holder_def = get_contract_def( + erc721_holder_cls = get_contract_class( 'openzeppelin/token/erc721/utils/ERC721_Holder.cairo') - return account_def, erc721_def, erc721_holder_def + return account_cls, erc721_cls, erc721_holder_cls @pytest.fixture(scope='module') -async def erc721_init(contract_defs): - account_def, erc721_def, erc721_holder_def = contract_defs +async def erc721_init(contract_classes): + account_cls, erc721_cls, erc721_holder_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc721 = await starknet.deploy( - contract_def=erc721_def, + contract_class=erc721_cls, constructor_calldata=[ str_to_felt("Non Fungible Token"), # name str_to_felt("NFT"), # ticker @@ -47,7 +47,7 @@ async def erc721_init(contract_defs): ] ) erc721_holder = await starknet.deploy( - contract_def=erc721_holder_def, + contract_class=erc721_holder_cls, constructor_calldata=[] ) return ( @@ -60,14 +60,14 @@ async def erc721_init(contract_defs): @pytest.fixture -def erc721_factory(contract_defs, erc721_init): - account_def, erc721_def, erc721_holder_def = contract_defs +def erc721_factory(contract_classes, erc721_init): + account_cls, erc721_cls, erc721_holder_cls = contract_classes state, account1, account2, erc721, erc721_holder = erc721_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - erc721 = cached_contract(_state, erc721_def, erc721) - erc721_holder = cached_contract(_state, erc721_holder_def, erc721_holder) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + erc721 = cached_contract(_state, erc721_cls, erc721) + erc721_holder = cached_contract(_state, erc721_holder_cls, erc721_holder) return erc721, account1, account2, erc721_holder diff --git a/tests/token/erc721/test_ERC721_SafeMintable_mock.py b/tests/token/erc721/test_ERC721_SafeMintable_mock.py index 553108cb2..9e757a43e 100644 --- a/tests/token/erc721/test_ERC721_SafeMintable_mock.py +++ b/tests/token/erc721/test_ERC721_SafeMintable_mock.py @@ -2,7 +2,7 @@ from starkware.starknet.testing.starknet import Starknet from utils import ( MockSigner, str_to_felt, ZERO_ADDRESS, INVALID_UINT256, assert_revert, - assert_event_emitted, get_contract_def, cached_contract, to_uint + assert_event_emitted, get_contract_class, cached_contract, to_uint ) @@ -15,31 +15,31 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc721_def = get_contract_def('tests/mocks/ERC721_SafeMintable_mock.cairo') - erc721_holder_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc721_cls = get_contract_class('tests/mocks/ERC721_SafeMintable_mock.cairo') + erc721_holder_cls = get_contract_class( 'openzeppelin/token/erc721/utils/ERC721_Holder.cairo') - unsupported_def = get_contract_def( + unsupported_cls = get_contract_class( 'tests/mocks/Initializable.cairo') - return account_def, erc721_def, erc721_holder_def, unsupported_def + return account_cls, erc721_cls, erc721_holder_cls, unsupported_cls @pytest.fixture(scope='module') -async def erc721_init(contract_defs): - account_def, erc721_def, erc721_holder_def, unsupported_def = contract_defs +async def erc721_init(contract_classes): + account_cls, erc721_cls, erc721_holder_cls, unsupported_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc721 = await starknet.deploy( - contract_def=erc721_def, + contract_class=erc721_cls, constructor_calldata=[ str_to_felt("Non Fungible Token"), # name str_to_felt("NFT"), # ticker @@ -47,11 +47,11 @@ async def erc721_init(contract_defs): ] ) erc721_holder = await starknet.deploy( - contract_def=erc721_holder_def, + contract_class=erc721_holder_cls, constructor_calldata=[] ) unsupported = await starknet.deploy( - contract_def=unsupported_def, + contract_class=unsupported_cls, constructor_calldata=[] ) return ( @@ -65,15 +65,15 @@ async def erc721_init(contract_defs): @pytest.fixture -def erc721_factory(contract_defs, erc721_init): - account_def, erc721_def, erc721_holder_def, unsupported_def = contract_defs +def erc721_factory(contract_classes, erc721_init): + account_cls, erc721_cls, erc721_holder_cls, unsupported_cls = contract_classes state, account1, account2, erc721, erc721_holder, unsupported = erc721_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - erc721 = cached_contract(_state, erc721_def, erc721) - erc721_holder = cached_contract(_state, erc721_holder_def, erc721_holder) - unsupported = cached_contract(_state, unsupported_def, unsupported) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + erc721 = cached_contract(_state, erc721_cls, erc721) + erc721_holder = cached_contract(_state, erc721_holder_cls, erc721_holder) + unsupported = cached_contract(_state, unsupported_cls, unsupported) return erc721, account1, account2, erc721_holder, unsupported diff --git a/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py b/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py index c837103e8..e6d01ccc3 100644 --- a/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py +++ b/tests/token/erc721_enumerable/test_ERC721_Enumerable_Mintable_Burnable.py @@ -1,7 +1,7 @@ import pytest from starkware.starknet.testing.starknet import Starknet from utils import ( - MockSigner, str_to_felt, MAX_UINT256, get_contract_def, cached_contract, + MockSigner, str_to_felt, MAX_UINT256, get_contract_class, cached_contract, TRUE, assert_revert, to_uint, sub_uint, add_uint ) @@ -23,28 +23,28 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - erc721_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + erc721_cls = get_contract_class( 'openzeppelin/token/erc721_enumerable/ERC721_Enumerable_Mintable_Burnable.cairo') - return account_def, erc721_def + return account_cls, erc721_cls @pytest.fixture(scope='module') -async def erc721_init(contract_defs): - account_def, erc721_def = contract_defs +async def erc721_init(contract_classes): + account_cls, erc721_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) erc721 = await starknet.deploy( - contract_def=erc721_def, + contract_class=erc721_cls, constructor_calldata=[ str_to_felt("Non Fungible Token"), # name str_to_felt("NFT"), # ticker @@ -60,13 +60,13 @@ async def erc721_init(contract_defs): @pytest.fixture -def erc721_factory(contract_defs, erc721_init): - account_def, erc721_def = contract_defs +def erc721_factory(contract_classes, erc721_init): + account_cls, erc721_cls = contract_classes state, account1, account2, erc721 = erc721_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - erc721 = cached_contract(_state, erc721_def, erc721) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + erc721 = cached_contract(_state, erc721_cls, erc721) return erc721, account1, account2 diff --git a/tests/upgrades/test_Proxy.py b/tests/upgrades/test_Proxy.py index 0c8bf4211..4a9a94538 100644 --- a/tests/upgrades/test_Proxy.py +++ b/tests/upgrades/test_Proxy.py @@ -1,10 +1,14 @@ import pytest from starkware.starknet.testing.starknet import Starknet from utils import ( - MockSigner, assert_revert, get_contract_def, cached_contract + MockSigner, + assert_revert, + get_contract_class, + cached_contract, + assert_event_emitted, + assert_revert_entry_point ) - # random value VALUE = 123 @@ -12,98 +16,163 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - implementation_def = get_contract_def( +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + implementation_cls = get_contract_class( 'tests/mocks/proxiable_implementation.cairo' ) - proxy_def = get_contract_def('openzeppelin/upgrades/Proxy.cairo') + proxy_cls = get_contract_class('openzeppelin/upgrades/Proxy.cairo') - return account_def, implementation_def, proxy_def + return account_cls, implementation_cls, proxy_cls @pytest.fixture(scope='module') -async def proxy_init(contract_defs): - account_def, implementation_def, proxy_def = contract_defs +async def proxy_init(contract_classes): + account_cls, implementation_cls, proxy_cls = contract_classes starknet = await Starknet.empty() - account = await starknet.deploy( - contract_def=account_def, + account1 = await starknet.deploy( + contract_class=account_cls, + constructor_calldata=[signer.public_key] + ) + account2 = await starknet.deploy( + contract_class=account_cls, constructor_calldata=[signer.public_key] ) - implementation = await starknet.deploy( - contract_def=implementation_def, - constructor_calldata=[] + implementation_decl = await starknet.declare( + contract_class=implementation_cls ) proxy = await starknet.deploy( - contract_def=proxy_def, - constructor_calldata=[implementation.contract_address] + contract_class=proxy_cls, + constructor_calldata=[implementation_decl.class_hash] ) return ( starknet.state, - account, - implementation, + account1, + account2, proxy ) @pytest.fixture -def proxy_factory(contract_defs, proxy_init): - account_def, implementation_def, proxy_def = contract_defs - state, account, implementation, proxy = proxy_init +def proxy_factory(contract_classes, proxy_init): + account_cls, _, proxy_cls = contract_classes + state, account1, account2, proxy = proxy_init _state = state.copy() - account = cached_contract(_state, account_def, account) - implementation = cached_contract( - _state, - implementation_def, - implementation + admin = cached_contract(_state, account_cls, account1) + other = cached_contract(_state, account_cls, account2) + proxy = cached_contract(_state, proxy_cls, proxy) + + return admin, other, proxy + + +@pytest.fixture +async def after_initialized(proxy_factory): + admin, other, proxy = proxy_factory + + # initialize proxy + await signer.send_transaction( + admin, proxy.contract_address, 'initializer', [admin.contract_address] ) - proxy = cached_contract(_state, proxy_def, proxy) - return account, implementation, proxy + return admin, other, proxy +# +# initializer +# @pytest.mark.asyncio -async def test_constructor_sets_correct_implementation(proxy_factory): - account, implementation, proxy = proxy_factory +async def test_initializer(proxy_factory): + admin, _, proxy = proxy_factory + + await signer.send_transaction( + admin, proxy.contract_address, 'initializer', [admin.contract_address] + ) + # check admin is set execution_info = await signer.send_transaction( - account, proxy.contract_address, 'get_implementation', [] + admin, proxy.contract_address, 'getAdmin', [] ) - assert execution_info.result.response == [implementation.contract_address] + assert execution_info.result.response == [admin.contract_address] @pytest.mark.asyncio -async def test_initializer(proxy_factory): - account, _, proxy = proxy_factory +async def test_initializer_after_initialized(after_initialized): + admin, _, proxy = after_initialized - await signer.send_transaction( - account, proxy.contract_address, 'initializer', [ - account.contract_address] + await assert_revert(signer.send_transaction( + admin, proxy.contract_address, 'initializer', [admin.contract_address]), + reverted_with="Proxy: contract already initialized" + ) + +# +# set_admin +# + +@pytest.mark.asyncio +async def test_set_admin(after_initialized): + admin, _, proxy = after_initialized + + # set admin + tx_exec_info = await signer.send_transaction( + admin, proxy.contract_address, 'setAdmin', [VALUE] + ) + + # check event + assert_event_emitted( + tx_exec_info, + from_address=proxy.contract_address, + name='AdminChanged', + data=[ + admin.contract_address, # old admin + VALUE # new admin + ] + ) + + # check new admin + execution_info = await signer.send_transaction( + admin, proxy.contract_address, 'getAdmin', [] + ) + assert execution_info.result.response == [VALUE] + + +@pytest.mark.asyncio +async def test_set_admin_from_unauthorized(after_initialized): + _, non_admin, proxy = after_initialized + + # set admin + await assert_revert(signer.send_transaction( + non_admin, proxy.contract_address, 'setAdmin', [VALUE]), + reverted_with="Proxy: caller is not admin" ) +# +# fallback function +# @pytest.mark.asyncio async def test_default_fallback(proxy_factory): - account, _, proxy = proxy_factory + admin, _, proxy = proxy_factory # set value through proxy await signer.send_transaction( - account, proxy.contract_address, 'set_value', [VALUE] + admin, proxy.contract_address, 'setValue', [VALUE] ) # get value through proxy execution_info = execution_info = await signer.send_transaction( - account, proxy.contract_address, 'get_value', [] + admin, proxy.contract_address, 'getValue', [] ) assert execution_info.result.response == [VALUE] @pytest.mark.asyncio async def test_fallback_when_selector_does_not_exist(proxy_factory): - account, _, proxy = proxy_factory + admin, _, proxy = proxy_factory - await assert_revert( + # should fail with entry point error + await assert_revert_entry_point( signer.send_transaction( - account, proxy.contract_address, 'bad_selector', [] - ) + admin, proxy.contract_address, 'invalid_selector', [] + ), + invalid_selector='invalid_selector' ) diff --git a/tests/upgrades/test_upgrades.py b/tests/upgrades/test_upgrades.py index b5c94f21f..8d4af3573 100644 --- a/tests/upgrades/test_upgrades.py +++ b/tests/upgrades/test_upgrades.py @@ -1,7 +1,12 @@ import pytest from starkware.starknet.testing.starknet import Starknet from utils import ( - MockSigner, assert_revert, assert_event_emitted, get_contract_def, cached_contract + MockSigner, + assert_revert, + assert_revert_entry_point, + assert_event_emitted, + get_contract_class, + cached_contract ) @@ -13,94 +18,79 @@ @pytest.fixture(scope='module') -def contract_defs(): - account_def = get_contract_def('openzeppelin/account/Account.cairo') - v1_def = get_contract_def('tests/mocks/upgrades_v1_mock.cairo') - v2_def = get_contract_def('tests/mocks/upgrades_v2_mock.cairo') - proxy_def = get_contract_def('openzeppelin/upgrades/Proxy.cairo') +def contract_classes(): + account_cls = get_contract_class('openzeppelin/account/Account.cairo') + v1_cls = get_contract_class('tests/mocks/upgrades_v1_mock.cairo') + v2_cls = get_contract_class('tests/mocks/upgrades_v2_mock.cairo') + proxy_cls = get_contract_class('openzeppelin/upgrades/Proxy.cairo') - return account_def, v1_def, v2_def, proxy_def + return account_cls, v1_cls, v2_cls, proxy_cls @pytest.fixture(scope='module') -async def proxy_init(contract_defs): - account_def, dummy_v1_def, dummy_v2_def, proxy_def = contract_defs +async def proxy_init(contract_classes): + account_cls, v1_cls, v2_cls, proxy_cls = contract_classes starknet = await Starknet.empty() account1 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) account2 = await starknet.deploy( - contract_def=account_def, + contract_class=account_cls, constructor_calldata=[signer.public_key] ) - v1 = await starknet.deploy( - contract_def=dummy_v1_def, - constructor_calldata=[] + v1_decl = await starknet.declare( + contract_class=v1_cls, ) - v2 = await starknet.deploy( - contract_def=dummy_v2_def, - constructor_calldata=[] + v2_decl = await starknet.declare( + contract_class=v2_cls, ) proxy = await starknet.deploy( - contract_def=proxy_def, - constructor_calldata=[v1.contract_address] + contract_class=proxy_cls, + constructor_calldata=[v1_decl.class_hash] ) return ( starknet.state, account1, account2, - v1, - v2, + v1_decl, + v2_decl, proxy ) @pytest.fixture -def proxy_factory(contract_defs, proxy_init): - account_def, dummy_v1_def, dummy_v2_def, proxy_def = contract_defs - state, account1, account2, v1, v2, proxy = proxy_init +def proxy_factory(contract_classes, proxy_init): + account_cls, _, _, proxy_cls = contract_classes + state, account1, account2, v1_decl, v2_decl, proxy = proxy_init _state = state.copy() - account1 = cached_contract(_state, account_def, account1) - account2 = cached_contract(_state, account_def, account2) - v1 = cached_contract(_state, dummy_v1_def, v1) - v2 = cached_contract(_state, dummy_v2_def, v2) - proxy = cached_contract(_state, proxy_def, proxy) + account1 = cached_contract(_state, account_cls, account1) + account2 = cached_contract(_state, account_cls, account2) + proxy = cached_contract(_state, proxy_cls, proxy) - return account1, account2, v1, v2, proxy + return account1, account2, proxy, v1_decl, v2_decl @pytest.fixture async def after_upgrade(proxy_factory): - admin, other, v1, v2, proxy = proxy_factory - - # initialize - await signer.send_transaction( - admin, proxy.contract_address, 'initializer', [ - admin.contract_address + admin, other, proxy, v1_decl, v2_decl = proxy_factory + + # initialize, set value, and upgrade to v2 + await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'initializer', [admin.contract_address]), + (proxy.contract_address, 'setValue1', [VALUE_1]), + (proxy.contract_address, 'upgrade', [v2_decl.class_hash]) ] ) - # set value - await signer.send_transaction( - admin, proxy.contract_address, 'set_value_1', [ - VALUE_1 - ] - ) - - # upgrade - await signer.send_transaction( - admin, proxy.contract_address, 'upgrade', [ - v2.contract_address - ] - ) - - return admin, other, v1, v2, proxy + return admin, other, proxy, v1_decl, v2_decl @pytest.mark.asyncio async def test_initializer(proxy_factory): - admin, _, _, _, proxy = proxy_factory + admin, _, proxy, *_ = proxy_factory await signer.send_transaction( admin, proxy.contract_address, 'initializer', [ @@ -111,7 +101,7 @@ async def test_initializer(proxy_factory): @pytest.mark.asyncio async def test_initializer_already_initialized(proxy_factory): - admin, _, _, _, proxy = proxy_factory + admin, _, proxy, *_ = proxy_factory await signer.send_transaction( admin, proxy.contract_address, 'initializer', [ @@ -131,45 +121,40 @@ async def test_initializer_already_initialized(proxy_factory): @pytest.mark.asyncio async def test_upgrade(proxy_factory): - admin, _, _, v2, proxy = proxy_factory - - # initialize implementation - await signer.send_transaction( - admin, proxy.contract_address, 'initializer', [ - admin.contract_address - ] - ) - - # set value - await signer.send_transaction( - admin, proxy.contract_address, 'set_value_1', [ - VALUE_1 + admin, _, proxy, _, v2_decl = proxy_factory + + # initialize and set value + await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'initializer', [admin.contract_address]), + (proxy.contract_address, 'setValue1', [VALUE_1]), ] ) # check value execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'get_value_1', [] + admin, proxy.contract_address, 'getValue1', [] ) - assert execution_info.result.response == [VALUE_1, ] + assert execution_info.result.response == [VALUE_1] # upgrade await signer.send_transaction( admin, proxy.contract_address, 'upgrade', [ - v2.contract_address + v2_decl.class_hash ] ) # check value execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'get_value_1', [] + admin, proxy.contract_address, 'getValue1', [] ) - assert execution_info.result.response == [VALUE_1, ] + assert execution_info.result.response == [VALUE_1] @pytest.mark.asyncio async def test_upgrade_event(proxy_factory): - admin, _, _, v2, proxy = proxy_factory + admin, _, proxy, _, v2_decl = proxy_factory # initialize implementation await signer.send_transaction( @@ -181,7 +166,7 @@ async def test_upgrade_event(proxy_factory): # upgrade tx_exec_info = await signer.send_transaction( admin, proxy.contract_address, 'upgrade', [ - v2.contract_address + v2_decl.class_hash ] ) @@ -191,14 +176,14 @@ async def test_upgrade_event(proxy_factory): from_address=proxy.contract_address, name='Upgraded', data=[ - v2.contract_address + v2_decl.class_hash # new class hash ] ) @pytest.mark.asyncio async def test_upgrade_from_non_admin(proxy_factory): - admin, non_admin, _, v2, proxy = proxy_factory + admin, non_admin, proxy, _, v2_decl = proxy_factory # initialize implementation await signer.send_transaction( @@ -211,64 +196,137 @@ async def test_upgrade_from_non_admin(proxy_factory): await assert_revert( signer.send_transaction( non_admin, proxy.contract_address, 'upgrade', [ - v2.contract_address + v2_decl.class_hash ] ), reverted_with="Proxy: caller is not admin" ) -# Using `after_upgrade` fixture henceforth @pytest.mark.asyncio async def test_implementation_v2(after_upgrade): - admin, _, _, v2, proxy = after_upgrade - - # check implementation address - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'get_implementation', [] + admin, _, proxy, _, v2_decl = after_upgrade + + execution_info = await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'getImplementationHash', []), + (proxy.contract_address, 'getAdmin', []), + (proxy.contract_address, 'getValue1', []) + ] ) - assert execution_info.result.response == [v2.contract_address] - # check admin - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'get_admin', [] - ) - assert execution_info.result.response == [admin.contract_address] + expected = [ + v2_decl.class_hash, # getImplementationHash + admin.contract_address, # getAdmin + VALUE_1 # getValue1 + ] - # check value - execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'get_value_1', [] - ) - assert execution_info.result.response == [VALUE_1, ] + assert execution_info.result.response == expected +# +# v2 functions +# @pytest.mark.asyncio async def test_set_admin(after_upgrade): - admin, new_admin, _, _, proxy = after_upgrade + admin, new_admin, proxy, *_ = after_upgrade # change admin await signer.send_transaction( - admin, proxy.contract_address, 'set_admin', [ + admin, proxy.contract_address, 'setAdmin', [ new_admin.contract_address ] ) # check admin execution_info = await signer.send_transaction( - admin, proxy.contract_address, 'get_admin', [] + admin, proxy.contract_address, 'getAdmin', [] ) assert execution_info.result.response == [new_admin.contract_address] @pytest.mark.asyncio async def test_set_admin_from_non_admin(after_upgrade): - _, non_admin, _, _, proxy = after_upgrade + _, non_admin, proxy, *_ = after_upgrade - # change admin should revert - await assert_revert( + # should revert + await assert_revert(signer.send_transaction( + non_admin, proxy.contract_address, 'setAdmin', [non_admin.contract_address]), + reverted_with="Proxy: caller is not admin" + ) + + +@pytest.mark.asyncio +async def test_v2_functions_pre_and_post_upgrade(proxy_factory): + admin, new_admin, proxy, _, v2_decl = proxy_factory + + # initialize + await signer.send_transaction( + admin, proxy.contract_address, 'initializer', [ + admin.contract_address + ] + ) + + # check getValue2 doesn't exist + await assert_revert_entry_point( signer.send_transaction( - non_admin, proxy.contract_address, 'set_admin', [ - non_admin.contract_address - ] - ) + admin, proxy.contract_address, 'getValue2', [] + ), + invalid_selector='getValue2' + ) + + # check setValue2 doesn't exist in v1 + await assert_revert_entry_point( + signer.send_transaction( + admin, proxy.contract_address, 'setValue2', [VALUE_2] + ), + invalid_selector='setValue2' + ) + + # check getAdmin doesn't exist in v1 + await assert_revert_entry_point( + signer.send_transaction( + admin, proxy.contract_address, 'getAdmin', [] + ), + invalid_selector='getAdmin' + ) + + # check setAdmin doesn't exist in v1 + await assert_revert_entry_point( + signer.send_transaction( + admin, proxy.contract_address, 'setAdmin', [new_admin.contract_address] + ), + invalid_selector='setAdmin' + ) + + # upgrade + await signer.send_transaction( + admin, proxy.contract_address, 'upgrade', [ + v2_decl.class_hash + ] + ) + + # set value 2 and admin + await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'setValue2', [VALUE_2]), + (proxy.contract_address, 'setAdmin', [new_admin.contract_address]) + ] ) + + # check value 2 and admin + execution_info = await signer.send_transactions( + admin, + [ + (proxy.contract_address, 'getValue2', []), + (proxy.contract_address, 'getAdmin', []) + ] + ) + + expected = [ + VALUE_2, # getValue2 + new_admin.contract_address # getAdmin + ] + assert execution_info.result.response == expected diff --git a/tests/utils.py b/tests/utils.py index fe0b882cf..a97a8548a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -96,6 +96,13 @@ async def assert_revert(fun, reverted_with=None): assert reverted_with in error['message'] +async def assert_revert_entry_point(fun, invalid_selector): + selector_hex = hex(get_selector_from_name(invalid_selector)) + entry_point_msg = f"Entry point {selector_hex} not found in contract" + + await assert_revert(fun, entry_point_msg) + + def assert_event_emitted(tx_exec_info, from_address, name, data): assert Event( from_address=from_address, @@ -104,21 +111,21 @@ def assert_event_emitted(tx_exec_info, from_address, name, data): ) in tx_exec_info.raw_events -def get_contract_def(path): - """Returns the contract definition from the contract path""" +def get_contract_class(path): + """Return the contract class from the contract path""" path = contract_path(path) - contract_def = compile_starknet_files( + contract_class = compile_starknet_files( files=[path], debug_info=True ) - return contract_def + return contract_class -def cached_contract(state, definition, deployed): - """Returns the cached contract""" +def cached_contract(state, _class, deployed): + """Return the cached contract""" contract = StarknetContract( state=state, - abi=definition.abi, + abi=_class.abi, contract_address=deployed.contract_address, deploy_execution_info=deployed.deploy_execution_info ) diff --git a/tox.ini b/tox.ini index 4a8f53241..fe1f04898 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,7 @@ passenv = HOME PYTHONPATH deps = - cairo-lang==0.8.2.1 + cairo-lang==0.9.0 cairo-nile==0.6.1 pytest-xdist # See https://github.com/starkware-libs/cairo-lang/issues/52