Skip to content

Commit

Permalink
Deprecate APIs returning raw ptrs and provide replacements (#11922)
Browse files Browse the repository at this point in the history
Provider better documentation
  • Loading branch information
yuslepukhin authored and RandyShuai committed Jul 6, 2022
1 parent 1e0c4f1 commit 82fc95f
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ inline std::vector<std::string> Session::GetInputNames() const {
size_t node_count = GetInputCount();
std::vector<std::string> out(node_count);
for (size_t i = 0; i < node_count; i++) {
char* tmp = GetInputName(i, allocator);
out[i] = tmp;
allocator.Free(tmp); // prevent memory leak
auto tmp = GetInputNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
Expand All @@ -47,9 +46,8 @@ inline std::vector<std::string> Session::GetOutputNames() const {
size_t node_count = GetOutputCount();
std::vector<std::string> out(node_count);
for (size_t i = 0; i < node_count; i++) {
char* tmp = GetOutputName(i, allocator);
out[i] = tmp;
allocator.Free(tmp); // prevent memory leak
auto tmp = GetOutputNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
Expand All @@ -59,9 +57,8 @@ inline std::vector<std::string> Session::GetOverridableInitializerNames() const
size_t init_count = GetOverridableInitializerCount();
std::vector<std::string> out(init_count);
for (size_t i = 0; i < init_count; i++) {
char* tmp = GetOverridableInitializerName(i, allocator);
out[i] = tmp;
allocator.Free(tmp); // prevent memory leak
auto tmp = GetOverridableInitializerNameAllocated(i, allocator);
out[i] = tmp.get();
}
return out;
}
Expand Down
185 changes: 177 additions & 8 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ struct TypeInfo;
struct Value;
struct ModelMetadata;

namespace detail {
// Light functor to release memory with OrtAllocator
struct AllocatedFree {
OrtAllocator* allocator_;
explicit AllocatedFree(OrtAllocator* allocator)
: allocator_(allocator) {}
void operator()(void* ptr) const { if(ptr) allocator_->Free(allocator_, ptr); }
};
} // namespace detail

/** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
* and release them at the end of the scope. The lifespan of the given allocator
* must eclipse the lifespan of AllocatedStringPtr instance
*/
using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;

/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.
Expand Down Expand Up @@ -385,13 +401,108 @@ struct ModelMetadata : Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API

char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
/** \deprecated use GetProducerNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetProducerName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName

/** \brief Returns a copy of the producer name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName

/** \deprecated use GetGraphNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetGraphName(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName

/** \brief Returns a copy of the graph name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName

/** \deprecated use GetDomainAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetDomain(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain

/** \brief Returns a copy of the domain name.
*
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain

/** \deprecated use GetDescriptionAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription

/** \brief Returns a copy of the description.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription

/** \deprecated use GetGraphDescriptionAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetGraphDescription(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription

/** \brief Returns a copy of the graph description.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription

/** \deprecated use GetCustomMetadataMapKeysAllocated()
* [[deprecated]]
* This interface produces multiple pointers that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys

std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys

/** \deprecated use LookupCustomMetadataMapAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap

/** \brief Looks up a value by a key in the Custom Metadata map
*
* \param zero terminated string key to lookup
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* maybe nullptr if key is not found.
*
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap

int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
};

Expand Down Expand Up @@ -436,12 +547,70 @@ struct Session : Base<OrtSession> {
size_t GetOutputCount() const; ///< Returns the number of model outputs
size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden

char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName
/** \deprecated use GetInputNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetInputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetInputName

/** \brief Returns a copy of input name at the specified index.
*
* \param index must less than the value returned by GetInputCount()
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;

/** \deprecated use GetOutputNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetOutputName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOutputName

/** \brief Returns a copy of output name at then specified index.
*
* \param index must less than the value returned by GetOutputCount()
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;

/** \deprecated use GetOverridableInitializerNameAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata

/** \brief Returns a copy of the overridable initializer name at then specified index.
*
* \param index must less than the value returned by GetOverridableInitializerCount()
* \param allocator to allocate memory for the copy of the name returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName

/** \deprecated use EndProfilingAllocated()
* [[deprecated]]
* This interface produces a pointer that must be released
* by the specified allocator and is often leaked. Not exception safe.
*/
char* EndProfiling(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling

/** \brief Returns a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
* \return a instance of smart pointer that would deallocate the buffer when out of scope.
* The OrtAllocator instances must be valid at the point of memory release.
*/
AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionEndProfiling
uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata

TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
Expand Down
Loading

0 comments on commit 82fc95f

Please sign in to comment.