diff --git a/dali/operators/util/get_property.cc b/dali/operators/util/get_property.cc index 409f6f86e4..f6abb90cf7 100644 --- a/dali/operators/util/get_property.cc +++ b/dali/operators/util/get_property.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ DALI_SCHEMA(GetProperty) The type of the output will depend on the ``key`` of the requested property.)code") .NumInput(1) + .InputDevice(0, InputDevice::Metadata) .NumOutput(1) .AddArg("key", R"code(Specifies, which property is requested. @@ -38,6 +39,110 @@ The following properties are supported: )code", DALI_STRING); +template +void GetPerSample(TensorList &out, const TensorList &in, + SampleShapeFunc &&sample_shape, CopySampleFunc &&copy_sample) { + int N = in.num_samples(); + TensorListShape<> tls; + for (int i = 0; i < N; i++) { + auto shape = sample_shape(in, i); + if (i == 0) + tls.resize(N, shape.sample_dim()); + tls.set_tensor_shape(i, shape); + } + out.Resize(tls, DALI_UINT8); + for (int i = 0; i < N; i++) { + copy_sample(out, in, i); + } +} + +template +void SourceInfoToTL(TensorList &out, const TensorList &in) { + GetPerSample(out, in, + [](auto &in, int idx) { + auto &info = in.GetMeta(idx).GetSourceInfo(); + return TensorShape<1>(info.length()); + }, + [](auto &out, auto &in, int idx) { + auto &info = in.GetMeta(idx).GetSourceInfo(); + std::memcpy(out.raw_mutable_tensor(idx), info.c_str(), info.length()); + }); +} + +template +void SourceInfoToTL(TensorList &out, const TensorList &in) { + TensorList tmp; + tmp.set_pinned(true); + SourceInfoToTL(tmp, in); + tmp.set_order(out.order()); + out.Copy(tmp); +} + +template +void SourceInfoToTL(TensorList &out, const Workspace &ws) { + ws.Output(0).set_order(ws.output_order()); + if 
(ws.InputIsType(0)) + return SourceInfoToTL(out, ws.Input(0)); + else if (ws.InputIsType(0)) + return SourceInfoToTL(out, ws.Input(0)); + else + DALI_FAIL("Internal error - input 0 is neither CPU nor GPU."); +} + +template +void RepeatTensor(TensorList &tl, const Tensor &t, int N) { + tl.Reset(); + tl.set_device_id(t.device_id()); + tl.SetSize(N); + tl.set_sample_dim(t.ndim()); + tl.set_type(t.type()); + tl.SetLayout(t.GetLayout()); + for (int i = 0; i < N; i++) + tl.SetSample(i, t); +} + +template +void RepeatFirstSample(TensorList &tl, int N) { + Tensor t; + TensorShape<> shape = tl[0].shape(); + t.ShareData(unsafe_sample_owner(tl, 0), shape.num_elements(), tl.is_pinned(), + shape, tl.type(), tl.device_id(), tl.order()); + t.SetMeta(tl.GetMeta(0)); + RepeatTensor(tl, t, N); +} + +void LayoutToTL(TensorList &out, const Workspace &ws) { + TensorLayout l = ws.GetInputLayout(0); + out.Resize(uniform_list_shape(1, { l.size() }), DALI_UINT8); + memcpy(out.raw_mutable_tensor(0), l.data(), l.size()); + RepeatFirstSample(out, ws.GetInputBatchSize(0)); +} + +void LayoutToTL(TensorList &out, const Workspace &ws) { + TensorLayout l = ws.GetInputLayout(0); + Tensor tmp_cpu; + Tensor tmp_gpu; + tmp_cpu.Resize(TensorShape<1>(l.size()), DALI_UINT8); + memcpy(tmp_cpu.raw_mutable_data(), l.data(), l.size()); + tmp_cpu.set_order(ws.output_order()); + tmp_gpu.set_order(ws.output_order()); + tmp_gpu.Copy(tmp_cpu); + + RepeatTensor(out, tmp_gpu, ws.GetInputBatchSize(0)); +} + +template +auto GetProperty::GetPropertyReader(std::string_view key) -> PropertyReader { + if (key == "source_info") { + return static_cast(SourceInfoToTL); + } else if (key == "layout") { + return static_cast(LayoutToTL); + } else { + DALI_FAIL(make_string("Unsupported property key: ", key)); + } +} + + DALI_REGISTER_OPERATOR(GetProperty, GetProperty, CPU) DALI_REGISTER_OPERATOR(GetProperty, GetProperty, GPU) diff --git a/dali/operators/util/get_property.h b/dali/operators/util/get_property.h index 
02ff7c1bd5..59c0a03ef1 100644 --- a/dali/operators/util/get_property.h +++ b/dali/operators/util/get_property.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,8 +17,8 @@ #include #include +#include #include -#include "dali/operators/util/property.h" #include "dali/pipeline/data/type_traits.h" #include "dali/pipeline/operator/common.h" #include "dali/pipeline/operator/checkpointing/stateless_operator.h" @@ -32,41 +32,29 @@ class GetProperty : public StatelessOperator { explicit GetProperty(const OpSpec &spec) : StatelessOperator(spec), property_key_(spec.template GetArgument("key")), - property_(PropertyFactory()) {} - - ~GetProperty() override = default; - DISABLE_COPY_MOVE_ASSIGN(GetProperty); + property_reader_(GetPropertyReader(property_key_)) {} protected: bool CanInferOutputs() const override { - return true; + return false; // we may broadcast a common value to all samples } bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { - const auto &input = ws.Input(0); - output_desc.resize(1); - output_desc[0].shape = property_->GetShape(input); - output_desc[0].type = property_->GetType(input); - return true; + return false; } void RunImpl(Workspace &ws) override { - property_->FillOutput(ws); + property_reader_(ws.Output(0), ws); } private: - std::unique_ptr> PropertyFactory() { - if (property_key_ == "source_info") { - return std::make_unique>(); - } else if (property_key_ == "layout") { - return std::make_unique>(); - } else { - DALI_FAIL(make_string("Unknown property key: ", property_key_)); - } - } + using PropertyReaderFunc = void(TensorList &, const Workspace &); + using PropertyReader = std::function; + + std::string property_key_; + PropertyReader 
property_reader_; - const std::string property_key_; - std::unique_ptr> property_; + static PropertyReader GetPropertyReader(std::string_view key); }; } // namespace dali diff --git a/dali/operators/util/property.cc b/dali/operators/util/property.cc deleted file mode 100644 index eb62213cd9..0000000000 --- a/dali/operators/util/property.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "dali/operators/util/property.h" -#include "dali/pipeline/data/backend.h" - -namespace dali { -namespace tensor_property { - -template <> -void SourceInfo::FillOutput(Workspace &ws) { - const auto& input = ws.Input(0); - auto& output = ws.Output(0); - for (int sample_id = 0; sample_id < input.num_samples(); sample_id++) { - auto si = GetSourceInfo(input, sample_id); - std::memcpy(output.mutable_tensor(sample_id), si.c_str(), si.length()); - } -} - -template <> -void Layout::FillOutput(Workspace &ws) { - const auto& input = ws.Input(0); - auto& output = ws.Output(0); - for (int sample_id = 0; sample_id < input.num_samples(); sample_id++) { - auto layout = GetLayout(input, sample_id); - std::memcpy(output.mutable_tensor(sample_id), layout.c_str(), layout.size()); - } -} - -template <> -void SourceInfo::FillOutput(Workspace &ws) { - const auto& input = ws.Input(0); - auto& output = ws.Output(0); - for (int sample_id = 0; sample_id < input.num_samples(); sample_id++) { - auto si = GetSourceInfo(input, sample_id); - auto output_ptr = output.raw_mutable_tensor(sample_id); - cudaMemcpyAsync(output_ptr, si.c_str(), si.length(), cudaMemcpyDefault, ws.stream()); - } -} - -template <> -void Layout::FillOutput(Workspace &ws) { - const auto& input = ws.Input(0); - auto& output = ws.Output(0); - for (int sample_id = 0; sample_id < input.num_samples(); sample_id++) { - auto layout = GetLayout(input, sample_id); - auto output_ptr = output.raw_mutable_tensor(sample_id); - cudaMemcpyAsync(output_ptr, layout.c_str(), layout.size(), cudaMemcpyDefault, ws.stream()); - } -} - -} // namespace tensor_property -} // namespace dali diff --git a/dali/operators/util/property.h b/dali/operators/util/property.h deleted file mode 100644 index 7c46a7b50c..0000000000 --- a/dali/operators/util/property.h +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DALI_OPERATORS_UTIL_PROPERTY_H_ -#define DALI_OPERATORS_UTIL_PROPERTY_H_ - -#include -#include "dali/pipeline/data/type_traits.h" -#include "dali/pipeline/operator/common.h" -#include "dali/pipeline/operator/operator.h" - -namespace dali { -namespace tensor_property { - -/** - * Base class for a property of the Tensor. - * @tparam Backend Backend of the operator. - */ -template -struct Property { - Property() = default; - virtual ~Property() = default; - - /** - * @return The shape of the tensor containing the property, based on the input to the operator. - */ - virtual TensorListShape<> GetShape(const TensorList& input) = 0; - - /** - * @return The type of the tensor containing the property, based on the input to the operator. - */ - virtual DALIDataType GetType(const TensorList& input) = 0; - - /** - * This function implements filling the output of the operator. Its implementation should - * be similar to any RunImpl function of the operator. 
- */ - virtual void FillOutput(Workspace&) = 0; -}; - - -template -struct SourceInfo : public Property { - TensorListShape<> GetShape(const TensorList& input) override { - TensorListShape<> ret{static_cast(input.num_samples()), 1}; - for (int i = 0; i < ret.size(); i++) { - ret.set_tensor_shape(i, {static_cast(GetSourceInfo(input, i).length())}); - } - return ret; - } - - DALIDataType GetType(const TensorList&) override { - return DALI_UINT8; - } - - void FillOutput(Workspace &ws) override; - - private: - const std::string& GetSourceInfo(const TensorList& input, size_t idx) { - return input.GetMeta(idx).GetSourceInfo(); - } -}; - - -template -struct Layout : public Property { - TensorListShape<> GetShape(const TensorList& input) override { - // Every tensor in the output has the same number of dimensions - return uniform_list_shape(input.num_samples(), {GetLayout(input, 0).size()}); - } - - DALIDataType GetType(const TensorList&) override { - return DALI_UINT8; - } - - void FillOutput(Workspace &ws) override; - - private: - const TensorLayout& GetLayout(const TensorList& input, int idx) { - return input.GetMeta(idx).GetLayout(); - } -}; - - -} // namespace tensor_property -} // namespace dali - -#endif // DALI_OPERATORS_UTIL_PROPERTY_H_ diff --git a/dali/python/nvidia/dali/data_node.py b/dali/python/nvidia/dali/data_node.py index 3892a1ce5c..1265501ae7 100644 --- a/dali/python/nvidia/dali/data_node.py +++ b/dali/python/nvidia/dali/data_node.py @@ -276,6 +276,32 @@ def shape(self, *, dtype=None, device="cpu"): self._check_gpu2cpu() return fn.shapes(self, dtype=dtype, device=device) + def property(self, key, *, device="cpu"): + """Returns a metadata property associated with a DataNode + + Parameters + ---------- + key : str + The name of the metadata item. 
Currently supported: + "source_info" - the file name or location in the dataset where the data originated + (each sample is a 1D uint8 tensor) + "layout" - the layout string + (each sample is a 1D uint8 tensor) + device : str, optional + The device where the value is returned; defaults to CPU. + """ + + from . import fn + + if device == "cpu": + self._check_gpu2cpu() + + return fn.get_property(self, key=key, device=device) + + def source_info(self, *, device="cpu"): + """Returns the "source_info" property. Equivalent to self.property("source_info").""" + return self.property("source_info", device=device) + def _check_gpu2cpu(self): if self.device == "gpu" and self.source and self.source.pipeline: if not self.source.pipeline._exec_dynamic: diff --git a/dali/test/python/operator_1/test_get_property.py b/dali/test/python/operator_1/test_get_property.py index 2be70a8379..da718300ca 100644 --- a/dali/test/python/operator_1/test_get_property.py +++ b/dali/test/python/operator_1/test_get_property.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,7 +15,8 @@ from nvidia.dali import pipeline_def import os import numpy as np -from nvidia.dali import fn +import nvidia.dali as dali +import nvidia.dali.fn as fn from test_utils import get_dali_extra_path from nose_utils import raises import tempfile @@ -52,12 +53,12 @@ def test_file_properties(): yield _test_file_properties, dev -@pipeline_def -def wds_properties(root_path, device, idx_paths): +@pipeline_def(experimental_exec_dynamic=True) +def wds_source_info(root_path, device, idx_paths): read = fn.readers.webdataset(paths=[root_path], index_paths=idx_paths, ext=["jpg"]) if device == "gpu": read = read.gpu() - return fn.get_property(read, key="source_info") + return read.source_info() def generate_wds_index(root_path, index_path): @@ -67,7 +68,7 @@ def generate_wds_index(root_path, index_path): ic.create_index() -def _test_wds_properties(device, generate_index): +def _test_wds_source_info(device, generate_index): root_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar") ref_filenames = [ "2000.jpg", @@ -84,25 +85,24 @@ def _test_wds_properties(device, generate_index): with tempfile.TemporaryDirectory() as idx_dir: index_paths = [os.path.join(idx_dir, os.path.basename(root_path) + ".idx")] generate_wds_index(root_path, index_paths[0]) - p = wds_properties( + p = wds_source_info( root_path, device, index_paths, batch_size=8, num_threads=4, device_id=0 ) p.build() output = p.run() else: - p = wds_properties(root_path, device, None, batch_size=8, num_threads=4, device_id=0) + p = wds_source_info(root_path, device, None, batch_size=8, num_threads=4, device_id=0) p.build() output = p.run() for out in output: - out = out if device == "cpu" else out.as_cpu() for source_info, ref_fname, ref_idx in zip(out, ref_filenames, ref_indices): assert _uint8_tensor_to_string(source_info) == f"{root_path}:{ref_idx}:{ref_fname}" -def test_wds_properties(): +def test_wds_source_info(): for dev in ["cpu", "gpu"]: for gen_idx in [True, False]: - yield 
_test_wds_properties, dev, gen_idx + yield _test_wds_source_info, dev, gen_idx @pipeline_def @@ -180,7 +180,7 @@ def improper_property(root_path, device): return fn.get_property(read, key=["this key doesn't exist"]) -@raises(RuntimeError, glob="Unknown property key*") +@raises(RuntimeError, glob="Unsupported property key*") def _test_improper_property(device): root_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar") p = improper_property(root_path, device, batch_size=8, num_threads=4, device_id=0) @@ -191,3 +191,16 @@ def _test_improper_property(device): def test_improper_property(): for dev in ["cpu", "gpu"]: yield _test_improper_property, dev + + +def test_get_property_gpu2cpu(): + @pipeline_def(batch_size=2, device_id=0, num_threads=1, experimental_exec_dynamic=True) + def test_pipe(): + data = dali.types.Constant(np.array([[[42]]]), device="gpu", layout="abc") + return fn.get_property(data, key="layout", device="cpu") + + pipe = test_pipe() + pipe.build() + (out,) = pipe.run() + assert _uint8_tensor_to_string(out[0]) == "abc" + assert _uint8_tensor_to_string(out[1]) == "abc"