Skip to content

Commit

Permalink
Reduce number of Workspace Input/Output APIs (NVIDIA#3446)
Browse files Browse the repository at this point in the history
Remove the Input/Output APIs from all Workspaces,
as they were non-uniform: sometimes
sample-based and sometimes batch-based.
Keep InputRef/OutputRef, which are
always batch-based.
Search and replace was used to adjust most
of the usages of the batch Input/Output
to InputRef/OutputRef.

InputRef/OutputRef have one common implementation.

Move `GetSample`, which acts somewhat like a
SampleWorkspace factory, out of the HostWorkspace,
as external access alone is sufficient.

Make SampleWorkspace use regular pointers
instead of shared_ptr - it is just a view of the data.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki authored and cyyever committed Jan 23, 2022
1 parent dd35b1b commit 66b4cf9
Show file tree
Hide file tree
Showing 69 changed files with 270 additions and 503 deletions.
6 changes: 3 additions & 3 deletions dali/benchmark/caffe2_alexnet_bench.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -127,7 +127,7 @@ BENCHMARK_DEFINE_F(C2Alexnet, Caffe2Pipe)(benchmark::State& st) { // NOLINT
}
}

WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down Expand Up @@ -258,7 +258,7 @@ BENCHMARK_DEFINE_F(C2Alexnet, HybridPipe)(benchmark::State& st) { // NOLINT
}
}

// WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
// WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down
6 changes: 3 additions & 3 deletions dali/benchmark/caffe_alexnet_bench.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -128,7 +128,7 @@ BENCHMARK_DEFINE_F(Alexnet, CaffePipe)(benchmark::State& st) { // NOLINT
}
}

WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down Expand Up @@ -259,7 +259,7 @@ BENCHMARK_DEFINE_F(Alexnet, HybridPipe)(benchmark::State& st) { // NOLINT
}
}

// WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
// WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down
2 changes: 1 addition & 1 deletion dali/benchmark/decoder_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class DecoderBench : public DALIBenchmark {
}
}

// WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
// WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + 1;
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down
4 changes: 2 additions & 2 deletions dali/benchmark/file_reader_alexnet_bench.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -128,7 +128,7 @@ BENCHMARK_DEFINE_F(FileReaderAlexnet, CaffePipe)(benchmark::State& st) { // NOLI
}
}

WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down
6 changes: 3 additions & 3 deletions dali/benchmark/resnet50_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ BENCHMARK_DEFINE_F(RN50, C2Pipe)(benchmark::State& st) { // NOLINT
}
}

// WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
// WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down Expand Up @@ -266,7 +266,7 @@ BENCHMARK_DEFINE_F(RN50, HybridPipe)(benchmark::State& st) { // NOLINT
}
}

// WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
// WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down Expand Up @@ -379,7 +379,7 @@ BENCHMARK_DEFINE_F(RN50, nvJPEGPipe)(benchmark::State& st) { // NOLINT
}
}

// WriteCHWBatch<float16>(ws.Output<GPUBackend>(0), 128, 1, "img");
// WriteCHWBatch<float16>(ws.OutputRef<GPUBackend>(0), 128, 1, "img");
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
benchmark::Counter::kIsRate);
Expand Down
4 changes: 2 additions & 2 deletions dali/benchmark/resnet50_nvjpeg_bench.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -112,7 +112,7 @@ BENCHMARK_DEFINE_F(RealRN50, nvjpegPipe)(benchmark::State& st) { // NOLINT
}

#if DALI_DEBUG
WriteHWCBatch(ws.Output<GPUBackend>(0), "img");
WriteHWCBatch(ws.OutputRef<GPUBackend>(0), "img");
#endif
int num_batches = st.iterations() + static_cast<int>(pipelined);
st.counters["FPS"] = benchmark::Counter(batch_size*num_batches,
Expand Down
16 changes: 8 additions & 8 deletions dali/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,17 +318,17 @@ void daliOutputRelease(daliPipelineHandle *pipe_handle) {
int64_t daliOutputHasUniformShape(daliPipelineHandle* pipe_handle, int i) {
dali::DeviceWorkspace* ws = reinterpret_cast<dali::DeviceWorkspace*>(pipe_handle->ws);
if (ws->OutputIsType<dali::CPUBackend>(i)) {
return is_uniform(ws->Output<dali::CPUBackend>(i).shape());
return is_uniform(ws->OutputRef<dali::CPUBackend>(i).shape());
} else {
return is_uniform(ws->Output<dali::GPUBackend>(i).shape());
return is_uniform(ws->OutputRef<dali::GPUBackend>(i).shape());
}
}

template<typename T>
static int64_t *daliShapeAtHelper(dali::DeviceWorkspace *ws, int n, int k) {
int64_t *c_shape = nullptr;
std::vector<dali::Index> shape;
const auto &out_tensor_list = ws->Output<T>(n);
const auto &out_tensor_list = ws->OutputRef<T>(n);
if (k >= 0) {
auto shape_span = out_tensor_list.tensor_shape_span(k);
shape = std::vector<dali::Index>(shape_span.begin(), shape_span.end());
Expand Down Expand Up @@ -366,7 +366,7 @@ int64_t* daliShapeAt(daliPipelineHandle* pipe_handle, int n) {

template <typename T>
static dali_data_type_t daliTypeAtHelper(dali::DeviceWorkspace* ws, int n) {
const auto &out_tensor_list = ws->Output<T>(n);
const auto &out_tensor_list = ws->OutputRef<T>(n);
auto type_id = out_tensor_list.type();
return static_cast<dali_data_type_t>(static_cast<int>(type_id));
}
Expand All @@ -383,7 +383,7 @@ dali_data_type_t daliTypeAt(daliPipelineHandle* pipe_handle, int n) {

template <typename T>
static size_t daliNumTensorsHelper(dali::DeviceWorkspace* ws, int n) {
return ws->Output<T>(n).num_samples();
return ws->OutputRef<T>(n).num_samples();
}

size_t daliNumTensors(daliPipelineHandle* pipe_handle, int n) {
Expand All @@ -397,7 +397,7 @@ size_t daliNumTensors(daliPipelineHandle* pipe_handle, int n) {

template <typename T>
static size_t daliNumElementsHelper(dali::DeviceWorkspace* ws, int n) {
return ws->Output<T>(n)._num_elements();
return ws->OutputRef<T>(n)._num_elements();
}

size_t daliNumElements(daliPipelineHandle* pipe_handle, int n) {
Expand All @@ -411,7 +411,7 @@ size_t daliNumElements(daliPipelineHandle* pipe_handle, int n) {

template <typename T>
static size_t daliTensorSizeHelper(dali::DeviceWorkspace* ws, int n) {
return ws->Output<T>(n).nbytes();
return ws->OutputRef<T>(n).nbytes();
}

size_t daliTensorSize(daliPipelineHandle* pipe_handle, int n) {
Expand All @@ -425,7 +425,7 @@ size_t daliTensorSize(daliPipelineHandle* pipe_handle, int n) {

template <typename T>
static size_t daliMaxDimTensorsHelper(dali::DeviceWorkspace* ws, int n) {
const auto &out_tensor_list = ws->Output<T>(n);
const auto &out_tensor_list = ws->OutputRef<T>(n);
size_t tensors_num = out_tensor_list.num_samples();
int max_num_dim = 0;
for (size_t i = 0; i < tensors_num; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions dali/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void ComparePipelinesOutputs(daliPipelineHandle &handle, Pipeline &baseline,
EXPECT_EQ(daliNumTensors(&handle, output), batch_size);
for (int elem = 0; elem < batch_size; elem++) {
auto *shape = daliShapeAtSample(&handle, output, elem);
auto ref_shape = ws.Output<Backend>(output).shape()[elem];
auto ref_shape = ws.OutputRef<Backend>(output).shape()[elem];
int D = ref_shape.size();
for (int d = 0; d < D; d++)
EXPECT_EQ(shape[d], ref_shape[d]);
Expand All @@ -162,7 +162,7 @@ void ComparePipelinesOutputs(daliPipelineHandle &handle, Pipeline &baseline,

TensorList<CPUBackend> pipeline_output_cpu, c_api_output_cpu;
// Unnecessary copy in case of CPUBackend, makes the code generic across Backends
pipeline_output_cpu.Copy(ws.Output<Backend>(0), cuda_stream);
pipeline_output_cpu.Copy(ws.OutputRef<Backend>(0), cuda_stream);

auto num_elems = pipeline_output_cpu.shape().num_elements();
auto backend_buf = AllocBuffer<Backend>(num_elems * sizeof(uint8_t), false);
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/bbox/bbox_paste.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ canvas and ``(1,1)`` aligns it to bottom-right.

template<>
void BBoxPaste<CPUBackend>::RunImpl(Workspace<CPUBackend> &ws) {
const auto &input = ws.Input<CPUBackend>(0);
const auto &input = ws.InputRef<CPUBackend>(0);
const auto input_data = input.data<float>();

DALI_ENFORCE(input.type() == DALI_FLOAT, "Bounding box in wrong format");
DALI_ENFORCE(input.size() % 4 == 0, "Bounding box tensor size must be a multiple of 4."
"Got: " + std::to_string(input.size()));

auto &output = ws.Output<CPUBackend>(0);
auto &output = ws.OutputRef<CPUBackend>(0);
output.Resize(input.shape(), DALI_FLOAT);
auto *output_data = output.mutable_data<float>();

Expand Down
6 changes: 3 additions & 3 deletions dali/operators/debug/dump_image.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -21,8 +21,8 @@ namespace dali {

template<>
void DumpImage<CPUBackend>::RunImpl(SampleWorkspace &ws) {
auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);
auto &input = ws.InputRef<CPUBackend>(0);
auto &output = ws.OutputRef<CPUBackend>(0);

DALI_ENFORCE(input.ndim() == 3,
make_string("Input images must have three dimensions, got input with `",
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/debug/dump_image.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -20,8 +20,8 @@ namespace dali {

template<>
void DumpImage<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
auto &input = ws.InputRef<GPUBackend>(0);
auto &output = ws.OutputRef<GPUBackend>(0);


DALI_ENFORCE(input.shape().sample_dim() == 3,
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/decoder/host/host_decoder.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -21,8 +21,8 @@
namespace dali {

void HostDecoder::RunImpl(SampleWorkspace &ws) {
const auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);
const auto &input = ws.InputRef<CPUBackend>(0);
auto &output = ws.OutputRef<CPUBackend>(0);
auto file_name = input.GetSourceInfo();

// Verify input
Expand Down
18 changes: 9 additions & 9 deletions dali/operators/decoder/nvjpeg/nvjpeg_decoder_decoupled_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
#endif // NVJPEG2K_ENABLED

for (int i = 0; i < curr_batch_size; i++) {
const auto &in = ws.Input<CPUBackend>(0, i);
const auto &in = ws.InputRef<CPUBackend>(0)[i];
const auto in_size = in.size();
thread_pool_.AddWork([this, i, &in, in_size](int tid) {
auto *input_data = in.data<uint8_t>();
Expand Down Expand Up @@ -683,7 +683,7 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
}

void ProcessImagesCache(MixedWorkspace &ws) {
auto& output = ws.Output<GPUBackend>(0);
auto& output = ws.OutputRef<GPUBackend>(0);
for (auto *sample : samples_cache_) {
assert(sample);
auto i = sample->sample_idx;
Expand All @@ -694,12 +694,12 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
}

void ProcessImagesCuda(MixedWorkspace &ws) {
auto& output = ws.Output<GPUBackend>(0);
auto& output = ws.OutputRef<GPUBackend>(0);
for (auto *sample : samples_single_) {
assert(sample);
auto i = sample->sample_idx;
auto *output_data = output.mutable_tensor<uint8_t>(i);
const auto &in = ws.Input<CPUBackend>(0, i);
const auto &in = ws.InputRef<CPUBackend>(0)[i];
thread_pool_.AddWork(
[this, sample, &in, output_data](int tid) {
SampleWorker(sample->sample_idx, sample->file_name, in.size(), tid,
Expand Down Expand Up @@ -799,11 +799,11 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
}

void ProcessImagesHost(MixedWorkspace &ws) {
auto& output = ws.Output<GPUBackend>(0);
auto& output = ws.OutputRef<GPUBackend>(0);
for (auto *sample : samples_host_) {
auto i = sample->sample_idx;
auto *output_data = output.mutable_tensor<uint8_t>(i);
const auto &in = ws.Input<CPUBackend>(0, i);
const auto &in = ws.InputRef<CPUBackend>(0)[i];
ImageCache::ImageShape shape = output_shape_[i].to_static<3>();
thread_pool_.AddWork(
[this, sample, &in, output_data, shape](int tid) {
Expand All @@ -816,7 +816,7 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {

void ProcessImagesHw(MixedWorkspace &ws) {
#if IS_HW_DECODER_COMPATIBLE
auto& output = ws.Output<GPUBackend>(0);
auto& output = ws.OutputRef<GPUBackend>(0);
if (!samples_hw_batched_.empty()) {
nvjpegJpegState_t &state = state_hw_batched_;
assert(state != nullptr);
Expand All @@ -839,7 +839,7 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {

for (auto *sample : samples_hw_batched_) {
int i = sample->sample_idx;
const auto &in = ws.Input<CPUBackend>(0, i);
const auto &in = ws.InputRef<CPUBackend>(0)[i];
const auto &out_shape = output_shape_.tensor_shape(i);

tv[j].ShareData(const_cast<Tensor<CPUBackend> &>(in));
Expand Down Expand Up @@ -891,7 +891,7 @@ class nvJPEGDecoder : public Operator<MixedBackend>, CachedDecoderImpl {
}

void ProcessImages(MixedWorkspace &ws) {
auto &output = ws.Output<GPUBackend>(0);
auto &output = ws.OutputRef<GPUBackend>(0);
assert(output_shape_.num_samples() ==
ws.GetInputBatchSize(0)); // If fails: Incorrect number of samples in shape
output.Resize(output_shape_, DALI_UINT8);
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/generic/flip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ void RunFlip(Tensor<CPUBackend> &output, const Tensor<CPUBackend> &input,

template <>
void Flip<CPUBackend>::RunImpl(Workspace<CPUBackend> &ws) {
const auto &input = ws.Input<CPUBackend>(0);
auto &output = ws.Output<CPUBackend>(0);
const auto &input = ws.InputRef<CPUBackend>(0);
auto &output = ws.OutputRef<CPUBackend>(0);
auto layout = input.GetLayout();
output.SetLayout(layout);
output.Resize(input.shape(), input.type());
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/generic/flip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ void RunKernel(TensorList<GPUBackend> &output, const TensorList<GPUBackend> &inp

template <>
void Flip<GPUBackend>::RunImpl(Workspace<GPUBackend> &ws) {
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
const auto &input = ws.InputRef<GPUBackend>(0);
auto &output = ws.OutputRef<GPUBackend>(0);
output.SetLayout(input.GetLayout());
output.Resize(input.shape(), input.type());
auto curr_batch_size = ws.GetInputBatchSize(0);
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/generic/pad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <>
bool Pad<GPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,
const workspace_t<GPUBackend> &ws) {
output_desc.resize(1);
const auto &input = ws.Input<GPUBackend>(0);
const auto &input = ws.InputRef<GPUBackend>(0);
auto in_shape = input.shape();
auto in_layout = input.GetLayout();
int ndim = in_shape.sample_dim();
Expand Down Expand Up @@ -57,8 +57,8 @@ bool Pad<GPUBackend>::SetupImpl(std::vector<OutputDesc> &output_desc,

template <>
void Pad<GPUBackend>::RunImpl(workspace_t<GPUBackend> &ws) {
const auto &input = ws.Input<GPUBackend>(0);
auto &output = ws.Output<GPUBackend>(0);
const auto &input = ws.InputRef<GPUBackend>(0);
auto &output = ws.OutputRef<GPUBackend>(0);
output.SetLayout(input.GetLayout());
int ndim = input.shape().sample_dim();
TYPE_SWITCH(input.type(), type2id, T, PAD_SUPPORTED_TYPES, (
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/image/color/brightness_contrast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ bool BrightnessContrastGpu::SetupImpl(std::vector<OutputDesc> &output_desc,


void BrightnessContrastGpu::RunImpl(workspace_t<GPUBackend> &ws) {
const auto &input = ws.template Input<GPUBackend>(0);
auto &output = ws.template Output<GPUBackend>(0);
const auto &input = ws.template InputRef<GPUBackend>(0);
auto &output = ws.template OutputRef<GPUBackend>(0);
output.SetLayout(input.GetLayout());
TYPE_SWITCH(input.type(), type2id, InputType, (uint8_t, int16_t, int32_t, float), (
TYPE_SWITCH(output_type_, type2id, OutputType, (uint8_t, int16_t, int32_t, float), (
Expand Down
Loading

0 comments on commit 66b4cf9

Please sign in to comment.