
Pipeline decorator #2629

Merged: 20 commits, Feb 12, 2021
3 changes: 2 additions & 1 deletion dali/python/__init__.py.in
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,3 +27,4 @@ from . import tfrecord
from . import types
from . import plugin_manager
from . import sysconfig
from .pipeline import Pipeline, pipeline_def
110 changes: 107 additions & 3 deletions dali/python/nvidia/dali/pipeline.py
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,20 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#pylint: disable=no-member
# pylint: disable=no-member
from collections import deque
from nvidia.dali import backend as b
from nvidia.dali import tensors as Tensors
from nvidia.dali import types
from nvidia.dali.backend import CheckDLPackCapsule
from threading import local as tls
from . import data_node as _data_node
import functools
import inspect
import warnings
import ctypes

pipeline_tls = tls()

from .data_node import DataNode
DataNode.__module__ = __name__ # move to pipeline

DataNode.__module__ = __name__ # move to pipeline


def _show_deprecation_warning(deprecated, in_favor_of):
# show only this warning
@@ -34,6 +39,7 @@ def _show_deprecation_warning(deprecated, in_favor_of):
    warnings.warn("{} is deprecated, please use {} instead".format(deprecated, in_favor_of),
                  Warning, stacklevel=2)


def _get_default_stream_for_array(array):
if types._is_torch_tensor(array):
import torch
@@ -990,3 +996,101 @@ def iter_setup(self):
For example, one can use this function to feed the input
data from NumPy arrays."""
pass


def _discriminate_args(func, **func_kwargs):
Contributor:

I wonder how it plays with political correctness ;P

Contributor:

I mentioned it before but the comment got lost.

  • I'd use "split" or "separate" - "discriminate" might carry a negative connotation.
  • There is a function def _separate_kwargs(kwargs): in ops.py. Isn't this exactly what you need? (haven't read it)

Member Author:

There's a discriminator in electronics and in mathematics, and GANs also have a discriminator. I wouldn't like to be oversensitive - should we avoid using "gap" just because there is a wage gap?

(sarcasm warning)
After all, that's precisely what this function does - treats function args better than ctor args ;)

Contributor:

Achievement unlocked: troll ;)

Contributor:

How about the existing function I pointed out? Is it a duplicate or not?

"""Split args on those applicable to Pipeline constructor and the decorated function."""
func_argspec = inspect.getfullargspec(func)
ctor_argspec = inspect.getfullargspec(Pipeline.__init__)

Contributor:

Suggested change:

    if func_argspec.varkw is not None:
        raise ValueError("Use of variadic keyword argument `**{}` in graph definition "
                         "function is not allowed.".format(func_argspec.varkw))

Contributor (klecki, Feb 10, 2021):

I agree that this is probably the simplest and cleanest solution, which we can easily document - just say that all the variadic kwargs are used for forwarding arguments to the pipeline, and the original function is not supposed to use or declare them, as we may steal them with changes to the Pipeline class.
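The suggested check can be tried in isolation with plain functions; a minimal sketch, assuming only the standard library (the helper name `has_varkw` and the two sample functions are hypothetical, not DALI APIs):

```python
import inspect

def has_varkw(func):
    # `varkw` holds the name of a **kwargs parameter, or None if there is none.
    return inspect.getfullargspec(func).varkw is not None

def with_kwargs(a, **kw):
    return a

def without_kwargs(a, b=1):
    return a + b

print(has_varkw(with_kwargs))     # True
print(has_varkw(without_kwargs))  # False
```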

    ctor_args = {}
    fn_args = {}

    for farg in func_kwargs.items():
        is_ctor_arg = farg[0] in ctor_argspec.args or farg[0] in ctor_argspec.kwonlyargs
        is_fn_arg = farg[0] in func_argspec.args or farg[0] in func_argspec.kwonlyargs
        if is_fn_arg:
            fn_args[farg[0]] = farg[1]
            if is_ctor_arg:
                print(
                    "Warning: the argument `{}` shadows a Pipeline constructor argument of the same name.".format(
Contributor:

Consider:

Suggested change:

    f"Warning: the argument `{farg[0]}` shadows a Pipeline constructor argument of the same name."
                        farg[0]))
        elif is_ctor_arg:
            ctor_args[farg[0]] = farg[1]
        else:
            fn_args[farg[0]] = farg[1]
Contributor:

Why not just raise here? If we disable kwargs, that would save us from guessing, as it would be an explicit error.

    return ctor_args, fn_args
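The splitting logic in `_discriminate_args` can be sketched independently of DALI with plain functions; in this hypothetical standalone version, `_split_kwargs`, `FakePipeline`, and `graph_fn` are illustrative names, not DALI APIs:

```python
import inspect

def _split_kwargs(func, ctor, **kwargs):
    # Collect the keyword-accepting parameter names of each callable.
    func_spec = inspect.getfullargspec(func)
    ctor_spec = inspect.getfullargspec(ctor)
    func_names = set(func_spec.args) | set(func_spec.kwonlyargs)
    ctor_names = set(ctor_spec.args) | set(ctor_spec.kwonlyargs)

    ctor_args, fn_args = {}, {}
    for name, value in kwargs.items():
        if name in func_names:
            # Function arguments win on a name clash (the shadowing case).
            fn_args[name] = value
        elif name in ctor_names:
            ctor_args[name] = value
        else:
            # Unknown names are forwarded to the function, mirroring the code above.
            fn_args[name] = value
    return ctor_args, fn_args

class FakePipeline:
    def __init__(self, batch_size=1, num_threads=1):
        self.batch_size = batch_size
        self.num_threads = num_threads

def graph_fn(flip, batch_size=None):  # `batch_size` shadows the ctor argument
    return flip

ctor_args, fn_args = _split_kwargs(graph_fn, FakePipeline.__init__,
                                   flip=True, batch_size=8, num_threads=2)
print(ctor_args)  # {'num_threads': 2}
print(fn_args)    # {'flip': True, 'batch_size': 8}
```

Note how `batch_size` never reaches the constructor once the graph function declares a parameter of the same name, which is exactly the shadowing behavior the review discussion is about.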


def pipeline_def(fn=None, **pipeline_kwargs):
"""
Decorator that converts a graph definition function into a DALI pipeline factory.

A graph definition function is a function that returns intended pipeline outputs.
You can decorate this function with ``@pipeline_def``::

@pipeline_def
def my_pipe(flip_vertical, flip_horizontal):
''' Creates a DALI pipeline, which returns flipped and original images '''
data, _ = fn.file_reader(file_root=images_dir)
img = fn.image_decoder(data, device="mixed")
flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical)
return flipped, img

The decorated function returns a DALI Pipeline object::

pipe = my_pipe(True, False)
# pipe.build() # the pipeline is not configured properly yet

A pipeline requires additional parameters such as batch size, number of worker threads,
GPU device id and so on (see :meth:`Pipeline.__init__` for a complete list of pipeline parameters).
These parameters can be supplied as additional keyword arguments, passed to the decorated function::

pipe = my_pipe(True, False, batch_size=32, num_threads=1, device_id=0)
pipe.build() # the pipeline is properly configured, we can build it now

The outputs from the original function became the outputs of the Pipeline::

flipped, img = pipe.run()

When some of the pipeline parameters are fixed, they can be specified by name in the decorator::

@pipeline_def(batch_size=42, num_threads=3)
def my_pipe(flip_vertical, flip_horizontal):
...

Any Pipeline constructor parameter passed later when calling the decorated function will
override the decorator-defined params::

@pipeline_def(batch_size=32, num_threads=3)
def my_pipe():
data = fn.external_source(source=my_generator)
return data

pipe = my_pipe(batch_size=128) # batch_size=128 overrides batch_size=32

.. warning::
Contributor:

Suggested change:

    .. warning::
        Avoid the use of `**kwargs` in the graph definition function, since it may result in
        unwanted, silent hijacking of some arguments of the same name by the Pipeline
        constructor. Code written this way may cease to work with future versions of DALI
        when new parameters are added to the Pipeline constructor.
Member Author:

Done

Contributor:

Can't we just disallow **kwargs completely for the function that is to be decorated and raise an error? I'd say play safe rather than write lengthy manuals. If we find a better way, or this turns out to be an issue, we can always relax that check (going the other way would be hard).


        The arguments of the function being decorated can shadow pipeline constructor
        arguments - in which case there's no way to alter their values. Be especially
        mindful about using ``**kwargs``, since code written this way may break with future
        versions of DALI, when new parameters are added to the ``Pipeline`` constructor.
    """
    def actual_decorator(func):
        @functools.wraps(func)
        def create_pipeline(*args, **kwargs):
            ctor_args, fn_kwargs = _discriminate_args(func, **kwargs)
            pipe = Pipeline(**{**pipeline_kwargs, **ctor_args})  # Merge and overwrite dict
            with pipe:
                pipe_outputs = func(*args, **fn_kwargs)
                if isinstance(pipe_outputs, tuple):
                    po = pipe_outputs
                elif pipe_outputs is None:
                    po = ()
                else:
                    po = (pipe_outputs,)
                pipe.set_outputs(*po)
            return pipe
        return create_pipeline
    return actual_decorator(fn) if fn else actual_decorator
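The override semantics in `Pipeline(**{**pipeline_kwargs, **ctor_args})` come from dict-merge ordering, where entries from the right-hand dict win; a minimal sketch with plain dicts:

```python
# Defaults captured by the decorator, e.g. @pipeline_def(batch_size=32, num_threads=3)
decorator_defaults = {"batch_size": 32, "num_threads": 3}

# Arguments supplied at call time, e.g. my_pipe(batch_size=128)
call_time_args = {"batch_size": 128}

# In a {**a, **b} merge, keys from b override keys from a.
merged = {**decorator_defaults, **call_time_args}
print(merged)  # {'batch_size': 128, 'num_threads': 3}
```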