Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update lightning qubit memory management #601

Merged
merged 35 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
635767f
update dev version
AmintorDusko Jan 25, 2024
5c05425
update clean
AmintorDusko Jan 25, 2024
567d21f
expand .gitignore
AmintorDusko Jan 25, 2024
b761f3f
expand lightning qubit class specific python bindings
AmintorDusko Jan 25, 2024
762ee8c
add some statevector manipulation methods to statevector managed simu…
AmintorDusko Jan 25, 2024
99150d8
update lightning qubit python class
AmintorDusko Jan 25, 2024
404d249
add c++ unit tests
AmintorDusko Jan 25, 2024
b760684
add comments to test
AmintorDusko Jan 25, 2024
ab3304a
update lightning_qubit.py
AmintorDusko Jan 25, 2024
6b9de7c
remove obsolete tests
AmintorDusko Jan 26, 2024
dd1e632
rename statevector instance
AmintorDusko Jan 26, 2024
2e497fb
return some important tests
AmintorDusko Jan 26, 2024
d0675d1
implement PR review sugestions
AmintorDusko Jan 26, 2024
1ec3cd2
add review suggestion
AmintorDusko Jan 26, 2024
9c66497
Add semi-colon.
vincentmr Jan 26, 2024
31f4104
Fix Projector obs in L-Qubit and add Proj support in L-Kokkos.
vincentmr Jan 29, 2024
7ee08a2
Implement previous commit fix for None shots only.
vincentmr Jan 29, 2024
8df8a2a
Add tests for Proj expval/var
vincentmr Jan 29, 2024
e19003e
Merge branch 'master' into update_lightning_qubit_memory_management
AmintorDusko Jan 29, 2024
9e39e05
Auto update version
github-actions[bot] Jan 29, 2024
e6412bf
Trigger CI
AmintorDusko Jan 29, 2024
4e7f1a9
remove comment
AmintorDusko Jan 29, 2024
0c8cffb
add projector support to LGPU
AmintorDusko Jan 29, 2024
645bbe0
format
AmintorDusko Jan 29, 2024
40a8a00
revert changes for LGPU
AmintorDusko Jan 29, 2024
8d28fd4
skip tests for Projector observable not supported
AmintorDusko Jan 29, 2024
f115f6f
expand tests for lightning.qubit
AmintorDusko Jan 29, 2024
3d67dfe
update changelog
AmintorDusko Jan 29, 2024
4134f0e
Merge branch 'master' into update_lightning_qubit_memory_management
AmintorDusko Jan 29, 2024
0d7b876
Auto update version
github-actions[bot] Jan 29, 2024
2e7d363
Trigger CI
AmintorDusko Jan 29, 2024
6a2cab9
add some review suggestions
AmintorDusko Jan 29, 2024
e800725
remove identities
AmintorDusko Jan 30, 2024
1dc6b3e
update LKokkos and LQubit _apply_state_vector
AmintorDusko Jan 30, 2024
91d672e
Update tests/test_apply.py
AmintorDusko Jan 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

### Improvements

* Decouple LightningQubit memory ownership from numpy and migrate it to LightningQubit managed state-vector class.
[(#601)](https://github.com/PennyLaneAI/pennylane-lightning/pull/601)

* Expand support for Projector observables on LightningKokkos.
[(#601)](https://github.com/PennyLaneAI/pennylane-lightning/pull/601)
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved

* Split Docker build cron job into two jobs: master and latest. This is mainly for reporting in the `plugin-test-matrix` repo.
[(#600)](https://github.com/PennyLaneAI/pennylane-lightning/pull/600)

Expand Down
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ doc/code/api/
PennyLane_Lightning.egg-info/
PennyLane_Lightning_Kokkos.egg-info/
build/
build_lightning_qubit/
build_lightning_kokkos/
build_lightning_gpu/
build_lightning_*/
Build/
BuildCov/
BuildGBench/
Expand Down
4 changes: 1 addition & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ help:
clean:
find . -type d -name '__pycache__' -exec rm -r {} \+
rm -rf build Build BuildTests BuildTidy BuildGBench
rm -rf build_lightning_qubit
rm -rf build_lightning_kokkos
rm -rf build_lightning_gpu
rm -rf build_*
rm -rf .coverage coverage_html_report/
rm -rf pennylane_lightning/*_ops*

Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.35.0-dev8"
__version__ = "0.35.0-dev10"
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class StateVectorLQubitManaged final
* @param memory_model Memory model the statevector will use
*/
explicit StateVectorLQubitManaged(
size_t num_qubits, Threading threading = Threading::SingleThread,
std::size_t num_qubits, Threading threading = Threading::SingleThread,
CPUMemoryModel memory_model = bestCPUMemoryModel())
: BaseType{num_qubits, threading, memory_model},
data_{exp2(num_qubits), ComplexT{0.0, 0.0},
Expand Down Expand Up @@ -103,7 +103,7 @@ class StateVectorLQubitManaged final
* @param threading Threading option the statevector to use
* @param memory_model Memory model the statevector will use
*/
StateVectorLQubitManaged(const ComplexT *other_data, size_t other_size,
StateVectorLQubitManaged(const ComplexT *other_data, std::size_t other_size,
Threading threading = Threading::SingleThread,
CPUMemoryModel memory_model = bestCPUMemoryModel())
: BaseType(log2PerfectPower(other_size), threading, memory_model),
Expand Down Expand Up @@ -145,7 +145,7 @@ class StateVectorLQubitManaged final
*
* @param index Index of the target element.
*/
void setBasisState(const size_t index) {
void setBasisState(const std::size_t index) {
std::fill(data_.begin(), data_.end(), 0);
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
data_[index] = {1, 0};
}
Expand All @@ -158,7 +158,7 @@ class StateVectorLQubitManaged final
*/
void setStateVector(const std::vector<std::size_t> &indices,
const std::vector<ComplexT> &values) {
for (size_t n = 0; n < indices.size(); n++) {
for (std::size_t n = 0; n < indices.size(); n++) {
data_[indices[n]] = values[n];
}
}
Expand Down Expand Up @@ -198,7 +198,7 @@ class StateVectorLQubitManaged final
* @param new_data data pointer to new data.
* @param new_size size of underlying data storage.
*/
void updateData(const ComplexT *new_data, size_t new_size) {
void updateData(const ComplexT *new_data, std::size_t new_size) {
PL_ASSERT(data_.size() == new_size);
std::copy(new_data, new_data + new_size, data_.data());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
data_ptr);
}
},
"Synchronize data from the GPU device to host.")
"Copy StateVector data into a Numpy array.")
.def(
"UpdateData",
mlxd marked this conversation as resolved.
Show resolved Hide resolved
[](StateVectorT &device_sv, const np_arr_c &state) {
Expand All @@ -219,7 +219,7 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
device_sv.updateData(data_ptr, length);
}
},
"Synchronize data from the host device to GPU.")
"Copy StateVector data into a Numpy array.")
.def("applyControlledMatrix", &applyControlledMatrix<StateVectorT>,
"Apply controlled operation")
.def("kernel_map", &svKernelMap<StateVectorT>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,9 @@ TEMPLATE_TEST_CASE("StateVectorLQubitManaged::setBasisState",

TestVectorT expected_state(size_t{1U} << num_qubits, 0.0,
getBestAllocator<ComplexT>());
expected_state[3] = {1.0, 0.0};

std::size_t index = GENERATE(0, 1, 2, 3, 4, 5, 6, 7);
expected_state[index] = {1.0, 0.0};
StateVectorLQubitManaged<PrecisionT> sv(init_state);

size_t index = 3;
sv.setBasisState(index);

REQUIRE(sv.getDataVector() == approx(expected_state));
Expand Down
15 changes: 10 additions & 5 deletions pennylane_lightning/lightning_kokkos/lightning_kokkos.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"Hadamard",
"Hermitian",
"Identity",
"Projector",
"SparseHamiltonian",
"Hamiltonian",
"Sum",
Expand Down Expand Up @@ -196,7 +197,7 @@
shots=None,
batch_obs=False,
kokkos_args=None,
): # pylint: disable=unused-argument
): # pylint: disable=unused-argument, too-many-arguments
super().__init__(wires, shots=shots, c_dtype=c_dtype)

if kokkos_args is None:
Expand All @@ -213,10 +214,6 @@
if not LightningKokkos.kokkos_config:
LightningKokkos.kokkos_config = _kokkos_configuration()

# Create the initial state. Internally, we store the
# state as an array of dimension [2]*wires.
self._pre_rotated_state = _kokkos_dtype(c_dtype)(self.num_wires)

@staticmethod
def _asarray(arr, dtype=None):
arr = np.asarray(arr) # arr is not copied
Expand Down Expand Up @@ -466,8 +463,12 @@
Expectation value of the observable
"""
if observable.name in [
"Identity",
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
"Projector",
]:
if self.shots is None:
qs = qml.tape.QuantumScript([], [qml.expval(observable)])
self.apply(self._get_diagonalizing_gates(qs))

Check warning on line 471 in pennylane_lightning/lightning_kokkos/lightning_kokkos.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_kokkos/lightning_kokkos.py#L469-L471

Added lines #L469 - L471 were not covered by tests
return super().expval(observable, shot_range=shot_range, bin_size=bin_size)

if self.shots is not None:
Expand Down Expand Up @@ -526,8 +527,12 @@
Variance of the observable
"""
if observable.name in [
"Identity",
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
"Projector",
]:
if self.shots is None:
qs = qml.tape.QuantumScript([], [qml.var(observable)])
self.apply(self._get_diagonalizing_gates(qs))

Check warning on line 535 in pennylane_lightning/lightning_kokkos/lightning_kokkos.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_kokkos/lightning_kokkos.py#L533-L535

Added lines #L533 - L535 were not covered by tests
return super().var(observable, shot_range=shot_range, bin_size=bin_size)

if self.shots is not None:
Expand Down
26 changes: 15 additions & 11 deletions pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
def _state_dtype(dtype):
if dtype not in [np.complex128, np.complex64]: # pragma: no cover
raise ValueError(f"Data type is not supported for state-vector computation: {dtype}")
return StateVectorC128 if dtype == np.complex128 else StateVectorC64

Check warning on line 81 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L81

Added line #L81 was not covered by tests
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved

allowed_operations = {
"Identity",
Expand Down Expand Up @@ -231,9 +231,7 @@

# Create the initial state. Internally, we store the
# state as an array of dimension [2]*wires.
self._state = _state_dtype(c_dtype)(self.num_wires)
self._pre_rotated_state = _state_dtype(c_dtype)(self.num_wires)

self._qubit_state = _state_dtype(c_dtype)(self.num_wires)

Check warning on line 234 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L234

Added line #L234 was not covered by tests
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
self._batch_obs = batch_obs
self._mcmc = mcmc
if self._mcmc:
Expand Down Expand Up @@ -280,14 +278,14 @@
Args:
index (int): integer representing the computational basis state.
"""
self._state.setBasisState(index)
self._qubit_state.setBasisState(index)

Check warning on line 281 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L281

Added line #L281 was not covered by tests

def reset(self):
"""Reset the device"""
super().reset()

# init the state vector to |00..0>
self._state.resetStateVector()
self._qubit_state.resetStateVector()

Check warning on line 288 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L288

Added line #L288 was not covered by tests

@property
def create_ops_list(self):
Expand All @@ -314,15 +312,15 @@
>>> print(dev.state)
[0.+0.j 1.+0.j]
"""
state = np.zeros(2**self.num_wires, dtype=self.C_DTYPE)
state = self._asarray(state, dtype=self.C_DTYPE)
self._state.getState(state.ravel(order="C"))
self._qubit_state.getState(state.ravel(order="C"))
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
return state

Check warning on line 318 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L315-L318

Added lines #L315 - L318 were not covered by tests

@property
def state_vector(self):
"""Returns a handle to the statevector."""
return self._state
return self._qubit_state

Check warning on line 323 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L323

Added line #L323 was not covered by tests

def _apply_state_vector(self, state, device_wires):
"""Initialize the internal state vector in a specified state.
Expand All @@ -332,11 +330,11 @@
device_wires (Wires): wires that get initialized in the state
"""

if isinstance(state, self._state.__class__):
if isinstance(state, self._qubit_state.__class__):
state_data = np.zeros(state.size, dtype=self.C_DTYPE)
state_data = self._asarray(state_data, dtype=self.C_DTYPE)
self._state.getState(state_data.ravel(order="C"))
self._qubit_state.getState(state_data.ravel(order="C"))
state = state_data

Check warning on line 337 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L333-L337

Added lines #L333 - L337 were not covered by tests

ravelled_indices, state = self._preprocess_state_vector(state, device_wires)

Expand All @@ -346,11 +344,11 @@

if len(device_wires) == self.num_wires and Wires(sorted(device_wires)) == device_wires:
# Initialize the entire device state with the input state
state = self._reshape(state, output_shape).ravel(order="C")
self._state.UpdateData(state)
self._qubit_state.UpdateData(state)

Check warning on line 348 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L347-L348

Added lines #L347 - L348 were not covered by tests
return

self._state.setStateVector(ravelled_indices, state) # this operation on device
self._qubit_state.setStateVector(ravelled_indices, state) # this operation on device

Check warning on line 351 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L351

Added line #L351 was not covered by tests

def _apply_basis_state(self, state, wires):
"""Initialize the state vector in a specified computational basis state.
Expand All @@ -364,7 +362,7 @@
Note: This function does not support broadcasted inputs yet.
"""
num = self._get_basis_state_index(state, wires)
self._create_basis_state(num)

Check warning on line 365 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L365

Added line #L365 was not covered by tests

def _apply_lightning_controlled(self, operation):
"""Apply an arbitrary controlled operation to the state tensor.
Expand All @@ -375,12 +373,12 @@
Returns:
array[complex]: the output state tensor
"""
state = self.state_vector

Check warning on line 376 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L376

Added line #L376 was not covered by tests

basename = "PauliX" if operation.name == "MultiControlledX" else operation.base.name
if basename == "Identity":
return
method = getattr(state, f"{basename}", None)

Check warning on line 381 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L381

Added line #L381 was not covered by tests
control_wires = self.wires.indices(operation.control_wires)
control_values = (
[bool(int(i)) for i in operation.hyperparameters["control_values"]]
Expand All @@ -396,7 +394,7 @@
param = operation.parameters
method(control_wires, control_values, target_wires, inv, param)
else: # apply gate as an n-controlled matrix
method = getattr(state, "applyControlledMatrix")

Check warning on line 397 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L397

Added line #L397 was not covered by tests
target_wires = self.wires.indices(operation.target_wires)
try:
method(
Expand All @@ -421,15 +419,15 @@
Returns:
array[complex]: the output state tensor
"""
state = self.state_vector

Check warning on line 422 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L422

Added line #L422 was not covered by tests

# Skip over identity operations instead of performing
# matrix multiplication with it.
for operation in operations:
name = operation.name
if name == "Identity":

Check warning on line 428 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L427-L428

Added lines #L427 - L428 were not covered by tests
continue
method = getattr(state, name, None)

Check warning on line 430 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L430

Added line #L430 was not covered by tests
wires = self.wires.indices(operation.wires)

if method is not None: # apply specialized gate
Expand All @@ -441,11 +439,11 @@
or name == "ControlledQubitUnitary"
or name == "MultiControlledX"
): # apply n-controlled gate
self._apply_lightning_controlled(operation)

Check warning on line 442 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L442

Added line #L442 was not covered by tests
else: # apply gate as a matrix
# Inverse can be set to False since qml.matrix(operation) is already in
# inverted form
method = getattr(state, "applyMatrix")

Check warning on line 446 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L446

Added line #L446 was not covered by tests
try:
method(qml.matrix(operation), wires, False)
except AttributeError: # pragma: no cover
Expand Down Expand Up @@ -473,7 +471,7 @@
f"Operations have already been applied on a {self.short_name} device."
)

self.apply_lightning(operations)

Check warning on line 474 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L474

Added line #L474 was not covered by tests

# pylint: disable=protected-access
def expval(self, observable, shot_range=None, bin_size=None):
Expand All @@ -494,6 +492,9 @@
"Identity",
"Projector",
]:
if self.shots is None:
qs = qml.tape.QuantumScript([], [qml.expval(observable)])
self.apply(self._get_diagonalizing_gates(qs))

Check warning on line 497 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L495-L497

Added lines #L495 - L497 were not covered by tests
return super().expval(observable, shot_range=shot_range, bin_size=bin_size)

if self.shots is not None:
Expand Down Expand Up @@ -548,6 +549,9 @@
"Identity",
"Projector",
]:
if self.shots is None:
qs = qml.tape.QuantumScript([], [qml.var(observable)])
self.apply(self._get_diagonalizing_gates(qs))

Check warning on line 554 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L552-L554

Added lines #L552 - L554 were not covered by tests
return super().var(observable, shot_range=shot_range, bin_size=bin_size)

if self.shots is not None:
Expand Down Expand Up @@ -624,10 +628,10 @@
# pylint: disable=attribute-defined-outside-init
def sample(self, observable, shot_range=None, bin_size=None, counts=False):
"""Return samples of an observable."""
if observable.name != "PauliZ":
self.apply_lightning(observable.diagonalizing_gates())
self._samples = self.generate_samples()
return super().sample(

Check warning on line 634 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L631-L634

Added lines #L631 - L634 were not covered by tests
observable, shot_range=shot_range, bin_size=bin_size, counts=counts
)

Expand Down Expand Up @@ -694,11 +698,11 @@
"The number of qubits of starting_state must be the same as "
"that of the device."
)
self._apply_state_vector(starting_state, self.wires)
elif not use_device_state:
self.reset()
self.apply(tape.operations)
return self.state_vector

Check warning on line 705 in pennylane_lightning/lightning_qubit/lightning_qubit.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_qubit/lightning_qubit.py#L701-L705

Added lines #L701 - L705 were not covered by tests

def adjoint_jacobian(self, tape, starting_state=None, use_device_state=False):
"""Computes and returns the Jacobian with the adjoint method."""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ def test_provide_starting_state(self, tol, dev):

dM1 = dev.adjoint_jacobian(tape)

if device_name == "lightning.kokkos":
dev._pre_rotated_state = dev.state_vector # necessary for lightning.kokkos
if device_name in ["lightning.kokkos", "lightning.qubit"]:
dev._pre_rotated_state = dev.state_vector

qml.execute([tape], dev, None)
dM2 = dev.adjoint_jacobian(tape, starting_state=dev._pre_rotated_state)
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading
Loading