Skip to content

Commit

Permalink
Add more APIs deprecated along with replacements
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Jun 21, 2022
1 parent b80f771 commit 189c9df
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 71 deletions.
142 changes: 121 additions & 21 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,16 +250,22 @@ struct TypeInfo;
struct Value;
struct ModelMetadata;

// Light functor to release memory with OrtAllocator
namespace detail {
// Light functor to release memory with OrtAllocator
struct AllocatedFree {
OrtAllocator* allocator_;
explicit AllocatedFree(OrtAllocator* allocator)
: allocator_(allocator) {}
void operator()(void* ptr) const { allocator_->Free(allocator_, ptr); }
void operator()(void* ptr) const { if(ptr) allocator_->Free(allocator_, ptr); }
};
} // namespace detail

/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
* and release them at the end of the scope. The lifespan of the given allocator
* must eclipse the lifespan of AllocatedStringPtr instance
*/
using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;

/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.
Expand Down Expand Up @@ -395,13 +401,108 @@ struct ModelMetadata : Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API

char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
/** \deprecated use GetProducerNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName

/** \brief Returns a copy of the producer name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName

/** \deprecated use GetGraphNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName

/** \brief Returns a copy of the graph name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName

/** \deprecated use GetDomainAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain

/** \brief Returns a copy of the domain name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain

/** \deprecated use GetDescriptionAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription

/** \brief Returns a copy of the description.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription

/** \deprecated use GetGraphDescriptionAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription

/** \brief Returns a copy of the graph description.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription

/** \deprecated use GetCustomMetadataMapKeysAllocated()
* [[deprecated]]
* This interface produces multiple pointers that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys

std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys

/** \deprecated use LookupCustomMetadataMapAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap

/** \brief Looks up a value by a key in the Custom Metadata map
*
* \param zero terminated string key to lookup
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* maybe nullptr if key is not found.
*
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap

int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
};

Expand Down Expand Up @@ -450,18 +551,10 @@ struct Session : Base<OrtSession> {
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName
/** \deprecated use GetOutputNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName

using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
*/
char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName

/** \brief Returns a copy of input name at the specified index. Replaces GetInputName().
/** \brief Returns a copy of input name at the specified index.
*
* \param index must less than the value returned by GetInputCount()
* \param allocator to allocate memory for the copy of the name returned
Expand All @@ -470,6 +563,13 @@ struct Session : Base<OrtSession> {
*/
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;

/** \deprecated use GetOutputNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName

/** \brief Returns a copy of output name at then specified index.
*
* \param index must less than the value returned by GetOutputCount()
Expand All @@ -483,7 +583,7 @@ struct Session : Base<OrtSession> {
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
*/
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName

/** \brief Returns a copy of the overridable initializer name at then specified index.
Expand All @@ -500,17 +600,17 @@ struct Session : Base<OrtSession> {
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling

/** \brief Returns a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata

TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
Expand Down
71 changes: 66 additions & 5 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ inline void IoBinding::BindOutput(const char* name, const MemoryInfo& mem_info)

inline std::vector<std::string> IoBinding::GetOutputNamesHelper(OrtAllocator* allocator) const {
std::vector<std::string> result;
auto free_fn = [allocator](void* p) { if (p) allocator->Free(allocator, p); };
auto free_fn = detail::AllocatedFree(allocator);
using Ptr = std::unique_ptr<void, decltype(free_fn)>;

char* buffer = nullptr;
Expand Down Expand Up @@ -656,13 +656,13 @@ inline char* Session::GetOutputName(size_t index, OrtAllocator* allocator) const
return out;
}

inline Session::AllocatedStringPtr Session::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
inline AllocatedStringPtr Session::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionGetInputName(p_, index, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline Session::AllocatedStringPtr Session::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
inline AllocatedStringPtr Session::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionGetOutputName(p_, index, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
Expand All @@ -674,7 +674,7 @@ inline char* Session::GetOverridableInitializerName(size_t index, OrtAllocator*
return out;
}

inline Session::AllocatedStringPtr Session::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
inline AllocatedStringPtr Session::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionGetOverridableInitializerName(p_, index, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
Expand All @@ -686,7 +686,7 @@ inline char* Session::EndProfiling(OrtAllocator* allocator) const {
return out;
}

inline Session::AllocatedStringPtr Session::EndProfilingAllocated(OrtAllocator* allocator) const {
inline AllocatedStringPtr Session::EndProfilingAllocated(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().SessionEndProfiling(p_, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
Expand All @@ -710,42 +710,103 @@ inline char* ModelMetadata::GetProducerName(OrtAllocator* allocator) const {
return out;
}

inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline char* ModelMetadata::GetGraphName(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
return out;
}

inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline char* ModelMetadata::GetDomain(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
return out;
}

inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline char* ModelMetadata::GetDescription(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
return out;
}

inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline char* ModelMetadata::GetGraphDescription(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
return out;
}

inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
return out;
}

inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
char* out;
ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
}

inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const {
char** out;
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
return out;
}

inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
auto deletor = detail::AllocatedFree(allocator);
std::vector<AllocatedStringPtr> result;

char** out = nullptr;
int64_t num_keys = 0;
ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
if (num_keys <= 0) {
return result;
}

// array of pointers will be freed
std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
// reserve may throw
auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
result.reserve(static_cast<size_t>(num_keys));
strings_guard.release();
for (int64_t i = 0; i < num_keys; ++i) {
result.push_back(AllocatedStringPtr(out[i], deletor));
}

return result;
}

inline int64_t ModelMetadata::GetVersion() const {
int64_t out;
ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
Expand Down
22 changes: 18 additions & 4 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1382,17 +1382,31 @@ ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetCustomMetadataMapKeys,
// To guard against overflow in the next step where we compute bytes to allocate
SafeInt<size_t> alloc_count(count);

InlinedVector<Ort::AllocatedStringPtr> string_holders;
string_holders.reserve(count);

auto deletor = Ort::detail::AllocatedFree(allocator);
// alloc_count * sizeof(...) will throw if there was an overflow which will be caught in API_IMPL_END
// and be returned to the user as a status
char** p = reinterpret_cast<char**>(allocator->Alloc(allocator, alloc_count * sizeof(char*)));
assert(p != nullptr);
auto map_iter = custom_metadata_map.cbegin();

// StrDup may throw
std::unique_ptr<void, decltype(deletor)> array_guard(p, deletor);

int64_t i = 0;
while (map_iter != custom_metadata_map.cend()) {
p[i++] = StrDup(map_iter->first, allocator);
++map_iter;
for (const auto& e : custom_metadata_map) {
auto* s = StrDup(e.first, allocator);
string_holders.push_back(Ort::AllocatedStringPtr(s, deletor));
p[i++] = s;
}

for (auto& s : string_holders) {
s.release();
}

*keys = p;
array_guard.release();
}

*num_keys = static_cast<int64_t>(count);
Expand Down
Loading

0 comments on commit 189c9df

Please sign in to comment.