Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UnitaryHack] Mock provider, backend #107

Merged
merged 11 commits into from
Jun 29, 2023
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,QuantumCircuit
ignored-classes=optparse.Values,thread._local,_thread._local,QuantumCircuit,Circuit

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
"braket.ir.jaqcd.program": {
"actionType": "braket.ir.jaqcd.program",
"version": ["1"],
"supportedOperations": ["H"],
"supportedOperations": ["H", "CNOT"],
}
},
"paradigm": {"qubitCount": 30},
Expand Down
40 changes: 40 additions & 0 deletions tests/providers/test_braket_job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Tests for AWS Braket job."""

from unittest import TestCase
from unittest.mock import Mock

import pytest
from braket.aws.aws_quantum_task import AwsQuantumTask
from qiskit.providers import JobStatus

from qiskit_braket_provider.providers import (
Expand Down Expand Up @@ -75,3 +78,40 @@ def test_AWS_result(self):
self.assertEqual(job.result().results[0].status, "COMPLETED")
self.assertEqual(job.result().results[0].shots, 3)
self.assertEqual(job.result().get_memory(), ["10", "10", "01"])


class TestBracketJobStatus:
"""Tests for AWS Braket job status."""

def _get_mock_aws_quantum_task(self, status: str) -> AwsQuantumTask:
"""
Creates a mock AwsQuantumTask with the given status.
Status can be one of "CREATED", "QUEUED", "RUNNING", "COMPLETED",
"FAILED", "CANCELLING", "CANCELLED"
"""
task = Mock(spec=AwsQuantumTask)
task.state.return_value = status
return task

@pytest.mark.parametrize(
"task_states, expected_status",
[
(["COMPLETED", "FAILED"], JobStatus.ERROR),
(["COMPLETED", "CANCELLED"], JobStatus.CANCELLED),
(["COMPLETED", "COMPLETED"], JobStatus.DONE),
(["RUNNING", "RUNNING"], JobStatus.RUNNING),
(["QUEUED", "QUEUED"], JobStatus.QUEUED),
],
)
def test_status(self, task_states, expected_status):
"""Tests job status when multiple task status is present."""
job = AWSBraketJob(
backend=BraketLocalBackend(name="default"),
job_id="MockId",
tasks=[MOCK_LOCAL_QUANTUM_TASK],
shots=100,
)
job._tasks = Mock(spec=AmazonBraketTask)
job._tasks = [self._get_mock_aws_quantum_task(state) for state in task_states]

assert job.status() == expected_status
101 changes: 72 additions & 29 deletions tests/providers/test_braket_provider.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Tests for AWS Braket provider."""
import unittest
from unittest import TestCase
from unittest.mock import Mock, patch
import uuid

from braket.aws import AwsDeviceType
from braket.circuits import Circuit
from braket.aws import AwsSession, AwsQuantumTaskBatch
from braket.aws import AwsDevice, AwsDeviceType
from qiskit import circuit as qiskit_circuit
from qiskit.circuit.random import random_circuit
from qiskit.compiler import transpile

from qiskit_braket_provider.providers import AWSBraketProvider
from qiskit_braket_provider.providers.braket_backend import (
BraketBackend,
Expand All @@ -24,50 +26,91 @@
class TestAWSBraketProvider(TestCase):
"""Tests AWSBraketProvider."""

def test_provider_backends(self):
"""Tests provider."""
mock_session = Mock()
def setUp(self):
self.mock_session = Mock()
simulators = [MOCK_GATE_MODEL_SIMULATOR_SV, MOCK_GATE_MODEL_SIMULATOR_TN]
mock_session.get_device.side_effect = simulators
mock_session.region = SIMULATOR_REGION
mock_session.boto_session.region_name = SIMULATOR_REGION
mock_session.search_devices.return_value = simulators
self.mock_session.get_device.side_effect = simulators
self.mock_session.region = SIMULATOR_REGION
self.mock_session.boto_session.region_name = SIMULATOR_REGION
self.mock_session.search_devices.return_value = simulators

def test_provider_backends(self):
"""Tests provider."""
provider = AWSBraketProvider()
backends = provider.backends(
aws_session=mock_session, types=[AwsDeviceType.SIMULATOR]
aws_session=self.mock_session, types=[AwsDeviceType.SIMULATOR]
)

self.assertTrue(len(backends) > 0)
for backend in backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, BraketBackend)

@unittest.skip("Call to external service")
def test_real_devices(self):
"""Tests real devices."""
provider = AWSBraketProvider()
backends = provider.backends()
self.assertTrue(len(backends) > 0)
for backend in backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)
with patch(
"qiskit_braket_provider.providers.braket_provider.AwsDevice"
) as mock_get_devices:
mock_get_devices.get_devices.return_value = [
AwsDevice(MOCK_GATE_MODEL_SIMULATOR_SV["deviceArn"], self.mock_session),
AwsDevice(MOCK_GATE_MODEL_SIMULATOR_TN["deviceArn"], self.mock_session),
]
provider = AWSBraketProvider()
backends = provider.backends()
self.assertTrue(len(backends) > 0)
for backend in backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)

online_simulators_backends = provider.backends(
statuses=["ONLINE"], types=["SIMULATOR"]
online_simulators_backends = provider.backends(
statuses=["ONLINE"], types=["SIMULATOR"]
)
for backend in online_simulators_backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)

@patch("qiskit_braket_provider.providers.braket_backend.AWSBraketBackend")
@patch("qiskit_braket_provider.providers.braket_backend.AwsDevice.get_devices")
def test_qiskit_circuit_transpilation_run(
self, mock_get_devices, mock_aws_braket_backend
):
"""Tests qiskit circuit transpilation."""
mock_get_devices.return_value = [
AwsDevice(MOCK_GATE_MODEL_SIMULATOR_SV["deviceArn"], self.mock_session)
]
s3_target = AwsSession.S3DestinationFolder("mock_bucket", "mock_key")
q_circuit = qiskit_circuit.QuantumCircuit(2)
q_circuit.h(0)
q_circuit.cx(0, 1)
braket_circuit = Circuit().h(0).cnot(0, 1)

mock_aws_braket_backend = Mock(spec=AWSBraketBackend)
mock_aws_braket_backend._device = Mock(spec=AwsDevice)
task = AwsQuantumTaskBatch(
Mock(),
MOCK_GATE_MODEL_SIMULATOR_SV["deviceArn"],
braket_circuit,
s3_target,
1000,
max_parallel=10,
)
for backend in online_simulators_backends:
with self.subTest(f"{backend.name}"):
self.assertIsInstance(backend, AWSBraketBackend)
task_mock = Mock()
task_mock.id = str(uuid.uuid4())
task_mock.state.return_value = "RUNNING"
task = Mock(spec=AwsQuantumTaskBatch, return_value=task)
task.tasks = [task_mock]

@unittest.skip("Call to external service")
def test_real_device_circuit_execution(self):
"""Tests circuit execution on real device."""
provider = AWSBraketProvider()
state_vector_backend = provider.get_backend("SV1")
circuit = random_circuit(3, 5, seed=42)
state_vector_backend = provider.get_backend(
"SV1", aws_session=self.mock_session
)

transpiled_circuit = transpile(
circuit, backend=state_vector_backend, seed_transpiler=42
q_circuit, backend=state_vector_backend, seed_transpiler=42
)

state_vector_backend._device.run_batch = Mock(
spec=AwsQuantumTaskBatch, return_value=task
)
result = state_vector_backend.run(transpiled_circuit, shots=10)
self.assertTrue(result)
Expand Down
Loading