Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename AmazonBraketTask to BraketTask #171

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qiskit_braket_provider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BraketAwsBackend,
BraketLocalBackend,
BraketProvider,
BraketTask,
to_braket,
to_qiskit,
)
1 change: 1 addition & 0 deletions qiskit_braket_provider/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from .braket_backend import AWSBraketBackend, BraketAwsBackend, BraketLocalBackend
from .braket_job import AmazonBraketTask, AWSBraketJob
from .braket_provider import AWSBraketProvider, BraketProvider
from .braket_task import BraketTask
12 changes: 6 additions & 6 deletions qiskit_braket_provider/providers/braket_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
local_simulator_to_target,
to_braket,
)
from .braket_job import AmazonBraketTask
from .braket_task import BraketTask

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,7 +125,7 @@ def control_channel(self, qubits: Iterable[int]):

def run(
self, run_input: Union[QuantumCircuit, list[QuantumCircuit]], **options
) -> AmazonBraketTask:
) -> BraketTask:
convert_input = (
[run_input] if isinstance(run_input, QuantumCircuit) else list(run_input)
)
Expand Down Expand Up @@ -160,7 +160,7 @@ def run(

task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks)

return AmazonBraketTask(
return BraketTask(
task_id=task_id,
tasks=tasks,
backend=self,
Expand Down Expand Up @@ -220,7 +220,7 @@ def __init__( # pylint: disable=too-many-arguments
)
self._target = aws_device_to_target(device=device)

def retrieve_job(self, task_id: str) -> AmazonBraketTask:
def retrieve_job(self, task_id: str) -> BraketTask:
"""Return a single job submitted to AWS backend.
Args:
Expand All @@ -231,7 +231,7 @@ def retrieve_job(self, task_id: str) -> AmazonBraketTask:
"""
task_ids = task_id.split(_TASK_ID_DIVIDER)

return AmazonBraketTask(
return BraketTask(
task_id=task_id,
backend=self,
tasks=[AwsQuantumTask(arn=task_id) for task_id in task_ids],
Expand Down Expand Up @@ -336,7 +336,7 @@ def run(self, run_input, **options):
tasks: list[AwsQuantumTask] = batch_task.tasks
task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks)

return AmazonBraketTask(
return BraketTask(
task_id=task_id, tasks=tasks, backend=self, shots=options.get("shots")
)

Expand Down
201 changes: 17 additions & 184 deletions qiskit_braket_provider/providers/braket_job.py
Original file line number Diff line number Diff line change
@@ -1,203 +1,40 @@
"""Amazon Braket task."""
"""Deprecated Amazon Braket Qiskit Job classes"""

from datetime import datetime
from typing import List, Optional, Union
from warnings import warn

from braket.aws import AwsQuantumTask, AwsQuantumTaskBatch
from braket.aws.queue_information import QuantumTaskQueueInfo
from braket.aws import AwsQuantumTask
from braket.tasks.local_quantum_task import LocalQuantumTask
from qiskit.providers import BackendV2, JobStatus, JobV1
from qiskit.quantum_info import Statevector
from qiskit.result import Result
from qiskit.result.models import ExperimentResult, ExperimentResultData
from qiskit.providers import BackendV2

from .braket_task import BraketTask

def retry_if_result_none(result):
"""Retry on result function."""
return result is None


def _get_result_from_aws_tasks(
tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]],
) -> Optional[List[ExperimentResult]]:
"""Returns experiment results of AWS tasks.
Args:
tasks: AWS Quantum tasks
shots: number of shots
Returns:
List of experiment results.
"""
experiment_results: List[ExperimentResult] = []

results = AwsQuantumTaskBatch._retrieve_results(
tasks, AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT
)

# For each task we create an ExperimentResult object with the downloaded results.
for task, result in zip(tasks, results):
if not result:
return None

if result.task_metadata.shots == 0:
braket_statevector = result.values[
result._result_types_indices[
"{'type': <Type.statevector: 'statevector'>}"
]
]
data = ExperimentResultData(
statevector=Statevector(braket_statevector).reverse_qargs().data,
)
else:
counts = {
k[::-1]: v for k, v in dict(result.measurement_counts).items()
} # convert to little-endian

data = ExperimentResultData(
counts=counts,
memory=[
"".join(shot_result[::-1].astype(str))
for shot_result in result.measurements
],
)

experiment_result = ExperimentResult(
shots=result.task_metadata.shots,
success=True,
status=task.state()
if isinstance(task, LocalQuantumTask)
else result.task_metadata.status,
data=data,
)
experiment_results.append(experiment_result)

return experiment_results


class AmazonBraketTask(JobV1):
class AmazonBraketTask(BraketTask):
"""AmazonBraketTask."""

def __init_subclass__(cls, **kwargs):
"""This throws a deprecation warning on subclassing."""
warn(f"{cls.__name__} is deprecated.", DeprecationWarning, stacklevel=2)
super().__init_subclass__(**kwargs)

def __init__(
self,
task_id: str,
backend: BackendV2,
tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]],
**metadata: Optional[dict],
):
"""AmazonBraketTask for local execution of circuits.
Args:
task_id: id of the task
backend: Local simulator
tasks: Executed tasks
**metadata:
"""
super().__init__(backend=backend, job_id=task_id, metadata=metadata)
self._task_id = task_id
self._backend = backend
self._metadata = metadata
self._tasks = tasks
self._date_of_creation = datetime.now()

@property
def shots(self) -> int:
"""Return the number of shots.
Returns:
shots: int with the number of shots.
"""
return (
self.metadata["metadata"]["shots"]
if "shots" in self.metadata["metadata"]
else 0
)

def submit(self):
return

def queue_position(self) -> QuantumTaskQueueInfo:
"""
The queue position details for the quantum job.
Returns:
QuantumTaskQueueInfo: Instance of QuantumTaskQueueInfo class
representing the queue position information for the quantum job.
The queue_position is only returned when quantum job is not in
RUNNING/CANCELLING/TERMINAL states, else queue_position is returned as None.
The normal tasks refers to the quantum jobs not submitted via Hybrid Jobs.
Whereas, the priority tasks refers to the total number of quantum jobs waiting to run
submitted through Amazon Braket Hybrid Jobs. These tasks run before the normal tasks.
If the queue position for normal or priority quantum tasks is greater than 2000,
we display their respective queue position as '>2000'.
Note: We don't provide queue information for the LocalQuantumTasks.
Examples:
job status = QUEUED and queue position is 2050
>>> task.queue_position()
QuantumTaskQueueInfo(queue_type=<QueueType.NORMAL: 'Normal'>,
queue_position='>2000', message=None)
job status = COMPLETED
>>> task.queue_position()
QuantumTaskQueueInfo(queue_type=<QueueType.NORMAL: 'Normal'>,
queue_position=None, message='Task is in COMPLETED status. AmazonBraket does
not show queue position for this status.')
"""
for task in self._tasks:
if isinstance(task, LocalQuantumTask):
raise NotImplementedError(
"We don't provide queue information for the LocalQuantumTask."
)
return AwsQuantumTask(self.task_id()).queue_position()

def task_id(self) -> str:
"""Return a unique id identifying the task."""
return self._task_id

def result(self) -> Result:
experiment_results = _get_result_from_aws_tasks(tasks=self._tasks)
status = self.status(use_cached_value=True)

return Result(
backend_name=self._backend,
backend_version=self._backend.version,
job_id=self._task_id,
qobj_id=0,
success=status not in AwsQuantumTask.NO_RESULT_TERMINAL_STATES,
results=experiment_results,
status=status,
"""This throws a deprecation warning on initialization."""
warn(
f"{self.__class__.__name__} is deprecated.",
DeprecationWarning,
stacklevel=2,
)

def cancel(self):
for task in self._tasks:
task.cancel()

def status(self, use_cached_value: bool = False):
braket_tasks_states = [
task.state()
if isinstance(task, LocalQuantumTask)
else task.state(use_cached_value=use_cached_value)
for task in self._tasks
]

if "FAILED" in braket_tasks_states:
status = JobStatus.ERROR
elif "CANCELLED" in braket_tasks_states:
status = JobStatus.CANCELLED
elif all(state == "COMPLETED" for state in braket_tasks_states):
status = JobStatus.DONE
elif all(state == "RUNNING" for state in braket_tasks_states):
status = JobStatus.RUNNING
else:
status = JobStatus.QUEUED

return status
super().__init__(task_id=task_id, backend=backend, tasks=tasks, **metadata)


class AWSBraketJob(AmazonBraketTask):
class AWSBraketJob(BraketTask):
"""AWSBraketJob."""

def __init_subclass__(cls, **kwargs):
Expand All @@ -220,7 +57,3 @@ def __init__(
)
super().__init__(task_id=job_id, backend=backend, tasks=tasks, **metadata)
self._job_id = job_id
self._backend = backend
self._metadata = metadata
self._tasks = tasks
self._date_of_creation = datetime.now()
Loading
Loading