Schema-based input device check (#5631)
* Check input device in OpSpec::AddInput.
* Add `Any` and `MatchBackendOrCPU` input devices.
* Fix InputDevice in operators. Add Any device capability to "shape_like" inputs.
* Add Python-side backend validation in logical expressions.

---------

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient committed Sep 16, 2024
1 parent 3399b74 commit 869afb3
Showing 24 changed files with 160 additions and 120 deletions.
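
The changes below follow a common pattern: each operator schema now declares, per input, which storage device it accepts. As a rough illustration of that pattern (the operator name and docstring here are made up for illustration and are not part of this commit), a schema for a "shape_like"-style operator could be declared as:

DALI_SCHEMA(MyShapeLikeOp)  // hypothetical operator, for illustration only
    .DocStr("Illustrates per-input device declarations.")
    .NumInput(2)
    .InputDox(0, "shape_like", "TensorList", "Only shape and type metadata is read from this input.")
    .InputDevice(0, InputDevice::Any)                // metadata-only input: CPU or GPU data is accepted
    .InputDox(1, "data", "TensorList", "Data that is actually processed.")
    .InputDevice(1, InputDevice::MatchBackendOrCPU)  // GPU operators also accept CPU data here; CPU operators require CPU
    .NumOutput(1);

With such declarations in place, OpSpec::AddInput (see dali/pipeline/operator/op_spec.cc below) can reject an input stored on an unsupported device while the pipeline graph is being built, instead of failing later inside the operator.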
3 changes: 2 additions & 1 deletion dali/operators/generic/cast.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -76,6 +76,7 @@ DALI_SCHEMA(Cast)
DALI_SCHEMA(CastLike)
.DocStr("Cast the first tensor to the type of the second tensor.")
.NumInput(2)
.InputDevice(1, InputDevice::Any)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric();
5 changes: 4 additions & 1 deletion dali/operators/generic/constant_value.cc
@@ -27,7 +27,7 @@ void ConstantValue<CPUBackend>::RunImpl(Workspace &ws) {
auto dtype = output.type();
auto &tp = ws.GetThreadPool();
if (has_fill_value_) {
auto &fill_value = ws.Input<CPUBackend>(kValueInputIdx);
auto &fill_value = ws.Input<CPUBackend>(value_input_idx_);
const auto &fill_value_sh = fill_value.shape();
TYPE_SWITCH(fill_value.type(), type2id, FillValueType, (DALI_CONSTANT_VALUE_TYPES), (
TYPE_SWITCH(dtype, type2id, OutputType, (DALI_CONSTANT_VALUE_TYPES), (
@@ -84,6 +84,7 @@ DALI_SCHEMA(FullLike)
.DocStr(R"code(Returns new data with the same shape and type as the input data, filled with a `fill_value`.)code")
.NumInput(2)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDevice(0, InputDevice::Any)
.InputDox(1, "fill_value", "TensorList", R"code(The fill value.)code")
.NumOutput(1);
DALI_REGISTER_OPERATOR(FullLike, FullLike<CPUBackend>, CPU);
@@ -101,6 +102,7 @@ DALI_SCHEMA(ZerosLike)
.DocStr(R"code(Returns new data with the same shape and type as the input array, filled with zeros.)code")
.NumInput(1)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDevice(0, InputDevice::Any)
.NumOutput(1)
.AddOptionalTypeArg("dtype", R"code(Overrides the output data type.)code", DALI_INT32);
DALI_REGISTER_OPERATOR(ZerosLike, ZerosLike<CPUBackend>, CPU);
@@ -118,6 +120,7 @@ DALI_SCHEMA(OnesLike)
.DocStr(R"code(Returns new data with the same shape and type as the input array, filled with ones.)code")
.NumInput(1)
.InputDox(0, "data_like", "TensorList", R"code(The input data value to copy the shape and type from.)code")
.InputDevice(0, InputDevice::Any)
.NumOutput(1)
.AddOptionalTypeArg("dtype", R"code(Overrides the output data type.)code", DALI_INT32);
DALI_REGISTER_OPERATOR(OnesLike, OnesLike<CPUBackend>, CPU);
16 changes: 8 additions & 8 deletions dali/operators/generic/constant_value.h
@@ -42,7 +42,7 @@ class ConstantValue : public StatelessOperator<Backend> {

int GetBatchSize(const Workspace &ws) const {
if (is_shape_like_)
return ws.Input<Backend>(kShapeLikeInputIdx).shape().size();
return ws.GetInputBatchSize(shape_like_input_idx_);
else
return ws.GetRequestedBatchSize(0);
}
@@ -70,18 +70,18 @@ class ConstantValue : public StatelessOperator<Backend> {
output_desc.resize(1);
auto &dtype = output_desc[0].type;
auto &shape = output_desc[0].shape;
dtype = is_shape_like_ && !has_dtype_ ? ws.Input<Backend>(kShapeLikeInputIdx).type() : dtype_;
dtype = is_shape_like_ && !has_dtype_ ? ws.GetInputDataType(shape_like_input_idx_) : dtype_;

if (is_shape_like_) {
shape = ws.Input<Backend>(kShapeLikeInputIdx).shape();
shape = ws.GetInputShape(shape_like_input_idx_);
} else if (has_shape_) {
GetShapeArgument(shape, spec_, "shape", ws, nsamples);
} else {
shape = uniform_list_shape(nsamples, TensorShape<0>{});
}

if (has_fill_value_) {
auto& fill_value = ws.Input<Backend>(kValueInputIdx);
auto& fill_value = ws.Input<Backend>(value_input_idx_);
auto fill_value_shape = fill_value.shape();
auto fill_value_dtype = fill_value.type();
int new_ndim = shape.sample_dim() + fill_value_shape.sample_dim();
@@ -110,16 +110,16 @@ class ConstantValue : public StatelessOperator<Backend> {
protected:
using Operator<Backend>::spec_;
using Operator<Backend>::max_batch_size_;
const bool has_fill_value_;
const bool is_shape_like_;
bool has_fill_value_;
bool is_shape_like_;
bool has_shape_, has_dtype_;
DALIDataType dtype_;

bool has_const_value_ = false;
int const_value_ = 0;

const int kShapeLikeInputIdx = is_shape_like_ ? 0 : -1;
const int kValueInputIdx = is_shape_like_ ? 1 : 0;
int shape_like_input_idx_ = is_shape_like_ ? 0 : -1;
int value_input_idx_ = is_shape_like_ ? 1 : 0;
};

template <typename Backend>
2 changes: 2 additions & 0 deletions dali/operators/generic/reshape.cc
@@ -33,6 +33,7 @@ The buffer contents are not copied.)code")
.NumOutput(1)
.InputDox(0, "data", "TensorList", "Data to be reshaped")
.InputDox(1, "shape_input", "1D TensorList of integers", "Same as ``shape`` keyword argument")
.InputDevice(1, InputDevice::CPU)
.PassThrough({{0, 0}})
.AllowSequences()
.SupportVolumetric()
@@ -105,6 +106,7 @@ The buffer contents are not copied.)")
.NumOutput(1)
.InputDox(0, "data", "TensorList", "Data to be reshaped")
.InputDox(1, "shape_input", "1D TensorList of integers", "Same as ``shape`` keyword argument")
.InputDevice(1, InputDevice::CPU)
.PassThrough({{0, 0}})
.AllowSequences()
.SupportVolumetric()
3 changes: 2 additions & 1 deletion dali/operators/generic/shapes.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ namespace dali {
DALI_SCHEMA(Shapes)
.DocStr(R"code(Returns the shapes of inputs.)code")
.NumInput(1)
.InputDevice(0, InputDevice::Any)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric()
4 changes: 2 additions & 2 deletions dali/operators/generic/slice/slice.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -59,7 +59,7 @@ arguments.
By default, the :meth:`nvidia.dali.fn.slice` operator uses normalized coordinates and ``WH``
order for the slice arguments.)code")
.NumInput(1, 3)
.InputDevice(1, 3, InputDevice::CPU)
.InputDevice(1, 3, InputDevice::MatchBackendOrCPU)
.NumOutput(1)
.InputDox(0, "data", "TensorList", R"code(Batch that contains the input data.)code")
.InputDox(1, "anchor", "1D TensorList of float or int",
3 changes: 2 additions & 1 deletion dali/operators/image/convolution/filter.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -66,6 +66,7 @@ the corresponding scalars when convolved with the filter.
The scalars must be of the same type as the input samples.
For video/sequence input, an array of scalars can be specified to be applied
:func:`per-frame<nvidia.dali.fn.per_frame>`.)code")
.InputDevice(1, 3, InputDevice::MatchBackendOrCPU)
.AddOptionalArg("anchor",
R"code(Specifies the position of the filter over the input.
3 changes: 2 additions & 1 deletion dali/operators/image/remap/warp_affine.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ DALI_SCHEMA(WarpAffine)
"Like ``matrix`` argument, but can be placed in GPU memory")
.NumOutput(1)
.InputLayout(0, { "HWC", "FHWC", "DHWC", "FDHWC" })
.InputDevice(1, InputDevice::MatchBackendOrCPU)
.SupportVolumetric()
.AddOptionalArg<float>("matrix",
R"code(Transform matrix.
2 changes: 2 additions & 0 deletions dali/operators/random/normal_distribution_cpu.cc
@@ -28,6 +28,7 @@ a single value per sample is generated.
.NumInput(0, 1)
.InputDox(0, "shape_like", "TensorList",
"Shape of this input will be used to infer the shape of the output, if provided.")
.InputDevice(0, InputDevice::Any)
.NumOutput(1)
.AddOptionalArg<float>("mean",
R"code(Mean of the distribution.)code",
@@ -50,6 +51,7 @@ a single value per sample is generated.
.NumInput(0, 1)
.InputDox(0, "shape_like", "TensorList",
"Shape of this input will be used to infer the shape of the output, if provided.")
.InputDevice(0, InputDevice::Any)
.NumOutput(1)
.AddParent("random__Normal")
.Deprecate("random__Normal"); // Deprecated in 0.30
10 changes: 1 addition & 9 deletions dali/operators/random/rng_base.h
@@ -183,15 +183,7 @@ class RNGBase : public OperatorWithRng<Backend> {
shape_ = ws.Input<Backend>(0).shape();
} else if (has_shape_like) {
int shape_like_idx = spec_.GetSchema().MinNumInput();
if (ws.InputIsType<Backend>(shape_like_idx)) {
shape_ = ws.Input<Backend>(shape_like_idx).shape();
} else if (std::is_same<GPUBackend, Backend>::value &&
ws.InputIsType<CPUBackend>(shape_like_idx)) {
shape_ = ws.Input<CPUBackend>(shape_like_idx).shape();
} else {
DALI_FAIL(
"Shape-like input can be either CPUBackend or GPUBackend for case of GPU operators.");
}
shape_ = ws.GetInputShape(shape_like_idx);
} else if (has_shape) {
GetShapeArgument(shape_, spec_, "shape", ws, nsamples);
} else {
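
The branch removed above dispatched on the input backend by hand; it is replaced by a single ws.GetInputShape(shape_like_idx) call. A minimal sketch of what such a backend-agnostic accessor could look like inside the workspace (an assumption for illustration, not the actual DALI implementation; the exact return type may differ):

// Hypothetical helper: returns the input shape regardless of where the data is stored.
TensorListShape<> GetInputShape(int idx) const {
  if (InputIsType<GPUBackend>(idx))
    return Input<GPUBackend>(idx).shape();   // GPU-resident input
  else
    return Input<CPUBackend>(idx).shape();   // CPU-resident input
}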
1 change: 1 addition & 0 deletions dali/pipeline/executor/executor2/stream_assignment_test.cc
@@ -38,6 +38,7 @@ class StreamAssignmentDummyOp : public Operator<Backend> {

DALI_SCHEMA(StreamAssignmentDummyOp)
.NumInput(0, 999)
.InputDevice(0, 999, InputDevice::Any)
.NumOutput(0)
.AdditionalOutputsFn([](const OpSpec &spec) {
return spec.NumOutput();
18 changes: 0 additions & 18 deletions dali/pipeline/executor/lowered_graph.cc
@@ -73,24 +73,6 @@ void CheckOpConstraints(const OpSpec &spec) {
" outputs, but was passed ", spec.NumOutput(), "."));
}

OpType ParseOpType(const std::string &device) {
if (device == "gpu") {
return OpType::GPU;
} else if (device == "cpu") {
return OpType::CPU;
} else if (device == "mixed") {
return OpType::MIXED;
}
DALI_FAIL("Unsupported device type: " + device + ".");
}

StorageDevice ParseStorageDevice(const std::string &io_device) {
if (io_device == "cpu") {
return StorageDevice::CPU;
}
return StorageDevice::GPU;
}

} // namespace

void OpGraph::Lower(const graph::OpGraph &definition) {
22 changes: 0 additions & 22 deletions dali/pipeline/graph/op_graph2.cc
@@ -21,28 +21,6 @@
namespace dali {
namespace graph {

namespace {

OpType ParseOpType(const std::string &device) {
if (device == "gpu") {
return OpType::GPU;
} else if (device == "cpu") {
return OpType::CPU;
} else if (device == "mixed") {
return OpType::MIXED;
}
DALI_FAIL("Unsupported device type: " + device + ".");
}

StorageDevice ParseStorageDevice(const std::string &io_device) {
if (io_device == "cpu") {
return StorageDevice::CPU;
}
return StorageDevice::GPU;
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
// OpGraph

3 changes: 2 additions & 1 deletion dali/pipeline/operator/builtin/copy.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -62,6 +62,7 @@ DALI_REGISTER_OPERATOR(Copy, Copy<GPUBackend>, GPU);
DALI_SCHEMA(Copy)
.DocStr("Creates a copy of the input tensor.")
.NumInput(1)
.InputDevice(0, InputDevice::Any)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric();
3 changes: 2 additions & 1 deletion dali/pipeline/operator/builtin/make_contiguous.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -80,6 +80,7 @@ bool IsPassThrough(const OperatorBase &op) {
DALI_SCHEMA(MakeContiguous)
.DocStr(R"code(Move input batch to a contiguous representation, more suitable for execution on the GPU)code")
.NumInput(1)
.InputDevice(0, InputDevice::MatchBackendOrCPU)
.NumOutput(1)
.MakeInternal();

2 changes: 2 additions & 0 deletions dali/pipeline/operator/op_schema.h
@@ -66,6 +66,8 @@ enum class InputDevice : uint8_t {
MatchBackend = 0,
CPU = 1,
GPU = 2,
Any = 3,
MatchBackendOrCPU = 4
};

class DLL_PUBLIC OpSchema {
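
The two new enumerators extend the per-input device contract that a schema exposes through GetInputDevice, which the AddInput check below consults. A rough usage sketch (the SchemaRegistry lookup and the default value are assumptions for illustration; the declared value for CastLike input 1 comes from cast.cc above):

const OpSchema &schema = SchemaRegistry::GetSchema("CastLike");
InputDevice dev0 = schema.GetInputDevice(0);  // presumably MatchBackend (0), the default when nothing is declared
InputDevice dev1 = schema.GetInputDevice(1);  // InputDevice::Any, declared by .InputDevice(1, InputDevice::Any)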
63 changes: 59 additions & 4 deletions dali/pipeline/operator/op_spec.cc
@@ -18,15 +18,70 @@

namespace dali {

inline bool IsCompatibleDevice(StorageDevice provided, InputDevice required, OpType op_device) {
switch (required) {
case InputDevice::CPU:
return provided == StorageDevice::CPU;
case InputDevice::GPU:
return provided == StorageDevice::GPU;
case InputDevice::MatchBackend:
return op_device == OpType::GPU
? provided == StorageDevice::GPU
: provided == StorageDevice::CPU;
case InputDevice::MatchBackendOrCPU:
return op_device == OpType::CPU ? provided == StorageDevice::CPU : true;
case InputDevice::Any:
return true;
default:
return false;
}
}

inline std::string ValidDevices(InputDevice required, OpType op_device) {
switch (required) {
case InputDevice::CPU:
return "\"cpu\"";
case InputDevice::GPU:
return "\"gpu\"";
case InputDevice::MatchBackend:
return op_device == OpType::GPU ? "\"gpu\"" : "\"cpu\"";
case InputDevice::MatchBackendOrCPU:
return op_device == OpType::GPU ? "\"gpu\" or \"cpu\"" : "\"cpu\"";
case InputDevice::Any:
return "\"gpu\" or \"cpu\"";
default:
assert(!"Unrechable");
return "<invalid>";
}
}

OpSpec& OpSpec::AddInput(const string &name, const string &device, bool regular_input) {
DALI_ENFORCE(device == "gpu" || device == "cpu", "Invalid device "
"specifier \"" + device + "\" for input \"" + name + "\". "
"Valid options are \"cpu\" or \"gpu\"");
auto dev = ParseStorageDevice(device);
if (regular_input) {
// We rely on the fact that regular inputs are first in inputs_ vector
DALI_ENFORCE(NumArgumentInput() == 0,
"All regular inputs (particularly, `" + name + "`) need to be added to the op `" +
"All regular inputs (particularly, \"" + name + "\") need to be added to the op `" +
GetOpDisplayName(*this, true) + "` before argument inputs.");

if (schema_) {
int idx = inputs_.size();
DALI_ENFORCE(idx < schema_->MaxNumInput(), make_string(
"The operator `", GetOpDisplayName(*this, true), "` takes up to ", schema_->MaxNumInput(),
" inputs. The input \"", name , "\" is out of range."));

if (HasArgument("device")) {
auto op_type_str = GetArgument<std::string>("device");
OpType op_type = ParseOpType(op_type_str);
auto inp_dev = schema_->GetInputDevice(idx);
DALI_ENFORCE(IsCompatibleDevice(dev, inp_dev, op_type),
make_string("The input ", idx, " for ", op_type_str, " operator `",
GetOpDisplayName(*this, true), "` is stored on incompatible device \"", device,
"\". Valid device is ", ValidDevices(inp_dev, op_type), "."));
}
}
} else {
DALI_ENFORCE(dev == StorageDevice::CPU, make_string("Invalid storage device \"", device,
"\" for a named input \"", name, "\". All named inputs must be on CPU."));
}

inputs_.push_back({name, device});
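
The practical effect is that an incompatible placement is reported while the operator spec is being built. A minimal sketch of the failure path (construction details are simplified and illustrative; the exact wording comes from the DALI_ENFORCE above):

OpSpec spec("Cast");
spec.AddArg("device", "cpu");
// Throws, roughly: The input 0 for cpu operator `Cast` is stored on incompatible
// device "gpu". Valid device is "cpu".
spec.AddInput("data", "gpu");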
