diff --git a/storage/google/cloud/storage/client.py b/storage/google/cloud/storage/client.py index 8541cc51469c..11179bd25ce9 100644 --- a/storage/google/cloud/storage/client.py +++ b/storage/google/cloud/storage/client.py @@ -14,6 +14,7 @@ """Client for interacting with the Google Cloud Storage API.""" +from six.moves.urllib.parse import urlsplit from google.auth.credentials import AnonymousCredentials @@ -24,6 +25,7 @@ from google.cloud.storage._http import Connection from google.cloud.storage.batch import Batch from google.cloud.storage.bucket import Bucket +from google.cloud.storage.blob import Blob _marker = object() @@ -341,6 +343,57 @@ def create_bucket(self, bucket_or_name, requester_pays=None, project=None): bucket.create(client=self, project=project) return bucket + def download_blob_to_file(self, blob_or_uri, file_obj, start=None, end=None): + """Download the contents of a blob object or blob URI into a file-like object. + + Args: + blob_or_uri (Union[ \ + :class:`~google.cloud.storage.blob.Blob`, \ + str, \ + ]): + The blob resource to pass or URI to download. + file_obj (file): + A file handle to which to write the blob's data. + start (int): + Optional. The first byte in a range to be downloaded. + end (int): + Optional. The last byte in a range to be downloaded. + + Examples: + Download a blob using using a blob resource. + + >>> from google.cloud import storage + >>> client = storage.Client() + + >>> bucket = client.get_bucket('my-bucket-name') + >>> blob = storage.Blob('path/to/blob', bucket) + + >>> with open('file-to-download-to') as file_obj: + >>> client.download_blob_to_file(blob, file) # API request. + + + Download a blob using a URI. + + >>> from google.cloud import storage + >>> client = storage.Client() + + >>> with open('file-to-download-to') as file_obj: + >>> client.download_blob_to_file( + >>> 'gs://bucket_name/path/to/blob', file) + + + """ + try: + blob_or_uri.download_to_file(file_obj, client=self, start=start, end=end) + except AttributeError: + scheme, netloc, path, query, frag = urlsplit(blob_or_uri) + if scheme != "gs": + raise ValueError("URI scheme must be gs") + bucket = Bucket(self, name=netloc) + blob_or_uri = Blob(path, bucket) + + blob_or_uri.download_to_file(file_obj, client=self, start=start, end=end) + def list_buckets( self, max_results=None, diff --git a/storage/tests/unit/test_client.py b/storage/tests/unit/test_client.py index 83daad9eff38..e874b44e1241 100644 --- a/storage/tests/unit/test_client.py +++ b/storage/tests/unit/test_client.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io import json import unittest import mock +import pytest import requests from six.moves import http_client @@ -598,6 +600,45 @@ def test_create_bucket_with_object_success(self): json_sent = http.request.call_args_list[0][1]["data"] self.assertEqual(json_expected, json.loads(json_sent)) + def test_download_blob_to_file_with_blob(self): + project = "PROJECT" + credentials = _make_credentials() + client = self._make_one(project=project, credentials=credentials) + blob = mock.Mock() + file_obj = io.BytesIO() + + client.download_blob_to_file(blob, file_obj) + blob.download_to_file.assert_called_once_with( + file_obj, client=client, start=None, end=None + ) + + def test_download_blob_to_file_with_uri(self): + project = "PROJECT" + credentials = _make_credentials() + client = self._make_one(project=project, credentials=credentials) + blob = mock.Mock() + file_obj = io.BytesIO() + + with mock.patch("google.cloud.storage.client.Blob", return_value=blob): + client.download_blob_to_file("gs://bucket_name/path/to/object", file_obj) + + blob.download_to_file.assert_called_once_with( + file_obj, client=client, start=None, end=None + ) + + def test_download_blob_to_file_with_invalid_uri(self): + project = "PROJECT" + credentials = _make_credentials() + client = self._make_one(project=project, credentials=credentials) + blob = mock.Mock() + file_obj = io.BytesIO() + + with mock.patch("google.cloud.storage.client.Blob", return_value=blob): + with pytest.raises(ValueError, match="URI scheme must be gs"): + client.download_blob_to_file( + "http://bucket_name/path/to/object", file_obj + ) + def test_list_buckets_wo_project(self): CREDENTIALS = _make_credentials() client = self._make_one(project=None, credentials=CREDENTIALS)