Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TypeTable/TypeInfo optimization #5634

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions dali/operators/math/expressions/expression_tree.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -84,9 +84,10 @@ class ExprNode {

virtual std::string GetOutputDesc() const {
const auto &op_type = TypeTable::GetTypeInfo(GetTypeId()).name();
std::string result = GetAbbreviation(GetNodeType());
result += IsScalarLike(GetShape()) ? "C:" : "T:";
return result + op_type;
return make_string(
GetAbbreviation(GetNodeType()),
IsScalarLike(GetShape()) ? "C:" : "T:",
op_type);
}

virtual NodeType GetNodeType() const = 0;
Expand Down Expand Up @@ -140,15 +141,18 @@ class ExprFunc : public ExprNode {

std::string GetNodeDesc() const override {
const auto &op_type = TypeTable::GetTypeInfo(GetTypeId()).name();
std::string result = func_name_ + (IsScalarLike(GetShape()) ? ":C:" : ":T:") + op_type + "(";
std::stringstream result;
result << func_name_ << (IsScalarLike(GetShape()) ? ":C:" : ":T:") << op_type << "(";

for (int i = 0; i < GetSubexpressionCount(); i++) {
result += (*this)[i].GetOutputDesc();
result << (*this)[i].GetOutputDesc();
if (i < GetSubexpressionCount() - 1) {
result += " ";
result << " ";
}
}
result += ")";
return result;

result << ")";
return result.str();
}

NodeType GetNodeType() const override {
Expand Down
4 changes: 2 additions & 2 deletions dali/pipeline/data/types.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -14,7 +14,7 @@

#define DALI_TYPENAME_REGISTERER(Type, dtype) \
{ \
return to_string(dtype); \
return dali::TypeName(dtype); \
}

#define DALI_TYPEID_REGISTERER(Type, dtype) \
Expand Down
170 changes: 109 additions & 61 deletions dali/pipeline/data/types.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,16 +15,20 @@
#ifndef DALI_PIPELINE_DATA_TYPES_H_
#define DALI_PIPELINE_DATA_TYPES_H_

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstring>
#include <functional>
#include <list>
#include <mutex>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <vector>
#include "dali/core/util.h"
#include "dali/core/common.h"
#include "dali/core/spinlock.h"
#include "dali/core/float16.h"
Expand Down Expand Up @@ -123,7 +127,8 @@ enum DALIDataType : int {
DALI_PYTHON_OBJECT = 24,
DALI_TENSOR_LAYOUT_VEC = 25,
DALI_DATA_TYPE_VEC = 26,
DALI_DATATYPE_END = 1000
DALI_NUM_BUILTIN_TYPES,
DALI_CUSTOM_TYPE_START = 1001
};

inline const char *GetBuiltinTypeName(DALIDataType t) {
Expand Down Expand Up @@ -397,7 +402,7 @@ class DLL_PUBLIC TypeInfo {
return type_size_;
}

DLL_PUBLIC inline const string &name() const {
DLL_PUBLIC inline std::string_view name() const {
return name_;
}

Expand All @@ -410,12 +415,12 @@ class DLL_PUBLIC TypeInfo {

DALIDataType id_ = DALI_NO_TYPE;
size_t type_size_ = 0;
std::string name_ = dali::to_string(DALI_NO_TYPE);
std::string_view name_ = GetBuiltinTypeName(DALI_NO_TYPE);
};

template <typename T>
struct TypeNameHelper {
static string GetTypeName() {
static std::string_view GetTypeName() {
return typeid(T).name();
}
};
Expand All @@ -427,23 +432,23 @@ class DLL_PUBLIC TypeTable {
public:
template <typename T>
DLL_PUBLIC static DALIDataType GetTypeId() {
auto &inst = instance();
static DALIDataType type_id = inst.RegisterType<T>(static_cast<DALIDataType>(++inst.index_));
static DALIDataType type_id = instance().RegisterType<T>(
static_cast<DALIDataType>(instance().next_id_++));
return type_id;
}

template <typename T>
DLL_PUBLIC static string GetTypeName() {
DLL_PUBLIC static std::string_view GetTypeName() {
return TypeNameHelper<T>::GetTypeName();
}

DLL_PUBLIC static const TypeInfo *TryGetTypeInfo(DALIDataType dtype) {
auto &inst = instance();
std::lock_guard<spinlock> guard(inst.lock_);
auto id_it = inst.type_info_map_.find(dtype);
if (id_it == inst.type_info_map_.end())
auto *types = instance().type_info_map_;
assert(types);
size_t idx = dtype - DALI_NO_TYPE;
if (idx >= types->size())
return nullptr;
return &id_it->second;
return (*types)[idx];
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when RegisterType allocates a new vector of TypeInfo pointers because idx >= type_info_map_->size() and we then ask for some type registered earlier? Won't we end up with nullptr here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, there was an insert that copied the previous vector, but I must have lost it somewhere along the way. This also shows that this is likely untested. I'll add tests that cover adding new types.


DLL_PUBLIC static const TypeInfo &GetTypeInfo(DALIDataType dtype) {
Expand All @@ -465,42 +470,77 @@ class DLL_PUBLIC TypeTable {

template <typename T>
DALIDataType RegisterType(DALIDataType dtype) {
std::lock_guard<spinlock> guard(lock_);
// Check the map for this types id
auto id_it = type_map_.find(typeid(T));

if (id_it == type_map_.end()) {
type_map_[typeid(T)] = dtype;
TypeInfo t;
t.SetType<T>(dtype);
type_info_map_[dtype] = t;
static DALIDataType id = [dtype, this]() {
std::lock_guard guard(insert_lock_);
size_t idx = dtype - DALI_NO_TYPE;
// We need the map because this function (and the static variable) may be instantiated
stiepan marked this conversation as resolved.
Show resolved Hide resolved
// in multiple shared objects whereas the map instance is tied to one well defined
// instance of the TypeTable returned by `instance()`.
auto [it, inserted] = type_map_.emplace(typeid(T), dtype);
if (!inserted)
return it->second;
if (!type_info_map_ || idx >= type_info_map_->size()) {
constexpr size_t kMinCapacity = next_pow2(DALI_CUSTOM_TYPE_START + 100);
// we don't need to look at the previous capacity to achieve std::vector-like growth
size_t capacity = next_pow2(idx + 1);
if (capacity < kMinCapacity)
capacity = kMinCapacity;
auto &m = type_info_maps_.emplace_back();
m.resize(capacity);
if (type_info_map_) // copy the old map into the new one
std::copy(type_info_map_->begin(), type_info_map_->end(), m.begin());
// The new map contains everything that the old map did - we can "publish" it.
// Make sure that the compiler doesn't reorder after the "publishing".
std::atomic_thread_fence(std::memory_order_release);
// Publish the new map.
type_info_map_ = &m;
}
TypeInfo &info = type_infos_.emplace_back();
info.SetType<T>(dtype);
if ((*type_info_map_)[idx] != nullptr)
DALI_FAIL("The type id ", idx, " is already taken by type ",
(*type_info_map_)[idx]->name());
(*type_info_map_)[idx] = &info;

return dtype;
} else {
return id_it->second;
}
}();
return id;
}


spinlock lock_;
using TypeInfoMap = std::vector<TypeInfo*>;
// The "current" type map - it's just a vector that maps type_id (adjusted and treated as index)
// to a TypeInfo pointer.
TypeInfoMap *type_info_map_ = nullptr;

std::mutex insert_lock_;
// All type info maps - old ones are never deleted to avoid locks when only read access is needed.
std::list<TypeInfoMap> type_info_maps_;
// The actual type info objects. Each type has exactly one TypeInfo - even if we need to grow
// the storage - hence, we need to store TypeInfo* in the map (see typedef TypeInfoMap) and
// we need to store TypeInfo instances in a container that never invalidates pointers
// (e.g. a list).
std::list<TypeInfo> type_infos_;
// This is necessary because it turns out that static field in RegisterType has many instances
// in a program built with multiple shared libraries.
std::unordered_map<std::type_index, DALIDataType> type_map_;
// Unordered maps do not work with enums,
// so we need to use underlying type instead of DALIDataType
std::unordered_map<std::underlying_type_t<DALIDataType>, TypeInfo> type_info_map_;
int index_ = DALI_DATATYPE_END;

int next_id_ = DALI_CUSTOM_TYPE_START;
DLL_PUBLIC static TypeTable &instance();
};

template <typename T, typename A>
struct TypeNameHelper<std::vector<T, A> > {
static string GetTypeName() {
return "list of " + TypeTable::GetTypeName<T>();
static std::string_view GetTypeName() {
static const std::string name = "list of " + std::string(TypeTable::GetTypeName<T>());
return name;
}
};

template <typename T, size_t N>
struct TypeNameHelper<std::array<T, N> > {
static string GetTypeName() {
return "list of " + TypeTable::GetTypeName<T>();
static std::string_view GetTypeName() {
static const std::string name = "list of " + std::string(TypeTable::GetTypeName<T>());
return name;
}
};

Expand All @@ -513,8 +553,9 @@ template <typename T>
void TypeInfo::SetType(DALIDataType dtype) {
// Note: We enforce the fact that NoType is invalid by
// explicitly setting its type size as 0
type_size_ = std::is_same<T, NoType>::value ? 0 : sizeof(T);
if (!std::is_same<T, NoType>::value) {
constexpr bool is_no_type = std::is_same_v<T, NoType>;
type_size_ = is_no_type ? 0 : sizeof(T);
if constexpr (!is_no_type) {
id_ = dtype != DALI_NO_TYPE ? dtype : TypeTable::GetTypeId<T>();
} else {
id_ = DALI_NO_TYPE;
Expand Down Expand Up @@ -555,17 +596,34 @@ DLL_PUBLIC inline bool IsValidType(const TypeInfo &type) {
return !IsType<NoType>(type);
}

// Returns a non-owning name for `dtype`.
// Lookup order: the built-in name table first, then the types registered in
// TypeTable. If neither knows the id, the literal "<unknown>" is returned.
inline std::string_view TypeName(DALIDataType dtype) {
  const char *builtin_name = GetBuiltinTypeName(dtype);
  if (builtin_name)
    return builtin_name;

  if (const auto *registered = TypeTable::TryGetTypeInfo(dtype))
    return registered->name();

  return "<unknown>";
}

// Returns an owning string with the name of `dtype`.
// When TypeName() reports "<unknown>", the result is "unknown type: "
// followed by the numeric value of the id.
inline std::string to_string(DALIDataType dtype) {
  const std::string_view type_name = TypeName(dtype);
  if (type_name != "<unknown>")
    return std::string(type_name);
  return "unknown type: " + std::to_string(static_cast<int>(dtype));
}

// Used to define a type for use in dali. Inserts the type into the
// TypeTable w/ a unique id and creates a method to get the name of
// the type as a string. This does not work for non-fundamental types,
// as we do not have any mechanism for calling the constructor of the
// type when the buffer allocates the memory.
#define DALI_REGISTER_TYPE(Type, dtype) \
template <> DLL_PUBLIC string TypeTable::GetTypeName<Type>() \
DALI_TYPENAME_REGISTERER(Type, dtype); \
template <> DLL_PUBLIC DALIDataType TypeTable::GetTypeId<Type>() \
DALI_TYPEID_REGISTERER(Type, dtype); \
DALI_STATIC_TYPE_MAPPING(Type, dtype); \
#define DALI_REGISTER_TYPE(Type, dtype) \
template <> DLL_PUBLIC std::string_view TypeTable::GetTypeName<Type>() \
DALI_TYPENAME_REGISTERER(Type, dtype); \
template <> DLL_PUBLIC DALIDataType TypeTable::GetTypeId<Type>() \
DALI_TYPEID_REGISTERER(Type, dtype); \
DALI_STATIC_TYPE_MAPPING(Type, dtype); \
DALI_REGISTER_TYPE_IMPL(Type, dtype);

// Instantiate some basic types
Expand Down Expand Up @@ -600,25 +658,15 @@ DALI_REGISTER_TYPE(std::vector<float>, DALI_FLOAT_VEC);
DALI_REGISTER_TYPE(std::vector<TensorLayout>, DALI_TENSOR_LAYOUT_VEC);
DALI_REGISTER_TYPE(std::vector<DALIDataType>, DALI_DATA_TYPE_VEC);


inline std::string to_string(DALIDataType dtype) {
if (const char *builtin = GetBuiltinTypeName(dtype))
return builtin;
auto *info = TypeTable::TryGetTypeInfo(dtype);
if (info)
return info->name();
return "unknown type: " + std::to_string(static_cast<int>(dtype));
}

inline std::ostream &operator<<(std::ostream &os, DALIDataType dtype) {
if (const char *builtin = GetBuiltinTypeName(dtype))
return os << builtin;
auto *info = TypeTable::TryGetTypeInfo(dtype);
if (info)
return os << info->name();
// Use string concatenation so that the result is the same as in to_string, unaffected by
// formatting & other settings in `os`.
return os << ("unknown type: " + std::to_string(static_cast<int>(dtype)));
std::string_view name = TypeName(dtype);
if (name == "<unknown>") {
// Use string concatenation so that the result is the same as in to_string, unaffected by
// formatting & other settings in `os`.
return os << ("unknown type: " + std::to_string(static_cast<int>(dtype)));
} else {
return os << name;
}
}

#define DALI_INTEGRAL_TYPES uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t
Expand Down
Loading
Loading