Skip to content

Commit

Permalink
fix: Translate Sum for adjoint gradient (#252)
Browse files Browse the repository at this point in the history
The plugin only translates `Hamiltonian`s for the adjoint gradient, but Hamiltonians are often actually `Sum` objects. This change adds `Sum` observable translation.

Co-authored-by: Coull <accoull@amazon.com>
  • Loading branch information
speller26 and Coull committed May 8, 2024
1 parent df4d830 commit dbba03a
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
Variance,
)
from pennylane.operation import Operation
from pennylane.ops import Hamiltonian
from pennylane.ops import Hamiltonian, Sum
from pennylane.tape import QuantumTape

from braket.pennylane_plugin.translation import (
Expand Down Expand Up @@ -162,7 +162,7 @@ def operations(self) -> frozenset[str]:

@property
def observables(self) -> frozenset[str]:
base_observables = frozenset(super().observables - {"SProd", "Sum"})
base_observables = frozenset(super().observables)
# Amazon Braket only supports coefficients and multiple terms when shots==0
if not self.shots:
return base_observables.union({"Hamiltonian", "LinearCombination"})
Expand Down Expand Up @@ -226,8 +226,8 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, (Hamiltonian, qml.Hamiltonian)):
targets = [self.map_wires(op.wires) for op in pl_observable.ops]
if isinstance(pl_observable, (Hamiltonian, qml.Hamiltonian, Sum)):
targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]]
else:
targets = self.map_wires(pl_observable.wires).tolist()

Expand Down
4 changes: 4 additions & 0 deletions src/braket/pennylane_plugin/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
CPhaseShift00,
CPhaseShift01,
CPhaseShift10,
PRx,
PSWAP,
GPi,
GPi2,
MS,
AAMS,
)
Operations
Expand All @@ -42,10 +44,12 @@
CPhaseShift00
CPhaseShift01
CPhaseShift10
PRx
PSWAP
GPi
GPi2
MS
AAMS
Code details
~~~~~~~~~~~~
Expand Down
19 changes: 19 additions & 0 deletions test/unit_tests/test_ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,25 @@ def test_check_validity_valid_circuit(self, H, params):

dev.check_validity(ops, obs)

@pytest.mark.parametrize("H, params", HAMILTONIANS_AND_PARAMS)
def test_check_validity_valid_circuit_no_op_math(self, H, params):
"""Tests that check_validity() doesn't raise any errors when the operations and
observables are valid."""
qml.operation.disable_new_opmath()
ops = [ParametrizedEvolution(H, params, [0, 1.5])]
obs = [
qml.PauliZ(0),
qml.expval(qml.PauliZ(0)),
qml.var(qml.Identity(0)),
qml.sample(qml.PauliZ(0)),
qml.prod(qml.PauliZ(0), qml.Identity(1)),
qml.Hamiltonian([2, 3], [qml.PauliZ(0), qml.PauliZ(1)]),
qml.counts(),
]
dev = qml.device("braket.local.ahs", wires=3)

dev.check_validity(ops, obs)

@pytest.mark.parametrize("H, params", HAMILTONIANS_AND_PARAMS)
def test_check_validity_raises_error_for_state_based_measurement(self, H, params):
"""Tests that requesting a measurement other than a sample-based
Expand Down
76 changes: 76 additions & 0 deletions test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,82 @@ def test_execute_with_gradient(
assert (results[1] == expected_pl_result[0][1]).all()


@patch.object(AwsDevice, "run")
@pytest.mark.parametrize(
"pl_circ, expected_braket_circ, wires, expected_inputs, result_types, expected_pl_result",
[
(
CIRCUIT_2,
Circuit()
.h(0)
.cnot(0, 1)
.rx(0, FreeParameter("p_0"))
.ry(0, FreeParameter("p_1"))
.adjoint_gradient(
observable=(2 * Observable.X() @ Observable.Y()),
target=[0, 1],
parameters=["p_0", "p_1"],
),
2,
{"p_0": 0.432, "p_1": 0.543},
[
{
"type": {
"observable": "2.0 * x() @ y()",
"targets": [[0, 1]],
"parameters": ["p_0", "p_1"],
"type": "adjoint_gradient",
},
"value": {
"gradient": {"p_0": -0.01894799, "p_1": 0.9316158},
"expectation": 0.0,
},
},
],
[
(
np.tensor([0.0], requires_grad=True),
np.tensor([-0.01894799, 0.9316158], requires_grad=True),
)
],
),
],
)
def test_execute_with_gradient_no_op_math(
mock_run,
pl_circ,
expected_braket_circ,
wires,
expected_inputs,
result_types,
expected_pl_result,
):
qml.operation.disable_new_opmath()

task = Mock()
type(task).id = PropertyMock(return_value="task_arn")
task.state.return_value = "COMPLETED"
task.result.return_value = get_test_result_object(rts=result_types)
mock_run.return_value = task
dev = _aws_device(wires=wires, foo="bar", shots=0, device_type=AwsDeviceType.SIMULATOR)

results = dev.execute(pl_circ, compute_gradient=True)

assert dev.task == task

mock_run.assert_called_with(
expected_braket_circ,
s3_destination_folder=("foo", "bar"),
shots=0,
poll_timeout_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_TIMEOUT,
poll_interval_seconds=AwsQuantumTask.DEFAULT_RESULTS_POLL_INTERVAL,
foo="bar",
inputs=expected_inputs,
)
assert (results[0] == expected_pl_result[0][0]).all()
assert (results[1] == expected_pl_result[0][1]).all()


@patch.object(AwsDevice, "run")
def test_execute_tracker(mock_run):
"""Asserts tracker stores information during execute when active"""
Expand Down

0 comments on commit dbba03a

Please sign in to comment.