Skip to content

Commit

Permalink
Dispatcher tutorial (pytorch#1072)
Browse files Browse the repository at this point in the history
* Dispatcher tutorial

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* typofix

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* morefix

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
  • Loading branch information
ezyang committed Jul 22, 2020
1 parent 0075e38 commit b6ffdb9
Show file tree
Hide file tree
Showing 5 changed files with 406 additions and 0 deletions.
281 changes: 281 additions & 0 deletions advanced_source/dispatcher.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
Dispatcher in C++
=================

The dispatcher is an internal component of PyTorch which is responsible for
figuring out what code should actually get run when you call a function like
``torch::add``. This can be nontrivial, because PyTorch operations need
to handle a lot of cross-cutting concerns that are "layered" on top of one
of another. Here is a sampling of some of the things it handles:

* Switching between the CPU and CUDA implementations of an operator, depending
on the devices of the input tensors.
* Switching between the autograd and backend implementations of an operator,
depending on whether or not autograd handling is necessary.
* Applying autocasting when necessary for automatic mixed precision.
* Applying batching rules when an operator is run under a ``vmap`` call.
* Tracing execution of operations, if you are tracing a model for export.

If in your `custom operator code <torch_script_custom_ops>`_ you find yourself
manually writing if statements to handle these cases, the dispatcher APIs can
help organize your code. (Conversely, if your custom operator is very simple
and is only for CPU inference, you probably don't need to use the dispatcher,
just use the basic API.)

In this tutorial, we will describe how to structure a custom operator
registration to use the dispatcher to organize various components. We'll
assume that you are familiar with how to
`register an operator <torch_script_custom_ops>`_ and how to write
a `custom autograd function <cpp_autograd>`_.

Defining schema and backend implementations
-------------------------------------------

The general principle behind the dispatcher is that it divides the
implementation of an operator into multiple kernels, each of which implements
functionality for a specific *dispatch key*; for example, CPU, CUDA or Autograd.
The dispatcher determines what the highest priority dispatch key is at the time
you call an operator (this is done by looking at both the tensor arguments as
well as some thread local state), and transfers control to the kernel for that
dispatch key. The end effect is that when you call an operator, we first
execute the Autograd kernel, and then we redispatch to the CPU or CUDA kernel
depending on the device types of the passed in tensors.

Let's take a look at the various parts involved in making this
happen. First, we must define the schema for the operator in question.
Unlike simple pybind11-style operator registration, we don't actually
provide an implementation of our operator at this point; we just
provide a schema string specifying the type signature of the operator
that all of our other kernels will abide by:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN TORCH_LIBRARY
:end-before: END TORCH_LIBRARY

Next, we need to actually provide some implementations of this operator.
For concreteness, here is a really simple implementation of addition on CPU:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN myadd_cpu
:end-before: END myadd_cpu

We'd like to register this function as an implementation of ``myops::myadd``.
However, the simple way of registering it (``def("myadd", myadd_cpu)``) would
register the kernel to run in all cases, even if the tensor is not a CPU
tensor! (Internally, we refer to these as "catch-all" kernels, since they
catch all cases.) To ensure that ``myadd_cpu`` is only run for
CPU tensors, we can use the ``TORCH_LIBRARY_IMPL`` macro:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN TORCH_LIBRARY_IMPL CPU
:end-before: END TORCH_LIBRARY_IMPL CPU

The ``TORCH_LIBRARY_IMPL`` lets us register implementations for operators on
a specific dispatch key (in this case, CPU). Each call to ``impl``
associates a CPU kernel with the corresponding operator (which we previously
defined in the ``TORCH_LIBRARY`` block). If we also have a CUDA implementation ``myadd_cuda``,
we can register it in a separate ``TORCH_LIBRARY_IMPL`` block:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN TORCH_LIBRARY_IMPL CUDA
:end-before: END TORCH_LIBRARY_IMPL CUDA

These registrations can be split across files or even across library boundaries; so
for example, you could have these two ``TORCH_LIBRARY_IMPL`` blocks compiled
into a separate ``myops_cpu`` and ``myops_cuda`` dynamic libraries. Generally,
speaking, the structure of your registrations will look like this:

1. A single ``TORCH_LIBRARY`` that lists every custom operator in your namespace
in a centralized place.
2. A ``TORCH_LIBRARY_IMPL`` per dispatch key that registers implementations for
that key (e.g., CPU or CUDA). If you like, you can further subdivide
``TORCH_LIBRARY_IMPL`` blocks into a block per operator. This is convenient
if you have a separate file per operator implementation, but don't want to
expose the operators in a header; you can just put the registration in the
cpp file that defines your operator.

.. note::

Did you know that you can also write ``TORCH_LIBRARY_IMPL`` blocks for existing
core operators in PyTorch? This is how XLA support for PyTorch is
implemented: the ``torch_xla`` library contains a ``TORCH_LIBRARY_IMPL``
that provides implementations for all basic operators on the XLA dispatch
key.

Adding autograd support
-----------------------

At this point, we have an operator with both CPU and CUDA implementations. How
can we add autograd support to it? As you might guess, we will register an
autograd kernel (similar to what's described in the `custom autograd function <cpp_autograd>`_ tutorial)!
However, there is a twist: unlike the CPU and CUDA kernels, the autograd kernel
needs to *redispatch*: it needs to call back into the dispatcher to get to
the final CPU and CUDA implementations.

Thus, before we write the autograd kernel, let's write a *dispatching function*
which calls into the dispatcher to find the right kernel for your operator.
This function constitutes the public C++ API for your operators--in fact, all of
the tensor functions in PyTorch's C++ API all call the dispatcher in the same
way under the hood. Here's what the dispatching function looks like:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN myadd
:end-before: END myadd

Let's break it down:

* In the first line, we look up a typed operator handle from the dispatcher
corresponding to the operator that we are going to dispatch to.
``findSchemaOrThrow`` takes two arguments: the (namespace qualified) name
of the operator, and the overload name of the operator (typically just
the empty string). ``typed`` casts the dynamically typed handle into
a statically typed handle (doing a runtime test to make sure you've given
the correct C++ type), so that we can do a normal C++ call on it. We
pass it ``decltype(myadd)`` since the type of the dispatching function is
the same as the type of the underlying kernels registered to the dispatcher.

For performance, this computation is done in a static variable, so that
we only need to do the (slow) lookup once. If you typoed the name of the
operator you want to call, this lookup will error the first time you call this
function.

* In the second line, we simply ``call`` the operator handle with all of the
arguments passed into the dispatching function. This will actually invoke
the dispatcher and in the end control will be transferred to whatever kernel
is appropriate for this call.

With the dispatch function in hand, we can now write the autograd kernel:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN myadd_autograd
:end-before: END myadd_autograd

The autograd function is written as normal using ``torch::autograd::Function``,
except that instead of directly writing the implementation in ``forward()``,
we:

1. Turn off autograd handling with the ``at::AutoNonVariableTypeMode`` RAII
guard, and then
2. Call the dispatch function ``myadd`` to call back into the dispatcher.

Without (1), your calls will infinite loop (and stack overflow), because
``myadd`` will send you back to this function (as the highest priority dispatch
key would still be autograd.) With (1),
autograd is excluded from the set of dispatch keys under consideration, and
we will go to the next handlers, which will either be CPU and CUDA.

We can now register this function in the same way we registered the CPU/CUDA
functions:

.. literalinclude:: ../advanced_source/dispatcher/op.cpp
:language: cpp
:start-after: BEGIN TORCH_LIBRARY_IMPL Autograd
:end-before: END TORCH_LIBRARY_IMPL Autograd

Going beyond autograd
---------------------

In some sense, the dispatcher isn't doing all that much: all it does is
implement a glorified if-statement, along the lines of this:

.. code-block:: cpp
class MyAddFunction : ... {
public:
static Tensor forward(
AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
if (self.device().type() == DeviceType::CPU) {
return add_cpu(self, other);
} else if (self.device().type() == DeviceType::CUDA) {
return add_cuda(self, other);
} else {
TORCH_CHECK(0, "Unsupported device ", self.device().type());
}
}
...
}
So why use the dispatcher? There are a few reasons:

1. It is decentralized. You can assemble all of the pieces of an operator
(CPU, CUDA, Autograd) without having to write a single, centralized
if statement that refers to all of them. Importantly, third parties can
register extra implementations for other aspects without having to patch the
original definition of an operator.

2. It supports more dispatch keys than CPU, CUDA and Autograd. You can
see a full list of dispatch keys that are currently implemented
in PyTorch in ``c10/core/DispatchKey.h``. These dispatch keys
implement a variety of optional functionality for operators, and if you
decide you want your custom operator to support this functionality,
all you have to register a kernel for the appropriate key.

3. The dispatcher implements support for boxed fallback functions, which
are functions that can be implemented once and apply to all operators
in the system. Boxed fallbacks can be used to provide default behavior
for a dispatch key; if you use the dispatcher to implement your operator,
you also opt into the fallbacks for all of these operations.

Here are some particular dispatch keys which you may need to define an operator
for.

Autocast
^^^^^^^^

The Autocast dispatch key implements support for
`automatic mixed precision <https://developer.nvidia.com/automatic-mixed-precision>`_
(AMP). An autocast kernel typically modifies the operation of an operator by casting the
input arguments to some precision before carrying out the operation. For some
operations, it is numerically safe to cast to lower precision, which is how AMP
can achieve speed ups and reduced memory usage without sacrificing much
accuracy. A nontrivial autocast kernel looks something like this:

.. code-block:: cpp
Tensor mymatmul_autocast(const Tensor& self, const Tensor& other) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return mymatmul(autocast::_cast(at::kHalf, self), autocast::_cast(at::kHalf, other));
}
Notice that, like our autograd kernels, we exclude the ``Autocast`` key from
dispatch before redispatching. By default, if no autocast kernel is provided,
we simply fallthrough directly to the regular operator implementation (no
autocasting occurs.) (We didn't use ``myadd`` for this example, since pointwise
addition doesn't do autocasting and should just fall through).

When should an autocast kernel be registered? Unfortunately, there aren't
cut-and-dry rules for when you should cast to a lower precision. You can
get a sense for what operators have autocasting behavior by looking at
the `AMP documentation
<https://pytorch.org/docs/master/amp.html#op-specific-behavior>`_. Some other
general rules:

* Operations that do reductions should be carried out in float32,
* Any operation with multiple float tensor inputs has to standardize them
to a common precision, and
* Any operation that does a convolution or gemm under the hood should
probably be float16

Batched
^^^^^^^

Batched tensors allow you to write your code in a per-example manner, and then
have them be automatically batched when run under a ``vmap`` invocation. The
API for writing batching rules is currently under development, but once it is
stabilized, you can add support for ``vmap`` for your operators by registering
a kernel at the Batched dispatch key.

Tracer
^^^^^^

The Tracer dispatch key implements support for recording invocations of operators
into a trace when you run ``torch.jit.trace``. We intend to provide a
boxed fallback that will implement tracing for arbitrary operations,
see `issue #41478 <https://github.com/pytorch/pytorch/issues/41478>`_ to track
progress.
8 changes: 8 additions & 0 deletions advanced_source/dispatcher/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(dispatcher)

find_package(Torch REQUIRED)

add_library(dispatcher SHARED op.cpp)
target_compile_features(dispatcher PRIVATE cxx_std_14)
target_link_libraries(dispatcher "${TORCH_LIBRARIES}")
105 changes: 105 additions & 0 deletions advanced_source/dispatcher/op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include <torch/torch.h>
#include <torch/script.h>

#include <ATen/NamedTensorUtils.h>

using torch::Tensor;
using torch::DeviceType;
using torch::autograd::tensor_list;
using torch::autograd::AutogradContext;

// BEGIN myadd
Tensor myadd(const Tensor& self, const Tensor& other) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("myops::myadd", "")
.typed<decltype(myadd)>();
return op.call(self, other);
}
// END myadd

// BEGIN TORCH_LIBRARY
TORCH_LIBRARY(myops, m) {
m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
// END TORCH_LIBRARY

// BEGIN myadd_cpu
Tensor myadd_cpu(const Tensor& self_, const Tensor& other_) {
TORCH_CHECK(self_.sizes() == other_.sizes());
TORCH_INTERNAL_ASSERT(self_.device().type() == DeviceType::CPU);
TORCH_INTERNAL_ASSERT(other_.device().type() == DeviceType::CPU);
Tensor self = self_.contiguous();
Tensor other = other_.contiguous();
Tensor result = torch::empty(self.sizes(), self.options());
const float* self_ptr = self.data_ptr<float>();
const float* other_ptr = other.data_ptr<float>();
float* result_ptr = result.data_ptr<float>();
for (int64_t i = 0; i < result.numel(); i++) {
result_ptr[i] = self_ptr[i] + other_ptr[i];
}
return result;
}
// END myadd_cpu

// BEGIN TORCH_LIBRARY_IMPL CPU
TORCH_LIBRARY_IMPL(myops, CPU, m) {
m.impl("myadd", myadd_cpu);
}
// END TORCH_LIBRARY_IMPL CPU

Tensor myadd_cuda(const Tensor& self, const Tensor& other) {
// Insert your CUDA implementation here
TORCH_CHECK(0, "CUDA not yet implemented");
}

// BEGIN TORCH_LIBRARY_IMPL CUDA
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("myadd", myadd_cuda);
}
// END TORCH_LIBRARY_IMPL CUDA

// BEGIN myadd_autograd
class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
public:
static Tensor forward(
AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
at::AutoNonVariableTypeMode g;
return myadd(self, other);
}

static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) {
auto grad_output = grad_outputs[0];
return {grad_output, grad_output};
}
};

Tensor myadd_autograd(const Tensor& self, const Tensor& other) {
return MyAddFunction::apply(self, other)[0];
}
// END myadd_autograd

// BEGIN TORCH_LIBRARY_IMPL Autograd
TORCH_LIBRARY_IMPL(myops, Autograd, m) {
m.impl("myadd", myadd_autograd);
}
// END TORCH_LIBRARY_IMPL Autograd

#if 0
// BEGIN TORCH_LIBRARY_IMPL Named
Tensor myadd_named(const Tensor& self, const Tensor& other) {
// TODO: shouldn't need to do size check here
TORCH_CHECK(self.sizes() == other.sizes());
auto maybe_outnames = at::unify_from_right(self.names(), other.names());
auto result = ([&]() {
at::NoNamesGuard guard;
return myadd(self, other);
})();
at::namedinference::propagate_names_if_nonempty(result, maybe_outnames);
return result;
}

TORCH_LIBRARY_IMPL(myops, Named, m) {
m.impl("myadd", myadd_named);
}
// END TORCH_LIBRARY_IMPL Named
#endif
Loading

0 comments on commit b6ffdb9

Please sign in to comment.