Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support file-like object in info #1108

Merged
merged 1 commit into from
Jan 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions test/torchaudio_unittest/soundfile_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import patch
import warnings
import tarfile

import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
Expand Down Expand Up @@ -125,3 +126,65 @@ class MockSoundFileInfo:
assert len(w) == 1
assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
assert info.bits_per_sample == 0


@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Check that ``soundfile_backend.info`` accepts file-like objects."""

    def _write_audio(self, path, subtype):
        """Write two seconds of random stereo audio at 16 kHz to ``path``.

        Args:
            path (str): Destination file; the format is taken from its extension.
            subtype (str): Sub-type passed to ``soundfile.write`` (e.g. ``PCM_16``).

        Returns:
            tuple: ``(sample_rate, num_frames, num_channels)`` of the written audio.
        """
        duration = 2
        sample_rate = 16000
        num_channels = 2
        num_frames = sample_rate * duration

        data = torch.randn(num_frames, num_channels).numpy()
        soundfile.write(path, data, sample_rate, subtype=subtype)
        return sample_rate, num_frames, num_channels

    def _assert_info(self, info, expected, bits_per_sample):
        """Compare the metadata in ``info`` against the values used at write time.

        Args:
            info: Object returned by ``soundfile_backend.info``.
            expected (tuple): ``(sample_rate, num_frames, num_channels)``.
            bits_per_sample (int): Expected bit depth.
        """
        sample_rate, num_frames, num_channels = expected
        assert info.sample_rate == sample_rate
        assert info.num_frames == num_frames
        assert info.num_channels == num_channels
        assert info.bits_per_sample == bits_per_sample

    def _test_fileobj(self, ext, subtype, bits_per_sample):
        """Query audio via file-like object works"""
        path = self.get_temp_path(f'test.{ext}')
        expected = self._write_audio(path, subtype)

        with open(path, 'rb') as fileobj:
            info = soundfile_backend.info(fileobj)
        self._assert_info(info, expected, bits_per_sample)

    def test_fileobj_wav(self):
        """Querying WAV audio via file-like object works"""
        self._test_fileobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_fileobj_flac(self):
        """Querying FLAC audio via file-like object works"""
        self._test_fileobj('flac', 'PCM_16', 16)

    def _test_tarobj(self, ext, subtype, bits_per_sample):
        """Query audio stored in a tar archive via file-like object works"""
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        # ``tarfile.TarFile`` writes a plain, uncompressed archive, so the file
        # is named ``.tar`` rather than ``.tar.gz``.
        archive_path = self.get_temp_path('archive.tar')
        expected = self._write_audio(audio_path, subtype)

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # The extracted member is only readable while the archive is open.
            fileobj = tarobj.extractfile(audio_file)
            info = soundfile_backend.info(fileobj)
        self._assert_info(info, expected, bits_per_sample)

    def test_tarobj_wav(self):
        """Querying WAV audio inside a tar archive via file-like object works"""
        self._test_tarobj('wav', 'PCM_16', 16)

    @skipIfFormatNotSupported("FLAC")
    def test_tarobj_flac(self):
        """Querying FLAC audio inside a tar archive via file-like object works"""
        self._test_tarobj('flac', 'PCM_16', 16)
151 changes: 150 additions & 1 deletion test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import io
import itertools
from parameterized import parameterized
import tarfile

from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils

from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
get_asset_path,
get_wav_data,
save_wav,
Expand All @@ -18,6 +23,10 @@
)


if _mod_utils.is_module_available("requests"):
import requests


@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
Expand Down Expand Up @@ -197,3 +206,143 @@ def test_mp3(self):
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Check that ``sox_io_backend.info`` accepts file-like objects."""

    def _assert_sinfo(self, sinfo, ext, sample_rate, num_channels, duration, bits_per_sample):
        """Compare the metadata in ``sinfo`` against the values used to generate the audio.

        Args:
            sinfo: Object returned by ``sox_io_backend.info``.
            ext (str): Audio format under test; decides whether the length is checked.
            sample_rate (int): Expected sample rate.
            num_channels (int): Expected number of channels.
            duration (float): Duration in seconds used when generating the audio.
            bits_per_sample (int): Expected bit depth.
        """
        assert sinfo.sample_rate == sample_rate
        assert sinfo.num_channels == num_channels
        if ext not in ['mp3', 'vorbis']:  # these container formats do not have length info
            assert sinfo.num_frames == sample_rate * duration
        assert sinfo.bits_per_sample == bits_per_sample

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_fileobj(self, ext, bits_per_sample):
        """Querying audio via file object works"""
        sample_rate = 16000
        num_channels = 2
        duration = 3
        # Only mp3 is given an explicit format hint; the others rely on detection.
        format_ = ext if ext in ['mp3'] else None
        path = self.get_temp_path(f'test.{ext}')

        sox_utils.gen_audio_file(
            path, sample_rate, num_channels=num_channels,
            duration=duration)

        with open(path, 'rb') as fileobj:
            sinfo = sox_io_backend.info(fileobj, format=format_)

        self._assert_sinfo(
            sinfo, ext, sample_rate, num_channels, duration, bits_per_sample)

    def _test_bytesio(self, ext, bits_per_sample, duration):
        """Generate ``duration`` seconds of audio and query it through ``io.BytesIO``."""
        sample_rate = 16000
        num_channels = 2
        format_ = ext if ext in ['mp3'] else None
        path = self.get_temp_path(f'test.{ext}')

        sox_utils.gen_audio_file(
            path, sample_rate, num_channels=num_channels,
            duration=duration)

        with open(path, 'rb') as file_:
            fileobj = io.BytesIO(file_.read())
        sinfo = sox_io_backend.info(fileobj, format=format_)

        self._assert_sinfo(
            sinfo, ext, sample_rate, num_channels, duration, bits_per_sample)

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_bytesio(self, ext, bits_per_sample):
        """Querying audio via BytesIO object works"""
        self._test_bytesio(ext, bits_per_sample, duration=3)

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_bytesio_tiny(self, ext, bits_per_sample):
        """Querying audio via BytesIO object works for small data"""
        self._test_bytesio(ext, bits_per_sample, duration=1 / 1600)

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_tarfile(self, ext, bits_per_sample):
        """Querying audio stored in a tar archive via file-like object works"""
        sample_rate = 16000
        num_channels = 2
        duration = 3
        format_ = ext if ext in ['mp3'] else None
        audio_file = f'test.{ext}'
        audio_path = self.get_temp_path(audio_file)
        # ``tarfile.TarFile`` writes a plain, uncompressed archive, so the file
        # is named ``.tar`` rather than ``.tar.gz``.
        archive_path = self.get_temp_path('archive.tar')

        sox_utils.gen_audio_file(
            audio_path, sample_rate, num_channels=num_channels, duration=duration)

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(audio_path, arcname=audio_file)
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            # The extracted member is only readable while the archive is open.
            fileobj = tarobj.extractfile(audio_file)
            sinfo = sox_io_backend.info(fileobj, format=format_)

        self._assert_sinfo(
            sinfo, ext, sample_rate, num_channels, duration, bits_per_sample)


@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Check that ``sox_io_backend.info`` can read from a streamed HTTP response."""

    @parameterized.expand([
        ('wav', 32),
        ('mp3', 0),
        ('flac', 24),
        ('vorbis', 0),
        ('amb', 32),
    ])
    def test_requests(self, ext, bits_per_sample):
        """Audio metadata can be queried from the raw stream of a ``requests`` response"""
        rate, channels, seconds = 16000, 2, 3
        file_name = f'test.{ext}'
        # Only mp3 is given an explicit format hint.
        format_hint = ext if ext in ['mp3'] else None

        local_path = self.get_temp_path(file_name)
        sox_utils.gen_audio_file(
            local_path, rate, num_channels=channels, duration=seconds)

        with requests.get(self.get_url(file_name), stream=True) as response:
            sinfo = sox_io_backend.info(response.raw, format=format_hint)

        # these container formats do not have length info
        length_is_known = ext not in ['mp3', 'vorbis']

        assert sinfo.sample_rate == rate
        assert sinfo.num_channels == channels
        if length_is_known:
            assert sinfo.num_frames == rate * seconds
        assert sinfo.bits_per_sample == bits_per_sample
10 changes: 6 additions & 4 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.

Args:
filepath (str or pathlib.Path): Path to audio file.
This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for the consistency with "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatibility.
format (str, optional):
Not used. PySoundFile does not accept format hint.

Expand Down
49 changes: 42 additions & 7 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,26 @@
from .common import AudioMetaData


@torch.jit.unused
def _info(
        filepath: str,
        format: Optional[str] = None,
) -> AudioMetaData:
    """Fetch audio metadata from either a path or a file-like object.

    Args:
        filepath: Path-like object, or any object exposing a ``read`` attribute.
        format: Format hint forwarded unchanged to the underlying call.

    Returns:
        AudioMetaData: Metadata built from the values reported by the extension.
    """
    # Anything without a ``read`` attribute is treated as a path.
    if not hasattr(filepath, 'read'):
        path_info = torch.ops.torchaudio.sox_io_get_info(
            os.fspath(filepath), format)
        return AudioMetaData(
            path_info.get_sample_rate(),
            path_info.get_num_frames(),
            path_info.get_num_channels(),
            path_info.get_bits_per_sample(),
        )

    # The file-object binding yields a 4-item sequence with the channel count
    # ahead of the frame count; AudioMetaData takes them the other way around.
    fileobj_info = torchaudio._torchaudio.get_info_fileobj(filepath, format)
    sample_rate, num_channels, num_frames, bits_per_sample = fileobj_info
    return AudioMetaData(
        sample_rate, num_frames, num_channels, bits_per_sample)


@_mod_utils.requires_module('torchaudio._torchaudio')
def info(
filepath: str,
Expand All @@ -18,9 +38,21 @@ def info(
"""Get signal information of an audio file.

Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects,
but is annotated as ``str`` for TorchScript compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript,
(e.g. ``torch.jit.script``), the following types are accepted;
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.

Note:
* When the input type is file-like object, this function cannot
get the correct length (``num_frames``) for certain formats,
such as ``mp3`` and ``vorbis``.
In this case, the value of ``num_frames`` is ``0``.
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
Expand All @@ -29,11 +61,14 @@ def info(
Returns:
AudioMetaData: Metadata of the given audio.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
return _info(filepath, format)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(),
sinfo.get_bits_per_sample())
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample())


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def(
"get_info_fileobj",
&torchaudio::sox_io::get_info_fileobj,
"Get metadata of audio in file object.");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
Expand Down
Loading