From f4fa0ca3aaa2f2c47ca0fd4235d578cc006c0982 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 30 Mar 2020 16:20:50 -0400 Subject: [PATCH] merging into common_utils. --- test/_test.py | 39 --------------------------------------- test/common_utils.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 39 deletions(-) delete mode 100644 test/_test.py diff --git a/test/_test.py b/test/_test.py deleted file mode 100644 index 9ccb41f341d..00000000000 --- a/test/_test.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -from contextlib import contextmanager - -import common_utils -import torchaudio - -BACKENDS = torchaudio._backend._audio_backends - - -@contextmanager -def AudioBackendScope(new_backend): - previous_backend = torchaudio.get_audio_backend() - try: - torchaudio.set_audio_backend(new_backend) - yield - finally: - torchaudio.set_audio_backend(previous_backend) - - -def get_backends_with_mp3(backends): - test_dirpath, _ = common_utils.create_temp_assets_dir() - test_filepath = os.path.join( - test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3" - ) - - backends_mp3 = [] - - for backend in backends: - try: - with AudioBackendScope(backend): - torchaudio.load(test_filepath) - backends_mp3.append(backend) - except RuntimeError: - pass - - return backends_mp3 - - -BACKENDS_MP3 = get_backends_with_mp3(BACKENDS) diff --git a/test/common_utils.py b/test/common_utils.py index a79f413d2bc..36715e38483 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -1,10 +1,13 @@ from __future__ import absolute_import, division, print_function, unicode_literals import os from shutil import copytree +from contextlib import contextmanager import backports.tempfile as tempfile import torch +import torchaudio TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) +BACKENDS = torchaudio._backend._audio_backends def create_temp_assets_dir(): @@ -48,3 +51,35 @@ def random_int_tensor(seed, size, low=0, high=2 ** 32, a=22695477, c=1, m=2 ** 3 """ Same as random_float_tensor but integers between [low, high) """ return torch.floor(random_float_tensor(seed, size, a, c, m) * (high - low)) + low + + +@contextmanager +def AudioBackendScope(new_backend): + previous_backend = torchaudio.get_audio_backend() + try: + torchaudio.set_audio_backend(new_backend) + yield + finally: + torchaudio.set_audio_backend(previous_backend) + + +def filter_backends_with_mp3(backends): + test_dirpath, _ = create_temp_assets_dir() + test_filepath = os.path.join( + test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3" + ) + + backends_mp3 = [] + + for backend in backends: + try: + with AudioBackendScope(backend): + torchaudio.load(test_filepath) + backends_mp3.append(backend) + except RuntimeError: + pass + + return backends_mp3 + + +BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)