Skip to content

Commit

Permalink
fix: Include measured in noncontiguous qubit map (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 committed Jun 26, 2024
1 parent c9730d2 commit d663479
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 93 deletions.
1 change: 1 addition & 0 deletions src/braket/default_simulator/openqasm/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def add_measure(self, target: tuple[int], classical_targets: Iterable[int] = Non
if qubit in self.measured_qubits:
raise ValueError(f"Qubit {qubit} is already measured or captured.")
self.measured_qubits.append(qubit)
self.qubit_set.add(qubit)
self.target_classical_indices.append(
classical_targets[index]
if classical_targets
Expand Down
71 changes: 31 additions & 40 deletions src/braket/default_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _create_results_obj(
openqasm_ir: OpenQASMProgram,
simulation: Simulation,
measured_qubits: list[int] = None,
mapped_measured_qubits: list[int] = None,
) -> GateModelTaskResult:
return GateModelTaskResult.construct(
taskMetadata=TaskMetadata(
Expand All @@ -267,10 +268,8 @@ def _create_results_obj(
action=openqasm_ir,
),
resultTypes=results,
measurements=self._formatted_measurements(simulation, measured_qubits),
measuredQubits=(
measured_qubits if measured_qubits else self._get_all_qubits(simulation.qubit_count)
),
measurements=self._formatted_measurements(simulation, mapped_measured_qubits),
measuredQubits=(measured_qubits or list(range(simulation.qubit_count))),
)

@staticmethod
Expand Down Expand Up @@ -348,10 +347,6 @@ def _validate_input_provided(self, circuit: Circuit) -> None:
missing_input = param.free_symbols.pop()
raise NameError(f"Missing input variable '{missing_input}'.")

@staticmethod
def _get_all_qubits(qubit_count: int) -> list[int]:
return list(range(qubit_count))

@staticmethod
def _tensor_product_index_dict(
observable: TensorProduct, func: Callable[[Observable], Any]
Expand Down Expand Up @@ -383,7 +378,7 @@ def _observable_hash(observable: Observable) -> Union[str, dict[int, str]]:
return str(observable.__class__.__name__)

@staticmethod
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> Circuit:
def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) -> dict[int, int]:
"""
Maps the qubits in operations and result types to contiguous qubits.
Expand All @@ -392,24 +387,23 @@ def _map_circuit_to_contiguous_qubits(circuit: Union[Circuit, JaqcdProgram]) ->
result types.
Returns:
Circuit: The circuit with qubits in operations and result types mapped
to contiguous qubits.
dict[int, int]: Map of qubit index to corresponding contiguous index
"""
circuit_qubit_set = BaseLocalSimulator._get_circuit_qubit_set(circuit)
qubit_map = BaseLocalSimulator._contiguous_qubit_mapping(circuit_qubit_set)
BaseLocalSimulator._map_instructions_to_qubits(circuit, qubit_map)
return circuit
BaseLocalSimulator._map_circuit_qubits(circuit, qubit_map)
return qubit_map

@staticmethod
def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set:
def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set[int]:
"""
Returns the set of qubits used in the given circuit.
Args:
circuit (Union[Circuit, JaqcdProgram]): The circuit from which to extract the qubit set.
Returns:
set: The set of qubits used in the circuit.
set[int]: The set of qubits used in the circuit.
"""
if isinstance(circuit, Circuit):
return circuit.qubit_set
Expand All @@ -425,12 +419,13 @@ def _get_circuit_qubit_set(circuit: Union[Circuit, JaqcdProgram]) -> set:
return BaseLocalSimulator._get_qubits_referenced(operations)

@staticmethod
def _map_instructions_to_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map: dict):
def _map_circuit_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map: dict[int, int]):
"""
Maps the qubits in operations and result types to contiguous qubits.
Args:
circuit (Circuit): The circuit containing the operations and result types.
qubit_map (dict[int, int]): The mapping from qubits to their contiguous indices.
Returns:
Circuit: The circuit with qubits in operations and result types mapped
Expand All @@ -441,7 +436,6 @@ def _map_instructions_to_qubits(circuit: Union[Circuit, JaqcdProgram], qubit_map
BaseLocalSimulator._map_circuit_results(circuit, qubit_map)
else:
BaseLocalSimulator._map_jaqcd_instructions(circuit, qubit_map)

return circuit

@staticmethod
Expand Down Expand Up @@ -514,13 +508,13 @@ def _map_instruction_attributes(instruction, qubit_map: dict):
instruction.targets = [qubit_map.get(q, q) for q in instruction.targets]

@staticmethod
def _contiguous_qubit_mapping(qubit_set: list[int]) -> dict[int, int]:
def _contiguous_qubit_mapping(qubit_set: set[int]) -> dict[int, int]:
"""
Maping of qubits to contiguous integers. The qubit mapping may be discontiguous or
contiguous.
Args:
qubit_set (list[int]): List of qubits to be mapped.
qubit_set (set[int]): List of qubits to be mapped.
Returns:
dict[int, int]: Dictionary where keys are qubits and values are contiguous integers.
Expand Down Expand Up @@ -548,22 +542,16 @@ def _formatted_measurements(
]
# Gets the subset of measurements from the full measurements
if measured_qubits is not None and measured_qubits != []:
if any(qubit in range(simulation.qubit_count) for qubit in measured_qubits):
measured_qubits = np.array(measured_qubits)
in_circuit_mask = measured_qubits < simulation.qubit_count
measured_qubits_in_circuit = measured_qubits[in_circuit_mask]
measured_qubits_not_in_circuit = measured_qubits[~in_circuit_mask]

measurements_array = np.array(measurements)
selected_measurements = measurements_array[:, measured_qubits_in_circuit]
measurements = np.pad(
selected_measurements, ((0, 0), (0, len(measured_qubits_not_in_circuit)))
).tolist()

else:
measurements = np.zeros(
(simulation.shots, len(measured_qubits)), dtype=int
).tolist()
measured_qubits = np.array(measured_qubits)
in_circuit_mask = measured_qubits < simulation.qubit_count
measured_qubits_in_circuit = measured_qubits[in_circuit_mask]
measured_qubits_not_in_circuit = measured_qubits[~in_circuit_mask]

measurements_array = np.array(measurements)
selected_measurements = measurements_array[:, measured_qubits_in_circuit]
measurements = np.pad(
selected_measurements, ((0, 0), (0, len(measured_qubits_not_in_circuit)))
).tolist()
return measurements

def run_openqasm(
Expand Down Expand Up @@ -593,8 +581,12 @@ def run_openqasm(
are requested when shots>0.
"""
circuit = self.parse_program(openqasm_ir).circuit
qubit_map = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)
qubit_count = circuit.num_qubits
measured_qubits = circuit.measured_qubits
mapped_measured_qubits = (
[qubit_map[q] for q in measured_qubits] if measured_qubits else None
)

self._validate_ir_results_compatibility(
circuit.results,
Expand All @@ -607,8 +599,6 @@ def run_openqasm(
self._validate_input_provided(circuit)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit.results, qubit_count)

circuit = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit)

results = circuit.results

simulation = self.initialize_simulation(
Expand All @@ -635,7 +625,9 @@ def run_openqasm(
else:
simulation.evolve(circuit.basis_rotation_instructions)

return self._create_results_obj(results, openqasm_ir, simulation, measured_qubits)
return self._create_results_obj(
results, openqasm_ir, simulation, measured_qubits, mapped_measured_qubits
)

def run_jaqcd(
self,
Expand Down Expand Up @@ -674,8 +666,7 @@ def run_jaqcd(
device_action_type=DeviceActionType.JAQCD,
)
BaseLocalSimulator._validate_shots_and_ir_results(shots, circuit_ir.results, qubit_count)

circuit_ir = BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)
BaseLocalSimulator._map_circuit_to_contiguous_qubits(circuit_ir)

operations = [
from_braket_instruction(instruction) for instruction in circuit_ir.instructions
Expand Down
6 changes: 0 additions & 6 deletions test/resources/discontiguous.qasm

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"braketSchemaHeader": {"name": "braket.ir.jaqcd.program", "version": "1"},
"instructions": [
{"target": 2, "type": "x"},
{"target": 2, "type": "h"},
{"control": 2, "target": 9, "type": "cnot"}
],
"results": [],
Expand Down
6 changes: 6 additions & 0 deletions test/resources/noncontiguous_physical.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
OPENQASM 3.0;
bit[2] b;
h $2;
cnot $2, $8;
b[0] = measure $2;
b[1] = measure $8;
7 changes: 7 additions & 0 deletions test/resources/noncontiguous_virtual.qasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
OPENQASM 3.0;
bit[2] b;
qubit[10] q;
h q[2];
cnot q[2], q[8];
b[0] = measure q[2];
b[1] = measure q[8];
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,12 @@ def grcs_8_qubit(ir_type):


@pytest.fixture
def discontiguous_jaqcd():
with open("test/resources/discontiguous_jaqcd.json") as jaqcd_definition:
def noncontiguous_jaqcd():
with open("test/resources/noncontiguous_jaqcd.json") as jaqcd_definition:
data = json.load(jaqcd_definition)
return json.dumps(data)


@pytest.fixture
def discontiguous_qasm():
return OpenQASMProgram(source="test/resources/discontiguous.qasm")


@pytest.fixture
def bell_ir(ir_type):
return (
Expand Down Expand Up @@ -828,35 +823,44 @@ def test_measure_no_gates():

def test_measure_with_qubits_not_used():
qasm = """
bit[4] b;
qubit[4] q;
h q[0];
cnot q[0], q[1];
bit[5] b;
qubit[5] q;
h q[1];
cnot q[1], q[3];
b = measure q;
"""
simulator = DensityMatrixSimulator()
result = simulator.run(OpenQASMProgram(source=qasm), shots=1000)
measurements = np.array(result.measurements, dtype=int)
assert 400 < np.sum(measurements, axis=0)[0] < 600
assert 400 < np.sum(measurements, axis=0)[1] < 600
assert 400 < np.sum(measurements, axis=0)[3] < 600
assert np.sum(measurements, axis=0)[0] == 0
assert np.sum(measurements, axis=0)[2] == 0
assert np.sum(measurements, axis=0)[3] == 0
assert len(measurements[0]) == 4
assert result.measuredQubits == [0, 1, 2, 3]
assert np.sum(measurements, axis=0)[4] == 0
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]


def test_discontiguous_qubits_jaqcd(discontiguous_jaqcd):
prg = JaqcdProgram.parse_raw(discontiguous_jaqcd)
def test_noncontiguous_qubits_jaqcd(noncontiguous_jaqcd):
prg = JaqcdProgram.parse_raw(noncontiguous_jaqcd)
result = DensityMatrixSimulator().run(prg, qubit_count=2, shots=1)

assert result.measuredQubits == [0, 1]
assert result.measurements == [["1", "1"]]
assert result.measurements in ([["0", "0"]], [["1", "1"]])


def test_discontiguous_qubits_openqasm(discontiguous_qasm):
@pytest.mark.parametrize("qasm_file_name", ["noncontiguous_virtual", "noncontiguous_physical"])
def test_noncontiguous_qubits_openqasm(qasm_file_name):
simulator = DensityMatrixSimulator()
result = simulator.run(discontiguous_qasm, shots=1000)
shots = 1000
result = simulator.run(
OpenQASMProgram(source=f"test/resources/{qasm_file_name}.qasm"), shots=shots
)

assert result.measuredQubits == [2, 8]
measurements = np.array(result.measurements, dtype=int)
assert len(measurements[0]) == 5
assert result.measuredQubits == [0, 1, 2, 3, 4]
assert measurements.shape == (shots, 2)
assert all(
(np.allclose(measurement, [0, 0]) or np.allclose(measurement, [1, 1]))
for measurement in measurements
)
Loading

0 comments on commit d663479

Please sign in to comment.