Pipeline decorator #2629
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,4 +1,4 @@ | ||||||||||||||||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||
# | ||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||
|
@@ -12,20 +12,25 @@ | |||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||
# limitations under the License. | ||||||||||||||||||
|
||||||||||||||||||
#pylint: disable=no-member | ||||||||||||||||||
# pylint: disable=no-member | ||||||||||||||||||
from collections import deque | ||||||||||||||||||
from nvidia.dali import backend as b | ||||||||||||||||||
from nvidia.dali import tensors as Tensors | ||||||||||||||||||
from nvidia.dali import types | ||||||||||||||||||
from nvidia.dali.backend import CheckDLPackCapsule | ||||||||||||||||||
from threading import local as tls | ||||||||||||||||||
from . import data_node as _data_node | ||||||||||||||||||
import functools | ||||||||||||||||||
import inspect | ||||||||||||||||||
import warnings | ||||||||||||||||||
import ctypes | ||||||||||||||||||
|
||||||||||||||||||
pipeline_tls = tls() | ||||||||||||||||||
|
||||||||||||||||||
from .data_node import DataNode | ||||||||||||||||||
DataNode.__module__ = __name__ # move to pipeline | ||||||||||||||||||
|
||||||||||||||||||
DataNode.__module__ = __name__ # move to pipeline | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _show_deprecation_warning(deprecated, in_favor_of): | ||||||||||||||||||
# show only this warning | ||||||||||||||||||
|
@@ -34,6 +39,7 @@ def _show_deprecation_warning(deprecated, in_favor_of): | |||||||||||||||||
warnings.warn("{} is deprecated, please use {} instead".format(deprecated, in_favor_of), | ||||||||||||||||||
Warning, stacklevel=2) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _get_default_stream_for_array(array): | ||||||||||||||||||
if types._is_torch_tensor(array): | ||||||||||||||||||
import torch | ||||||||||||||||||
|
@@ -990,3 +996,101 @@ def iter_setup(self): | |||||||||||||||||
For example, one can use this function to feed the input | ||||||||||||||||||
data from NumPy arrays.""" | ||||||||||||||||||
pass | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _discriminate_args(func, **func_kwargs): | ||||||||||||||||||
"""Split args on those applicable to Pipeline constructor and the decorated function.""" | ||||||||||||||||||
func_argspec = inspect.getfullargspec(func) | ||||||||||||||||||
ctor_argspec = inspect.getfullargspec(Pipeline.__init__) | ||||||||||||||||||
|
||||||||||||||||||
Suggested change
I agree that this is probably the simplest and cleanest solution, and one we can easily document: just say that all variadic kwargs are used for forwarding arguments to the pipeline, and that the original function is not supposed to use or declare them, as we may steal them with changes to the Pipeline class. |
||||||||||||||||||
ctor_args = {} | ||||||||||||||||||
fn_args = {} | ||||||||||||||||||
|
||||||||||||||||||
for farg in func_kwargs.items(): | ||||||||||||||||||
is_ctor_arg = farg[0] in ctor_argspec.args or farg[0] in ctor_argspec.kwonlyargs | ||||||||||||||||||
is_fn_arg = farg[0] in func_argspec.args or farg[0] in func_argspec.kwonlyargs | ||||||||||||||||||
if is_fn_arg: | ||||||||||||||||||
fn_args[farg[0]] = farg[1] | ||||||||||||||||||
if is_ctor_arg: | ||||||||||||||||||
print( | ||||||||||||||||||
"Warning: the argument `{}` shadows a Pipeline constructor argument of the same name.".format( | ||||||||||||||||||
Consider
Suggested change
|
||||||||||||||||||
farg[0])) | ||||||||||||||||||
elif is_ctor_arg: | ||||||||||||||||||
ctor_args[farg[0]] = farg[1] | ||||||||||||||||||
else: | ||||||||||||||||||
fn_args[farg[0]] = farg[1] | ||||||||||||||||||
Why not just raise here? If we disallow kwargs, that would save us from guessing, since it would be an explicit error. |
||||||||||||||||||
|
||||||||||||||||||
return ctor_args, fn_args | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def pipeline_def(fn=None, **pipeline_kwargs): | ||||||||||||||||||
""" | ||||||||||||||||||
Decorator that converts a graph definition function into a DALI pipeline factory. | ||||||||||||||||||
|
||||||||||||||||||
A graph definition function is a function that returns intended pipeline outputs. | ||||||||||||||||||
You can decorate this function with ``@pipeline_def``:: | ||||||||||||||||||
|
||||||||||||||||||
@pipeline_def | ||||||||||||||||||
def my_pipe(flip_vertical, flip_horizontal): | ||||||||||||||||||
''' Creates a DALI pipeline, which returns flipped and original images ''' | ||||||||||||||||||
data, _ = fn.file_reader(file_root=images_dir) | ||||||||||||||||||
img = fn.image_decoder(data, device="mixed") | ||||||||||||||||||
flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical) | ||||||||||||||||||
return flipped, img | ||||||||||||||||||
|
||||||||||||||||||
The decorated function returns a DALI Pipeline object:: | ||||||||||||||||||
|
||||||||||||||||||
pipe = my_pipe(True, False) | ||||||||||||||||||
# pipe.build() # the pipeline is not configured properly yet | ||||||||||||||||||
|
||||||||||||||||||
A pipeline requires additional parameters such as batch size, number of worker threads, | ||||||||||||||||||
GPU device id and so on (see :meth:`Pipeline.__init__` for a complete list of pipeline parameters). | ||||||||||||||||||
These parameters can be supplied as additional keyword arguments, passed to the decorated function:: | ||||||||||||||||||
|
||||||||||||||||||
pipe = my_pipe(True, False, batch_size=32, num_threads=1, device_id=0) | ||||||||||||||||||
pipe.build() # the pipeline is properly configured, we can build it now | ||||||||||||||||||
|
||||||||||||||||||
The outputs from the original function became the outputs of the Pipeline:: | ||||||||||||||||||
|
||||||||||||||||||
flipped, img = pipe.run() | ||||||||||||||||||
|
||||||||||||||||||
When some of the pipeline parameters are fixed, they can be specified by name in the decorator:: | ||||||||||||||||||
|
||||||||||||||||||
@pipeline_def(batch_size=42, num_threads=3) | ||||||||||||||||||
def my_pipe(flip_vertical, flip_horizontal): | ||||||||||||||||||
... | ||||||||||||||||||
|
||||||||||||||||||
Any Pipeline constructor parameter passed later when calling the decorated function will | ||||||||||||||||||
override the decorator-defined params:: | ||||||||||||||||||
|
||||||||||||||||||
@pipeline_def(batch_size=32, num_threads=3) | ||||||||||||||||||
def my_pipe(): | ||||||||||||||||||
data = fn.external_source(source=my_generator) | ||||||||||||||||||
return data | ||||||||||||||||||
|
||||||||||||||||||
pipe = my_pipe(batch_size=128) # batch_size=128 overrides batch_size=32 | ||||||||||||||||||
|
||||||||||||||||||
.. warning:: | ||||||||||||||||||
Suggested change
Done
Can't we just disallow the |
||||||||||||||||||
|
||||||||||||||||||
The arguments of the function being decorated can shadow pipeline constructor arguments - | ||||||||||||||||||
in which case there's no way to alter their values. Be especially mindful about using | ||||||||||||||||||
``**kwargs``, since code written this way may break with future versions of DALI, when | ||||||||||||||||||
new parameters are added to the ``Pipeline`` constructor. | ||||||||||||||||||
""" | ||||||||||||||||||
def actual_decorator(func): | ||||||||||||||||||
@functools.wraps(func) | ||||||||||||||||||
def create_pipeline(*args, **kwargs): | ||||||||||||||||||
ctor_args, fn_kwargs = _discriminate_args(func, **kwargs) | ||||||||||||||||||
pipe = Pipeline(**{**pipeline_kwargs, **ctor_args}) # Merge and overwrite dict | ||||||||||||||||||
with pipe: | ||||||||||||||||||
pipe_outputs = func(*args, **fn_kwargs) | ||||||||||||||||||
if isinstance(pipe_outputs, tuple): | ||||||||||||||||||
po = pipe_outputs | ||||||||||||||||||
elif pipe_outputs is None: | ||||||||||||||||||
po = () | ||||||||||||||||||
else: | ||||||||||||||||||
po = (pipe_outputs,) | ||||||||||||||||||
pipe.set_outputs(*po) | ||||||||||||||||||
return pipe | ||||||||||||||||||
return create_pipeline | ||||||||||||||||||
return actual_decorator(fn) if fn else actual_decorator |
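The mechanics of the diff above can be tried without DALI installed. The following is a minimal sketch, not the PR's code: `FakePipeline`, `split_kwargs` and `fake_pipeline_def` are hypothetical stand-ins that mirror the argument splitting, the merge of decorator-level and call-level parameters, and the normalization of the returned outputs to a tuple.

```python
import functools
import inspect


class FakePipeline:
    """Hypothetical stand-in for the DALI Pipeline; it only records its settings."""

    def __init__(self, batch_size=-1, num_threads=-1, device_id=-1):
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.device_id = device_id
        self.outputs = ()


def split_kwargs(func, kwargs):
    """Split kwargs into (constructor kwargs, function kwargs).

    As in the diff, a name declared by the function wins over a constructor
    name, and unknown names are passed on to the function.
    """
    fn_spec = inspect.getfullargspec(func)
    ctor_spec = inspect.getfullargspec(FakePipeline.__init__)
    fn_names = set(fn_spec.args) | set(fn_spec.kwonlyargs)
    ctor_names = (set(ctor_spec.args) | set(ctor_spec.kwonlyargs)) - {"self"}
    ctor_kwargs, fn_kwargs = {}, {}
    for name, value in kwargs.items():
        if name in fn_names or name not in ctor_names:
            fn_kwargs[name] = value
        else:
            ctor_kwargs[name] = value
    return ctor_kwargs, fn_kwargs


def fake_pipeline_def(fn=None, **decorator_kwargs):
    """Usable both as @fake_pipeline_def and as @fake_pipeline_def(batch_size=...)."""

    def decorator(func):
        @functools.wraps(func)
        def create(*args, **kwargs):
            ctor_kwargs, fn_kwargs = split_kwargs(func, kwargs)
            # Call-site values override the values given to the decorator.
            pipe = FakePipeline(**{**decorator_kwargs, **ctor_kwargs})
            outputs = func(*args, **fn_kwargs)
            if outputs is None:
                pipe.outputs = ()
            elif isinstance(outputs, tuple):
                pipe.outputs = outputs
            else:
                pipe.outputs = (outputs,)
            return pipe

        return create

    return decorator(fn) if fn else decorator


@fake_pipeline_def(batch_size=32, num_threads=3)
def my_pipe(flip):
    return "flipped" if flip else "original"


pipe = my_pipe(True, batch_size=128)  # batch_size=128 overrides batch_size=32
```

In this sketch `num_threads` keeps the decorator-level value, while `batch_size` takes the call-site value, which matches the override behavior described in the docstring.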
I wonder how it plays with political correctness ;P
I mentioned it before but the comment got lost.
def _separate_kwargs(kwargs):
in ops.py. Isn't this exactly what you need? (haven't read it)
There's a discriminator in electronics and in mathematics, and GANs also have a discriminator. I wouldn't like to be oversensitive: should we avoid using "gap" just because there is a wage gap?
(sarcasm warning)
After all, that's precisely what this function does - treats function args better than ctor args ;)
Achievement unlocked: troll ;)
How about the existing function I pointed out? Is it a duplicate or not?
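One possible reading of the inline review suggestions (reserve variadic kwargs for the pipeline, and raise instead of guessing) is sketched below. This is an illustration under stated assumptions, not the change that was merged: `strict_split` and `FakePipeline` are hypothetical names, and the exact error policy is a guess at what a stricter variant could look like.

```python
import inspect


class FakePipeline:
    """Hypothetical stand-in for the DALI Pipeline constructor signature."""

    def __init__(self, batch_size=-1, num_threads=-1, device_id=-1):
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.device_id = device_id


def strict_split(func, pipeline_cls, kwargs):
    """Split kwargs into (constructor kwargs, function kwargs), raising on ambiguity."""
    fn_spec = inspect.getfullargspec(func)
    if fn_spec.varkw is not None:
        # The decorated function may not declare **kwargs at all.
        raise TypeError(
            "{} must not accept **{}; variadic keyword arguments are "
            "reserved for the pipeline constructor".format(func.__name__, fn_spec.varkw))

    ctor_spec = inspect.getfullargspec(pipeline_cls.__init__)
    fn_names = set(fn_spec.args) | set(fn_spec.kwonlyargs)
    ctor_names = set(ctor_spec.args[1:]) | set(ctor_spec.kwonlyargs)  # skip `self`

    clash = fn_names & ctor_names
    if clash:
        # Shadowing becomes an explicit error rather than a printed warning.
        raise TypeError(
            "arguments {} of {} shadow pipeline constructor arguments".format(
                sorted(clash), func.__name__))

    ctor_kwargs, fn_kwargs = {}, {}
    for name, value in kwargs.items():
        if name in fn_names:
            fn_kwargs[name] = value
        elif name in ctor_names:
            ctor_kwargs[name] = value
        else:
            raise TypeError("unexpected keyword argument {!r}".format(name))
    return ctor_kwargs, fn_kwargs


def graph(flip):
    return flip


ctor_kwargs, fn_kwargs = strict_split(graph, FakePipeline, {"flip": True, "batch_size": 8})
```

The trade-off in this variant is that adding a new constructor parameter can never silently change which side receives an argument, at the cost of rejecting functions that were previously accepted with a warning.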