From 89042095c46285ee4a9fa93920411fc7045310d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Tue, 17 Sep 2024 17:19:59 +0200 Subject: [PATCH] TypeTable/TypeInfo optimization (#5634) * TypeTable/TypeInfo optimization - TypeInfo uses string_view for type name - TypeTable stores types in an array - TypeTable read access is lockless --------- Signed-off-by: Michal Zientkiewicz --- .../math/expressions/expression_tree.h | 22 ++- dali/pipeline/data/types.cc | 4 +- dali/pipeline/data/types.h | 170 +++++++++++------- dali/pipeline/data/types_test.cc | 55 +++++- 4 files changed, 174 insertions(+), 77 deletions(-) diff --git a/dali/operators/math/expressions/expression_tree.h b/dali/operators/math/expressions/expression_tree.h index 3dc73adbeb..f345332a38 100644 --- a/dali/operators/math/expressions/expression_tree.h +++ b/dali/operators/math/expressions/expression_tree.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -84,9 +84,10 @@ class ExprNode { virtual std::string GetOutputDesc() const { const auto &op_type = TypeTable::GetTypeInfo(GetTypeId()).name(); - std::string result = GetAbbreviation(GetNodeType()); - result += IsScalarLike(GetShape()) ? "C:" : "T:"; - return result + op_type; + return make_string( + GetAbbreviation(GetNodeType()), + IsScalarLike(GetShape()) ? "C:" : "T:", + op_type); } virtual NodeType GetNodeType() const = 0; @@ -140,15 +141,18 @@ class ExprFunc : public ExprNode { std::string GetNodeDesc() const override { const auto &op_type = TypeTable::GetTypeInfo(GetTypeId()).name(); - std::string result = func_name_ + (IsScalarLike(GetShape()) ? ":C:" : ":T:") + op_type + "("; + std::stringstream result; + result << func_name_ << (IsScalarLike(GetShape()) ? ":C:" : ":T:") << op_type << "("; + for (int i = 0; i < GetSubexpressionCount(); i++) { - result += (*this)[i].GetOutputDesc(); + result << (*this)[i].GetOutputDesc(); if (i < GetSubexpressionCount() - 1) { - result += " "; + result << " "; } } - result += ")"; - return result; + + result << ")"; + return result.str(); } NodeType GetNodeType() const override { diff --git a/dali/pipeline/data/types.cc b/dali/pipeline/data/types.cc index 1ce5fd654c..51fd207d2a 100644 --- a/dali/pipeline/data/types.cc +++ b/dali/pipeline/data/types.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ #define DALI_TYPENAME_REGISTERER(Type, dtype) \ { \ - return to_string(dtype); \ + return dali::TypeName(dtype); \ } #define DALI_TYPEID_REGISTERER(Type, dtype) \ diff --git a/dali/pipeline/data/types.h b/dali/pipeline/data/types.h index 40e429e3d5..cad79254b3 100644 --- a/dali/pipeline/data/types.h +++ b/dali/pipeline/data/types.h @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,9 +15,12 @@ #ifndef DALI_PIPELINE_DATA_TYPES_H_ #define DALI_PIPELINE_DATA_TYPES_H_ +#include +#include #include #include #include +#include #include #include #include @@ -25,6 +28,7 @@ #include #include #include +#include "dali/core/util.h" #include "dali/core/common.h" #include "dali/core/spinlock.h" #include "dali/core/float16.h" @@ -123,7 +127,8 @@ enum DALIDataType : int { DALI_PYTHON_OBJECT = 24, DALI_TENSOR_LAYOUT_VEC = 25, DALI_DATA_TYPE_VEC = 26, - DALI_DATATYPE_END = 1000 + DALI_NUM_BUILTIN_TYPES, + DALI_CUSTOM_TYPE_START = 1001 }; inline const char *GetBuiltinTypeName(DALIDataType t) { @@ -397,7 +402,7 @@ class DLL_PUBLIC TypeInfo { return type_size_; } - DLL_PUBLIC inline const string &name() const { + DLL_PUBLIC inline std::string_view name() const { return name_; } @@ -410,12 +415,12 @@ class DLL_PUBLIC TypeInfo { DALIDataType id_ = DALI_NO_TYPE; size_t type_size_ = 0; - std::string name_ = dali::to_string(DALI_NO_TYPE); + std::string_view name_ = GetBuiltinTypeName(DALI_NO_TYPE); }; template struct TypeNameHelper { - static string GetTypeName() { + static std::string_view GetTypeName() { return typeid(T).name(); } }; @@ -427,23 +432,23 @@ class DLL_PUBLIC TypeTable { public: template DLL_PUBLIC static DALIDataType GetTypeId() { - auto &inst = instance(); - static DALIDataType type_id = inst.RegisterType(static_cast(++inst.index_)); + static DALIDataType type_id = instance().RegisterType( + static_cast(instance().next_id_++)); return type_id; } template - DLL_PUBLIC static string GetTypeName() { + DLL_PUBLIC static std::string_view GetTypeName() { return TypeNameHelper::GetTypeName(); } DLL_PUBLIC static const TypeInfo *TryGetTypeInfo(DALIDataType dtype) { - auto &inst = instance(); - std::lock_guard guard(inst.lock_); - auto id_it = inst.type_info_map_.find(dtype); - if (id_it == inst.type_info_map_.end()) + auto *types = instance().type_info_map_; + assert(types); + size_t idx = dtype - DALI_NO_TYPE; + if (idx >= types->size()) return nullptr; - return &id_it->second; + return (*types)[idx]; } DLL_PUBLIC static const TypeInfo &GetTypeInfo(DALIDataType dtype) { @@ -465,42 +470,77 @@ class DLL_PUBLIC TypeTable { template DALIDataType RegisterType(DALIDataType dtype) { - std::lock_guard guard(lock_); - // Check the map for this types id - auto id_it = type_map_.find(typeid(T)); - - if (id_it == type_map_.end()) { - type_map_[typeid(T)] = dtype; - TypeInfo t; - t.SetType(dtype); - type_info_map_[dtype] = t; + static DALIDataType id = [dtype, this]() { + std::lock_guard guard(insert_lock_); + size_t idx = dtype - DALI_NO_TYPE; + // We need the map because this function (and the static variable) may be instantiated + // in multiple shared objects whereas the map instance is tied to one well defined + // instance of the TypeTable returned by `instance()`. + auto [it, inserted] = type_map_.emplace(typeid(T), dtype); + if (!inserted) + return it->second; + if (!type_info_map_ || idx >= type_info_map_->size()) { + constexpr size_t kMinCapacity = next_pow2(DALI_CUSTOM_TYPE_START + 100); + // we don't need to look at the previous capacity to achieve std::vector-like growth + size_t capacity = next_pow2(idx + 1); + if (capacity < kMinCapacity) + capacity = kMinCapacity; + auto &m = type_info_maps_.emplace_back(); + m.resize(capacity); + if (type_info_map_) // copy the old map into the new one + std::copy(type_info_map_->begin(), type_info_map_->end(), m.begin()); + // The new map contains everything that the old map did - we can "publish" it. + // Make sure that the compiler doesn't reorder after the "publishing". + std::atomic_thread_fence(std::memory_order_release); + // Publish the new map. + type_info_map_ = &m; + } + TypeInfo &info = type_infos_.emplace_back(); + info.SetType(dtype); + if ((*type_info_map_)[idx] != nullptr) + DALI_FAIL("The type id ", idx, " is already taken by type ", + (*type_info_map_)[idx]->name()); + (*type_info_map_)[idx] = &info; + return dtype; - } else { - return id_it->second; - } + }(); + return id; } - - spinlock lock_; + using TypeInfoMap = std::vector; + // The "current" type map - it's just a vector that maps type_id (adjusted and treated as index) + // to a TypeInfo pointer. + TypeInfoMap *type_info_map_ = nullptr; + + std::mutex insert_lock_; + // All type info maps - old ones are never deleted to avoid locks when only read access is needed. + std::list type_info_maps_; + // The actual type info objects. Each type has exactly one TypeInfo - even if we need to grow + // the storage - hence, we need to store TypeInfo* in the pas (see typedef TypeInfoMap) and + // we need to store TypeInfo instances in a container that never invalidates pointers + // (e.g. a list). + std::list type_infos_; + // This is necessary because it turns out that static field in RegisterType has many instances + // in a program built with multiple shared libraries. std::unordered_map type_map_; - // Unordered maps do not work with enums, - // so we need to use underlying type instead of DALIDataType - std::unordered_map, TypeInfo> type_info_map_; - int index_ = DALI_DATATYPE_END; + + int next_id_ = DALI_CUSTOM_TYPE_START; DLL_PUBLIC static TypeTable &instance(); }; template struct TypeNameHelper > { - static string GetTypeName() { - return "list of " + TypeTable::GetTypeName(); + static std::string_view GetTypeName() { + static const std::string name = "list of " + std::string(TypeTable::GetTypeName()); + return name; } }; template struct TypeNameHelper > { - static string GetTypeName() { - return "list of " + TypeTable::GetTypeName(); + static std::string_view GetTypeName() { + static const std::string name = "list of " + std::string(TypeTable::GetTypeName()); + return name; } }; @@ -513,8 +553,9 @@ template void TypeInfo::SetType(DALIDataType dtype) { // Note: We enforce the fact that NoType is invalid by // explicitly setting its type size as 0 - type_size_ = std::is_same::value ? 0 : sizeof(T); - if (!std::is_same::value) { + constexpr bool is_no_type = std::is_same_v; + type_size_ = is_no_type ? 0 : sizeof(T); + if constexpr (!is_no_type) { id_ = dtype != DALI_NO_TYPE ? dtype : TypeTable::GetTypeId(); } else { id_ = DALI_NO_TYPE; @@ -555,17 +596,34 @@ DLL_PUBLIC inline bool IsValidType(const TypeInfo &type) { return !IsType(type); } +inline std::string_view TypeName(DALIDataType dtype) { + if (const char *builtin = GetBuiltinTypeName(dtype)) + return builtin; + auto *info = TypeTable::TryGetTypeInfo(dtype); + if (info) + return info->name(); + return ""; +} + +inline std::string to_string(DALIDataType dtype) { + std::string_view name = TypeName(dtype); + if (name == "") + return "unknown type: " + std::to_string(static_cast(dtype)); + else + return std::string(name); +} + // Used to define a type for use in dali. Inserts the type into the // TypeTable w/ a unique id and creates a method to get the name of // the type as a string. This does not work for non-fundamental types, // as we do not have any mechanism for calling the constructor of the // type when the buffer allocates the memory. -#define DALI_REGISTER_TYPE(Type, dtype) \ - template <> DLL_PUBLIC string TypeTable::GetTypeName() \ - DALI_TYPENAME_REGISTERER(Type, dtype); \ - template <> DLL_PUBLIC DALIDataType TypeTable::GetTypeId() \ - DALI_TYPEID_REGISTERER(Type, dtype); \ - DALI_STATIC_TYPE_MAPPING(Type, dtype); \ +#define DALI_REGISTER_TYPE(Type, dtype) \ + template <> DLL_PUBLIC std::string_view TypeTable::GetTypeName() \ + DALI_TYPENAME_REGISTERER(Type, dtype); \ + template <> DLL_PUBLIC DALIDataType TypeTable::GetTypeId() \ + DALI_TYPEID_REGISTERER(Type, dtype); \ + DALI_STATIC_TYPE_MAPPING(Type, dtype); \ DALI_REGISTER_TYPE_IMPL(Type, dtype); // Instantiate some basic types @@ -600,25 +658,15 @@ DALI_REGISTER_TYPE(std::vector, DALI_FLOAT_VEC); DALI_REGISTER_TYPE(std::vector, DALI_TENSOR_LAYOUT_VEC); DALI_REGISTER_TYPE(std::vector, DALI_DATA_TYPE_VEC); - -inline std::string to_string(DALIDataType dtype) { - if (const char *builtin = GetBuiltinTypeName(dtype)) - return builtin; - auto *info = TypeTable::TryGetTypeInfo(dtype); - if (info) - return info->name(); - return "unknown type: " + std::to_string(static_cast(dtype)); -} - inline std::ostream &operator<<(std::ostream &os, DALIDataType dtype) { - if (const char *builtin = GetBuiltinTypeName(dtype)) - return os << builtin; - auto *info = TypeTable::TryGetTypeInfo(dtype); - if (info) - return os << info->name(); - // Use string concatenation so that the result is the same as in to_string, unaffected by - // formatting & other settings in `os`. - return os << ("unknown type: " + std::to_string(static_cast(dtype))); + std::string_view name = TypeName(dtype); + if (name == "") { + // Use string concatenation so that the result is the same as in to_string, unaffected by + // formatting & other settings in `os`. + return os << ("unknown type: " + std::to_string(static_cast(dtype))); + } else { + return os << name; + } } #define DALI_INTEGRAL_TYPES uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, uint64_t, int64_t diff --git a/dali/pipeline/data/types_test.cc b/dali/pipeline/data/types_test.cc index 387dfda5ae..0786da3f7d 100644 --- a/dali/pipeline/data/types_test.cc +++ b/dali/pipeline/data/types_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -97,19 +97,32 @@ typedef ::testing::Types::value) = type2id::value> +void CheckBuiltinType(const TypeInfo *t) { + EXPECT_EQ(t->id(), type2id::value); + EXPECT_EQ(TypeTable::GetTypeId(), type2id::value); +} + +template +void CheckBuiltinType(...) {} + + TYPED_TEST(TypesTest, TestRegisteredType) { typedef TypeParam T; TypeInfo type; // Verify we start with no type - ASSERT_EQ(type.name(), ""); - ASSERT_EQ(type.size(), 0); + EXPECT_EQ(type.name(), ""); + EXPECT_EQ(type.size(), 0); type.SetType(); - ASSERT_EQ(type.size(), sizeof(T)); - ASSERT_EQ(type.name(), this->TypeName()); + EXPECT_EQ(type.size(), sizeof(T)); + EXPECT_EQ(type.name(), this->TypeName()); + CheckBuiltinType(&type); } struct CustomTestType {}; @@ -142,4 +155,36 @@ TEST(ListTypeNames, ListTypeNames) { ASSERT_EQ(str1, expected_str1); } +// The following disabled code tests the scenario when we need to grow the type table +// - for which we need an inordinate number of artifical types. The compilation is extremely +// slow and setting a breakpoint in types.h becomes a nightmare. +// Uncomment to test this particular scenario, leave commented otherwise. +#if 0 + +template +void TestTypeTableGrowth(std::integral_constant = {}) { + if constexpr (n < 3000) { + const TypeInfo &ti = TypeTable::GetTypeInfo>(); + std::cout << ti.name() << std::endl; + EXPECT_NE(ti.name().find(std::to_string(n)), std::string::npos); + + TestTypeTableGrowth(std::integral_constant()); + TestTypeTableGrowth(std::integral_constant()); + + const TypeInfo &ti2 = TypeTable::GetTypeInfo>(); + EXPECT_EQ(&ti, &ti2); + } + if constexpr (n == 0) + exit(0); +} + + +TEST(TypeTable, TestTableGrowth) { + TestTypeTableGrowth(); + GTEST_FLAG_SET(death_test_style, "threadsafe"); + // This avoids polluting the global type table with dummy types + EXPECT_EXIT(TestTypeTableGrowth(), ::testing::ExitedWithCode(0), ""); +} +#endif + } // namespace dali