Skip to content

Commit

Permalink
Use device-specific gatesets for transpilation (#141)
Browse files Browse the repository at this point in the history
* Direct support for controlled gates

* More gates

* Use device-specific gateset

* Consider supported modifiers

* linters

* constant names

* bugfix

* Added supported modifiers to mocks

* Address comments

* Expose controlled gateset through method

* Reject negative control
  • Loading branch information
speller26 committed Feb 7, 2024
1 parent c6938fe commit b6f3112
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 46 deletions.
163 changes: 129 additions & 34 deletions qiskit_braket_provider/providers/adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Util function for provider."""
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from collections.abc import Callable, Iterable
from typing import Optional, Union
import warnings

from braket.aws import AwsDevice
Expand Down Expand Up @@ -30,15 +31,51 @@

from qiskit import QuantumCircuit, transpile
from qiskit.circuit import Instruction as QiskitInstruction
from qiskit.circuit import Measure, Parameter
from qiskit.circuit import ControlledGate, Measure, Parameter
import qiskit.circuit.library as qiskit_gates

from qiskit.transpiler import InstructionProperties, Target
from qiskit_braket_provider.exception import QiskitBraketException

BRAKET_TO_QISKIT_NAMES = {
"u": "u",
"phaseshift": "p",
"cnot": "cx",
"x": "x",
"y": "y",
"z": "z",
"t": "t",
"ti": "tdg",
"s": "s",
"si": "sdg",
"v": "sx",
"vi": "sxdg",
"swap": "swap",
"rx": "rx",
"ry": "ry",
"rz": "rz",
"xx": "rxx",
"yy": "ryy",
"zz": "rzz",
"i": "id",
"h": "h",
"cy": "cy",
"cz": "cz",
"ccnot": "ccx",
"cswap": "cswap",
"cphaseshift": "cp",
"ecr": "ecr",
}

_CONTROLLED_GATES_BY_QUBIT_COUNT = {
1: {"ch", "cs", "csdg", "csx", "crx", "cry", "crz", "ccz"},
3: {"c3sx"},
}
_ARBITRARY_CONTROLLED_GATES = {"mcx"}

_EPS = 1e-10 # global variable used to chop very small numbers to zero

GATE_NAME_TO_BRAKET_GATE: Dict[str, Callable] = {
GATE_NAME_TO_BRAKET_GATE: dict[str, Callable] = {
"u1": lambda lam: [braket_gates.PhaseShift(lam)],
"u2": lambda phi, lam: [
braket_gates.PhaseShift(lam),
Expand Down Expand Up @@ -84,12 +121,26 @@
"iswap": lambda: [braket_gates.ISwap()],
}

_QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES: dict[str, Callable] = {
"ch": braket_gates.H,
"cs": braket_gates.S,
"csdg": braket_gates.Si,
"csx": braket_gates.V,
"ccz": braket_gates.CZ,
"c3sx": braket_gates.V,
"mcx": braket_gates.CNot,
"crx": braket_gates.Rx,
"cry": braket_gates.Ry,
"crz": braket_gates.Rz,
}

_TRANSLATABLE_QISKIT_GATE_NAMES = set(GATE_NAME_TO_BRAKET_GATE.keys()).union(
{"measure", "barrier", "reset"}
_TRANSLATABLE_QISKIT_GATE_NAMES = (
set(GATE_NAME_TO_BRAKET_GATE.keys())
.union(set(_QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES))
.union({"measure", "barrier", "reset"})
)

GATE_NAME_TO_QISKIT_GATE: Dict[str, Optional[QiskitInstruction]] = {
GATE_NAME_TO_QISKIT_GATE: dict[str, Optional[QiskitInstruction]] = {
"u": qiskit_gates.UGate(Parameter("theta"), Parameter("phi"), Parameter("lam")),
"u1": qiskit_gates.U1Gate(Parameter("theta")),
"u2": qiskit_gates.U2Gate(Parameter("theta"), Parameter("lam")),
Expand Down Expand Up @@ -124,6 +175,28 @@
}


def get_controlled_gateset(max_qubits: Optional[int] = None) -> set[str]:
"""Returns the Qiskit gates expressible as controlled versions of existing Braket gates
This set can be filtered by the maximum number of control qubits.
Args:
max_qubits (Optional[int]): The maximum number of control qubits that can be used to express
the Qiskit gate as a controlled Braket gate. If `None`, then there is no limit to the
number of control qubits. Default: `None`.
Returns:
set[str]: The names of the controlled gates.
"""
if max_qubits is None:
gateset = set().union(*[g for _, g in _CONTROLLED_GATES_BY_QUBIT_COUNT.items()])
gateset.update(_ARBITRARY_CONTROLLED_GATES)
return gateset
return set().union(
*[g for q, g in _CONTROLLED_GATES_BY_QUBIT_COUNT.items() if q <= max_qubits]
)


def local_simulator_to_target(simulator: LocalSimulator) -> Target:
"""Converts properties of LocalSimulator into Qiskit Target object.
Expand All @@ -146,7 +219,7 @@ def local_simulator_to_target(simulator: LocalSimulator) -> Target:

for instruction in instructions:
instruction_props: Optional[
Dict[Union[Tuple[int], Tuple[int, int]], Optional[InstructionProperties]]
dict[Union[tuple[int], tuple[int, int]], Optional[InstructionProperties]]
] = {}

if instruction.num_qubits == 1:
Expand Down Expand Up @@ -189,7 +262,7 @@ def aws_device_to_target(device: AwsDevice) -> Target:
)
paradigm: GateModelQpuParadigmProperties = properties.paradigm
connectivity = paradigm.connectivity
instructions: List[QiskitInstruction] = []
instructions: list[QiskitInstruction] = []

for operation in action_properties.supportedOperations:
instruction = GATE_NAME_TO_QISKIT_GATE.get(operation.lower(), None)
Expand All @@ -205,8 +278,8 @@ def aws_device_to_target(device: AwsDevice) -> Target:

for instruction in instructions:
instruction_props: Optional[
Dict[
Union[Tuple[int], Tuple[int, int]], Optional[InstructionProperties]
dict[
Union[tuple[int], tuple[int, int]], Optional[InstructionProperties]
]
] = {}
# adding 1 qubit instructions
Expand Down Expand Up @@ -303,8 +376,8 @@ def convert_continuous_qubit_indices(

for instruction in instructions:
simulator_instruction_props: Optional[
Dict[
Union[Tuple[int], Tuple[int, int]],
dict[
Union[tuple[int], tuple[int, int]],
Optional[InstructionProperties],
]
] = {}
Expand All @@ -331,26 +404,26 @@ def convert_continuous_qubit_indices(
return target


def to_braket(circuit: QuantumCircuit) -> Circuit:
def to_braket(circuit: QuantumCircuit, gateset: Iterable[str] = None) -> Circuit:
"""Return a Braket quantum circuit from a Qiskit quantum circuit.
Args:
circuit (QuantumCircuit): Qiskit Quantum Circuit
circuit (QuantumCircuit): Qiskit quantum circuit
gateset (Iterable[str]): The gateset to transpile to
Returns:
Circuit: Braket circuit
"""
gateset = gateset or _TRANSLATABLE_QISKIT_GATE_NAMES
if not isinstance(circuit, QuantumCircuit):
raise TypeError(f"Expected a QuantumCircuit, got {type(circuit)} instead.")

quantum_circuit = Circuit()
braket_circuit = Circuit()
if not (
{gate.name for gate, _, _ in circuit.data}.issubset(
_TRANSLATABLE_QISKIT_GATE_NAMES
)
):
circuit = transpile(
circuit, basis_gates=_TRANSLATABLE_QISKIT_GATE_NAMES, optimization_level=0
)
circuit = transpile(circuit, basis_gates=gateset, optimization_level=0)

# handle qiskit to braket conversion
for circuit_instruction in circuit.data:
Expand All @@ -362,7 +435,7 @@ def to_braket(circuit: QuantumCircuit) -> Circuit:
if gate_name == "measure":
qubit = qubits[0] # qubit count = 1 for measure
qubit_index = circuit.find_bit(qubit).index
quantum_circuit.sample(
braket_circuit.sample(
observable=observables.Z(),
target=[
qubit_index,
Expand All @@ -377,23 +450,44 @@ def to_braket(circuit: QuantumCircuit) -> Circuit:
"reset operation not supported by qiskit to braket adapter"
)
else:
params = operation.params if hasattr(operation, "params") else []

for i, param in enumerate(params):
if isinstance(param, Parameter):
params[i] = FreeParameter(param.name)

for gate in GATE_NAME_TO_BRAKET_GATE[gate_name](*params):
params = _create_free_parameters(operation)
if (
isinstance(operation, ControlledGate)
and operation.ctrl_state != 2**operation.num_ctrl_qubits - 1
):
raise ValueError("Negative control is not supported")
if gate_name in _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES:
gate = _QISKIT_CONTROLLED_GATE_NAMES_TO_BRAKET_GATES[gate_name](*params)
qubit_indices = [circuit.find_bit(qubit).index for qubit in qubits]
gate_qubit_count = gate.qubit_count
target_indices = qubit_indices[-gate_qubit_count:]
instruction = Instruction(
# Getting the index from the bit mapping
operator=gate,
target=[circuit.find_bit(qubit).index for qubit in qubits],
target=target_indices,
control=qubit_indices[:-gate_qubit_count],
)
quantum_circuit += instruction
braket_circuit += instruction
else:
for gate in GATE_NAME_TO_BRAKET_GATE[gate_name](*params):
instruction = Instruction(
operator=gate,
target=[circuit.find_bit(qubit).index for qubit in qubits],
)
braket_circuit += instruction

if circuit.global_phase > _EPS:
quantum_circuit.gphase(circuit.global_phase)
braket_circuit.gphase(circuit.global_phase)

return braket_circuit

return quantum_circuit

def _create_free_parameters(operation):
params = operation.params if hasattr(operation, "params") else []
for i, param in enumerate(params):
if isinstance(param, Parameter):
params[i] = FreeParameter(param.name)
return params


def convert_qiskit_to_braket_circuit(circuit: QuantumCircuit) -> Circuit:
Expand All @@ -414,11 +508,12 @@ def convert_qiskit_to_braket_circuit(circuit: QuantumCircuit) -> Circuit:


def convert_qiskit_to_braket_circuits(
circuits: List[QuantumCircuit],
circuits: list[QuantumCircuit],
) -> Iterable[Circuit]:
"""Converts all Qiskit circuits to Braket circuits.
Args:
circuits (List(QuantumCircuit)): Qiskit Quantum Cricuit
circuits (List(QuantumCircuit)): Qiskit quantum circuit
Returns:
Circuit (Iterable[Circuit]): Braket circuit
Expand All @@ -436,7 +531,7 @@ def convert_qiskit_to_braket_circuits(
def to_qiskit(circuit: Circuit) -> QuantumCircuit:
"""Return a Qiskit quantum circuit from a Braket quantum circuit.
Args:
circuit (Circuit): Braket Quantum Cricuit
circuit (Circuit): Braket quantum circuit
Returns:
QuantumCircuit: Qiskit quantum circuit
Expand Down Expand Up @@ -489,7 +584,7 @@ def _create_gate(
return gate_cls(*gate_params)


def wrap_circuits_in_verbatim_box(circuits: List[Circuit]) -> Iterable[Circuit]:
def wrap_circuits_in_verbatim_box(circuits: list[Circuit]) -> Iterable[Circuit]:
"""Convert each Braket circuit an equivalent one wrapped in verbatim box.
Args:
Expand Down
38 changes: 27 additions & 11 deletions qiskit_braket_provider/providers/braket_backend.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
"""AWS Braket backends."""


import datetime
import logging
import enum
from abc import ABC
from typing import Iterable, Union, List
from collections.abc import Iterable
from typing import Union

from braket.aws import AwsDevice, AwsQuantumTaskBatch, AwsQuantumTask
from braket.aws.queue_information import QueueDepthInfo
from braket.circuits import Circuit
from braket.device_schema import DeviceActionType
from braket.devices import LocalSimulator
from braket.ir.openqasm.modifiers import Control
from braket.tasks.local_quantum_task import LocalQuantumTask
from qiskit import QuantumCircuit
from qiskit.providers import BackendV2, QubitProperties, Options, Provider

from .adapter import (
aws_device_to_target,
BRAKET_TO_QISKIT_NAMES,
local_simulator_to_target,
to_braket,
wrap_circuits_in_verbatim_box,
get_controlled_gateset,
)
from .braket_job import AmazonBraketTask
from .. import version
Expand Down Expand Up @@ -89,12 +93,12 @@ def dtm(self) -> float:
)

@property
def meas_map(self) -> List[List[int]]:
def meas_map(self) -> list[list[int]]:
raise NotImplementedError(f"Measurement map is not supported by {self.name}.")

def qubit_properties(
self, qubit: Union[int, List[int]]
) -> Union[QubitProperties, List[QubitProperties]]:
self, qubit: Union[int, list[int]]
) -> Union[QubitProperties, list[QubitProperties]]:
raise NotImplementedError

def drive_channel(self, qubit: int):
Expand All @@ -110,12 +114,24 @@ def control_channel(self, qubits: Iterable[int]):
raise NotImplementedError(f"Control channel is not supported by {self.name}.")

def run(
self, run_input: Union[QuantumCircuit, List[QuantumCircuit]], **options
self, run_input: Union[QuantumCircuit, list[QuantumCircuit]], **options
) -> AmazonBraketTask:
convert_input = (
[run_input] if isinstance(run_input, QuantumCircuit) else list(run_input)
)
circuits: List[Circuit] = [to_braket(input) for input in convert_input]
action = self._aws_device.properties.action[DeviceActionType.OPENQASM]
gateset = {
BRAKET_TO_QISKIT_NAMES[op.lower()]
for op in action.supportedOperations
if op.lower() in BRAKET_TO_QISKIT_NAMES
}
max_control = 0
for modifier in action.supportedModifiers:
if isinstance(modifier, Control):
max_control = modifier.max_qubits
break
gateset.update(get_controlled_gateset(max_control))
circuits: list[Circuit] = [to_braket(circ, gateset) for circ in convert_input]
shots = options["shots"] if "shots" in options else 1024
if shots == 0:
circuits = list(map(lambda x: x.state_vector(), circuits))
Expand Down Expand Up @@ -223,8 +239,8 @@ def _default_options(cls):
return Options()

def qubit_properties(
self, qubit: Union[int, List[int]]
) -> Union[QubitProperties, List[QubitProperties]]:
self, qubit: Union[int, list[int]]
) -> Union[QubitProperties, list[QubitProperties]]:
# TODO: fetch information from device.properties.provider # pylint: disable=fixme
raise NotImplementedError

Expand Down Expand Up @@ -268,7 +284,7 @@ def dtm(self) -> float:
)

@property
def meas_map(self) -> List[List[int]]:
def meas_map(self) -> list[list[int]]:
raise NotImplementedError(f"Measurement map is not supported by {self.name}.")

def drive_channel(self, qubit: int):
Expand Down Expand Up @@ -303,7 +319,7 @@ def run(self, run_input, **options):
batch_task: AwsQuantumTaskBatch = self._device.run_batch(
braket_circuits, **options
)
tasks: List[AwsQuantumTask] = batch_task.tasks
tasks: list[AwsQuantumTask] = batch_task.tasks
task_id = TASK_ID_DIVIDER.join(task.id for task in tasks)

return AmazonBraketTask(
Expand Down
Loading

0 comments on commit b6f3112

Please sign in to comment.