Skip to content

Commit

Permalink
Support file-like object in info
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 15, 2021
1 parent f1d8d1e commit 3e1b33c
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 13 deletions.
62 changes: 62 additions & 0 deletions test/torchaudio_unittest/soundfile_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import tarfile

import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
Expand Down Expand Up @@ -93,3 +95,63 @@ def test_sphere(self, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Verify that ``soundfile_backend.info`` can query file-like objects."""

    def _make_audio(self, file_name):
        """Write random audio into the temp directory.

        The signal is 2 seconds of 2-channel noise at 16 kHz.

        Returns:
            tuple: ``(path, sample_rate, num_frames, num_channels)`` where the
            last three items are the values ``info`` is expected to report.
        """
        rate, channels, seconds = 16000, 2, 2
        frames = rate * seconds
        path = self.get_temp_path(file_name)
        soundfile.write(path, torch.randn(frames, channels).numpy(), rate)
        return path, rate, frames, channels

    def _test_fileobj(self, ext):
        """Query audio via file-like object works"""
        path, rate, frames, channels = self._make_audio(f'test.{ext}')

        with open(path, 'rb') as stream:
            meta = soundfile_backend.info(stream)

        assert meta.sample_rate == rate
        assert meta.num_frames == frames
        assert meta.num_channels == channels

    def test_fileobj_wav(self):
        """Querying wav audio via file-like object works"""
        self._test_fileobj('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Querying flac audio via file-like object works"""
        self._test_fileobj('flac')

    def _test_tarobj(self, ext):
        """Query audio stored in a tar archive via file-like object works"""
        member = f'test.{ext}'
        source, rate, frames, channels = self._make_audio(member)
        archive = self.get_temp_path('archive.tar.gz')

        # Pack the audio file first, then re-open the archive for reading.
        with tarfile.TarFile(archive, 'w') as tar:
            tar.add(source, arcname=member)
        with tarfile.TarFile(archive, 'r') as tar:
            # The extracted member is only readable while the archive is open.
            meta = soundfile_backend.info(tar.extractfile(member))

        assert meta.sample_rate == rate
        assert meta.num_frames == frames
        assert meta.num_channels == channels

    def test_tarobj_wav(self):
        """Querying wav audio in a tar archive via file-like object works"""
        self._test_tarobj('wav')

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Querying flac audio in a tar archive via file-like object works"""
        self._test_tarobj('flac')
134 changes: 133 additions & 1 deletion test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import io
import itertools
from parameterized import parameterized
import tarfile

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils

from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
get_asset_path,
get_wav_data,
save_wav,
Expand All @@ -18,6 +23,10 @@
)


if _mod_utils.is_module_available("requests"):
import requests


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
Expand Down Expand Up @@ -184,3 +193,126 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Verify that ``sox_io_backend.info`` can query file-like objects."""

    @parameterized.expand([
        ('wav', ),
        ('mp3', ),
        ('flac', ),
        ('vorbis', ),
        ('amb', ),
    ])
    def test_fileobj(self, ext):
        """Querying audio via file object works"""
        sample_rate = 16000
        num_channels = 2
        duration = 3
        # Only mp3 is given an explicit format hint; other formats are auto-detected.
        format_ = ext if ext in ['mp3'] else None
        path = self.get_temp_path(f'test.{ext}')

        # Generate with the same channel count that is asserted below.
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels=num_channels,
            duration=duration)

        with open(path, 'rb') as fileobj:
            sinfo = sox_io_backend.info(fileobj, format=format_)

        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        if ext not in ['mp3', 'vorbis']:  # these container formats do not have length info
            assert sinfo.num_frames == sample_rate * duration

    @parameterized.expand([
        ('wav', ),
        ('mp3', ),
        ('flac', ),
        ('vorbis', ),
        ('amb', ),
    ])
    def test_bytesio(self, ext):
        """Querying audio via BytesIO object works"""
        sample_rate = 16000
        num_channels = 2
        duration = 3
        # Only mp3 is given an explicit format hint; other formats are auto-detected.
        format_ = ext if ext in ['mp3'] else None
        path = self.get_temp_path(f'test.{ext}')

        # Generate with the same channel count that is asserted below.
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels=num_channels,
            duration=duration)

        # Load the whole file into memory so the query runs on an in-memory stream.
        with open(path, 'rb') as file_:
            fileobj = io.BytesIO(file_.read())
        sinfo = sox_io_backend.info(fileobj, format=format_)

        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        if ext not in ['mp3', 'vorbis']:  # these container formats do not have length info
            assert sinfo.num_frames == sample_rate * duration

    @parameterized.expand([
        ('wav', ),
        ('mp3', ),
        ('flac', ),
        ('vorbis', ),
        ('amb', ),
    ])
    def test_tarfile(self, ext):
        """Querying audio stored in a tar archive via file-like object works"""
        sample_rate = 16000
        num_channels = 2
        duration = 3
        # Only mp3 is given an explicit format hint; other formats are auto-detected.
        format_ = ext if ext in ['mp3'] else None
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        # NOTE(review): despite the ``.tar.gz`` name, ``TarFile(..., 'w')`` writes
        # an uncompressed archive -- confirm whether gzip was intended.
        archive_path = self.get_temp_path('archive.tar.gz')

        sox_utils.gen_audio_file(
            audio_path, sample_rate, num_channels=num_channels, duration=duration)

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # The extracted member is only readable while the archive is open.
            fileobj = tarobj.extractfile(audio_file)
            sinfo = sox_io_backend.info(fileobj, format=format_)

        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        if ext not in ['mp3', 'vorbis']:  # these container formats do not have length info
            assert sinfo.num_frames == sample_rate * duration


@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Verify that ``sox_io_backend.info`` can query an HTTP response stream."""

    @parameterized.expand([
        ('wav', ),
        ('mp3', ),
        ('flac', ),
        ('vorbis', ),
        ('amb', ),
    ])
    def test_requests(self, ext):
        """Querying audio streamed via requests works"""
        rate, channels, seconds = 16000, 2, 3
        file_name = f'test.{ext}'
        # Only mp3 is given an explicit format hint; other formats are auto-detected.
        hint = ext if ext == 'mp3' else None

        sox_utils.gen_audio_file(
            self.get_temp_path(file_name), rate, num_channels=channels, duration=seconds)

        # Query the raw response stream while the connection is still open.
        with requests.get(self.get_url(file_name), stream=True) as response:
            sinfo = sox_io_backend.info(response.raw, format=hint)

        assert sinfo.sample_rate == rate
        assert sinfo.num_channels == channels
        if ext in ('mp3', 'vorbis'):
            # these container formats do not have length info
            return
        assert sinfo.num_frames == rate * seconds
10 changes: 6 additions & 4 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path): Path to audio file.
This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for consistency with the "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatibility.
format (str, optional):
Not used. PySoundFile does not accept format hint.
Expand Down
27 changes: 22 additions & 5 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,21 @@ def info(
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects,
but is annotated as ``str`` for TorchScript compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* When the input type is file-like object, this function cannot
get the correct length (``num_frames``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_frames`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
Expand All @@ -29,8 +41,13 @@ def info(
Returns:
AudioMetaData: Metadata of the given audio.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
sample_rate, num_channels, num_frames = torchaudio._torchaudio.get_info_fileobj(
filepath, format)
return AudioMetaData(sample_rate, num_frames, num_channels)
sinfo = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())

Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def(
"get_info_fileobj",
&torchaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
Expand Down
32 changes: 31 additions & 1 deletion torchaudio/csrc/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ int64_t SignalInfo::getNumFrames() const {
return num_frames;
}

c10::intrusive_ptr<SignalInfo> get_info(
c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path,
c10::optional<std::string>& format) {
SoxFormat sf(sox_open_read(
Expand Down Expand Up @@ -140,6 +140,36 @@ void save_audio_file(

#ifdef TORCH_API_INCLUDE_EXTENSION_H

// Query metadata of audio provided through a Python file-like object.
//
// Only the first `buf_size` bytes are read from `fileobj` (via its `read`
// method) and handed to libsox as an in-memory file.
//
// fileobj: Python object exposing `read(size) -> bytes`.
// format:  optional file type hint passed to libsox; when empty, libsox is
//          left to detect the format.
//
// Returns a tuple of (sample rate, number of channels, number of frames).
// The number of frames can be 0 when the length is not available.
std::tuple<int64_t, int64_t, int64_t> get_info_fileobj(
    py::object fileobj,
    c10::optional<std::string>& format) {
  // 4096 is fixed minimum buffer size
  // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L46
  const size_t buf_size = 4096;
  std::string buffer(buf_size, 'x');
  auto* buf = &buffer[0];

  // Fetch the header, and copy it to the buffer.
  const auto header = static_cast<std::string>(
      static_cast<py::bytes>(fileobj.attr("read")(buf_size)));
  // A file-like object is free to return more bytes than requested;
  // clamp the copy so that it never writes past the end of the buffer.
  const size_t num_copy =
      header.length() < buf_size ? header.length() : buf_size;
  memcpy(static_cast<void*>(buf),
         static_cast<const void*>(header.data()),
         num_copy);

  SoxFormat sf(sox_open_mem_read(
      buf,
      buf_size,
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));

  // In case of streamed data, length can be 0
  validate_input_file(sf, /*check_length=*/false);

  // `signal.length` is divided by the channel count to get frames;
  // guard against a zero channel count to avoid dividing by zero.
  const auto num_channels = sf->signal.channels;
  const auto num_frames =
      num_channels > 0 ? sf->signal.length / num_channels : 0;

  return std::make_tuple(
      static_cast<int64_t>(sf->signal.rate),
      static_cast<int64_t>(num_channels),
      static_cast<int64_t>(num_frames));
}

std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t>& frame_offset,
Expand Down
6 changes: 5 additions & 1 deletion torchaudio/csrc/sox/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t getNumFrames() const;
};

c10::intrusive_ptr<SignalInfo> get_info(
c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path,
c10::optional<std::string>& format);

Expand All @@ -47,6 +47,10 @@ void save_audio_file(

#ifdef TORCH_API_INCLUDE_EXTENSION_H

// Query metadata of audio provided through a Python file-like object.
// `fileobj` must expose a `read` method returning bytes; `format` is an
// optional file type hint forwarded to libsox.
// Returns a tuple of (sample rate, number of channels, number of frames).
std::tuple<int64_t, int64_t, int64_t> get_info_fileobj(
py::object fileobj,
c10::optional<std::string>& format);

std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t>& frame_offset,
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/sox/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);

m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
"torchaudio::sox_io_load_audio_file("
"str path,"
Expand Down

0 comments on commit 3e1b33c

Please sign in to comment.