Skip to content

Commit

Permalink
Move measurements after verbatim box (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Aug 5, 2024
1 parent fbb3f00 commit 2b09a3f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
20 changes: 17 additions & 3 deletions qiskit_braket_provider/providers/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
from qiskit.circuit import Measure, Parameter, ParameterExpression
from qiskit.circuit.parametervector import ParameterVectorElement
from qiskit.transpiler import Target
from qiskit_ionq import ionq_gates
from qiskit_ionq import add_equivalences, ionq_gates

from qiskit_braket_provider.exception import QiskitBraketException

add_equivalences()

_GPHASE_GATE_NAME = "global_phase"

_BRAKET_TO_QISKIT_NAMES = {
Expand Down Expand Up @@ -436,6 +438,7 @@ def to_braket(
_validate_name_conflicts(circuit.parameters)

# Handle qiskit to braket conversion
measured_qubits = set()
for circuit_instruction in circuit.data:
operation = circuit_instruction.operation
gate_name = operation.name
Expand All @@ -445,7 +448,11 @@ def to_braket(
if gate_name == "measure":
qubit = qubits[0] # qubit count = 1 for measure
qubit_index = circuit.find_bit(qubit).index
braket_circuit.measure(qubit_index)
if qubit_index in measured_qubits:
raise ValueError(
f"Cannot measure previously measured qubit {qubit_index}"
)
measured_qubits.add(qubit_index)
elif gate_name == "barrier":
warnings.warn(
"The Qiskit circuit contains barrier instructions that are ignored."
Expand All @@ -463,6 +470,10 @@ def to_braket(

# Getting the index from the bit mapping
qubit_indices = [circuit.find_bit(qubit).index for qubit in qubits]
if intersection := measured_qubits.intersection(qubit_indices):
raise ValueError(
f"Cannot apply operation {gate_name} to measured qubits {intersection}"
)
params = _create_free_parameters(operation)
if gate_name in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES:
for gate in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES[gate_name](
Expand Down Expand Up @@ -492,10 +503,13 @@ def to_braket(
)

if verbatim:
return Circuit(braket_circuit.result_types).add_verbatim_box(
braket_circuit = Circuit(braket_circuit.result_types).add_verbatim_box(
Circuit(braket_circuit.instructions)
)

for qubit in sorted(measured_qubits):
braket_circuit.measure(qubit)

return braket_circuit


Expand Down
2 changes: 1 addition & 1 deletion qiskit_braket_provider/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Qiskit-Braket provider version."""

__version__ = "0.4.0"
__version__ = "0.4.1"
27 changes: 25 additions & 2 deletions tests/providers/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,28 @@ def test_measure(self):

self.assertEqual(braket_circuit, expected_braket_circuit)

def test_measure_repeated(self):
"""Tests that repeated measurement on a qubit raises a ValueError."""
qiskit_circuit = QuantumCircuit(2, 2)
qiskit_circuit.h(0)
qiskit_circuit.cx(0, 1)
qiskit_circuit.measure(0, 0)
qiskit_circuit.measure([0, 1], [0, 1])

with self.assertRaises(ValueError):
to_braket(qiskit_circuit)

def test_gate_after_measure(self):
"""Tests that adding a gate to a measured qubit raises a ValueError."""
qiskit_circuit = QuantumCircuit(2, 2)
qiskit_circuit.h(0)
qiskit_circuit.cx(0, 1)
qiskit_circuit.measure(0, 0)
qiskit_circuit.h(0)

with self.assertRaises(ValueError):
to_braket(qiskit_circuit)

def test_reset(self):
"""Tests if NotImplementedError is raised for reset operation."""

Expand Down Expand Up @@ -436,13 +458,14 @@ def test_multiple_registers(self):

def test_verbatim(self):
"""Tests that transpilation is skipped for verbatim circuits."""
qiskit_circuit = QuantumCircuit(2)
qiskit_circuit = QuantumCircuit(2, 1)
qiskit_circuit.h(0)
qiskit_circuit.cx(0, 1)
qiskit_circuit.measure(1, 0)

assert to_braket(qiskit_circuit, {"x"}, True) == Circuit().add_verbatim_box(
Circuit().h(0).cnot(0, 1)
)
).measure(1)

def test_parameter_vector(self):
"""Tests ParameterExpression translation."""
Expand Down

0 comments on commit 2b09a3f

Please sign in to comment.