Skip to content

Commit

Permalink
Move tfrecord reader to readers module (#2722)
Browse files Browse the repository at this point in the history
Added placeholders/alias schemas in the same fashion as other ops (to facilitate grepping with readers__).
Reworked the wrapper class to be internal only. Use it to create two proper wrapper classes inserted into appropriate modules and wrapped into corresponding fn api functions.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Feb 25, 2021
1 parent 18118b0 commit 3d521a8
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 87 deletions.
61 changes: 49 additions & 12 deletions dali/operators/reader/tfrecord_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,20 @@

namespace dali {

DALI_REGISTER_OPERATOR(_TFRecordReader, TFRecordReader, CPU);
namespace {

// Output-count callback shared by the TFRecord reader schemas: the operator
// produces one output per entry of the "feature_names" argument.
int TFRecordReaderOutputFn(const OpSpec &spec) {
  const auto feature_names = spec.GetRepeatedArgument<std::string>("feature_names");
  return static_cast<int>(feature_names.size());
}

} // namespace

// Internal readers._tfrecord, note the triple underscore.
DALI_REGISTER_OPERATOR(readers___TFRecord, TFRecordReader, CPU);

DALI_SCHEMA(_TFRecordReaderBase)
// Common part of schema for internal readers._tfrecord and public readers.tfrecord schema.
DALI_SCHEMA(readers___TFRecordBase)
.DocStr(R"code(Read sample data from a TensorFlow TFRecord file.)code")
.AddArg("path",
R"code(List of paths to TFRecord files.)code",
Expand All @@ -36,24 +47,22 @@ The index files can be obtained from TFRecord files by using the ``tfrecord2idx`
that is distributed with DALI.)code",
DALI_STRING_VEC);

DALI_SCHEMA(_TFRecordReader)
// Internal readers._tfrecord schema.
DALI_SCHEMA(readers___TFRecord)
.DocStr(R"code(Reads samples from a TensorFlow TFRecord file.)code")
.OutputFn([](const OpSpec &spec) {
std::vector<std::string> v = spec.GetRepeatedArgument<std::string>("feature_names");
return v.size();
})
.OutputFn(TFRecordReaderOutputFn)
.NumInput(0)
.AddArg("feature_names", "Names of the features in TFRecord.",
DALI_STRING_VEC)
.AddArg("features", "List of features.",
DALI_TF_FEATURE_VEC)
.AddParent("_TFRecordReaderBase")
.AddParent("readers___TFRecordBase")
.AddParent("LoaderBase")
.MakeInternal();

// Schema for the actual TFRecordReader op exposed
// in Python. It is here for proper docstring generation
DALI_SCHEMA(TFRecordReader)
// Schema for the actual readers.tfrecord op exposed in Python.
// It is here for proper docstring generation. Note the double underscore.
DALI_SCHEMA(readers__TFRecord)
.DocStr(R"code(Reads samples from a TensorFlow TFRecord file.)code")
.AddArg("features",
R"code(A dictionary that maps names of the TFRecord features to extract to the feature type.
Expand All @@ -65,9 +74,37 @@ Typically obtained by using the ``dali.tfrecord.FixedLenFeature`` and
the data will be reshaped to match its value, and the first dimension will be inferred from
the data size.)code",
DALI_TF_FEATURE_DICT)
.AddParent("_TFRecordReaderBase")
.AddParent("readers___TFRecordBase")
.AddParent("LoaderBase");


// Deprecated alias for internal op. Necessary for deprecation warning.
DALI_REGISTER_OPERATOR(_TFRecordReader, TFRecordReader, CPU);

// Schema for the legacy internal operator name `_TFRecordReader`.
// It takes the internal `readers___TFRecord` schema (triple underscore) as its parent,
// so the old name keeps resolving while being flagged as deprecated.
DALI_SCHEMA(_TFRecordReader)
.DocStr("Legacy alias for :meth:`readers.tfrecord`.")
// Same output-count callback as `readers___TFRecord`: one output per "feature_names" entry.
.OutputFn(TFRecordReaderOutputFn)
.NumInput(0)
.AddParent("readers___TFRecord")
.MakeInternal()
// The replacement named here is the public `readers__TFRecord` schema (double underscore).
.Deprecate(
"readers__TFRecord",
R"code(In DALI 1.0 all readers were moved into a dedicated :mod:`~nvidia.dali.fn.readers`
submodule and renamed to follow a common pattern. This is a placeholder operator with identical
functionality to allow for backward compatibility.)code"); // Deprecated in 1.0;


// Deprecated alias: schema for the legacy public operator name `TFRecordReader`.
// It takes the public `readers__TFRecord` schema (double underscore) as its parent,
// marks its documentation as partially hidden, and names `readers__TFRecord`
// as the replacement in the deprecation notice.
DALI_SCHEMA(TFRecordReader)
.DocStr("Legacy alias for :meth:`readers.tfrecord`.")
.AddParent("readers__TFRecord")
.MakeDocPartiallyHidden()
.Deprecate(
"readers__TFRecord",
R"code(In DALI 1.0 all readers were moved into a dedicated :mod:`~nvidia.dali.fn.readers`
submodule and renamed to follow a common pattern. This is a placeholder operator with identical
functionality to allow for backward compatibility.)code"); // Deprecated in 1.0;

} // namespace dali

#endif // DALI_BUILD_PROTO3
32 changes: 25 additions & 7 deletions dali/python/nvidia/dali/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,7 @@ def Reload():
_load_ops()

# custom wrappers around ops
class TFRecordReader(metaclass=_DaliOperatorMeta):
global _cpu_ops
_cpu_ops = _cpu_ops.union({'TFRecordReader'})
class _TFRecordReaderImpl():

def __init__(self, path, index_path, features, **kwargs):
if isinstance(path, list):
Expand All @@ -748,8 +746,8 @@ def __init__(self, path, index_path, features, **kwargs):
self._index_path = index_path
else:
self._index_path = [index_path]
self._schema = _b.GetSchema("_TFRecordReader")
self._spec = _b.OpSpec("_TFRecordReader")
self._schema = _b.GetSchema(self._internal_schema_name)
self._spec = _b.OpSpec(self._internal_schema_name)
self._device = "cpu"

self._spec.AddArg("path", self._path)
Expand Down Expand Up @@ -806,7 +804,26 @@ def __call__(self, *inputs, **kwargs):
op_instance.spec.AddArg("features", features)
return outputs

TFRecordReader.__call__.__doc__ = _docstring_generator_call("TFRecordReader")
_TFRecordReaderImpl.__call__.__doc__ = _docstring_generator_call("readers__TFRecord")

def _load_readers_tfrecord():
    """Create the TFRecord reader wrapper classes and expose them in the ops modules.

    Adds both the new (``readers__TFRecord``) and the legacy (``TFRecordReader``)
    names to the set of CPU ops, builds one wrapper class per name on top of
    ``_TFRecordReaderImpl`` and, for each name not already present in its target
    module, inserts the class there and passes it to ``_wrap_op``.
    """
    global _cpu_ops
    _cpu_ops = _cpu_ops.union({'readers__TFRecord', 'TFRecordReader'})

    this_module = sys.modules[__name__]

    class TFRecordReader(_TFRecordReaderImpl, metaclass=_DaliOperatorMeta):
        pass

    class TFRecord(_TFRecordReaderImpl, metaclass=_DaliOperatorMeta):
        pass

    # Each entry: (registered op name, internal backend schema name, wrapper class).
    wrappers = (
        ('readers__TFRecord', 'readers___TFRecord', TFRecord),
        ('TFRecordReader', '_TFRecordReader', TFRecordReader),
    )
    for registered_name, backend_schema, wrapper_cls in wrappers:
        wrapper_cls.schema_name = registered_name
        wrapper_cls._internal_schema_name = backend_schema
        _, submodule, short_name = _process_op_name(registered_name)
        target_module = _internal.get_submodule(this_module, submodule)
        if hasattr(target_module, short_name):
            # Already exposed under this name; leave the existing attribute alone.
            continue
        wrapper_cls.__module__ = target_module.__name__
        setattr(target_module, short_name, wrapper_cls)
        _wrap_op(wrapper_cls, submodule)


class PythonFunctionBase(metaclass=_DaliOperatorMeta):
def __init__(self, impl_name, function, num_outputs=1, device='cpu', **kwargs):
Expand Down Expand Up @@ -988,7 +1005,6 @@ def __init__(self, function, num_outputs=1, device='cpu', synchronize_stream=Tru

_wrap_op(PythonFunction)
_wrap_op(DLTensorPythonFunction)
_wrap_op(TFRecordReader)


def _choose_device(inputs):
Expand Down Expand Up @@ -1230,4 +1246,6 @@ def Compose(op_list):
_cpu_ops = _cpu_ops.union({"Compose"})
_gpu_ops = _gpu_ops.union({"Compose"})


_load_ops()
_load_readers_tfrecord()
22 changes: 11 additions & 11 deletions dali/test/python/test_RN50_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,17 @@ def __init__(self, **kwargs):
tfrecord = sorted(glob.glob(kwargs['data_paths'][0]))
tfrecord_idx = sorted(glob.glob(kwargs['data_paths'][1]))
cache_enabled = kwargs['decoder_cache_params']['cache_enabled']
self.input = ops.TFRecordReader(path = tfrecord,
index_path = tfrecord_idx,
shard_id = kwargs['shard_id'],
num_shards = kwargs['num_shards'],
random_shuffle = kwargs['random_shuffle'],
dont_use_mmap = kwargs['dont_use_mmap'],
stick_to_shard = cache_enabled,
#skip_cached_images = cache_enabled,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)
})
self.input = ops.readers.TFRecord(path = tfrecord,
index_path = tfrecord_idx,
shard_id = kwargs['shard_id'],
num_shards = kwargs['num_shards'],
random_shuffle = kwargs['random_shuffle'],
dont_use_mmap = kwargs['dont_use_mmap'],
stick_to_shard = cache_enabled,
#skip_cached_images = cache_enabled,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)
})

def define_graph(self):
inputs = self.input(name="Reader")
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/test_coco_tfrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class TFRecordDetectionPipeline(Pipeline):
def __init__(self, args):
super(TFRecordDetectionPipeline, self).__init__(
args.batch_size, args.num_workers, 0, 0)
self.input = ops.TFRecordReader(
self.input = ops.readers.TFRecord(
path = os.path.join(test_dummy_data_path, 'small_coco.tfrecord'),
index_path = os.path.join(test_dummy_data_path, 'small_coco_index.idx'),
features = {
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/test_dali_cpu_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def test_tfrecord_reader_cpu():
pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=None)
tfrecord = sorted(glob.glob(os.path.join(tfrecord_dir, '*[!i][!d][!x]')))
tfrecord_idx = sorted(glob.glob(os.path.join(tfrecord_dir, '*idx')))
input = fn.tfrecord_reader(path = tfrecord,
input = fn.readers.tfrecord(path = tfrecord,
index_path = tfrecord_idx,
shard_id=0, num_shards=1,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
Expand Down
14 changes: 7 additions & 7 deletions dali/test/python/test_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ def __init__(self, batch_size, num_threads, device_id, num_gpus, data_paths, don
super(TFRecordPipeline, self).__init__(batch_size, num_threads, device_id)
tfrecord = sorted(glob.glob(data_paths[0]))
tfrecord_idx = sorted(glob.glob(data_paths[1]))
self.input = ops.TFRecordReader(path = tfrecord,
index_path = tfrecord_idx,
shard_id = device_id,
num_shards = num_gpus,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)
}, dont_use_mmap=dont_use_mmap)
self.input = ops.readers.TFRecord(path = tfrecord,
index_path = tfrecord_idx,
shard_id = device_id,
num_shards = num_gpus,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)
}, dont_use_mmap=dont_use_mmap)

def define_graph(self):
inputs = self.input(name="Reader")
Expand Down
31 changes: 23 additions & 8 deletions dali/test/python/test_operator_readers_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def test_tfrecord():
class TFRecordPipeline(Pipeline):
def __init__(self, batch_size, num_threads, device_id, num_gpus, data, data_idx):
super(TFRecordPipeline, self).__init__(batch_size, num_threads, device_id)
self.input = ops.TFRecordReader(path = data,
index_path = data_idx,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)
})
self.input = ops.readers.TFRecord(path = data,
index_path = data_idx,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)}
)

def define_graph(self):
inputs = self.input(name="Reader")
Expand Down Expand Up @@ -95,9 +95,9 @@ def test_wrong_feature_shape():
test_dummy_data_path = os.path.join(get_dali_extra_path(), 'db', 'coco_dummy')
pipe = Pipeline(1, 1, 0)
with pipe:
input = fn.tfrecord_reader(path = os.path.join(test_dummy_data_path, 'small_coco.tfrecord'),
index_path = os.path.join(test_dummy_data_path, 'small_coco_index.idx'),
features = features)
input = fn.readers.tfrecord(path = os.path.join(test_dummy_data_path, 'small_coco.tfrecord'),
index_path = os.path.join(test_dummy_data_path, 'small_coco_index.idx'),
features = features)
pipe.set_outputs(input['image/encoded'], input['image/object/class/label'], input['image/object/bbox'])
pipe.build()
# the error is raised because FixedLenFeature is used with insufficient shape to house the input
Expand All @@ -117,3 +117,18 @@ def test_mxnet_reader_alias():
new_pipe = mxnet_pipe(fn.readers.mxnet, recordio, recordio_idx)
legacy_pipe = mxnet_pipe(fn.mxnet_reader, recordio, recordio_idx)
compare_pipelines(new_pipe, legacy_pipe, batch_size_alias_test, 50)


@pipeline_def(batch_size=batch_size_alias_test, device_id=0, num_threads=4)
def tfrecord_pipe(tfrecord_op, path, index_path):
    """Pipeline that reads TFRecord samples with ``tfrecord_op`` and returns
    the "image/encoded" output."""
    requested_features = {
        "image/encoded": tfrec.FixedLenFeature((), tfrec.string, ""),
        "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1),
    }
    reader_outputs = tfrecord_op(path=path, index_path=index_path,
                                 features=requested_features)
    return reader_outputs["image/encoded"]

def test_tfrecord_reader_alias():
    """Compare the pipeline built with ``fn.readers.tfrecord`` against the one
    built with the legacy ``fn.tfrecord_reader`` alias."""
    def tfrecord_db_file(file_name):
        # Resolve a file inside <DALI_extra>/db/tfrecord.
        return os.path.join(get_dali_extra_path(), 'db', 'tfrecord', file_name)

    data_path = tfrecord_db_file('train')
    index_path = tfrecord_db_file('train.idx')
    pipe_new_api = tfrecord_pipe(fn.readers.tfrecord, data_path, index_path)
    pipe_legacy_api = tfrecord_pipe(fn.tfrecord_reader, data_path, index_path)
    compare_pipelines(pipe_new_api, pipe_legacy_api, batch_size_alias_test, 50)
24 changes: 12 additions & 12 deletions dali/test/python/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,17 +1049,17 @@ def __init__(self, reader_type, batch_size, is_cached=False, is_cached_batch_cop
skip_cached_images = skip_cached_images,
prefetch_queue_depth = 1)

elif reader_type == "TFRecordReader":
elif reader_type == "readers.TFRecord":
tfrecord = sorted(glob.glob(os.path.join(tfrecord_db_folder, '*[!i][!d][!x]')))
tfrecord_idx = sorted(glob.glob(os.path.join(tfrecord_db_folder, '*idx')))
self.input = ops.TFRecordReader(path = tfrecord,
index_path = tfrecord_idx,
shard_id = 0,
num_shards = num_shards,
stick_to_shard = True,
skip_cached_images = skip_cached_images,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)})
self.input = ops.readers.TFRecord(path = tfrecord,
index_path = tfrecord_idx,
shard_id = 0,
num_shards = num_shards,
stick_to_shard = True,
skip_cached_images = skip_cached_images,
features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""),
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)})

if is_cached:
self.decode = ops.ImageDecoder(device = "mixed", output_type = types.RGB,
Expand All @@ -1073,7 +1073,7 @@ def __init__(self, reader_type, batch_size, is_cached=False, is_cached_batch_cop
# hw_decoder_load=0.0 for deterministic results
self.decode = ops.ImageDecoder(device = "mixed", output_type = types.RGB, hw_decoder_load = 0.0)
def define_graph(self):
if self.reader_type == "TFRecordReader":
if self.reader_type == "readers.TFRecord":
inputs = self.input()
jpegs = inputs["image/encoded"]
labels = inputs["image/class/label"]
Expand All @@ -1085,14 +1085,14 @@ def define_graph(self):

def test_nvjpeg_cached_batch_copy_pipelines():
batch_size = 26
for reader_type in {"MXNetReader", "CaffeReader", "Caffe2Reader", "FileReader", "TFRecordReader"}:
for reader_type in {"MXNetReader", "CaffeReader", "Caffe2Reader", "FileReader", "readers.TFRecord"}:
compare_pipelines(CachedPipeline(reader_type, batch_size, is_cached=True, is_cached_batch_copy=True),
CachedPipeline(reader_type, batch_size, is_cached=True, is_cached_batch_copy=False),
batch_size=batch_size, N_iterations=20)

def test_nvjpeg_cached_pipelines():
batch_size = 26
for reader_type in {"MXNetReader", "CaffeReader", "Caffe2Reader", "FileReader", "TFRecordReader"}:
for reader_type in {"MXNetReader", "CaffeReader", "Caffe2Reader", "FileReader", "readers.TFRecord"}:
compare_pipelines(CachedPipeline(reader_type, batch_size, is_cached=False),
CachedPipeline(reader_type, batch_size, is_cached=True),
batch_size=batch_size, N_iterations=20)
Expand Down
2 changes: 1 addition & 1 deletion dali/test/python/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_pipeline_including_custom_plugin(self):

def test_python_operator_and_custom_plugin(self):
plugin_manager.load_library( test_bin_dir + "/libcustomdummyplugin.so")
ops.TFRecordReader(path="dummy", index_path="dummy", features={})
ops.readers.TFRecord(path="dummy", index_path="dummy", features={})

if __name__ == '__main__':
unittest.main()
12 changes: 6 additions & 6 deletions docs/examples/frameworks/mxnet/mxnet-various-readers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"- MXNetReader\n",
"- CaffeReader\n",
"- FileReader\n",
"- TFRecordReader\n",
"- readers.TFRecord\n",
"\n",
"For details on how to use them please see other [examples](../../index.rst)."
]
Expand Down Expand Up @@ -216,11 +216,11 @@
"class TFRecordPipeline(CommonPipeline):\n",
" def __init__(self, batch_size, num_threads, device_id, num_gpus):\n",
" super(TFRecordPipeline, self).__init__(batch_size, num_threads, device_id)\n",
" self.input = ops.TFRecordReader(path = tfrecord, \n",
" index_path = tfrecord_idx,\n",
" features = {\"image/encoded\" : tfrec.FixedLenFeature((), tfrec.string, \"\"),\n",
" \"image/class/label\": tfrec.FixedLenFeature([1], tfrec.int64, -1)\n",
" })\n",
" self.input = ops.readers.TFRecord(path = tfrecord, \n",
" index_path = tfrecord_idx,\n",
" features = {\"image/encoded\" : tfrec.FixedLenFeature((), tfrec.string, \"\"),\n",
" \"image/class/label\": tfrec.FixedLenFeature([1], tfrec.int64, -1)\n",
" })\n",
"\n",
" def define_graph(self):\n",
" inputs = self.input(name=\"Reader\")\n",
Expand Down
Loading

0 comments on commit 3d521a8

Please sign in to comment.