Skip to content

Commit

Permalink
Defer op-arithmetic to default qubit (#349)
Browse files Browse the repository at this point in the history
* defer op-arithmetic to default qubit

* Auto update version

* tests, changelog, black

* differentiation tests

* black

Co-authored-by: Dev version update bot <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
albi3ro and github-actions[bot] committed Sep 9, 2022
1 parent 3b78924 commit 84e31bd
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 32 deletions.
13 changes: 12 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
* Implements caching for Kokkos installation.
[(#316)](https://github.com/PennyLaneAI/pennylane-lightning/pull/316)

* Supports measurements of operator arithmetic classes such as `Sum`, `Prod`,
and `SProd` by deferring handling of them to `DefaultQubit`.
[(#349)](https://github.com/PennyLaneAI/pennylane-lightning/pull/349)

```
@qml.qnode(qml.device('lightning.qubit', wires=2))
def circuit():
obs = qml.s_prod(2.1, qml.PauliZ(0)) + qml.op_sum(qml.PauliX(0), qml.PauliZ(1))
return qml.expval(obs)
```

### Documentation

### Bug fixes
Expand All @@ -32,7 +43,7 @@

This release contains contributions from (in alphabetical order):

Amintor Dusko, Chae-Yeun Park
Amintor Dusko, Christina Lee, Chae-Yeun Park

---

Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.26.0-dev11"
__version__ = "0.26.0-dev12"
22 changes: 9 additions & 13 deletions pennylane_lightning/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ def _remove_snapshot_from_operations(operations):
return operations


def _remove_op_arithmetic_from_observables(observables):
observables = observables.copy()
observables.discard("Sum")
observables.discard("SProd")
observables.discard("Prod")
return observables


class LightningQubit(DefaultQubit):
"""PennyLane Lightning device.
Expand Down Expand Up @@ -111,7 +103,6 @@ class LightningQubit(DefaultQubit):
author = "Xanadu Inc."
_CPP_BINARY_AVAILABLE = True
operations = _remove_snapshot_from_operations(DefaultQubit.operations)
observables = _remove_op_arithmetic_from_observables(DefaultQubit.observables)

def __init__(self, wires, *, c_dtype=np.complex128, shots=None, batch_obs=False):
if c_dtype is np.complex64:
Expand Down Expand Up @@ -617,10 +608,15 @@ def expval(self, observable, shot_range=None, bin_size=None):
Returns:
Expectation value of the observable
"""
if isinstance(observable.name, List) or observable.name in [
"Identity",
"Projector",
]:
if (
(observable.arithmetic_depth > 0)
or isinstance(observable.name, List)
or observable.name
in [
"Identity",
"Projector",
]
):
return super().expval(observable, shot_range=shot_range, bin_size=bin_size)

if self.shots is not None:
Expand Down
17 changes: 0 additions & 17 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,3 @@ def test_create_device_with_dtype(C):
def test_create_device_with_unsupported_dtype():
with pytest.raises(TypeError, match="Unsupported complex Type:"):
dev = qml.device("lightning.qubit", wires=1, c_dtype=np.complex256)


def test_no_op_arithmetic_support():
"""Test that lightning qubit explicitly does not support SProd, Prod, and Sum."""

dev = qml.device("lightning.qubit", wires=2)
for name in {"Prod", "SProd", "Sum"}:
assert name not in dev.operations

obs = qml.prod(qml.PauliX(0), qml.PauliY(1))

@qml.qnode(dev)
def circuit():
return qml.expval(obs)

with pytest.raises(qml.DeviceError, match=r"Observable Prod not supported on device"):
circuit()
83 changes: 83 additions & 0 deletions tests/test_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,89 @@ def test_hadamard_expectation(self, theta, phi, qubit_device_3_wires, tol):
assert np.allclose(res, expected, tol)


@pytest.mark.parametrize("diff_method", ("parameter-shift", "adjoint"))
class TestExpOperatorArithmetic:
"""Test integration of lightning with SProd, Prod, and Sum."""

dev = qml.device("lightning.qubit", wires=2)

def test_sprod(self, diff_method):
"""Test the `SProd` class with lightning qubit."""

@qml.qnode(self.dev, diff_method=diff_method)
def circuit(x):
qml.RX(x, wires=0)
return qml.expval(qml.s_prod(0.5, qml.PauliZ(0)))

x = qml.numpy.array(0.123, requires_grad=True)
res = circuit(x)
assert qml.math.allclose(res, 0.5 * np.cos(x))

g = qml.grad(circuit)(x)
expected_grad = -0.5 * np.sin(x)
assert qml.math.allclose(g, expected_grad)

def test_prod(self, diff_method):
"""Test the `Prod` class with lightning qubit."""

@qml.qnode(self.dev, diff_method=diff_method)
def circuit(x):
qml.RX(x, wires=0)
qml.Hadamard(1)
qml.PauliZ(1)
return qml.expval(qml.prod(qml.PauliZ(0), qml.PauliX(1)))

x = qml.numpy.array(0.123, requires_grad=True)
res = circuit(x)
assert qml.math.allclose(res, -np.cos(x))

g = qml.grad(circuit)(x)
expected_grad = np.sin(x)
assert qml.math.allclose(g, expected_grad)

def test_sum(self, diff_method):
"""Test the `Sum` class with lightning qubit."""

@qml.qnode(self.dev, diff_method=diff_method)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
return qml.expval(qml.op_sum(qml.PauliZ(0), qml.PauliX(1)))

x = qml.numpy.array(-3.21, requires_grad=True)
y = qml.numpy.array(2.34, requires_grad=True)
res = circuit(x, y)
assert qml.math.allclose(res, np.cos(x) + np.sin(y))

g = qml.grad(circuit)(x, y)
expected = (-np.sin(x), np.cos(y))
assert qml.math.allclose(g, expected)

def test_integration(self, diff_method):
"""Test a Combination of `Sum`, `SProd`, and `Prod`."""

obs = qml.op_sum(
qml.s_prod(2.3, qml.PauliZ(0)), -0.5 * qml.prod(qml.PauliY(0), qml.PauliZ(1))
)

@qml.qnode(self.dev, diff_method=diff_method)
def circuit(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=1)
return qml.expval(obs)

x = qml.numpy.array(0.654, requires_grad=True)
y = qml.numpy.array(-0.634, requires_grad=True)

res = circuit(x, y)
expected = 2.3 * np.cos(x) + 0.5 * np.sin(x) * np.cos(y)
assert qml.math.allclose(res, expected)

g = qml.grad(circuit)(x, y)
expected = (-2.3 * np.sin(x) + 0.5 * np.cos(y) * np.cos(x), -0.5 * np.sin(x) * np.sin(y))
assert qml.math.allclose(g, expected)


@pytest.mark.parametrize("theta,phi,varphi", list(zip(THETA, PHI, VARPHI)))
class TestTensorExpval:
"""Test tensor expectation values"""
Expand Down

0 comments on commit 84e31bd

Please sign in to comment.