-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pipeline decorator #2629
Pipeline decorator #2629
Changes from 18 commits
b1a4eb3
2843e1d
a1d7e57
484d370
7e0bd2d
3dfd20a
32374e7
22d499f
978f69f
f7a747d
9b47f82
a151663
c102afe
293658f
0a53386
3fe6694
9b687e1
12c01d8
a1ad829
932f1de
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,4 +1,4 @@ | ||||||||||||||||||
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||
# Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||
# | ||||||||||||||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||
# you may not use this file except in compliance with the License. | ||||||||||||||||||
|
@@ -12,20 +12,25 @@ | |||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||
# limitations under the License. | ||||||||||||||||||
|
||||||||||||||||||
#pylint: disable=no-member | ||||||||||||||||||
# pylint: disable=no-member | ||||||||||||||||||
from collections import deque | ||||||||||||||||||
from nvidia.dali import backend as b | ||||||||||||||||||
from nvidia.dali import tensors as Tensors | ||||||||||||||||||
from nvidia.dali import types | ||||||||||||||||||
from nvidia.dali.backend import CheckDLPackCapsule | ||||||||||||||||||
from threading import local as tls | ||||||||||||||||||
from . import data_node as _data_node | ||||||||||||||||||
import functools | ||||||||||||||||||
import inspect | ||||||||||||||||||
import warnings | ||||||||||||||||||
import ctypes | ||||||||||||||||||
|
||||||||||||||||||
pipeline_tls = tls() | ||||||||||||||||||
|
||||||||||||||||||
from .data_node import DataNode | ||||||||||||||||||
DataNode.__module__ = __name__ # move to pipeline | ||||||||||||||||||
|
||||||||||||||||||
DataNode.__module__ = __name__ # move to pipeline | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _show_deprecation_warning(deprecated, in_favor_of): | ||||||||||||||||||
# show only this warning | ||||||||||||||||||
|
@@ -34,6 +39,7 @@ def _show_deprecation_warning(deprecated, in_favor_of): | |||||||||||||||||
warnings.warn("{} is deprecated, please use {} instead".format(deprecated, in_favor_of), | ||||||||||||||||||
Warning, stacklevel=2) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _get_default_stream_for_array(array): | ||||||||||||||||||
if types._is_torch_tensor(array): | ||||||||||||||||||
import torch | ||||||||||||||||||
|
@@ -990,3 +996,113 @@ def iter_setup(self): | |||||||||||||||||
For example, one can use this function to feed the input | ||||||||||||||||||
data from NumPy arrays.""" | ||||||||||||||||||
pass | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _discriminate_args(func, **func_kwargs): | ||||||||||||||||||
"""Split args on those applicable to Pipeline constructor and the decorated function.""" | ||||||||||||||||||
func_argspec = inspect.getfullargspec(func) | ||||||||||||||||||
ctor_argspec = inspect.getfullargspec(Pipeline.__init__) | ||||||||||||||||||
|
||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that this is probably the simplest and cleanest solution, that we can easily document - just say that all the variadic kwargs are used for forwarding arguments to the pipeline and the original function is not supposed to use it/have it as we may steal them with changes to Pipeline class. |
||||||||||||||||||
ctor_args = {} | ||||||||||||||||||
fn_args = {} | ||||||||||||||||||
|
||||||||||||||||||
for farg in func_kwargs.items(): | ||||||||||||||||||
is_ctor_arg = farg[0] in ctor_argspec.args or farg[0] in ctor_argspec.kwonlyargs | ||||||||||||||||||
is_fn_arg = farg[0] in func_argspec.args or farg[0] in func_argspec.kwonlyargs | ||||||||||||||||||
if is_fn_arg: | ||||||||||||||||||
fn_args[farg[0]] = farg[1] | ||||||||||||||||||
if is_ctor_arg: | ||||||||||||||||||
print( | ||||||||||||||||||
"Warning: the argument `{}` shadows a Pipeline constructor argument of the same name.".format( | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider
Suggested change
|
||||||||||||||||||
farg[0])) | ||||||||||||||||||
elif is_ctor_arg: | ||||||||||||||||||
ctor_args[farg[0]] = farg[1] | ||||||||||||||||||
else: | ||||||||||||||||||
fn_args[farg[0]] = farg[1] | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not just raise here? If we disable kwargs that would save us from guessing as it would be an explicit error. |
||||||||||||||||||
|
||||||||||||||||||
for farg in fn_args.items(): | ||||||||||||||||||
if farg[0] not in func_argspec.args and farg[0] not in func_argspec.kwonlyargs: | ||||||||||||||||||
raise TypeError( | ||||||||||||||||||
"Using non-explicitly declared arguments in graph-defining function is not allowed. " | ||||||||||||||||||
"Please remove `{}` argument or declare it explicitly in the function signature.".format( | ||||||||||||||||||
farg[0])) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can explicitly check **kwargs (see above), but we can still produce a nice error here - but with a slightly different descirption. |
||||||||||||||||||
|
||||||||||||||||||
return ctor_args, fn_args | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def pipeline_def(fn=None, **pipeline_kwargs): | ||||||||||||||||||
""" | ||||||||||||||||||
Decorator that converts a graph definition function into a DALI pipeline factory. | ||||||||||||||||||
|
||||||||||||||||||
A graph definition function is a function that returns intended pipeline outputs. | ||||||||||||||||||
You can decorate this function with ``@pipeline_def``:: | ||||||||||||||||||
|
||||||||||||||||||
@pipeline_def | ||||||||||||||||||
def my_pipe(flip_vertical, flip_horizontal): | ||||||||||||||||||
''' Creates a DALI pipeline, which returns flipped and original images ''' | ||||||||||||||||||
data, _ = fn.file_reader(file_root=images_dir) | ||||||||||||||||||
img = fn.image_decoder(data, device="mixed") | ||||||||||||||||||
flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical) | ||||||||||||||||||
return flipped, img | ||||||||||||||||||
|
||||||||||||||||||
The decorated function returns a DALI Pipeline object:: | ||||||||||||||||||
|
||||||||||||||||||
pipe = my_pipe(True, False) | ||||||||||||||||||
# pipe.build() # the pipeline is not configured properly yet | ||||||||||||||||||
|
||||||||||||||||||
A pipeline requires additional parameters such as batch size, number of worker threads, | ||||||||||||||||||
GPU device id and so on (see :meth:`Pipeline.__init__` for a complete list of pipeline parameters). | ||||||||||||||||||
These parameters can be supplied as additional keyword arguments, passed to the decorated function:: | ||||||||||||||||||
|
||||||||||||||||||
pipe = my_pipe(True, False, batch_size=32, num_threads=1, device_id=0) | ||||||||||||||||||
pipe.build() # the pipeline is properly configured, we can build it now | ||||||||||||||||||
|
||||||||||||||||||
The outputs from the original function became the outputs of the Pipeline:: | ||||||||||||||||||
|
||||||||||||||||||
flipped, img = pipe.run() | ||||||||||||||||||
|
||||||||||||||||||
When some of the pipeline parameters are fixed, they can be specified by name in the decorator:: | ||||||||||||||||||
|
||||||||||||||||||
@pipeline_def(batch_size=42, num_threads=3) | ||||||||||||||||||
def my_pipe(flip_vertical, flip_horizontal): | ||||||||||||||||||
... | ||||||||||||||||||
|
||||||||||||||||||
Any Pipeline constructor parameter passed later when calling the decorated function will | ||||||||||||||||||
override the decorator-defined params:: | ||||||||||||||||||
|
||||||||||||||||||
@pipeline_def(batch_size=32, num_threads=3) | ||||||||||||||||||
def my_pipe(): | ||||||||||||||||||
data = fn.external_source(source=my_generator) | ||||||||||||||||||
return data | ||||||||||||||||||
|
||||||||||||||||||
pipe = my_pipe(batch_size=128) # batch_size=128 overrides batch_size=32 | ||||||||||||||||||
|
||||||||||||||||||
.. warning:: | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we just disallow the |
||||||||||||||||||
|
||||||||||||||||||
The arguments of the function being decorated can shadow pipeline constructor arguments - | ||||||||||||||||||
in which case there's no way to alter their values. | ||||||||||||||||||
|
||||||||||||||||||
.. note:: | ||||||||||||||||||
|
||||||||||||||||||
Using non-explicitly declared arguments in graph-defining function is not allowed. | ||||||||||||||||||
They may result in unwanted, silent hijacking of some arguments of the same name by | ||||||||||||||||||
Pipeline constructor. Code written this way may cease to work with future versions of DALI | ||||||||||||||||||
when new parameters are added to the Pipeline constructor. | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
It's not allowed, so this explanation is describes a purely hypothetical scenario. |
||||||||||||||||||
""" | ||||||||||||||||||
def actual_decorator(func): | ||||||||||||||||||
@functools.wraps(func) | ||||||||||||||||||
def create_pipeline(*args, **kwargs): | ||||||||||||||||||
ctor_args, fn_kwargs = _discriminate_args(func, **kwargs) | ||||||||||||||||||
pipe = Pipeline(**{**pipeline_kwargs, **ctor_args}) # Merge and overwrite dict | ||||||||||||||||||
with pipe: | ||||||||||||||||||
pipe_outputs = func(*args, **fn_kwargs) | ||||||||||||||||||
if isinstance(pipe_outputs, tuple): | ||||||||||||||||||
po = pipe_outputs | ||||||||||||||||||
elif pipe_outputs is None: | ||||||||||||||||||
po = () | ||||||||||||||||||
else: | ||||||||||||||||||
po = (pipe_outputs,) | ||||||||||||||||||
pipe.set_outputs(*po) | ||||||||||||||||||
return pipe | ||||||||||||||||||
return create_pipeline | ||||||||||||||||||
return actual_decorator(fn) if fn else actual_decorator |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nvidia.dali import Pipeline, pipeline_def | ||
from nose.tools import nottest, raises | ||
import nvidia.dali.fn as fn | ||
from test_utils import get_dali_extra_path, compare_pipelines | ||
import os | ||
|
||
data_root = get_dali_extra_path() | ||
images_dir = os.path.join(data_root, 'db', 'single', 'jpeg') | ||
|
||
N_ITER = 7 | ||
|
||
max_batch_size = 16 | ||
num_threads = 4 | ||
device_id = 0 | ||
|
||
|
||
def reference_pipeline(flip_vertical, flip_horizontal, ref_batch_size=max_batch_size): | ||
pipeline = Pipeline(ref_batch_size, num_threads, device_id) | ||
with pipeline: | ||
data, _ = fn.file_reader(file_root=images_dir) | ||
img = fn.image_decoder(data, device="mixed") | ||
flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical) | ||
pipeline.set_outputs(flipped, img) | ||
return pipeline | ||
|
||
|
||
@nottest | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you need nottest? The function won't be run automatically (not prefixed with test_). If you need that, why is it not there in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To test, that pipeline decorator works with other decorators There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not obvious, maybe add a comment? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
@pipeline_def(batch_size=max_batch_size, num_threads=num_threads, device_id=device_id) | ||
def pipeline_static(flip_vertical, flip_horizontal): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're missing the tests that were discussed:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder what happens when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We agreed, that the decorator should pass everything to return
return None and how to deal with this. Added specific |
||
data, _ = fn.file_reader(file_root=images_dir) | ||
img = fn.image_decoder(data, device="mixed") | ||
flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical) | ||
return flipped, img | ||
|
||
|
||
@nottest | ||
@pipeline_def | ||
def pipeline_runtime(flip_vertical, flip_horizontal): | ||
data, _ = fn.file_reader(file_root=images_dir) | ||
img = fn.image_decoder(data, device="mixed") | ||
flipped = fn.flip(img, horizontal=flip_horizontal, vertical=flip_vertical) | ||
return flipped, img | ||
|
||
|
||
@nottest | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in other test suites we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
def test_pipeline_static(flip_vertical, flip_horizontal): | ||
put_args = pipeline_static(flip_vertical, flip_horizontal) | ||
ref = reference_pipeline(flip_vertical, flip_horizontal) | ||
compare_pipelines(put_args, ref, batch_size=max_batch_size, N_iterations=N_ITER) | ||
|
||
|
||
@nottest | ||
def test_pipeline_runtime(flip_vertical, flip_horizontal): | ||
put_combined = pipeline_runtime(flip_vertical, flip_horizontal, batch_size=max_batch_size, | ||
num_threads=num_threads, device_id=device_id) | ||
ref = reference_pipeline(flip_vertical, flip_horizontal) | ||
compare_pipelines(put_combined, ref, batch_size=max_batch_size, N_iterations=N_ITER) | ||
|
||
|
||
@nottest | ||
def test_pipeline_override(flip_vertical, flip_horizontal, batch_size): | ||
put_combined = pipeline_static(flip_vertical, flip_horizontal, batch_size=batch_size, | ||
num_threads=num_threads, device_id=device_id) | ||
ref = reference_pipeline(flip_vertical, flip_horizontal, ref_batch_size=batch_size) | ||
compare_pipelines(put_combined, ref, batch_size=batch_size, N_iterations=N_ITER) | ||
|
||
|
||
def test_pipeline_decorator(): | ||
for vert in [0, 1]: | ||
for hori in [0, 1]: | ||
yield test_pipeline_static, vert, hori | ||
yield test_pipeline_runtime, vert, hori | ||
yield test_pipeline_override, vert, hori, 16 | ||
yield test_pipeline_runtime, fn.random.coin_flip(seed=123), fn.random.coin_flip(seed=234) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this work? can we pass data nodes from outside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Apparently yes... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can if they don't have side-effects. |
||
yield test_pipeline_static, fn.random.coin_flip(seed=123), fn.random.coin_flip(seed=234) | ||
|
||
|
||
def test_duplicated_argument(): | ||
@pipeline_def(batch_size=max_batch_size, num_threads=num_threads, device_id=device_id) | ||
def ref_pipeline(val): | ||
data, _ = fn.file_reader(file_root=images_dir) | ||
return data + val | ||
|
||
@pipeline_def(batch_size=max_batch_size, num_threads=num_threads, device_id=device_id) | ||
def pipeline_duplicated_arg(max_streams): | ||
data, _ = fn.file_reader(file_root=images_dir) | ||
return data + max_streams | ||
|
||
pipe = pipeline_duplicated_arg(max_streams=42) | ||
assert pipe._max_streams == -1 | ||
ref = ref_pipeline(42) | ||
compare_pipelines(pipe, ref, batch_size=max_batch_size, N_iterations=N_ITER) | ||
|
||
|
||
# test_kwargs_exception tests against user introducing arguments, | ||
# that are not explicitly declared in function signature | ||
|
||
@pipeline_def | ||
def pipeline_kwargs(arg1, arg2, *args, **kwargs): | ||
pass | ||
|
||
|
||
@pipeline_def | ||
def pipeline_kwonlyargs(arg1, *, arg2, **kwargs): | ||
pass | ||
|
||
|
||
def test_kwargs_exception_1(): | ||
pipeline_kwargs(1, arg2=2) | ||
|
||
|
||
def test_kwargs_exception_2(): | ||
pipeline_kwonlyargs(1, arg2=2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These should raise an exception for mere presence of **kwargs. |
||
|
||
|
||
@raises(TypeError) | ||
def test_kwargs_exception_3(): | ||
pipeline_kwargs(arg1=1, arg2=2, arg3=3) | ||
|
||
|
||
def test_kwargs_exception_4(): | ||
pipeline_kwargs(1, 2, 3, 4, 5) | ||
|
||
|
||
@raises(TypeError) | ||
def test_kwargs_exception_5(): | ||
pipeline_kwonlyargs(1, arg2=2, arg3=3) | ||
|
||
|
||
@raises(TypeError) | ||
def test_kwargs_exception_6(): | ||
pipeline_kwargs(1, arg2=2, arg3=3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder how it plays with political correctness ;P
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mentioned it before but the comment got lost.
def _separate_kwargs(kwargs):
in ops.py. Isn't this exactly what you need? (haven't read it)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a discriminator in electronics, mathematics, GANs also have a discriminator. I wouldn't like to be oversensitive - should we avoid to use "gap", just because there is a wage-gap?
(sacrasm warning)
After all, that's precisely what this function does - treats function args better than ctor args ;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Achievement unlocked: troll ;)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about the existing function I pointed out? Is it a duplicate or not?