Skip to content

Commit

Permalink
use wav instead of mp3 for testing functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentqb committed Feb 6, 2020
1 parent 510bb19 commit 2f470ed
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TestFunctional(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()

test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
'steam-train-whistle-daniel_simon.wav')
waveform_train, sr_train = torchaudio.load(test_filepath)

def test_torchscript_spectrogram(self):
Expand Down
10 changes: 8 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)

BACKENDS = torchaudio._backend._audio_backends


def _test_script_module(f, tensor, *args, **kwargs):

Expand Down Expand Up @@ -55,7 +57,7 @@ class Tester(unittest.TestCase):
# file for stereo stft test
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
'steam-train-whistle-daniel_simon.wav')

def scale(self, waveform, factor=float(2**31)):
# scales a waveform by a factor
Expand Down Expand Up @@ -469,8 +471,12 @@ def test_batch_melspectrogram(self):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

@unittest.skipIf(set("sox") not in set(BACKENDS), "sox are not available")
def test_batch_mfcc(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
test_filepath = os.path.join(
test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3'
)
waveform, sample_rate = torchaudio.load(test_filepath)

# Single then transform then batch
expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)
Expand Down

0 comments on commit 2f470ed

Please sign in to comment.