Skip to content

Commit

Permalink
test: Compatibility with LegacyDeviceFacade (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Sep 4, 2024
1 parent 17dae39 commit 8030b47
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/braket/pennylane_plugin/ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.aws import AwsDevice, AwsQuantumTask, AwsSession
from braket.devices import Device, LocalSimulator
from braket.tasks.local_quantum_task import LocalQuantumTask
from pennylane import QubitDevice
from pennylane._version import __version__
from pennylane.measurements import MeasurementProcess, SampleMeasurement
Expand Down Expand Up @@ -562,11 +563,10 @@ def _ahs_program_from_evolution(

return ahs_program

def _run_task(self, ahs_program: AnalogHamiltonianSimulation) -> AwsQuantumTask:
def _run_task(self, ahs_program: AnalogHamiltonianSimulation) -> LocalQuantumTask:
"""Run and return a task executing the AnalogHamiltonianSimulation program on the
device"""
task = self._device.run(ahs_program, shots=self.shots, steps=100)
return task
return self._device.run(ahs_program, shots=self.shots, steps=100)

def _validate_pulses(self, pulses: list[HardwarePulse]): # noqa: C901
"""Validate that all pulses are defined as expected by the device. This validation includes:
Expand Down
18 changes: 15 additions & 3 deletions test/unit_tests/test_ahs_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,15 +537,20 @@ def test_generate_samples(self):
task run"""
ahs_program = dummy_ahs_program()
dev = qml.device("braket.local.ahs", wires=3)

# checked in _validate_operations in the full pipeline
# since these are created manually for the unit test elsewhere in the file,
# we confirm the values used for the test are valid here
assert len(ahs_program.register.coordinate_list(0)) == len(dev.wires)

task = dev._run_task(ahs_program)

dev._task = task
# PennyLane 0.38+ wraps the device in a `LegacyDeviceFacade`
# TODO: Remove else branch once minimum PennyLane is >=0.38
if hasattr(dev, "target_device"):
dev.target_device._task = task
else:
dev._task = task

samples = dev.generate_samples()

assert len(samples) == 1000
Expand All @@ -557,7 +562,7 @@ def test_expval_handles_nan(self):

dev = qml.device("braket.local.ahs", wires=4, shots=4)

dev._samples = np.array(
samples = np.array(
[
[0, 1, 1, np.NaN],
[1, 1, 0, 0],
Expand All @@ -566,6 +571,13 @@ def test_expval_handles_nan(self):
]
)

# PennyLane 0.38+ wraps the device in a `LegacyDeviceFacade`
# TODO: Remove else branch once minimum PennyLane is >=0.38
if hasattr(dev, "target_device"):
dev.target_device._samples = samples
else:
dev._samples = samples

res = dev.expval(qml.PauliZ(3))

assert res != np.NaN
Expand Down

0 comments on commit 8030b47

Please sign in to comment.