From 3e1b33cec69ff33b4405599f8be053fcbdc890b1 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 18 Dec 2020 23:41:51 +0000 Subject: [PATCH] Support file-like object in info --- .../soundfile_backend/info_test.py | 62 ++++++++ .../sox_io_backend/info_test.py | 134 +++++++++++++++++- torchaudio/backend/_soundfile_backend.py | 10 +- torchaudio/backend/sox_io_backend.py | 27 +++- torchaudio/csrc/pybind.cpp | 4 + torchaudio/csrc/sox/io.cpp | 32 ++++- torchaudio/csrc/sox/io.h | 6 +- torchaudio/csrc/sox/register.cpp | 2 +- 8 files changed, 264 insertions(+), 13 deletions(-) diff --git a/test/torchaudio_unittest/soundfile_backend/info_test.py b/test/torchaudio_unittest/soundfile_backend/info_test.py index 71acb20689c..703d8bb59d9 100644 --- a/test/torchaudio_unittest/soundfile_backend/info_test.py +++ b/test/torchaudio_unittest/soundfile_backend/info_test.py @@ -1,3 +1,5 @@ +import tarfile + import torch from torchaudio.backend import _soundfile_backend as soundfile_backend from torchaudio._internal import module_utils as _mod_utils @@ -93,3 +95,63 @@ def test_sphere(self, sample_rate, num_channels): assert info.sample_rate == sample_rate assert info.num_frames == sample_rate * duration assert info.num_channels == num_channels + + +@skipIfNoModule("soundfile") +class TestFileObject(TempDirMixin, PytorchTestCase): + def _test_fileobj(self, ext): + """Query audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + path = self.get_temp_path(f'test.{ext}') + + data = torch.randn(num_frames, num_channels).numpy() + soundfile.write(path, data, sample_rate) + + with open(path, 'rb') as fileobj: + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + + def test_fileobj_wav(self): + """Loading audio via file-like object works""" + self._test_fileobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_fileobj_flac(self): + """Loading audio via file-like object works""" + self._test_fileobj('flac') + + def _test_tarobj(self, ext): + """Query compressed audio via file-like object works""" + duration = 2 + sample_rate = 16000 + num_channels = 2 + num_frames = sample_rate * duration + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + data = torch.randn(num_frames, num_channels).numpy() + soundfile.write(audio_path, data, sample_rate) + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + info = soundfile_backend.info(fileobj) + assert info.sample_rate == sample_rate + assert info.num_frames == num_frames + assert info.num_channels == num_channels + + def test_tarobj_wav(self): + """Query compressed audio via file-like object works""" + self._test_tarobj('wav') + + @skipIfFormatNotSupported("FLAC") + def test_tarobj_flac(self): + """Query compressed audio via file-like object works""" + self._test_tarobj('flac') diff --git a/test/torchaudio_unittest/sox_io_backend/info_test.py b/test/torchaudio_unittest/sox_io_backend/info_test.py index 49fc7973545..8a209aa8128 100644 --- a/test/torchaudio_unittest/sox_io_backend/info_test.py +++ b/test/torchaudio_unittest/sox_io_backend/info_test.py @@ -1,13 +1,18 @@ +import io import itertools -from parameterized import parameterized +import tarfile +from parameterized import parameterized from torchaudio.backend import sox_io_backend +from torchaudio._internal import module_utils as _mod_utils from torchaudio_unittest.common_utils import ( TempDirMixin, + HttpServerMixin, PytorchTestCase, skipIfNoExec, skipIfNoExtension, + skipIfNoModule, get_asset_path, get_wav_data, save_wav, @@ -18,6 +23,10 @@ ) +if _mod_utils.is_module_available("requests"): + import requests + + @skipIfNoExec('sox') @skipIfNoExtension class TestInfo(TempDirMixin, PytorchTestCase): @@ -184,3 +193,126 @@ def test_mp3(self): path = get_asset_path("mp3_without_ext") sinfo = sox_io_backend.info(path, format="mp3") assert sinfo.sample_rate == 16000 + + +@skipIfNoExtension +@skipIfNoExec('sox') +class TestFileObject(TempDirMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', ), + ('mp3', ), + ('flac', ), + ('vorbis', ), + ('amb', ), + ]) + def test_fileobj(self, ext): + """Querying audio via file object works""" + sample_rate = 16000 + num_channels = 2 + duration = 3 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + duration=duration) + + with open(path, 'rb') as fileobj: + sinfo = sox_io_backend.info(fileobj, format_) + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + if ext not in ['mp3', 'vorbis']: # these container formats do not have length info + assert sinfo.num_frames == sample_rate * duration + + @parameterized.expand([ + ('wav', ), + ('mp3', ), + ('flac', ), + ('vorbis', ), + ('amb', ), + ]) + def test_bytesio(self, ext): + """Querying audio via ByteIO object works""" + sample_rate = 16000 + num_channels = 2 + duration = 3 + format_ = ext if ext in ['mp3'] else None + path = self.get_temp_path(f'test.{ext}') + + sox_utils.gen_audio_file( + path, sample_rate, num_channels=2, + duration=duration) + + with open(path, 'rb') as file_: + fileobj = io.BytesIO(file_.read()) + sinfo = sox_io_backend.info(fileobj, format_) + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + if ext not in ['mp3', 'vorbis']: # these container formats do not have length info + assert sinfo.num_frames == sample_rate * duration + + @parameterized.expand([ + ('wav', ), + ('mp3', ), + ('flac', ), + ('vorbis', ), + ('amb', ), + ]) + def test_tarfile(self, ext): + """Querying compressed audio via file-like object works""" + sample_rate = 16000 + num_channels = 2 + duration = 3 + format_ = ext if ext in ['mp3'] else None + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + archive_path = self.get_temp_path('archive.tar.gz') + + sox_utils.gen_audio_file( + audio_path, sample_rate, num_channels=num_channels, duration=duration) + + with tarfile.TarFile(archive_path, 'w') as tarobj: + tarobj.add(audio_path, arcname=audio_file) + with tarfile.TarFile(archive_path, 'r') as tarobj: + fileobj = tarobj.extractfile(audio_file) + sinfo = sox_io_backend.info(fileobj, format=format_) + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + if ext not in ['mp3', 'vorbis']: # these container formats do not have length info + assert sinfo.num_frames == sample_rate * duration + + +@skipIfNoExtension +@skipIfNoExec('sox') +@skipIfNoModule("requests") +class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): + @parameterized.expand([ + ('wav', ), + ('mp3', ), + ('flac', ), + ('vorbis', ), + ('amb', ), + ]) + def test_requests(self, ext): + """Querying compressed audio via requests works""" + sample_rate = 16000 + num_channels = 2 + duration = 3 + format_ = ext if ext in ['mp3'] else None + audio_file = f'test.{ext}' + audio_path = self.get_temp_path(audio_file) + + sox_utils.gen_audio_file( + audio_path, sample_rate, num_channels=num_channels, duration=duration) + + url = self.get_url(audio_file) + with requests.get(url, stream=True) as resp: + sinfo = sox_io_backend.info(resp.raw, format=format_) + + assert sinfo.sample_rate == sample_rate + assert sinfo.num_channels == num_channels + if ext not in ['mp3', 'vorbis']: # these container formats do not have length info + assert sinfo.num_frames == sample_rate * duration diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index 3366780bdb2..704b777e88d 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -16,10 +16,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData: """Get signal information of an audio file. Args: - filepath (str or pathlib.Path): Path to audio file. - This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str`` - for the consistency with "sox_io" backend, which has a restriction on type annotation - for TorchScript compiler compatiblity. + filepath (path-like object or file-like object): + Source of audio data. + Note: + * This argument is intentionally annotated as ``str`` only, + for the consistency with "sox_io" backend, which has a restriction + on type annotation due to TorchScript compiler compatiblity. format (str, optional): Not used. PySoundFile does not accept format hint. diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 18296ef7e27..c67fc11941c 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -18,9 +18,21 @@ def info( """Get signal information of an audio file. Args: - filepath (str or pathlib.Path): - Path to audio file. This function also handles ``pathlib.Path`` objects, - but is annotated as ``str`` for TorchScript compatibility. + filepath (path-like object or file-like object): + Source of audio data. When the function is not compiled by TorchScript, + (e.g. ``torch.jit.script``), the following types are accepted; + * ``path-like``: file path + * ``file-like``: Object with ``read(size: int) -> bytes`` method, + which returns byte string of at most ``size`` length. + When the function is compiled by TorchScript, only ``str`` type is allowed. + + Note: + * When the input type is file-like object, this function cannot + get the correct length (``num_samples``) for certain formats, + such as ``mp3`` and ``vorbis``. + In this case, the value of ``num_samples`` is ``0``. + * This argument is intentionally annotated as ``str`` only due to + TorchScript compiler compatibility. format (str, optional): Override the format detection with the given format. Providing the argument might help when libsox can not infer the format @@ -29,8 +41,13 @@ def info( Returns: AudioMetaData: Metadata of the given audio. """ - # Cast to str in case type is `pathlib.Path` - filepath = str(filepath) + if not torch.jit.is_scripting(): + if hasattr(filepath, 'read'): + sample_rate, num_channels, num_frames = torchaudio._torchaudio.get_info_fileobj( + filepath, format) + return AudioMetaData(sample_rate, num_frames, num_channels) + sinfo = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format) + return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels()) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels()) diff --git a/torchaudio/csrc/pybind.cpp b/torchaudio/csrc/pybind.cpp index eb8c30b96ac..162716e22d8 100644 --- a/torchaudio/csrc/pybind.cpp +++ b/torchaudio/csrc/pybind.cpp @@ -96,6 +96,10 @@ PYBIND11_MODULE(_torchaudio, m) { "get_info", &torch::audio::get_info, "Gets information about an audio file"); + m.def( + "get_info_fileobj", + &torchaudio::sox_io::get_info_fileobj, + "Get metadata of audio in file object."); m.def( "load_audio_fileobj", &torchaudio::sox_io::load_audio_fileobj, diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index c6531eb77eb..aa50eabcde7 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -30,7 +30,7 @@ int64_t SignalInfo::getNumFrames() const { return num_frames; } -c10::intrusive_ptr get_info( +c10::intrusive_ptr get_info_file( const std::string& path, c10::optional& format) { SoxFormat sf(sox_open_read( @@ -140,6 +140,36 @@ void save_audio_file( #ifdef TORCH_API_INCLUDE_EXTENSION_H +std::tuple get_info_fileobj( + py::object fileobj, + c10::optional& format) { + // 4096 is fixed minimum buffer size + // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L46 + const size_t buf_size = 4096; + std::string buffer(buf_size, 'x'); + auto* buf = const_cast(buffer.data()); + + // Fetch the header, and copy it to the buffer. + auto header = static_cast(static_cast(fileobj.attr("read")(4096))); + memcpy(static_cast(buf), + static_cast(const_cast(header.data())), header.length()); + + SoxFormat sf(sox_open_mem_read( + buf, + buf_size, + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); + + // In case of streamed data, length can be 0 + validate_input_file(sf, /*check_length=*/false); + + return std::make_tuple( + static_cast(sf->signal.rate), + static_cast(sf->signal.channels), + static_cast(sf->signal.length / sf->signal.channels)); +} + std::tuple load_audio_fileobj( py::object fileobj, c10::optional& frame_offset, diff --git a/torchaudio/csrc/sox/io.h b/torchaudio/csrc/sox/io.h index ac7191527f8..1af7d94369b 100644 --- a/torchaudio/csrc/sox/io.h +++ b/torchaudio/csrc/sox/io.h @@ -25,7 +25,7 @@ struct SignalInfo : torch::CustomClassHolder { int64_t getNumFrames() const; }; -c10::intrusive_ptr get_info( +c10::intrusive_ptr get_info_file( const std::string& path, c10::optional& format); @@ -47,6 +47,10 @@ void save_audio_file( #ifdef TORCH_API_INCLUDE_EXTENSION_H +std::tuple get_info_fileobj( + py::object fileobj, + c10::optional& format); + std::tuple load_audio_fileobj( py::object fileobj, c10::optional& frame_offset, diff --git a/torchaudio/csrc/sox/register.cpp b/torchaudio/csrc/sox/register.cpp index 7c65bebe2df..d82974f5928 100644 --- a/torchaudio/csrc/sox/register.cpp +++ b/torchaudio/csrc/sox/register.cpp @@ -44,7 +44,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { .def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels) .def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames); - m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info); + m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file); m.def( "torchaudio::sox_io_load_audio_file(" "str path,"