Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add warp transducer #1099

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/torchscript_bc_test/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ build_master() {
conda install -y -q pytorch "cpuonly" -c pytorch-nightly
printf "* Installing torchaudio\n"
cd "${_root_dir}" || exit 1
git submodule update --init --recursive
BUILD_SOX=1 python setup.py clean install
}
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ conda install -y -c "pytorch-${UPLOAD_CHANNEL}" pytorch ${cudatoolkit}

# 2. Install torchaudio
printf "* Installing torchaudio\n"
git submodule update --init --recursive
vincentqb marked this conversation as resolved.
Show resolved Hide resolved
BUILD_SOX=1 python setup.py install

# 3. Install Test tools
Expand Down
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[submodule "third_party/warp_transducer/submodule"]
path = third_party/warp_transducer/submodule
url = https://github.com/HawkAaron/warp-transducer
branch = master
86 changes: 69 additions & 17 deletions build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import platform
import subprocess
import torch

from pathlib import Path

from torch.utils.cpp_extension import (
Expand All @@ -17,23 +19,26 @@
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
_CSRC_DIR = _ROOT_DIR / 'torchaudio' / 'csrc'
_TP_BASE_DIR = _ROOT_DIR / 'third_party'
_TP_TRANSDUCER_BASE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer'
_TP_TRANSDUCER_MODULE_DIR = _ROOT_DIR / 'third_party' / 'warp_transducer' / 'submodule'
_TP_INSTALL_DIR = _TP_BASE_DIR / 'install'


def _get_build_sox():
val = os.environ.get('BUILD_SOX', '0')
def _get_build_option(var):
val = os.environ.get(var, '0')
trues = ['1', 'true', 'TRUE', 'on', 'ON', 'yes', 'YES']
falses = ['0', 'false', 'FALSE', 'off', 'OFF', 'no', 'NO']
if val in trues:
return True
if val not in falses:
print(
f'WARNING: Unexpected environment variable value `BUILD_SOX={val}`. '
f'WARNING: Unexpected environment variable value `{var}={val}`. '
f'Expected one of {trues + falses}')
return False


_BUILD_SOX = _get_build_sox()
_BUILD_SOX = _get_build_option("BUILD_SOX")
_BUILD_CUDA_WARP_TRANSDUCER = _get_build_option("BUILD_CUDA_WT")


def _get_eca(debug):
Expand Down Expand Up @@ -101,11 +106,12 @@ def _get_libraries():
return [] if _BUILD_SOX else ['sox']


def _build_third_party():
build_dir = str(_TP_BASE_DIR / 'build')
def _build_third_party(base_dir, target=['..'], options=[]):
print(f"Building third party library in {base_dir}...")
build_dir = str(base_dir / 'build')
os.makedirs(build_dir, exist_ok=True)
subprocess.run(
args=['cmake', '..'],
args=['cmake'] + target + options,
cwd=build_dir,
check=True,
)
Expand All @@ -116,27 +122,73 @@ def _build_third_party():
)


def _get_ext(debug):
return CppExtension(
_EXT_NAME,
_get_srcs(),
libraries=_get_libraries(),
include_dirs=_get_include_dirs(),
extra_compile_args=_get_eca(debug),
extra_objects=_get_extra_objects(),
extra_link_args=_get_ela(debug),
)


def _get_ext_transducer(debug):
extra_compile_args = [
'-fPIC',
'-std=c++14',
]

if _BUILD_CUDA_WARP_TRANSDUCER and torch.cuda.is_available():
print("Building GPU extensions for warp_transudcer.")
if "CUDA_HOME" not in os.environ:
raise RuntimeError("Please specify the environment variable CUDA_HOME.")
extra_compile_args += ['-DWARPRNNT_ENABLE_GPU']
else:
print("Not building GPU extensions for warp_transudcer.")

librairies = ['warprnnt']
build_path = _TP_TRANSDUCER_MODULE_DIR / 'build'
include_path = _TP_TRANSDUCER_MODULE_DIR / 'include'
source_path = _TP_TRANSDUCER_BASE_DIR / 'binding.cpp'

return CppExtension(
name='_warp_transducer',
sources=[os.path.realpath(source_path)],
libraries=librairies,
include_dirs=[os.path.realpath(include_path)],
library_dirs=[os.path.realpath(build_path)],
extra_compile_args=extra_compile_args,
extra_objects=[str(build_path / f'lib{l}.a') for l in librairies],
extra_link_args=['-Wl,-rpath,' + os.path.realpath(build_path)],
)


_EXT_NAME = 'torchaudio._torchaudio'


def get_ext_modules(debug=False):
if platform.system() == 'Windows':
return None
return [
CppExtension(
_EXT_NAME,
_get_srcs(),
libraries=_get_libraries(),
include_dirs=_get_include_dirs(),
extra_compile_args=_get_eca(debug),
extra_objects=_get_extra_objects(),
extra_link_args=_get_ela(debug),
),
_get_ext(debug),
_get_ext_transducer(debug),
]


class BuildExtension(TorchBuildExtension):
def build_extension(self, ext):
if ext.name == _EXT_NAME and _BUILD_SOX:
_build_third_party()
_build_third_party(_TP_BASE_DIR)
if ext.name == "_warp_transducer":
# TODO Support OMP on MacOS
_build_third_party(
_TP_TRANSDUCER_MODULE_DIR,
target=[str(_TP_TRANSDUCER_BASE_DIR)],
options=[
"-DWITH_OMP=OFF",
"-DWITH_GPU=ON" if _BUILD_CUDA_WARP_TRANSDUCER else "-DWITH_GPU=OFF",
]
)
super().build_extension(ext)
1 change: 1 addition & 0 deletions packaging/build_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ export NO_CUDA_PACKAGE=1
setup_env 0.8.0
export SOURCE_ROOT_DIR="$PWD"
setup_conda_pytorch_constraint
git submodule update --init --recursive
conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchaudio
1 change: 1 addition & 0 deletions packaging/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ setup_wheel_python
pip_install numpy future
setup_pip_pytorch_version
python setup.py clean
git submodule update --init --recursive
if [[ "$OSTYPE" == "msys" ]]; then
python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def run(self):
build_dirs = [
ROOT_DIR / 'build',
ROOT_DIR / 'third_party' / 'build',
ROOT_DIR / 'third_party' / 'warp_transducer' / 'submodule' / 'build',
]
for path in build_dirs:
if path.exists():
Expand Down Expand Up @@ -83,7 +84,8 @@ def run(self):
packages=find_packages(exclude=["build*", "test*", "torchaudio.csrc*", "third_party*", "build_tools*"]),
ext_modules=setup_helpers.get_ext_modules(),
cmdclass={
'build_ext': setup_helpers.BuildExtension.with_options(no_python_abi_suffix=True)
'build_ext': setup_helpers.BuildExtension.with_options(no_python_abi_suffix=True),
'clean': clean,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

},
install_requires=[pytorch_package_dep],
zip_safe=False,
Expand Down
2 changes: 2 additions & 0 deletions test/torchaudio_unittest/common_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
skipIfNoModule,
skipIfNoExtension,
skipIfNoSoxBackend,
skipIfNoTransducer,
skipIfNoCudaTransducer,
)
from .wav_utils import (
get_wav_data,
Expand Down
18 changes: 18 additions & 0 deletions test/torchaudio_unittest/common_utils/case_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,21 @@ def skipIfNoExtension(test_item):
if 'TORCHAUDIO_TEST_FAIL_IF_NO_EXTENSION' in os.environ:
raise RuntimeError('torchaudio C++ extension is not available.')
return unittest.skip('torchaudio C++ extension is not available')(test_item)


skipIfNoTransducer = unittest.skipIf(
not is_module_available('_warp_transducer'),
'"_warp_transducer" is not available',
)

try:
torch.ops.warprnnt_pytorch_warp_rnnt.gpu_rnnt
_CUDA_TRANSDUCER = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd make this switch dependent on whether CUDA is available at all.

except RuntimeError:
_CUDA_TRANSDUCER = False


skipIfNoCudaTransducer = unittest.skipIf(
not _CUDA_TRANSDUCER,
'"_warp_transducer" not built with GPU support',
)
52 changes: 52 additions & 0 deletions test/torchaudio_unittest/transducer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch

from torchaudio_unittest import common_utils
from torchaudio.prototype.transducer import RNNTLoss


class TransducerTester:
def test_basic_backward(self):
rnnt_loss = RNNTLoss()

acts = torch.FloatTensor(
[
[
[
[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.6, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.8, 0.1],
],
[
[0.1, 0.6, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.2, 0.1, 0.1],
[0.7, 0.1, 0.2, 0.1, 0.1],
],
]
]
)
labels = torch.IntTensor([[1, 2]])
act_length = torch.IntTensor([2])
label_length = torch.IntTensor([2])

acts = acts.to(self.device)
labels = labels.to(self.device)
act_length = act_length.to(self.device)
label_length = label_length.to(self.device)

acts = torch.autograd.Variable(acts, requires_grad=True)
labels = torch.autograd.Variable(labels)
act_length = torch.autograd.Variable(act_length)
label_length = torch.autograd.Variable(label_length)

loss = rnnt_loss(acts, labels, act_length, label_length)
loss.backward()


@common_utils.skipIfNoTransducer
class CPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
device = "cpu"


@common_utils.skipIfNoCudaTransducer
class GPUTransducerTester(TransducerTester, common_utils.PytorchTestCase):
device = "cuda"
Loading