From dbba03a358f926527de884b1d4566eb1d2e27f5a Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Wed, 8 May 2024 14:31:29 -0700 Subject: [PATCH] fix: Translate `Sum` for adjoint gradient (#252) The plugin only translates `Hamiltonian`s for the adjoint gradient, but Hamiltonians are often actually `Sum` objects. This change adds `Sum` observable translation. Co-authored-by: Coull --- src/braket/pennylane_plugin/braket_device.py | 8 +-- src/braket/pennylane_plugin/ops.py | 4 ++ test/unit_tests/test_ahs_device.py | 19 +++++ test/unit_tests/test_braket_device.py | 76 ++++++++++++++++++++ 4 files changed, 103 insertions(+), 4 deletions(-) diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index 854b4d42..c6d6c87f 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -65,7 +65,7 @@ Variance, ) from pennylane.operation import Operation -from pennylane.ops import Hamiltonian +from pennylane.ops import Hamiltonian, Sum from pennylane.tape import QuantumTape from braket.pennylane_plugin.translation import ( @@ -162,7 +162,7 @@ def operations(self) -> frozenset[str]: @property def observables(self) -> frozenset[str]: - base_observables = frozenset(super().observables - {"SProd", "Sum"}) + base_observables = frozenset(super().observables) # Amazon Braket only supports coefficients and multiple terms when shots==0 if not self.shots: return base_observables.union({"Hamiltonian", "LinearCombination"}) @@ -226,8 +226,8 @@ def _apply_gradient_result_type(self, circuit, braket_circuit): f"Braket can only compute gradients for circuits with a single expectation" f" observable, not a {pl_measurements.return_type} observable." ) - if isinstance(pl_observable, (Hamiltonian, qml.Hamiltonian)): - targets = [self.map_wires(op.wires) for op in pl_observable.ops] + if isinstance(pl_observable, (Hamiltonian, qml.Hamiltonian, Sum)): + targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]] else: targets = self.map_wires(pl_observable.wires).tolist() diff --git a/src/braket/pennylane_plugin/ops.py b/src/braket/pennylane_plugin/ops.py index c9d1a3c5..3ea3b7a8 100644 --- a/src/braket/pennylane_plugin/ops.py +++ b/src/braket/pennylane_plugin/ops.py @@ -29,10 +29,12 @@ CPhaseShift00, CPhaseShift01, CPhaseShift10, + PRx, PSWAP, GPi, GPi2, MS, + AAMS, ) Operations @@ -42,10 +44,12 @@ CPhaseShift00 CPhaseShift01 CPhaseShift10 + PRx PSWAP GPi GPi2 MS + AAMS Code details ~~~~~~~~~~~~ diff --git a/test/unit_tests/test_ahs_device.py b/test/unit_tests/test_ahs_device.py index bb57a3f3..3ff27025 100644 --- a/test/unit_tests/test_ahs_device.py +++ b/test/unit_tests/test_ahs_device.py @@ -439,6 +439,25 @@ def test_check_validity_valid_circuit(self, H, params): dev.check_validity(ops, obs) + @pytest.mark.parametrize("H, params", HAMILTONIANS_AND_PARAMS) + def test_check_validity_valid_circuit_no_op_math(self, H, params): + """Tests that check_validity() doesn't raise any errors when the operations and + observables are valid.""" + qml.operation.disable_new_opmath() + ops = [ParametrizedEvolution(H, params, [0, 1.5])] + obs = [ + qml.PauliZ(0), + qml.expval(qml.PauliZ(0)), + qml.var(qml.Identity(0)), + qml.sample(qml.PauliZ(0)), + qml.prod(qml.PauliZ(0), qml.Identity(1)), + qml.Hamiltonian([2, 3], [qml.PauliZ(0), qml.PauliZ(1)]), + qml.counts(), + ] + dev = qml.device("braket.local.ahs", wires=3) + + dev.check_validity(ops, obs) + @pytest.mark.parametrize("H, params", HAMILTONIANS_AND_PARAMS) def test_check_validity_raises_error_for_state_based_measurement(self, H, params): """Tests that requesting a measurement other than a sample-based diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index ea21ceea..7d818570 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -507,6 +507,82 @@ def test_execute_with_gradient( assert (results[1] == expected_pl_result[0][1]).all() +@patch.object(AwsDevice, "run") +@pytest.mark.parametrize( + "pl_circ, expected_braket_circ, wires, expected_inputs, result_types, expected_pl_result", + [ + ( + CIRCUIT_2, + Circuit() + .h(0) + .cnot(0, 1) + .rx(0, FreeParameter("p_0")) + .ry(0, FreeParameter("p_1")) + .adjoint_gradient( + observable=(2 * Observable.X() @ Observable.Y()), + target=[0, 1], + parameters=["p_0", "p_1"], + ), + 2, + {"p_0": 0.432, "p_1": 0.543}, + [ + { + "type": { + "observable": "2.0 * x() @ y()", + "targets": [[0, 1]], + "parameters": ["p_0", "p_1"], + "type": "adjoint_gradient", + }, + "value": { + "gradient": {"p_0": -0.01894799, "p_1": 0.9316158}, + "expectation": 0.0, + }, + }, + ], + [ + ( + np.tensor([0.0], requires_grad=True), + np.tensor([-0.01894799, 0.9316158], requires_grad=True), + ) + ], + ), + ], +) +def test_execute_with_gradient_no_op_math( + mock_run, + pl_circ, + expected_braket_circ, + wires, + expected_inputs, + result_types, + expected_pl_result, +): + qml.operation.disable_new_opmath() + + task = Mock() + type(task).id = PropertyMock(return_value="task_arn") + task.state.return_value = "COMPLETED" + task.result.return_value = get_test_result_object(rts=result_types) + mock_run.return_value = task + dev = _aws_device(wires=wires, foo="bar", shots=0, device_type=AwsDeviceType.SIMULATOR) + + results = dev.execute(pl_circ, compute_gradient=True) + + assert dev.task == task + + mock_run.assert_called_with( + expected_braket_circ, + s3_destination_folder=("foo", "bar"), + shots=0, + poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT, + poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL, + foo="bar", + inputs=expected_inputs, + ) + assert (results[0] == expected_pl_result[0][0]).all() + assert (results[1] == expected_pl_result[0][1]).all() + + @patch.object(AwsDevice, "run") def test_execute_tracker(mock_run): """Asserts tracker stores information during execute when active"""