Skip to content

Commit

Permalink
[Contrib] Added default non-verbose to download_testdata(), pass to d…
Browse files Browse the repository at this point in the history
…ownload() (#8533)

* [Contrib] Added default non-verbose to download_testdata(), pass to download().

Minor cleanup as well, while in the file

- Using tempfile.TemporaryDirectory instead of explicit cleanup.

- Pass through verbose/retries arguments if replacing a corrupted
  copy.

* [Contrib] Switched download.py from print statements to logging

* [Contrib] Added shutil.copy2 fallback after downloading file.

Initial implementation using tempfile.TemporaryDirectory assumed that
the tempdir and output location were on the same drive, and could be
renamed.  This update falls back to copying from the temporary
directory, in case the tempdir is on a different drive.

Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
  • Loading branch information
Lunderberg and Lunderberg committed Jul 29, 2021
1 parent fcbd2b6 commit 6b2cbfe
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 53 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def download_package(tophub_location, package_name):

download_url = "{0}/{1}".format(tophub_location, package_name)
logger.info("Download pre-tuned parameters package from %s", download_url)
download(download_url, Path(rootpath, package_name), True, verbose=0)
download(download_url, Path(rootpath, package_name), overwrite=True)


# global cache for load_reference_log
Expand Down
119 changes: 67 additions & 52 deletions python/tvm/contrib/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""Helper utility for downloading"""

import logging
import os
from pathlib import Path
from os import environ
import sys
import time
import uuid
import shutil
import tempfile
import time

LOG = logging.getLogger("download")


def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=3):
def download(url, path, overwrite=False, size_compare=False, retries=3):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Expand All @@ -33,19 +36,18 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
Download url.
path : str
Local file path to save downloaded file
Local file path to save downloaded file.
overwrite : bool, optional
Whether to overwrite existing file
Whether to overwrite existing file, defaults to False.
size_compare : bool, optional
Whether to do size compare to check downloaded file.
verbose: int, optional
Verbose level
Whether to do size compare to check downloaded file, defaults
to False
retries: int, optional
Number of time to retry download, default at 3.
Number of time to retry download, defaults to 3.
"""
# pylint: disable=import-outside-toplevel
import urllib.request as urllib2
Expand All @@ -62,66 +64,73 @@ def download(url, path, overwrite=False, size_compare=False, verbose=1, retries=
res_get = urllib2.urlopen(url)
url_file_size = int(res_get.headers["Content-Length"])
if url_file_size != file_size:
print("exist file got corrupted, downloading %s file freshly..." % path)
download(url, path, True, False)
LOG.warning("Existing file %s has incorrect size, downloading fresh copy", path)
download(url, path, overwrite=True, size_compare=False, retries=retries)
return
print("File {} exists, skip.".format(path))

LOG.info("File %s exists, skipping.", path)
return

if verbose >= 1:
print("Downloading from url {} to {}".format(url, path))
LOG.info("Downloading from url %s to %s", url, path)

# Stateful start time
start_time = time.time()
dirpath = path.parent
dirpath.mkdir(parents=True, exist_ok=True)
random_uuid = str(uuid.uuid4())
tempfile = Path(dirpath, random_uuid)

def _download_progress(count, block_size, total_size):
# pylint: disable=unused-argument
"""Show the download progress."""
if count == 0:
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
progress_bytes = int(count * block_size)
progress_megabytes = progress_bytes / (1024.0 * 1024)
speed_kbps = int(progress_bytes / (1024 * duration))
percent = min(int(count * block_size * 100 / total_size), 100)
sys.stdout.write(
"\r...%d%%, %.2f MB, %d KB/s, %d seconds passed"
% (percent, progress_size / (1024.0 * 1024), speed, duration)

# Temporarily suppress newlines on the output stream.
prev_terminator = logging.StreamHandler.terminator
logging.StreamHandler.terminator = ""
LOG.debug(
"\r...%d%%, %.2f MB, %d KB/s, %d seconds passed",
percent,
progress_megabytes,
speed_kbps,
duration,
)
sys.stdout.flush()

while retries >= 0:
# Disable pyling too broad Exception
# pylint: disable=W0703
try:
if sys.version_info >= (3,):
urllib2.urlretrieve(url, tempfile, reporthook=_download_progress)
print("")
else:
f = urllib2.urlopen(url)
data = f.read()
with open(tempfile, "wb") as code:
code.write(data)
shutil.move(tempfile, path)
break
except Exception as err:
retries -= 1
if retries == 0:
if tempfile.exists():
tempfile.unlink()
raise err
print(
"download failed due to {}, retrying, {} attempt{} left".format(
repr(err), retries, "s" if retries > 1 else ""
logging.StreamHandler.terminator = prev_terminator

with tempfile.TemporaryDirectory() as tempdir:
tempdir = Path(tempdir)
download_loc = tempdir.joinpath(path.name)

for i_retry in range(retries):
# pylint: disable=broad-except
try:

urllib2.urlretrieve(url, download_loc, reporthook=_download_progress)
LOG.debug("")
try:
download_loc.rename(path)
except OSError:
# Prefer a move, but if the tempdir and final
# location are in different drives, fall back to a
# copy.
shutil.copy2(download_loc, path)
return

except Exception as err:
if i_retry == retries - 1:
raise err

LOG.warning(
"%s\nDownload attempt %d/%d failed, retrying.", repr(err), i_retry, retries
)
)


if "TEST_DATA_ROOT_PATH" in environ:
TEST_DATA_ROOT_PATH = Path(environ.get("TEST_DATA_ROOT_PATH"))
if "TEST_DATA_ROOT_PATH" in os.environ:
TEST_DATA_ROOT_PATH = Path(os.environ.get("TEST_DATA_ROOT_PATH"))
else:
TEST_DATA_ROOT_PATH = Path(Path("~").expanduser(), ".tvm_test_data")
TEST_DATA_ROOT_PATH.mkdir(parents=True, exist_ok=True)
Expand All @@ -141,10 +150,16 @@ def download_testdata(url, relpath, module=None, overwrite=False):
module : Union[str, list, tuple], optional
Subdirectory paths under test data folder.
overwrite : bool, defaults to False
If True, will download a fresh copy of the file regardless of
the cache. If False, will only download the file if a cached
version is missing.
Returns
-------
abspath : str
Absolute file path of downloaded file
"""
global TEST_DATA_ROOT_PATH
if module is None:
Expand Down

0 comments on commit 6b2cbfe

Please sign in to comment.