Skip to content

Commit

Permalink
feat: Support SProd and CompositeOp for expval (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Sep 6, 2024
1 parent b154dce commit df838cb
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 133 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
packages=find_namespace_packages(where="src", exclude=("test",)),
package_dir={"": "src"},
install_requires=[
"amazon-braket-sdk>=1.47.0",
"amazon-braket-sdk>=1.87.0",
"autoray>=0.6.11",
"pennylane>=0.34.0",
],
Expand Down
14 changes: 7 additions & 7 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def operations(self) -> frozenset[str]:
@property
def observables(self) -> frozenset[str]:
base_observables = frozenset(super().observables)
# Amazon Braket only supports coefficients and multiple terms when shots==0
# Amazon Braket only supports scalar multiplication and addition when shots==0
if not self.shots:
return base_observables.union({"Hamiltonian", "LinearCombination"})
return base_observables
Expand Down Expand Up @@ -254,9 +254,8 @@ def _pl_to_braket_circuit(
braket_circuit = self._apply_gradient_result_type(circuit, braket_circuit)
elif not isinstance(circuit.measurements[0], MeasurementTransform):
for measurement in circuit.measurements:
dev_wires = self.map_wires(measurement.wires).tolist()
translated = translate_result_type(
measurement, dev_wires, self._braket_result_types
measurement.map_wires(self.wire_map), None, self._braket_result_types
)
if isinstance(translated, tuple):
for result_type in translated:
Expand All @@ -281,7 +280,7 @@ def _apply_gradient_result_type(self, circuit, braket_circuit):
f"Braket can only compute gradients for circuits with a single expectation"
f" observable, not a {pl_measurements.return_type} observable."
)
if isinstance(pl_observable, (Hamiltonian, qml.Hamiltonian, Sum)):
if isinstance(pl_observable, (Hamiltonian, Sum)):
targets = [self.map_wires(op.wires) for op in pl_observable.terms()[1]]
else:
targets = self.map_wires(pl_observable.wires).tolist()
Expand Down Expand Up @@ -544,9 +543,10 @@ def _run_task(self, circuit, inputs=None):
def _run_snapshots(self, snapshot_circuits, n_qubits, mapped_wires):
raise NotImplementedError("Need to implement snapshots runner")

def _get_statistic(self, braket_result, observable):
dev_wires = self.map_wires(observable.wires).tolist()
return translate_result(braket_result, observable, dev_wires, self._braket_result_types)
def _get_statistic(self, braket_result, mp):
return translate_result(
braket_result, mp.map_wires(self.wire_map), None, self._braket_result_types
)

@staticmethod
def _get_trainable_parameters(tape: QuantumTape) -> dict[int, numbers.Number]:
Expand Down
101 changes: 50 additions & 51 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from pennylane import numpy as np
from pennylane.measurements import MeasurementProcess, ObservableReturnTypes
from pennylane.operation import Observable, Operation
from pennylane.ops import Adjoint, Hamiltonian
from pennylane.pulse import ParametrizedEvolution

from braket.pennylane_plugin.ops import (
Expand Down Expand Up @@ -434,7 +433,7 @@ def _(ms: AAMS, parameters, device=None):


@_translate_operation.register
def _(adjoint: Adjoint, parameters, device=None):
def _(adjoint: qml.ops.Adjoint, parameters, device=None):
if isinstance(adjoint.base, qml.ISWAP):
# gates.ISwap.adjoint() returns a different value
return gates.PSwap(3 * np.pi / 2)
Expand Down Expand Up @@ -523,23 +522,25 @@ def get_adjoint_gradient_result_type(
):
if "AdjointGradient" not in supported_result_types:
raise NotImplementedError("Unsupported return type: AdjointGradient")
braket_observable = _translate_observable(observable)

braket_observable = _translate_observable(_flatten_observable(observable))
braket_observable = (
braket_observable.item() if hasattr(braket_observable, "item") else braket_observable
)
return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters)


def translate_result_type( # noqa: C901
measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str]
measurement: MeasurementProcess,
targets: Optional[list[int]],
supported_result_types: frozenset[str],
) -> Union[ResultType, tuple[ResultType, ...]]:
"""Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``.
Args:
measurement (MeasurementProcess): The PennyLane ``MeasurementProcess`` to translate
targets (list[int]): The target wires of the observable using a consecutive integer wire
ordering
targets (Optional[list[int]]): The target wires of the observable using a consecutive
integer wire ordering
supported_result_types (frozenset[str]): Braket result types supported by the Braket device
Returns:
Expand All @@ -548,6 +549,7 @@ def translate_result_type( # noqa: C901
then this will return a result type for each term.
"""
return_type = measurement.return_type
targets = targets or measurement.wires.tolist()
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
Expand All @@ -560,93 +562,88 @@ def translate_result_type( # noqa: C901
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in observable.ops
)
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

if observable is None:
if return_type is ObservableReturnTypes.Counts:
return tuple(Sample(observables.Z(), target) for target in targets or measurement.wires)
return tuple(Sample(observables.Z(target)) for target in targets or measurement.wires)
raise NotImplementedError(f"Unsupported return type: {return_type}")

observable = _flatten_observable(observable)

if isinstance(observable, qml.ops.LinearCombination):
if return_type is ObservableReturnTypes.Expectation:
return tuple(Expectation(_translate_observable(op)) for op in observable.terms()[1])
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
return Expectation(braket_observable, targets)
return Expectation(braket_observable)
elif return_type is ObservableReturnTypes.Variance:
return Variance(braket_observable, targets)
return Variance(braket_observable)
elif return_type in (ObservableReturnTypes.Sample, ObservableReturnTypes.Counts):
return Sample(braket_observable, targets)
return Sample(braket_observable)
else:
raise NotImplementedError(f"Unsupported return type: {return_type}")


def _flatten_observable(observable):
if isinstance(observable, (qml.ops.Hamiltonian, qml.ops.CompositeOp, qml.ops.SProd)):
simplified = qml.ops.LinearCombination(*observable.terms()).simplify()
coeffs, _ = simplified.terms()
if len(coeffs) > 1 or coeffs[0] != 1:
return simplified
return observable


@singledispatch
def _translate_observable(observable):
raise qml.DeviceError(f"Unsupported observable: {type(observable)}")


@_translate_observable.register(Hamiltonian)
@_translate_observable.register(qml.Hamiltonian)
def _(H: Union[Hamiltonian, qml.Hamiltonian]):
# terms is structured like [C, O] where C is a tuple of all the coefficients, and O is
# a tuple of all the corresponding observable terms (X, Y, Z, H, etc or a tensor product
# of them)
coefficents, pl_observables = H.terms()
braket_observables = list(map(lambda obs: _translate_observable(obs), pl_observables))
braket_hamiltonian = sum(
(coef * obs for coef, obs in zip(coefficents[1:], braket_observables[1:])),
coefficents[0] * braket_observables[0],
)
return braket_hamiltonian


@_translate_observable.register
def _(_: qml.PauliX):
return observables.X()
def _(obs: qml.PauliX):
return observables.X(obs.wires[0])


@_translate_observable.register
def _(_: qml.PauliY):
return observables.Y()
def _(obs: qml.PauliY):
return observables.Y(obs.wires[0])


@_translate_observable.register
def _(_: qml.PauliZ):
return observables.Z()
def _(obs: qml.PauliZ):
return observables.Z(obs.wires[0])


@_translate_observable.register
def _(_: qml.Hadamard):
return observables.H()
def _(obs: qml.Hadamard):
return observables.H(obs.wires[0])


@_translate_observable.register
def _(_: qml.Identity):
return observables.I()
def _(obs: qml.Identity):
return observables.I(obs.wires[0])


@_translate_observable.register
def _(h: qml.Hermitian):
return observables.Hermitian(qml.matrix(h))
def _(obs: qml.Hermitian):
return observables.Hermitian(qml.matrix(obs), targets=obs.wires)


_zero = np.array([[1, 0], [0, 0]])
_one = np.array([[0, 0], [0, 1]])


@_translate_observable.register
def _(p: qml.Projector):
state, wires = p.parameters[0], p.wires
def _(obs: qml.Projector):
state = obs.parameters[0]
wires = obs.wires
if len(state) == len(wires): # state is a basis state
products = [_one if b else _zero for b in state]
hermitians = [observables.Hermitian(p) for p in products]
hermitians = [observables.Hermitian(p, targets=[w]) for p, w in zip(products, wires)]
return observables.TensorProduct(hermitians)

# state is a state vector
return observables.Hermitian(p.matrix())
return observables.Hermitian(obs.matrix(), targets=wires)


@_translate_observable.register
Expand All @@ -672,7 +669,7 @@ def _(t: qml.ops.Sum):
def translate_result(
braket_result: GateModelQuantumTaskResult,
measurement: MeasurementProcess,
targets: list[int],
targets: Optional[list[int]],
supported_result_types: frozenset[str],
) -> Any:
"""Translates a Braket result into the corresponding PennyLane return type value.
Expand All @@ -681,7 +678,7 @@ def translate_result(
braket_result (GateModelQuantumTaskResult): The Braket result to translate.
measurement (MeasurementProcess): The PennyLane measurement process associated with the
result.
targets (list[int]): The qubits in the result.
targets (Optional[list[int]]): The qubits in the result.
supported_result_types (frozenset[str]): The result types supported by the device.
Returns:
Expand All @@ -706,6 +703,7 @@ def translate_result(
for i in sorted(key_indices)
]

targets = targets or measurement.wires.tolist()
if measurement.return_type is ObservableReturnTypes.Counts and observable is None:
if targets:
new_dict = {}
Expand All @@ -719,7 +717,8 @@ def translate_result(
return dict(braket_result.measurement_counts)

translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
observable = _flatten_observable(observable)
if isinstance(observable, qml.ops.LinearCombination):
coeffs, _ = observable.terms()
return sum(
coeff * braket_result.get_value_by_result_type(result_type)
Expand Down
Loading

0 comments on commit df838cb

Please sign in to comment.