Skip to content

Commit

Permalink
Support bytes and file-like obj in load func
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Dec 22, 2020
1 parent be44256 commit b2793c7
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 27 deletions.
50 changes: 50 additions & 0 deletions test/torchaudio_unittest/soundfile_backend/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,53 @@ def test_wav(self, format_):
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)


@skipIfNoModule("soundfile")
class TestFileLikeObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')

data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T

with open(path, 'rb') as fileobj:
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)

def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav')

@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac')

def _test_bytes(self, ext):
"""Loading audio via bytes works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')

data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T

with open(path, 'rb') as fileobj:
data = fileobj.read()
found, sr = soundfile_backend.load(data)
assert sr == sample_rate
self.assertEqual(expected, found)

def _test_bytes_wav(self):
"""Loading audio via bytes works"""
self._test_bytes('wav')

@skipIfFormatNotSupported("FLAC")
def _test_bytes_flac(self):
"""Loading audio via bytes works"""
self._test_bytes('flac')
65 changes: 65 additions & 0 deletions test/torchaudio_unittest/sox_io_backend/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,68 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
assert sr == 16000


@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileLikeObject(TempDirMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Loading audio via file-like object returns the same result as via file path.
We campare the result of file-like object input against file path input because
`load` function is rigrously tested for file path inputs to match libsox's result,
"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')

sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)

expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as fileobj:
found, sr = sox_io_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)

@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytes(self, ext, compression):
"""Loading audio via bytes returns the same result as via file path.
We campare the result of file-like object input against file path input because
`load` function is rigrously tested for file path inputs to match libsox's result,
"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')

sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)

expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as fileobj:
data = fileobj.read()
found, sr = sox_io_backend.load(data)
assert sr == sample_rate
self.assertEqual(expected, found)
12 changes: 10 additions & 2 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The new soundfile backend which will become default in 0.8.0 onward"""
import io
from typing import Tuple, Optional
import warnings

Expand Down Expand Up @@ -80,8 +81,13 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str or pathlib.Path): Path to audio file.
This functionalso handles ``pathlib.Path`` objects, but is annotated as ``str``
filepath (str, pathlib.Path, or file-like object):
Source of audio data. One of the following types;
* ``str`` or ``pathlib.Path``: file path
* ``bytes``: Audio data in bytes
* ``file-like``: A file-like object with ``read`` method
that returns ``bytes``.
This argument is intentionally annotated as only ``str`` for
for the consistency with "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatiblity.
frame_offset (int):
Expand Down Expand Up @@ -109,6 +115,8 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
"""
if isinstance(filepath, bytes):
filepath = io.BytesIO(filepath)
with soundfile.SoundFile(filepath, "r") as file_:
if file_.format != "WAV" or normalize:
dtype = "float32"
Expand Down
27 changes: 22 additions & 5 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from pathlib import Path
from typing import Tuple, Optional

import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)

import torchaudio
from .common import AudioMetaData


Expand Down Expand Up @@ -82,9 +84,14 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
filepath (str, pathlib.Path, file-like object or bytes):
Source of audio data. One of the following types;
* ``str`` or ``pathlib.Path``: file path
* ``bytes``: Audio data in bytes
* ``file-like``: A file-like object with ``read`` method
that returns ``bytes``.
This argument is intentionally annotated as only ``str`` for
TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
Expand Down Expand Up @@ -112,8 +119,18 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
if isinstance(filepath, (str, Path)):
signal = torch.ops.torchaudio.sox_io_load_audio_file(
str(filepath), frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
if isinstance(filepath, bytes):
return torchaudio._torchaudio.load_audio_bytes(
filepath, frame_offset, num_frames, normalize, channels_first, format)
if hasattr(filepath, 'read'):
return torchaudio._torchaudio.load_audio_bytes(
filepath.read(), frame_offset, num_frames, normalize, channels_first, format)
raise RuntimeError('The `filepath` object must be one of str, Path, bytes, file-like object.')
signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
Expand Down
19 changes: 19 additions & 0 deletions torchaudio/csrc/sox.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torchaudio/csrc/sox.h>
#include <torchaudio/csrc/sox_io.h>

#include <algorithm>
#include <cstdint>
Expand Down Expand Up @@ -178,6 +179,23 @@ void write_audio_file(
} // namespace audio
} // namespace torch

namespace {

std::tuple<torch::Tensor, int64_t> load_audio_bytes(
py::bytes src,
c10::optional<int64_t> frame_offset,
c10::optional<int64_t> num_frames,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string>& format) {
auto bytes = static_cast<std::string>(src);
auto result = torchaudio::sox_io::load_audio_bytes(
bytes, frame_offset, num_frames, normalize, channels_first, format);
return std::make_tuple(result->getTensor(), result->getSampleRate());
}

} // namespace

PYBIND11_MODULE(_torchaudio, m) {
py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t")
.def(py::init<>())
Expand Down Expand Up @@ -271,4 +289,5 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def("load_audio_bytes", &load_audio_bytes, "Load audio from byte string.");
}
47 changes: 37 additions & 10 deletions torchaudio/csrc/sox_effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,13 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
out_tensor, chain.getOutputSampleRate(), channels_first);
}

c10::intrusive_ptr<TensorSignal> apply_effects_file(
const std::string path,
namespace {

c10::intrusive_ptr<TensorSignal> apply_effects_impl(
const SoxFormat& sf,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
// Open input file
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
c10::optional<bool>& channels_first) {

validate_input_file(sf);

Expand Down Expand Up @@ -135,5 +130,37 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
tensor, chain.getOutputSampleRate(), channels_first_);
}

} // namespace

c10::intrusive_ptr<TensorSignal> apply_effects_file(
const std::string& path,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
// Open input file
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
return apply_effects_impl(sf, effects, normalize, channels_first);
}

c10::intrusive_ptr<TensorSignal> apply_effects_bytes(
const std::string& bytes,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
SoxFormat sf(sox_open_mem_read(
static_cast<void*>(const_cast<char*>(bytes.c_str())),
bytes.length(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
return apply_effects_impl(sf, effects, normalize, channels_first);
}

} // namespace sox_effects
} // namespace torchaudio
9 changes: 8 additions & 1 deletion torchaudio/csrc/sox_effects.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_tensor(
std::vector<std::vector<std::string>> effects);

c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file(
const std::string path,
const std::string& path,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);

c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_bytes(
const std::string& bytes,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
Expand Down
34 changes: 28 additions & 6 deletions torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,11 @@ c10::intrusive_ptr<SignalInfo> get_info(
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
}

c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path,
namespace {

std::vector<std::vector<std::string>> get_effects(
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
c10::optional<int64_t>& num_frames) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
Expand All @@ -79,11 +77,35 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
os_offset << offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
}
return effects;
}

} // namespace

c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return torchaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first, format);
}

c10::intrusive_ptr<TensorSignal> load_audio_bytes(
const std::string& bytes,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return torchaudio::sox_effects::apply_effects_bytes(
bytes, effects, normalize, channels_first, format);
}

void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal,
Expand Down
8 changes: 8 additions & 0 deletions torchaudio/csrc/sox_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);

c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_bytes(
const std::string& bytes,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);

void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
Expand Down
3 changes: 0 additions & 3 deletions torchaudio/csrc/sox_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,6 @@ void validate_input_file(const SoxFormat& sf) {
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding.");
}
if (sf->signal.length == 0) {
throw std::runtime_error("Error reading audio file: unkown length.");
}
}

void validate_input_tensor(const torch::Tensor tensor) {
Expand Down

0 comments on commit b2793c7

Please sign in to comment.