From 8868fecb5f9c335b669183c8bfeed5f487964771 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Tue, 9 Apr 2024 11:31:22 -0700 Subject: [PATCH] Rename AmazonBraketTask to BraketTask (#171) This makes all the provider class names consistent --- qiskit_braket_provider/__init__.py | 1 + qiskit_braket_provider/providers/__init__.py | 1 + .../providers/braket_backend.py | 12 +- .../providers/braket_job.py | 201 ++---------------- .../providers/braket_task.py | 196 +++++++++++++++++ ...test_braket_job.py => test_braket_task.py} | 59 +++-- 6 files changed, 267 insertions(+), 203 deletions(-) create mode 100644 qiskit_braket_provider/providers/braket_task.py rename tests/providers/{test_braket_job.py => test_braket_task.py} (78%) diff --git a/qiskit_braket_provider/__init__.py b/qiskit_braket_provider/__init__.py index 74bbf39..5222784 100644 --- a/qiskit_braket_provider/__init__.py +++ b/qiskit_braket_provider/__init__.py @@ -8,6 +8,7 @@ BraketAwsBackend, BraketLocalBackend, BraketProvider, + BraketTask, to_braket, to_qiskit, ) diff --git a/qiskit_braket_provider/providers/__init__.py b/qiskit_braket_provider/providers/__init__.py index 464af93..0097fb1 100644 --- a/qiskit_braket_provider/providers/__init__.py +++ b/qiskit_braket_provider/providers/__init__.py @@ -24,3 +24,4 @@ from .braket_backend import AWSBraketBackend, BraketAwsBackend, BraketLocalBackend from .braket_job import AmazonBraketTask, AWSBraketJob from .braket_provider import AWSBraketProvider, BraketProvider +from .braket_task import BraketTask diff --git a/qiskit_braket_provider/providers/braket_backend.py b/qiskit_braket_provider/providers/braket_backend.py index a4d8bc4..ab35a62 100644 --- a/qiskit_braket_provider/providers/braket_backend.py +++ b/qiskit_braket_provider/providers/braket_backend.py @@ -25,7 +25,7 @@ local_simulator_to_target, to_braket, ) -from .braket_job import AmazonBraketTask +from .braket_task import BraketTask logger = logging.getLogger(__name__) @@ -125,7 +125,7 @@ def control_channel(self, qubits: Iterable[int]): def run( self, run_input: Union[QuantumCircuit, list[QuantumCircuit]], **options - ) -> AmazonBraketTask: + ) -> BraketTask: convert_input = ( [run_input] if isinstance(run_input, QuantumCircuit) else list(run_input) ) @@ -160,7 +160,7 @@ def run( task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks) - return AmazonBraketTask( + return BraketTask( task_id=task_id, tasks=tasks, backend=self, @@ -220,7 +220,7 @@ def __init__( # pylint: disable=too-many-arguments ) self._target = aws_device_to_target(device=device) - def retrieve_job(self, task_id: str) -> AmazonBraketTask: + def retrieve_job(self, task_id: str) -> BraketTask: """Return a single job submitted to AWS backend. Args: @@ -231,7 +231,7 @@ def retrieve_job(self, task_id: str) -> AmazonBraketTask: """ task_ids = task_id.split(_TASK_ID_DIVIDER) - return AmazonBraketTask( + return BraketTask( task_id=task_id, backend=self, tasks=[AwsQuantumTask(arn=task_id) for task_id in task_ids], @@ -336,7 +336,7 @@ def run(self, run_input, **options): tasks: list[AwsQuantumTask] = batch_task.tasks task_id = _TASK_ID_DIVIDER.join(task.id for task in tasks) - return AmazonBraketTask( + return BraketTask( task_id=task_id, tasks=tasks, backend=self, shots=options.get("shots") ) diff --git a/qiskit_braket_provider/providers/braket_job.py b/qiskit_braket_provider/providers/braket_job.py index 1615c03..acf1280 100644 --- a/qiskit_braket_provider/providers/braket_job.py +++ b/qiskit_braket_provider/providers/braket_job.py @@ -1,84 +1,23 @@ -"""Amazon Braket task.""" +"""Deprecated Amazon Braket Qiskit Job classes""" -from datetime import datetime from typing import List, Optional, Union from warnings import warn -from braket.aws import AwsQuantumTask, AwsQuantumTaskBatch -from braket.aws.queue_information import QuantumTaskQueueInfo +from braket.aws import AwsQuantumTask from braket.tasks.local_quantum_task import LocalQuantumTask -from qiskit.providers import BackendV2, JobStatus, JobV1 -from qiskit.quantum_info import Statevector -from qiskit.result import Result -from qiskit.result.models import ExperimentResult, ExperimentResultData +from qiskit.providers import BackendV2 +from .braket_task import BraketTask -def retry_if_result_none(result): - """Retry on result function.""" - return result is None - -def _get_result_from_aws_tasks( - tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]], -) -> Optional[List[ExperimentResult]]: - """Returns experiment results of AWS tasks. - - Args: - tasks: AWS Quantum tasks - shots: number of shots - - Returns: - List of experiment results. - """ - experiment_results: List[ExperimentResult] = [] - - results = AwsQuantumTaskBatch._retrieve_results( - tasks, AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT - ) - - # For each task we create an ExperimentResult object with the downloaded results. - for task, result in zip(tasks, results): - if not result: - return None - - if result.task_metadata.shots == 0: - braket_statevector = result.values[ - result._result_types_indices[ - "{'type': }" - ] - ] - data = ExperimentResultData( - statevector=Statevector(braket_statevector).reverse_qargs().data, - ) - else: - counts = { - k[::-1]: v for k, v in dict(result.measurement_counts).items() - } # convert to little-endian - - data = ExperimentResultData( - counts=counts, - memory=[ - "".join(shot_result[::-1].astype(str)) - for shot_result in result.measurements - ], - ) - - experiment_result = ExperimentResult( - shots=result.task_metadata.shots, - success=True, - status=task.state() - if isinstance(task, LocalQuantumTask) - else result.task_metadata.status, - data=data, - ) - experiment_results.append(experiment_result) - - return experiment_results - - -class AmazonBraketTask(JobV1): +class AmazonBraketTask(BraketTask): """AmazonBraketTask.""" + def __init_subclass__(cls, **kwargs): + """This throws a deprecation warning on subclassing.""" + warn(f"{cls.__name__} is deprecated.", DeprecationWarning, stacklevel=2) + super().__init_subclass__(**kwargs) + def __init__( self, task_id: str, @@ -86,118 +25,16 @@ def __init__( tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]], **metadata: Optional[dict], ): - """AmazonBraketTask for local execution of circuits. - - Args: - task_id: id of the task - backend: Local simulator - tasks: Executed tasks - **metadata: - """ - super().__init__(backend=backend, job_id=task_id, metadata=metadata) - self._task_id = task_id - self._backend = backend - self._metadata = metadata - self._tasks = tasks - self._date_of_creation = datetime.now() - - @property - def shots(self) -> int: - """Return the number of shots. - - Returns: - shots: int with the number of shots. - """ - return ( - self.metadata["metadata"]["shots"] - if "shots" in self.metadata["metadata"] - else 0 - ) - - def submit(self): - return - - def queue_position(self) -> QuantumTaskQueueInfo: - """ - The queue position details for the quantum job. - - Returns: - QuantumTaskQueueInfo: Instance of QuantumTaskQueueInfo class - representing the queue position information for the quantum job. - The queue_position is only returned when quantum job is not in - RUNNING/CANCELLING/TERMINAL states, else queue_position is returned as None. - The normal tasks refers to the quantum jobs not submitted via Hybrid Jobs. - Whereas, the priority tasks refers to the total number of quantum jobs waiting to run - submitted through Amazon Braket Hybrid Jobs. These tasks run before the normal tasks. - If the queue position for normal or priority quantum tasks is greater than 2000, - we display their respective queue position as '>2000'. - - Note: We don't provide queue information for the LocalQuantumTasks. - - Examples: - job status = QUEUED and queue position is 2050 - >>> task.queue_position() - QuantumTaskQueueInfo(queue_type=, - queue_position='>2000', message=None) - - job status = COMPLETED - >>> task.queue_position() - QuantumTaskQueueInfo(queue_type=, - queue_position=None, message='Task is in COMPLETED status. AmazonBraket does - not show queue position for this status.') - """ - for task in self._tasks: - if isinstance(task, LocalQuantumTask): - raise NotImplementedError( - "We don't provide queue information for the LocalQuantumTask." - ) - return AwsQuantumTask(self.task_id()).queue_position() - - def task_id(self) -> str: - """Return a unique id identifying the task.""" - return self._task_id - - def result(self) -> Result: - experiment_results = _get_result_from_aws_tasks(tasks=self._tasks) - status = self.status(use_cached_value=True) - - return Result( - backend_name=self._backend, - backend_version=self._backend.version, - job_id=self._task_id, - qobj_id=0, - success=status not in AwsQuantumTask.NO_RESULT_TERMINAL_STATES, - results=experiment_results, - status=status, + """This throws a deprecation warning on initialization.""" + warn( + f"{self.__class__.__name__} is deprecated.", + DeprecationWarning, + stacklevel=2, ) - - def cancel(self): - for task in self._tasks: - task.cancel() - - def status(self, use_cached_value: bool = False): - braket_tasks_states = [ - task.state() - if isinstance(task, LocalQuantumTask) - else task.state(use_cached_value=use_cached_value) - for task in self._tasks - ] - - if "FAILED" in braket_tasks_states: - status = JobStatus.ERROR - elif "CANCELLED" in braket_tasks_states: - status = JobStatus.CANCELLED - elif all(state == "COMPLETED" for state in braket_tasks_states): - status = JobStatus.DONE - elif all(state == "RUNNING" for state in braket_tasks_states): - status = JobStatus.RUNNING - else: - status = JobStatus.QUEUED - - return status + super().__init__(task_id=task_id, backend=backend, tasks=tasks, **metadata) -class AWSBraketJob(AmazonBraketTask): +class AWSBraketJob(BraketTask): """AWSBraketJob.""" def __init_subclass__(cls, **kwargs): @@ -220,7 +57,3 @@ def __init__( ) super().__init__(task_id=job_id, backend=backend, tasks=tasks, **metadata) self._job_id = job_id - self._backend = backend - self._metadata = metadata - self._tasks = tasks - self._date_of_creation = datetime.now() diff --git a/qiskit_braket_provider/providers/braket_task.py b/qiskit_braket_provider/providers/braket_task.py new file mode 100644 index 0000000..4e9dd95 --- /dev/null +++ b/qiskit_braket_provider/providers/braket_task.py @@ -0,0 +1,196 @@ +"""Amazon Braket task.""" + +from datetime import datetime +from typing import List, Optional, Union + +from braket.aws import AwsQuantumTask, AwsQuantumTaskBatch +from braket.aws.queue_information import QuantumTaskQueueInfo +from braket.tasks.local_quantum_task import LocalQuantumTask +from qiskit.providers import BackendV2, JobStatus, JobV1 +from qiskit.quantum_info import Statevector +from qiskit.result import Result +from qiskit.result.models import ExperimentResult, ExperimentResultData + + +def retry_if_result_none(result): + """Retry on result function.""" + return result is None + + +def _get_result_from_aws_tasks( + tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]], +) -> Optional[List[ExperimentResult]]: + """Returns experiment results of AWS tasks. + + Args: + tasks: AWS Quantum tasks + shots: number of shots + + Returns: + List of experiment results. + """ + experiment_results: List[ExperimentResult] = [] + + results = AwsQuantumTaskBatch._retrieve_results( + tasks, AwsQuantumTaskBatch.MAX_CONNECTIONS_DEFAULT + ) + + # For each task we create an ExperimentResult object with the downloaded results. + for task, result in zip(tasks, results): + if not result: + return None + + if result.task_metadata.shots == 0: + braket_statevector = result.values[ + result._result_types_indices[ + "{'type': }" + ] + ] + data = ExperimentResultData( + statevector=Statevector(braket_statevector).reverse_qargs().data, + ) + else: + counts = { + k[::-1]: v for k, v in dict(result.measurement_counts).items() + } # convert to little-endian + + data = ExperimentResultData( + counts=counts, + memory=[ + "".join(shot_result[::-1].astype(str)) + for shot_result in result.measurements + ], + ) + + experiment_result = ExperimentResult( + shots=result.task_metadata.shots, + success=True, + status=task.state() + if isinstance(task, LocalQuantumTask) + else result.task_metadata.status, + data=data, + ) + experiment_results.append(experiment_result) + + return experiment_results + + +class BraketTask(JobV1): + """BraketTask.""" + + def __init__( + self, + task_id: str, + backend: BackendV2, + tasks: Union[List[LocalQuantumTask], List[AwsQuantumTask]], + **metadata: Optional[dict], + ): + """BraketTask for execution of circuits on Amazon Braket or locally. + + Args: + task_id: Semicolon-separated IDs of the underlying tasks + backend: BraketBackend that ran the circuit + tasks: Executed tasks + **metadata: Additional metadata + """ + super().__init__(backend=backend, job_id=task_id, metadata=metadata) + self._task_id = task_id + self._backend = backend + self._metadata = metadata + self._tasks = tasks + self._date_of_creation = datetime.now() + + @property + def shots(self) -> int: + """Return the number of shots. + + Returns: + shots: int with the number of shots. + """ + return ( + self.metadata["metadata"]["shots"] + if "shots" in self.metadata["metadata"] + else 0 + ) + + def submit(self): + return + + def queue_position(self) -> QuantumTaskQueueInfo: + """ + The queue position details for the quantum job. + + Returns: + QuantumTaskQueueInfo: Instance of QuantumTaskQueueInfo class + representing the queue position information for the quantum job. + The queue_position is only returned when quantum job is not in + RUNNING/CANCELLING/TERMINAL states, else queue_position is returned as None. + The normal tasks refers to the quantum jobs not submitted via Hybrid Jobs. + Whereas, the priority tasks refers to the total number of quantum jobs waiting to run + submitted through Amazon Braket Hybrid Jobs. These tasks run before the normal tasks. + If the queue position for normal or priority quantum tasks is greater than 2000, + we display their respective queue position as '>2000'. + + Note: We don't provide queue information for the LocalQuantumTasks. + + Examples: + job status = QUEUED and queue position is 2050 + >>> task.queue_position() + QuantumTaskQueueInfo(queue_type=, + queue_position='>2000', message=None) + + job status = COMPLETED + >>> task.queue_position() + QuantumTaskQueueInfo(queue_type=, + queue_position=None, message='Task is in COMPLETED status. AmazonBraket does + not show queue position for this status.') + """ + for task in self._tasks: + if isinstance(task, LocalQuantumTask): + raise NotImplementedError( + "We don't provide queue information for the LocalQuantumTask." + ) + return AwsQuantumTask(self.task_id()).queue_position() + + def task_id(self) -> str: + """Return a unique id identifying the task.""" + return self._task_id + + def result(self) -> Result: + experiment_results = _get_result_from_aws_tasks(tasks=self._tasks) + status = self.status(use_cached_value=True) + + return Result( + backend_name=self._backend, + backend_version=self._backend.version, + job_id=self._task_id, + qobj_id=0, + success=status not in AwsQuantumTask.NO_RESULT_TERMINAL_STATES, + results=experiment_results, + status=status, + ) + + def cancel(self): + for task in self._tasks: + task.cancel() + + def status(self, use_cached_value: bool = False): + braket_tasks_states = [ + task.state() + if isinstance(task, LocalQuantumTask) + else task.state(use_cached_value=use_cached_value) + for task in self._tasks + ] + + if "FAILED" in braket_tasks_states: + status = JobStatus.ERROR + elif "CANCELLED" in braket_tasks_states: + status = JobStatus.CANCELLED + elif all(state == "COMPLETED" for state in braket_tasks_states): + status = JobStatus.DONE + elif all(state == "RUNNING" for state in braket_tasks_states): + status = JobStatus.RUNNING + else: + status = JobStatus.QUEUED + + return status diff --git a/tests/providers/test_braket_job.py b/tests/providers/test_braket_task.py similarity index 78% rename from tests/providers/test_braket_job.py rename to tests/providers/test_braket_task.py index e01be63..ae87d4e 100644 --- a/tests/providers/test_braket_job.py +++ b/tests/providers/test_braket_task.py @@ -11,10 +11,55 @@ AmazonBraketTask, AWSBraketJob, BraketLocalBackend, + BraketTask, ) from tests.providers.mocks import MOCK_LOCAL_QUANTUM_TASK +class TestBraketTask(TestCase): + """Tests BraketTask.""" + + def _get_task(self): + return BraketTask( + backend=BraketLocalBackend(name="default"), + task_id="AwesomeId", + tasks=[MOCK_LOCAL_QUANTUM_TASK], + shots=10, + ) + + def test_task(self): + """Tests task.""" + task = self._get_task() + + self.assertTrue(isinstance(task, BraketTask)) + self.assertEqual(task.shots, 10) + + self.assertEqual(task.status(), JobStatus.DONE) + + def test_result(self): + """Tests result.""" + task = self._get_task() + + self.assertEqual(task.result().job_id, "AwesomeId") + self.assertEqual(task.result().results[0].data.counts, {"01": 1, "10": 2}) + self.assertEqual(task.result().results[0].data.memory, ["10", "10", "01"]) + self.assertEqual(task.result().results[0].status, "COMPLETED") + self.assertEqual(task.result().results[0].shots, 3) + self.assertEqual(task.result().get_memory(), ["10", "10", "01"]) + + def test_queue_position_for_local_quantum_task(self): + """Tests job status when multiple task status is present.""" + task = BraketTask( + backend=BraketLocalBackend(name="default"), + task_id="MockId", + tasks=[MOCK_LOCAL_QUANTUM_TASK], + shots=100, + ) + message = "We don't provide queue information for the LocalQuantumTask." + with pytest.raises(NotImplementedError, match=message): + task.queue_position() + + class TestAmazonBraketTask(TestCase): """Tests AmazonBraketTask.""" @@ -79,18 +124,6 @@ def test_AWS_result(self): self.assertEqual(job.result().results[0].shots, 3) self.assertEqual(job.result().get_memory(), ["10", "10", "01"]) - def test_queue_position_for_local_quantum_task(self): - """Tests job status when multiple task status is present.""" - job = AWSBraketJob( - backend=BraketLocalBackend(name="default"), - job_id="MockId", - tasks=[MOCK_LOCAL_QUANTUM_TASK], - shots=100, - ) - message = "We don't provide queue information for the LocalQuantumTask." - with pytest.raises(NotImplementedError, match=message): - job.queue_position() - class TestBraketJobStatus: """Tests for Amazon Braket job status.""" @@ -123,7 +156,7 @@ def test_status(self, task_states, expected_status): tasks=[MOCK_LOCAL_QUANTUM_TASK], shots=100, ) - job._tasks = Mock(spec=AmazonBraketTask) + job._tasks = Mock(spec=BraketTask) job._tasks = [self._get_mock_aws_quantum_task(state) for state in task_states] assert job.status() == expected_status