From b2793c7dbb2d7880054e4cfac96811121478462d Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 18 Dec 2020 23:41:51 +0000 Subject: [PATCH] Support bytes and file-like obj in load func --- .../soundfile_backend/load_test.py | 50 ++++++++++++++ .../sox_io_backend/load_test.py | 65 +++++++++++++++++++ torchaudio/backend/_soundfile_backend.py | 12 +++- torchaudio/backend/sox_io_backend.py | 27 ++++++-- torchaudio/csrc/sox.cpp | 19 ++++++ torchaudio/csrc/sox_effects.cpp | 47 +++++++++++--- torchaudio/csrc/sox_effects.h | 9 ++- torchaudio/csrc/sox_io.cpp | 34 ++++++++-- torchaudio/csrc/sox_io.h | 8 +++ torchaudio/csrc/sox_utils.cpp | 3 - 10 files changed, 247 insertions(+), 27 deletions(-) diff --git a/test/torchaudio_unittest/soundfile_backend/load_test.py b/test/torchaudio_unittest/soundfile_backend/load_test.py index 4277ac03e98..da342f5f831 100644 --- a/test/torchaudio_unittest/soundfile_backend/load_test.py +++ b/test/torchaudio_unittest/soundfile_backend/load_test.py @@ -299,3 +299,53 @@ def test_wav(self, format_): @skipIfFormatNotSupported("FLAC") def test_flac(self, format_): self._test_format(format_) + + +@skipIfNoModule("soundfile") +class TestFileLikeObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext): + """Loading audio via file-like object works""" + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(path, data, sample_rate) + expected = soundfile.read(path, dtype='float32')[0].T + + with open(path, 'rb') as fileobj: + found, sr = soundfile_backend.load(fileobj) + assert sr == sample_rate + self.assertEqual(expected, found) + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj('flac') + + def _test_bytes(self, ext): + """Loading audio via bytes works""" + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + data = get_wav_data('float32', num_channels=2).numpy().T + soundfile.write(path, data, sample_rate) + expected = soundfile.read(path, dtype='float32')[0].T + + with open(path, 'rb') as fileobj: + data = fileobj.read() + found, sr = soundfile_backend.load(data) + assert sr == sample_rate + self.assertEqual(expected, found) + + def _test_bytes_wav(self): + """Loading audio via bytes works""" + self._test_bytes('wav') + + @skipIfFormatNotSupported("FLAC") + def _test_bytes_flac(self): + """Loading audio via bytes works""" + self._test_bytes('flac') diff --git a/test/torchaudio_unittest/sox_io_backend/load_test.py b/test/torchaudio_unittest/sox_io_backend/load_test.py index 933ab861950..fac9675c710 100644 --- a/test/torchaudio_unittest/sox_io_backend/load_test.py +++ b/test/torchaudio_unittest/sox_io_backend/load_test.py @@ -369,3 +369,68 @@ def test_mp3(self): path = get_asset_path("mp3_without_ext") _, sr = sox_io_backend.load(path, format="mp3") assert sr == 16000 + + +@skipIfNoExtension +@skipIfNoExec('sox') +class TestFileLikeObject(TempDirMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_fileobj(self, ext, compression): + """Loading audio via file-like object returns the same result as via file path. + + We campare the result of file-like object input against file path input because + `load` function is rigrously tested for file path inputs to match libsox's result, + """ + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + + expected, _ = sox_io_backend.load(path) + with open(path, 'rb') as fileobj: + found, sr = sox_io_backend.load(fileobj) + assert sr == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand([ + ('wav', None), + ('mp3', 128), + ('mp3', 320), + ('flac', 0), + ('flac', 5), + ('flac', 8), + ('vorbis', -1), + ('vorbis', 10), + ('amb', None), + ]) + def test_bytes(self, ext, compression): + """Loading audio via bytes returns the same result as via file path. + + We campare the result of file-like object input against file path input because + `load` function is rigrously tested for file path inputs to match libsox's result, + """ + sample_rate = 16000 + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + compression=compression) + + expected, _ = sox_io_backend.load(path) + with open(path, 'rb') as fileobj: + data = fileobj.read() + found, sr = sox_io_backend.load(data) + assert sr == sample_rate + self.assertEqual(expected, found) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index aa0a2e854bb..736bf9888a0 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -1,4 +1,5 @@ """The new soundfile backend which will become default in 0.8.0 onward""" +import io from typing import Tuple, Optional import warnings @@ -80,8 +81,13 @@ def load( ``[-1.0, 1.0]``. Args: - filepath (str or pathlib.Path): Path to audio file. - This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str`` + filepath (str, pathlib.Path, or file-like object): + Source of audio data. One of the following types; + * ``str`` or ``pathlib.Path``: file path + * ``bytes``: Audio data in bytes + * ``file-like``: A file-like object with ``read`` method + that returns ``bytes``. + This argument is intentionally annotated as only ``str`` for for the consistency with "sox_io" backend, which has a restriction on type annotation for TorchScript compiler compatiblity. frame_offset (int): @@ -109,6 +115,8 @@ def load( integer type, else ``float32`` type. If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``. """ + if isinstance(filepath, bytes): + filepath = io.BytesIO(filepath) with soundfile.SoundFile(filepath, "r") as file_: if file_.format != "WAV" or normalize: dtype = "float32" diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index fbff7efba3d..fb77fe3d319 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Tuple, Optional import torch @@ -5,6 +6,7 @@ module_utils as _mod_utils, ) +import torchaudio from .common import AudioMetaData @@ -82,9 +84,14 @@ def load( ``[-1.0, 1.0]``. Args: - filepath (str or pathlib.Path): - Path to audio file. This function also handles ``pathlib.Path`` objects, but is - annotated as ``str`` for TorchScript compiler compatibility. + filepath (str, pathlib.Path, file-like object or bytes): + Source of audio data. One of the following types; + * ``str`` or ``pathlib.Path``: file path + * ``bytes``: Audio data in bytes + * ``file-like``: A file-like object with ``read`` method + that returns ``bytes``. + This argument is intentionally annotated as only ``str`` for + TorchScript compiler compatibility. frame_offset (int): Number of frames to skip before start reading data. num_frames (int): @@ -112,8 +119,18 @@ def load( integer type, else ``float32`` type. If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``. """ - # Cast to str in case type is `pathlib.Path` - filepath = str(filepath) + if not torch.jit.is_scripting(): + if isinstance(filepath, (str, Path)): + signal = torch.ops.torchaudio.sox_io_load_audio_file( + str(filepath), frame_offset, num_frames, normalize, channels_first, format) + return signal.get_tensor(), signal.get_sample_rate() + if isinstance(filepath, bytes): + return torchaudio._torchaudio.load_audio_bytes( + filepath, frame_offset, num_frames, normalize, channels_first, format) + if hasattr(filepath, 'read'): + return torchaudio._torchaudio.load_audio_bytes( + filepath.read(), frame_offset, num_frames, normalize, channels_first, format) + raise RuntimeError('The `filepath` object must be one of str, Path, bytes, file-like object.') signal = torch.ops.torchaudio.sox_io_load_audio_file( filepath, frame_offset, num_frames, normalize, channels_first, format) return signal.get_tensor(), signal.get_sample_rate() diff --git a/torchaudio/csrc/sox.cpp b/torchaudio/csrc/sox.cpp index 3be7d013148..ce4ae305e03 100644 --- a/torchaudio/csrc/sox.cpp +++ b/torchaudio/csrc/sox.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -178,6 +179,23 @@ void write_audio_file( } // namespace audio } // namespace torch +namespace { + +std::tuple load_audio_bytes( + py::bytes src, + c10::optional frame_offset, + c10::optional num_frames, + c10::optional normalize, + c10::optional channels_first, + c10::optional& format) { + auto bytes = static_cast(src); + auto result = torchaudio::sox_io::load_audio_bytes( + bytes, frame_offset, num_frames, normalize, channels_first, format); + return std::make_tuple(result->getTensor(), result->getSampleRate()); +} + +} // namespace + PYBIND11_MODULE(_torchaudio, m) { py::class_(m, "sox_signalinfo_t") .def(py::init<>()) @@ -271,4 +289,5 @@ PYBIND11_MODULE(_torchaudio, m) { "get_info", &torch::audio::get_info, "Gets information about an audio file"); + m.def("load_audio_bytes", &load_audio_bytes, "Load audio from byte string."); } diff --git a/torchaudio/csrc/sox_effects.cpp b/torchaudio/csrc/sox_effects.cpp index 4f12212c305..9961511a475 100644 --- a/torchaudio/csrc/sox_effects.cpp +++ b/torchaudio/csrc/sox_effects.cpp @@ -88,18 +88,13 @@ c10::intrusive_ptr apply_effects_tensor( out_tensor, chain.getOutputSampleRate(), channels_first); } -c10::intrusive_ptr apply_effects_file( - const std::string path, +namespace { + +c10::intrusive_ptr apply_effects_impl( + const SoxFormat& sf, std::vector> effects, c10::optional& normalize, - c10::optional& channels_first, - c10::optional& format) { - // Open input file - SoxFormat sf(sox_open_read( - path.c_str(), - /*signal=*/nullptr, - /*encoding=*/nullptr, - /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + c10::optional& channels_first) { validate_input_file(sf); @@ -135,5 +130,37 @@ c10::intrusive_ptr apply_effects_file( tensor, chain.getOutputSampleRate(), channels_first_); } +} // namespace + +c10::intrusive_ptr apply_effects_file( + const std::string& path, + std::vector> effects, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + // Open input file + SoxFormat sf(sox_open_read( + path.c_str(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + return apply_effects_impl(sf, effects, normalize, channels_first); +} + +c10::intrusive_ptr apply_effects_bytes( + const std::string& bytes, + std::vector> effects, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + SoxFormat sf(sox_open_mem_read( + static_cast(const_cast(bytes.c_str())), + bytes.length(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + return apply_effects_impl(sf, effects, normalize, channels_first); +} + } // namespace sox_effects } // namespace torchaudio diff --git a/torchaudio/csrc/sox_effects.h b/torchaudio/csrc/sox_effects.h index c99d3dc0301..67a8bd3f4ff 100644 --- a/torchaudio/csrc/sox_effects.h +++ b/torchaudio/csrc/sox_effects.h @@ -16,7 +16,14 @@ c10::intrusive_ptr apply_effects_tensor( std::vector> effects); c10::intrusive_ptr apply_effects_file( - const std::string path, + const std::string& path, + std::vector> effects, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format); + +c10::intrusive_ptr apply_effects_bytes( + const std::string& bytes, std::vector> effects, c10::optional& normalize, c10::optional& channels_first, diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index bd92e4b8808..7bf3c29caa6 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -49,13 +49,11 @@ c10::intrusive_ptr get_info( static_cast(sf->signal.length / sf->signal.channels)); } -c10::intrusive_ptr load_audio_file( - const std::string& path, +namespace { + +std::vector> get_effects( c10::optional& frame_offset, - c10::optional& num_frames, - c10::optional& normalize, - c10::optional& channels_first, - c10::optional& format) { + c10::optional& num_frames) { const auto offset = frame_offset.value_or(0); if (offset < 0) { throw std::runtime_error( @@ -79,11 +77,35 @@ c10::intrusive_ptr load_audio_file( os_offset << offset << "s"; effects.emplace_back(std::vector{"trim", os_offset.str()}); } + return effects; +} + +} // namespace +c10::intrusive_ptr load_audio_file( + const std::string& path, + c10::optional& frame_offset, + c10::optional& num_frames, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + auto effects = get_effects(frame_offset, num_frames); return torchaudio::sox_effects::apply_effects_file( path, effects, normalize, channels_first, format); } +c10::intrusive_ptr load_audio_bytes( + const std::string& bytes, + c10::optional& frame_offset, + c10::optional& num_frames, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format) { + auto effects = get_effects(frame_offset, num_frames); + return torchaudio::sox_effects::apply_effects_bytes( + bytes, effects, normalize, channels_first, format); +} + void save_audio_file( const std::string& file_name, const c10::intrusive_ptr& signal, diff --git a/torchaudio/csrc/sox_io.h b/torchaudio/csrc/sox_io.h index cb18bb1cc05..334ae05a8e3 100644 --- a/torchaudio/csrc/sox_io.h +++ b/torchaudio/csrc/sox_io.h @@ -33,6 +33,14 @@ c10::intrusive_ptr load_audio_file( c10::optional& channels_first, c10::optional& format); +c10::intrusive_ptr load_audio_bytes( + const std::string& bytes, + c10::optional& frame_offset, + c10::optional& num_frames, + c10::optional& normalize, + c10::optional& channels_first, + c10::optional& format); + void save_audio_file( const std::string& file_name, const c10::intrusive_ptr& signal, diff --git a/torchaudio/csrc/sox_utils.cpp b/torchaudio/csrc/sox_utils.cpp index 656cc63348b..ed2c52a3bbf 100644 --- a/torchaudio/csrc/sox_utils.cpp +++ b/torchaudio/csrc/sox_utils.cpp @@ -99,9 +99,6 @@ void validate_input_file(const SoxFormat& sf) { if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { throw std::runtime_error("Error loading audio file: unknown encoding."); } - if (sf->signal.length == 0) { - throw std::runtime_error("Error reading audio file: unkown length."); - } } void validate_input_tensor(const torch::Tensor tensor) {