Skip to content

Commit

Permalink
grouping backend functions into a separate file.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentqb committed Mar 26, 2020
1 parent a8a6895 commit 470e655
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 44 deletions.
44 changes: 44 additions & 0 deletions test/_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import unittest
from contextlib import contextmanager

import common_utils
import torchaudio

BACKENDS = torchaudio._backend._audio_backends


@contextmanager
def AudioBackendScope(new_backend):
previous_backend = torchaudio.get_audio_backend()
try:
# unittest.skipIf(not new_backend, "No backend supporting this test.")
# unittest.skipIf(new_backend not in BACKENDS, new_backend + " is not available.")
torchaudio.set_audio_backend(new_backend)
yield
finally:
torchaudio.set_audio_backend(previous_backend)


def get_backends_with_mp3(backends):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(
test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3"
)

backends_mp3 = []

for b in backends:
torchaudio.load(test_filepath)
try:
with AudioBackendScope(b):
waveform, sample_rate = torchaudio.load(test_filepath)
backends_mp3.append(b)
except:
pass

return backends_mp3


BACKENDS_MP3 = get_backends_with_mp3(BACKENDS)
FIRST_BACKEND_MP3 = BACKENDS_MP3[0] if BACKENDS_MP3 else None
19 changes: 1 addition & 18 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,7 @@
import torchaudio
import math
import os


BACKENDS = torchaudio._backend._audio_backends
BACKENDS_MP3 = ["sox"] if "sox" in BACKENDS else []


class AudioBackendScope:
def __init__(self, backend):
self.new_backend = backend
self.previous_backend = torchaudio.get_audio_backend()

def __enter__(self):
torchaudio.set_audio_backend(self.new_backend)
return self.new_backend

def __exit__(self, type, value, traceback):
backend = self.previous_backend
torchaudio.set_audio_backend(backend)
from _test import AudioBackendScope, BACKENDS, BACKENDS_MP3


class Test_LoadSave(unittest.TestCase):
Expand Down
17 changes: 15 additions & 2 deletions test/test_compliance_kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,24 @@
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest
from contextlib import contextmanager


BACKENDS = torchaudio._backend._audio_backends


@contextmanager
def AudioBackendScope(new_backend):
previous_backend = torchaudio.get_audio_backend()
try:
unittest.skipIf(not new_backend, "No backend supporting this test.")
unittest.skipIf(new_backend not in BACKENDS, new_backend + " is not available.")
torchaudio.set_audio_backend(new_backend)
yield
finally:
torchaudio.set_audio_backend(previous_backend)


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
# just a copy of ExtractWindow from feature-window.cc in python
def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
Expand Down Expand Up @@ -163,7 +176,7 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
self.assertTrue(output.shape, kaldi_output.shape)
self.assertTrue(torch.allclose(output, kaldi_output, atol=atol, rtol=rtol))

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope("sox")
def test_spectrogram(self):
def get_output_fn(sound, args):
output = kaldi.spectrogram(
Expand Down Expand Up @@ -214,7 +227,7 @@ def get_output_fn(sound, args):

self._compliance_test_helper(self.test_filepath, 'fbank', 97, 22, get_output_fn, atol=1e-3, rtol=1e-1)

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope("sox")
def test_mfcc(self):
def get_output_fn(sound, args):
output = kaldi.mfcc(
Expand Down
14 changes: 13 additions & 1 deletion test/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
BACKENDS = torchaudio._backend._audio_backends


@contextmanager
def AudioBackendScope(new_backend):
previous_backend = torchaudio.get_audio_backend()
try:
unittest.skipIf(not new_backend, "No backend supporting this test.")
unittest.skipIf(new_backend not in BACKENDS, new_backend + " is not available.")
torchaudio.set_audio_backend(new_backend)
yield
finally:
torchaudio.set_audio_backend(previous_backend)


class TORCHAUDIODS(Dataset):

test_dirpath, test_dir = common_utils.create_temp_assets_dir()
Expand All @@ -37,7 +49,7 @@ def __len__(self):
return len(self.data)


@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
class Test_DataLoader(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down
9 changes: 4 additions & 5 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
import unittest
import common_utils

from _test import AudioBackendScope, BACKENDS
from torchaudio.common_utils import IMPORT_LIBROSA

if IMPORT_LIBROSA:
import numpy as np
import librosa

BACKENDS = torchaudio._backend._audio_backends


def _test_torchscript_functional_shape(py_method, *args, **kwargs):
jit_method = torch.jit.script(py_method)
Expand Down Expand Up @@ -436,7 +435,7 @@ def test_create_fb(self):
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope("sox")
def test_gain(self):
waveform_gain = F.gain(self.waveform_train, 3)
self.assertTrue(waveform_gain.abs().max().item(), 1.)
Expand All @@ -448,7 +447,7 @@ def test_gain(self):

self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04))

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope("sox")
def test_dither(self):
waveform_dithered = F.dither(self.waveform_train)
waveform_dithered_noiseshaped = F.dither(self.waveform_train, noise_shaping=True)
Expand All @@ -466,7 +465,7 @@ def test_dither(self):

self.assertTrue(torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02))

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope("sox")
def test_vctk_transform_pipeline(self):
test_filepath_vctk = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav")
wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)
Expand Down
18 changes: 7 additions & 11 deletions test/test_functional_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import unittest
import common_utils
import time


BACKENDS = torchaudio._backend._audio_backends
from _test import AudioBackendScope, BACKENDS, BACKENDS_MP3, FIRST_BACKEND_MP3


def _test_torchscript_functional(py_method, *args, **kwargs):
Expand Down Expand Up @@ -93,15 +91,13 @@ def _test_lfilter(self, waveform, device):
assert output_waveform.size(1) == waveform.size(1)
_test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope(FIRST_BACKEND_MP3)
def test_lfilter(self):

filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
waveform, _ = torchaudio.load(filepath, normalization=True)

self._test_lfilter(waveform, torch.device("cpu"))

@unittest.skipIf("sox" not in BACKENDS, "sox is not available")
@AudioBackendScope(FIRST_BACKEND_MP3)
def test_lfilter_gpu(self):
if torch.cuda.is_available():
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
Expand All @@ -113,7 +109,7 @@ def test_lfilter_gpu(self):
print("skipping GPU test for lfilter because device not available")
pass

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_lowpass(self):

"""
Expand All @@ -134,7 +130,7 @@ def test_lowpass(self):
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_highpass(self):
"""
Test biquad highpass filter, compare to SoX implementation
Expand Down Expand Up @@ -334,7 +330,7 @@ def test_riaa(self):
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_equalizer(self):
"""
Test biquad peaking equalizer filter, compare to SoX implementation
Expand All @@ -356,7 +352,7 @@ def test_equalizer(self):
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
_test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)

@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_perf_biquad_filtering(self):

fn_sine = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
Expand Down
5 changes: 2 additions & 3 deletions test/test_sox_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
import math
import os

from _test import AudioBackendScope

BACKENDS = torchaudio._backend._audio_backends


@unittest.skipIf("sox" not in BACKENDS, "sox not available")
# @AudioBackendScope("sox")
class Test_SoxEffectsChain(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, "assets",
Expand Down
7 changes: 3 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
import common_utils
from _test import AudioBackendScope, BACKENDS, BACKENDS_MP3, FIRST_BACKEND_MP3

if IMPORT_LIBROSA:
import librosa
Expand All @@ -19,8 +20,6 @@
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)

BACKENDS = torchaudio._backend._audio_backends


def _test_script_module(f, tensor, *args, **kwargs):

Expand Down Expand Up @@ -533,10 +532,10 @@ def test_batch_melspectrogram(self):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

@unittest.skipIf("sox" not in BACKENDS, "sox are not available")
@AudioBackendScope(FIRST_BACKEND_MP3)
def test_batch_mfcc(self):
test_filepath = os.path.join(
test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3'
self.test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3'
)
waveform, sample_rate = torchaudio.load(test_filepath)

Expand Down

0 comments on commit 470e655

Please sign in to comment.