From 3bf8a7b6b78f1269f59c4c9de3d4dbf798db4fea Mon Sep 17 00:00:00 2001 From: Dong Xu Date: Fri, 12 Apr 2024 16:58:06 +0800 Subject: [PATCH] [ATLAS] Refact TNN ATLAS for Multiple Model --- .../tnn/device/atlas/atlas_blob_converter.cc | 112 +- .../tnn/device/atlas/atlas_blob_converter.h | 8 +- source/tnn/device/atlas/atlas_common_types.cc | 5 - source/tnn/device/atlas/atlas_common_types.h | 68 +- source/tnn/device/atlas/atlas_context.cc | 110 ++ source/tnn/device/atlas/atlas_context.h | 80 ++ source/tnn/device/atlas/atlas_device.cc | 18 +- .../tnn/device/atlas/atlas_mat_converter.cc | 65 +- source/tnn/device/atlas/atlas_mat_converter.h | 5 +- .../device/atlas/atlas_model_interpreter.cc | 112 -- .../device/atlas/atlas_model_interpreter.h | 51 - source/tnn/device/atlas/atlas_network.cc | 1039 +++++++++-------- source/tnn/device/atlas/atlas_network.h | 103 +- .../atlas/atlas_om_model_interpreter.cc | 45 + .../device/atlas/atlas_om_model_interpreter.h | 50 + source/tnn/device/atlas/atlas_runtime.cc | 134 --- source/tnn/device/atlas/atlas_runtime.h | 57 - source/tnn/device/atlas/atlas_utils.cc | 41 - source/tnn/device/atlas/atlas_utils.h | 6 - source/tnn/device/atlas/tnn_impl_atlas.cc | 214 ++-- source/tnn/device/atlas/tnn_impl_atlas.h | 38 +- 21 files changed, 1219 insertions(+), 1142 deletions(-) delete mode 100644 source/tnn/device/atlas/atlas_common_types.cc create mode 100644 source/tnn/device/atlas/atlas_context.cc create mode 100644 source/tnn/device/atlas/atlas_context.h delete mode 100644 source/tnn/device/atlas/atlas_model_interpreter.cc delete mode 100644 source/tnn/device/atlas/atlas_model_interpreter.h create mode 100644 source/tnn/device/atlas/atlas_om_model_interpreter.cc create mode 100644 source/tnn/device/atlas/atlas_om_model_interpreter.h delete mode 100644 source/tnn/device/atlas/atlas_runtime.cc delete mode 100644 source/tnn/device/atlas/atlas_runtime.h diff --git a/source/tnn/device/atlas/atlas_blob_converter.cc b/source/tnn/device/atlas/atlas_blob_converter.cc index 303d77131..da0cae6e3 100644 --- a/source/tnn/device/atlas/atlas_blob_converter.cc +++ b/source/tnn/device/atlas/atlas_blob_converter.cc @@ -12,10 +12,10 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
-#include "tnn/device/atlas/atlas_blob_converter.h" + #include "tnn/core/macro.h" -#include "tnn/device/atlas/atlas_runtime.h" +#include "tnn/device/atlas/atlas_blob_converter.h" #include "tnn/device/atlas/atlas_utils.h" #include "tnn/memory_manager/blob_memory_size_info.h" #include "tnn/utils/blob_memory_size_utils.h" @@ -31,17 +31,16 @@ AtlasBlobConverterAcc::AtlasBlobConverterAcc(Blob *blob) : BlobConverterAcc(blob blob_bytesize_ = GetBlobMemoryBytesSize(size_info); LOGD("blob bytesize: %d\n", blob_bytesize_); - auto model_info_map = AtlasRuntime::GetInstance()->GetModleInfoMap(); // for input blob, need to find model info - if (model_info_map.find(blob) != model_info_map.end()) { - model_info_ = model_info_map[blob]; + if (global_blob_om_model_info_map.find(blob) != global_blob_om_model_info_map.end()) { + om_model_info_ = global_blob_om_model_info_map[blob]; aclError acl_ret = - aclmdlGetInputIndexByName(model_info_.model_desc, ACL_DYNAMIC_AIPP_NAME, &dynamic_aipp_index_); + aclmdlGetInputIndexByName(om_model_info_->model_desc, ACL_DYNAMIC_AIPP_NAME, &dynamic_aipp_index_); LOGD("acl ret: %d input_index: %d\n", acl_ret, dynamic_aipp_index_); if (ACL_ERROR_NONE == acl_ret) { aipp_type_ = AIPP_DYNAMIC; } else { - if (model_info_.has_aipp) { + if (!(om_model_info_->aipp_input_format_map.empty())) { aipp_type_ = AIPP_STATIC; } else { aipp_type_ = AIPP_NONE; @@ -73,13 +72,13 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, return Status(TNNERR_PARAM_ERR, "not support postprocess yet!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { + aclrtStream stream_ptr = static_cast(command_queue); + if (stream_ptr == nullptr) { LOGE("get atlas command queue failed!\n"); return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); } - acl_ret = aclrtSetCurrentContext(atlas_cmd_queue->context); + acl_ret = aclrtSetCurrentContext(global_stream_context_map[stream_ptr]); if (acl_ret != ACL_ERROR_NONE) { LOGE("set context failed\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set context failed"); @@ -95,7 +94,7 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, LOGD("Convert To Mat: mat type: %d, mat device type: %d, byte_size: %d.\n", mat.GetMatType(), mat.GetDeviceType(), blob_bytesize_); if (DATA_FORMAT_NCHW == blob_dataformat && DATA_TYPE_FLOAT == blob_datatype) { tnn_ret = AtlasMemoryCopyAsync(mat.GetData(), blob_->GetHandle().base, mat.GetDeviceType(), blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; } else if (DATA_FORMAT_NHWC == blob_dataformat && DATA_TYPE_FLOAT == blob_datatype) { @@ -105,12 +104,12 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, buffer_.reset(new char[blob_bytesize_], [](char *p) { delete[] p; }); } tnn_ret = AtlasMemoryCopyAsync(buffer_.get(), blob_->GetHandle().base, DEVICE_NAIVE, blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; // force sync LOGD("force sync to get buffer data\n"); - acl_ret = aclrtSynchronizeStream(atlas_cmd_queue->stream); + acl_ret = aclrtSynchronizeStream(stream_ptr); if (acl_ret != ACL_ERROR_NONE) { return Status(TNNERR_ATLAS_RUNTIME_ERROR, "stream sync failed"); } @@ -127,12 +126,12 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, buffer_.reset(new char[blob_bytesize_], [](char *p) { delete[] p; }); } tnn_ret = 
AtlasMemoryCopyAsync(buffer_.get(), blob_->GetHandle().base, DEVICE_NAIVE, blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; // force sync LOGD("force sync to get buffer data\n"); - acl_ret = aclrtSynchronizeStream(atlas_cmd_queue->stream); + acl_ret = aclrtSynchronizeStream(stream_ptr); if (acl_ret != ACL_ERROR_NONE) { return Status(TNNERR_ATLAS_RUNTIME_ERROR, "stream sync failed"); } @@ -159,14 +158,14 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, LOGD("Convert To NC_INT32 Mat: mat type: %d, mat device type: %d, byte_size: %d.\n", mat.GetMatType(), mat.GetDeviceType(), blob_bytesize_); if (DATA_TYPE_INT32 == blob_datatype) { tnn_ret = AtlasMemoryCopyAsync(mat.GetData(), blob_->GetHandle().base, mat.GetDeviceType(), blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; } else if (DATA_TYPE_FLOAT == blob_datatype) { LOGD("WARNING: Target Blob name is '%s', internally convert Blob DataType from FLOAT to INT32.", blob_->GetBlobDesc().name.c_str()); blob_->GetBlobDesc().data_type = DATA_TYPE_INT32; tnn_ret = AtlasMemoryCopyAsync(mat.GetData(), blob_->GetHandle().base, mat.GetDeviceType(), blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; } else { @@ -177,7 +176,7 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, LOGD("Convert To NC_INT64 Mat: mat type: %d, mat device type: %d, byte_size: %d.\n", mat.GetMatType(), mat.GetDeviceType(), blob_bytesize_); if (DATA_TYPE_INT64 == blob_datatype) { tnn_ret = AtlasMemoryCopyAsync(mat.GetData(), blob_->GetHandle().base, mat.GetDeviceType(), blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; } else if (DATA_TYPE_FLOAT == blob_datatype || DATA_TYPE_INT32 == blob_datatype) { @@ -186,7 +185,7 @@ Status AtlasBlobConverterAcc::ConvertToMatAsync(Mat &mat, MatConvertParam param, BlobMemorySizeInfo new_size_info = Calculate1DMemorySize(blob_->GetBlobDesc()); blob_bytesize_ = GetBlobMemoryBytesSize(new_size_info); // sizeof(int64_t) == 8, re-calculate ByteSize tnn_ret = AtlasMemoryCopyAsync(mat.GetData(), blob_->GetHandle().base, mat.GetDeviceType(), blob_bytesize_, - atlas_cmd_queue->stream, false); + stream_ptr, false); if (tnn_ret != TNN_OK) return tnn_ret; } else { @@ -210,27 +209,30 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsync(Mat &mat, MatConvertParam para Status tnn_ret = TNN_OK; aclError acl_ret = ACL_ERROR_NONE; - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { + aclrtStream stream_ptr = static_cast(command_queue); + if (stream_ptr == nullptr) { LOGE("get atlas command queue failed!\n"); return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); } - acl_ret = aclrtSetCurrentContext(atlas_cmd_queue->context); + aclrtContext aclrt_context = global_stream_context_map[stream_ptr]; + acl_ret = aclrtSetCurrentContext(aclrt_context); if (acl_ret != ACL_ERROR_NONE) { LOGE("set context failed\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set context failed"); } if (AIPP_DYNAMIC == aipp_type_) { - LOGD("run with dynamic aipp\n"); - tnn_ret = ConvertFromMatAsyncWithDynamicAipp(mat, param, atlas_cmd_queue); + //LOGD("run with dynamic aipp\n"); + //tnn_ret = ConvertFromMatAsyncWithDynamicAipp(mat, param, stream_ptr); + LOGE("Convert From Mat With Dynamic AIPP NOT SUPPORTED yet.\n"); + return 
Status(TNNERR_NULL_PARAM, "Convert From Mat With Dynamic AIPP NOT SUPPORTED yet"); } else if (AIPP_STATIC == aipp_type_) { LOGD("run with static aipp\n"); - tnn_ret = ConvertFromMatAsyncWithStaticAipp(mat, param, atlas_cmd_queue); + tnn_ret = ConvertFromMatAsyncWithStaticAipp(mat, param, stream_ptr); } else { LOGD("run without aipp\n"); - tnn_ret = ConvertFromMatAsyncWithoutAipp(mat, param, atlas_cmd_queue); + tnn_ret = ConvertFromMatAsyncWithoutAipp(mat, param, stream_ptr); } return tnn_ret; @@ -239,12 +241,13 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsync(Mat &mat, MatConvertParam para Status AtlasBlobConverterAcc::ConvertToMat(Mat &mat, MatConvertParam param, void *command_queue) { Status ret = ConvertToMatAsync(mat, param, command_queue); if (ret == TNN_OK) { - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { + aclrtStream stream_ptr = static_cast(command_queue); + if (stream_ptr == nullptr) { LOGE("get atlas command queue failed!\n"); return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); } - aclError acl_ret = aclrtSynchronizeStream(atlas_cmd_queue->stream); + + aclError acl_ret = aclrtSynchronizeStream(stream_ptr); if (acl_ret != ACL_ERROR_NONE) { return Status(TNNERR_ATLAS_RUNTIME_ERROR, "stream sync failed"); } @@ -260,12 +263,13 @@ Status AtlasBlobConverterAcc::ConvertFromMat(Mat &mat, MatConvertParam param, vo Status ret = ConvertFromMatAsync(mat, param, command_queue); if (ret == TNN_OK) { - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { + aclrtStream stream_ptr = static_cast(command_queue); + if (stream_ptr == nullptr) { LOGE("get atlas command queue failed!\n"); return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); } - aclError acl_ret = aclrtSynchronizeStream(atlas_cmd_queue->stream); + + aclError acl_ret = aclrtSynchronizeStream(stream_ptr); if (acl_ret != ACL_ERROR_NONE) { return Status(TNNERR_ATLAS_RUNTIME_ERROR, "stream sync failed"); } @@ -274,7 +278,7 @@ Status AtlasBlobConverterAcc::ConvertFromMat(Mat &mat, MatConvertParam param, vo } Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConvertParam param, - AtlasCommandQueue *atlas_cmd_queue) { + const aclrtStream& aclrt_stream) { Status tnn_ret = TNN_OK; aclError acl_ret = ACL_ERROR_NONE; @@ -303,7 +307,7 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConver if (NCHW_FLOAT == mat.GetMatType()) { if (DATA_FORMAT_NCHW == blob_dataformat && DATA_TYPE_FLOAT == blob_datatype) { tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else if (DATA_FORMAT_NHWC == blob_dataformat && DATA_TYPE_FLOAT == blob_datatype) { @@ -319,7 +323,7 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConver blob_dim[0], blob_dim[3], blob_dim[1], blob_dim[2]); tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, buffer_.get(), DEVICE_NAIVE, mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else { @@ -332,14 +336,14 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConver LOGD("Convert from NC_INT32 Mat: mat device type: %d\n", mat.GetDeviceType()); if (DATA_TYPE_INT32 == blob_datatype) { tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, 
true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else if (DATA_TYPE_FLOAT == blob_datatype) { LOGD("WARNING: Target Blob name is '%s', internally convert Blob DataType from FLOAT to INT32.", blob_->GetBlobDesc().name.c_str()); blob_->GetBlobDesc().data_type = DATA_TYPE_INT32; tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else { @@ -350,7 +354,7 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConver LOGD("Convert from NC_INT64 Mat: mat device type: %d\n", mat.GetDeviceType()); if (DATA_TYPE_INT64 == blob_datatype) { tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else if (DATA_TYPE_FLOAT == blob_datatype || DATA_TYPE_INT32 == blob_datatype) { @@ -359,7 +363,7 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConver BlobMemorySizeInfo new_size_info = Calculate1DMemorySize(blob_->GetBlobDesc()); blob_bytesize_ = GetBlobMemoryBytesSize(new_size_info); // sizeof(int64_t) == 8, re-calculate ByteSize tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else { @@ -374,7 +378,7 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithoutAipp(Mat &mat, MatConver } Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithStaticAipp(Mat &mat, MatConvertParam param, - AtlasCommandQueue *atlas_cmd_queue) { + const aclrtStream& aclrt_stream) { Status tnn_ret = TNN_OK; aclError acl_ret = ACL_ERROR_NONE; @@ -391,14 +395,15 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithStaticAipp(Mat &mat, MatCon return tnn_ret; } + auto aipp_input_format = om_model_info_->aipp_input_format_map[blob_->GetBlobDesc().name]; + LOGD("Convert From Mat: mat type: %d mat device type: %d acl input format:%d\n", mat.GetMatType(), - mat.GetDeviceType(), model_info_.aipp_input_format); - if ((N8UC3 == mat.GetMatType() && ACL_RGB888_U8 == model_info_.aipp_input_format) || - (NGRAY == mat.GetMatType() && ACL_YUV400_U8 == model_info_.aipp_input_format) || - ((NNV12 == mat.GetMatType() || NNV21 == mat.GetMatType()) && - ACL_YUV420SP_U8 == model_info_.aipp_input_format)) { + mat.GetDeviceType(), aipp_input_format); + if ((N8UC3 == mat.GetMatType() && ACL_RGB888_U8 == aipp_input_format) || + (NGRAY == mat.GetMatType() && ACL_YUV400_U8 == aipp_input_format) || + ((NNV12 == mat.GetMatType() || NNV21 == mat.GetMatType()) && ACL_YUV420SP_U8 == aipp_input_format)) { tnn_ret = AtlasMemoryCopyAsync(blob_->GetHandle().base, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); if (tnn_ret != TNN_OK) return tnn_ret; } else { @@ -409,13 +414,14 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithStaticAipp(Mat &mat, MatCon } Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithDynamicAipp(Mat &mat, MatConvertParam param, - AtlasCommandQueue *atlas_cmd_queue) { + const aclrtStream& aclrt_stream) { + /* Status tnn_ret = SetDynamicAipp(mat, param); if (TNN_OK != tnn_ret) { LOGE("set dynamic aipp failed!\n"); return tnn_ret; } - auto data_buffer = aclmdlGetDatasetBuffer(model_info_.input_dataset, 0); + auto data_buffer = 
aclmdlGetDatasetBuffer(om_model_info_.input_dataset, 0); if (nullptr == data_buffer) { LOGE("get data buffer from dataset failed!\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get data buffer failed"); @@ -441,9 +447,11 @@ Status AtlasBlobConverterAcc::ConvertFromMatAsyncWithDynamicAipp(Mat &mat, MatCo } tnn_ret = AtlasMemoryCopyAsync(data_buffer_ptr, mat.GetData(), mat.GetDeviceType(), mat_bytesize, - atlas_cmd_queue->stream, true); + aclrt_stream, true); return tnn_ret; + */ + return TNN_OK; } bool AtlasBlobConverterAcc::NeedDoScaleBias(MatConvertParam ¶m) { @@ -492,11 +500,12 @@ Status AtlasBlobConverterAcc::AtlasMemoryCopyAsync(void *dst, void *src, DeviceT } Status AtlasBlobConverterAcc::SetDynamicAipp(Mat &mat, MatConvertParam ¶m) { + /* aclError acl_ret = ACL_ERROR_NONE; Status tnn_ret = TNN_OK; if (nullptr == aipp_dynamic_set_) { - aipp_mat_batchsize_ = GetMaxBatchSize(model_info_.model_desc, blob_->GetBlobDesc().dims[0]); + aipp_mat_batchsize_ = GetMaxBatchSize(om_model_info_->model_desc, blob_->GetBlobDesc().dims[0]); aipp_dynamic_set_ = aclmdlCreateAIPP(aipp_mat_batchsize_); if (nullptr == aipp_dynamic_set_) { LOGE("create aipp info failed\n"); @@ -593,10 +602,11 @@ Status AtlasBlobConverterAcc::SetDynamicAipp(Mat &mat, MatConvertParam ¶m) { // set input aipp acl_ret = - aclmdlSetInputAIPP(model_info_.model_id, model_info_.input_dataset, dynamic_aipp_index_, aipp_dynamic_set_); + aclmdlSetInputAIPP(om_model_info_->model_id, om_model_info_.input_dataset, dynamic_aipp_index_, aipp_dynamic_set_); if (ACL_ERROR_NONE != acl_ret) { return Status(TNNERR_ATLAS_RUNTIME_ERROR, "aipp set input failed!\n"); } + */ return TNN_OK; } diff --git a/source/tnn/device/atlas/atlas_blob_converter.h b/source/tnn/device/atlas/atlas_blob_converter.h index 760eb623f..d77254f8f 100644 --- a/source/tnn/device/atlas/atlas_blob_converter.h +++ b/source/tnn/device/atlas/atlas_blob_converter.h @@ -37,9 +37,9 @@ class AtlasBlobConverterAcc : public BlobConverterAcc { virtual Status ConvertFromMatAsync(Mat& mat, MatConvertParam param, void* command_queue = NULL); private: - Status ConvertFromMatAsyncWithoutAipp(Mat& mat, MatConvertParam param, AtlasCommandQueue* atlas_cmd_queue); - Status ConvertFromMatAsyncWithStaticAipp(Mat& mat, MatConvertParam param, AtlasCommandQueue* atlas_cmd_queue); - Status ConvertFromMatAsyncWithDynamicAipp(Mat& mat, MatConvertParam param, AtlasCommandQueue* atlas_cmd_queue); + Status ConvertFromMatAsyncWithoutAipp(Mat& mat, MatConvertParam param, const aclrtStream& aclrt_stream); + Status ConvertFromMatAsyncWithStaticAipp(Mat& mat, MatConvertParam param, const aclrtStream& aclrt_stream); + Status ConvertFromMatAsyncWithDynamicAipp(Mat& mat, MatConvertParam param, const aclrtStream& aclrt_stream); bool NeedDoScaleBias(MatConvertParam& param); Status AtlasMemoryCopyAsync(void* dst, void* src, DeviceType mat_device_type, int bytes, void* stream, @@ -55,7 +55,7 @@ class AtlasBlobConverterAcc : public BlobConverterAcc { AippType aipp_type_ = AIPP_NONE; int aipp_mat_batchsize_ = 0; size_t dynamic_aipp_index_ = 0; - AtlasModelInfo model_info_; + std::shared_ptr om_model_info_; }; } // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_common_types.cc b/source/tnn/device/atlas/atlas_common_types.cc deleted file mode 100644 index eadb4b594..000000000 --- a/source/tnn/device/atlas/atlas_common_types.cc +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2019 Tencent. 
All Rights Reserved - -#include "atlas_common_types.h" - -namespace TNN_NS {} // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_common_types.h b/source/tnn/device/atlas/atlas_common_types.h index 69036a763..50b6330bf 100644 --- a/source/tnn/device/atlas/atlas_common_types.h +++ b/source/tnn/device/atlas/atlas_common_types.h @@ -1,53 +1,63 @@ -// Copyright 2019 Tencent. All Rights Reserved +// Tencent is pleased to support the open source community by making TNN available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. #ifndef TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_ #define TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_ +#include #include #include #include +#include #include "acl/acl.h" #include "tnn/core/blob.h" #include "tnn/core/macro.h" +/////////////////////// +#include +/////////////////////// + namespace TNN_NS { -enum ImageTypeT { - IMAGE_TYPE_RAW = -1, - IMAGE_TYPE_NV12 = 0, - IMAGE_TYPE_JPEG, - IMAGE_TYPE_PNG, - IMAGE_TYPE_BMP, - IMAGE_TYPE_TIFF, - IMAGE_TYPE_VIDEO = 100 +enum class AtlasOmModelDynamicMode { + Static = 0, + DynamicBatch = 1, + DynamicHW = 2, + GenericDynamic = 3, // New Dynamic Mode, convert by input_shape_range or input_shape without dynamic dim/hw specified. }; -struct AtlasModelConfig { - std::string om_str = ""; - bool is_path = false; -}; +struct AtlasOMModelInfo { + aclmdlDesc* model_desc = nullptr; + uint32_t model_id = INT_MAX; + aclmdlDataset* input_dataset = nullptr; + aclrtContext aclrt_context = nullptr; -struct DimInfo { - uint32_t batch = 0; - uint32_t channel = 0; - uint32_t height = 0; - uint32_t width = 0; -}; + size_t memory_size = 0; + size_t weight_size = 0; -struct AtlasCommandQueue { - void* context; - void* stream; -}; + // Dynamic Input + AtlasOmModelDynamicMode dynamic_mode; + std::unordered_set generic_dynamic_input_names; -struct AtlasModelInfo { - aclmdlDesc* model_desc = nullptr; - uint32_t model_id = 0; - aclmdlDataset* input_dataset = nullptr; - bool has_aipp = false; - aclAippInputFormat aipp_input_format = ACL_AIPP_RESERVED; + // AIPP Input + std::map aipp_input_format_map; }; +extern std::map> global_blob_om_model_info_map; +extern std::map global_stream_context_map; + } // namespace TNN_NS #endif // TNN_SOURCE_DEVICE_ATLAS_ATLAS_COMMON_TYPES_H_ diff --git a/source/tnn/device/atlas/atlas_context.cc b/source/tnn/device/atlas/atlas_context.cc new file mode 100644 index 000000000..145e5deec --- /dev/null +++ b/source/tnn/device/atlas/atlas_context.cc @@ -0,0 +1,110 @@ +// Tencent is pleased to support the open source community by making TNN available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "tnn/device/atlas/atlas_context.h" + +namespace TNN_NS { + +AtlasContext::~AtlasContext() { + // Aclrt Stream is created and maintained by AtlasNetwork + // Do not Destroy aclrtStream HERE. + //if (this->aclrt_stream_ != nullptr) { + // ret = aclrtDestroyStream(this->aclrt_stream_); + // this->aclrt_stream_ = nullptr; + //} +} + +Status AtlasContext::Setup(int device_id) { + this->device_id_ = device_id; + return TNN_OK; +} + +Status AtlasContext::LoadLibrary(std::vector path) { + return TNN_OK; +} + +Status AtlasContext::GetCommandQueue(void** command_queue) { + // Reshape Model For different Model Types + if (this->model_type_ == MODEL_TYPE_TORCHSCRIPT) { + LOGE("Fail to GetCommandQueue, MODEL_TYPE_TORCHSCRIPT not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, MODEL_TYPE_TORCHSCRIPT not supported YET"); + } else if (this->model_type_ == MODEL_TYPE_TNN || this->model_type_ == MODEL_TYPE_RAPIDNET) { + LOGE("Fail to GetCommandQueue, MODEL_TYPE_TNN not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, MODEL_TYPE_TNN not supported YET"); + } else if (this->model_type_ == MODEL_TYPE_ATLAS) { + *command_queue = this->aclrt_stream_; + } else { + LOGE("Fail to GetCommandQueue, model type not supported.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to GetCommandQueue, model type not supported"); + } + + return TNN_OK; +} + +Status AtlasContext::SetCommandQueue(void* command_queue) { + return TNN_OK; +} + +Status AtlasContext::ShareCommandQueue(Context* context) { + return TNN_OK; +} + +Status AtlasContext::OnInstanceForwardBegin() { + return TNN_OK; +} + +Status AtlasContext::OnInstanceForwardEnd() { + return TNN_OK; +} + +Status AtlasContext::Synchronize() { + if (model_type_ == MODEL_TYPE_TNN || model_type_ == MODEL_TYPE_RAPIDNET || + model_type_ == MODEL_TYPE_ATLAS) { + aclError acl_ret = aclrtSynchronizeStream(this->aclrt_stream_); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("before forward synchronize stream failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "before forward synchronize stream failed"); + } + } + return TNN_OK; +} + +aclrtStream& AtlasContext::GetAclrtStream() { + return this->aclrt_stream_; +} + +void AtlasContext::SetAclrtStream(const aclrtStream& stream) { + this->aclrt_stream_ = stream; +} + +Status AtlasContext::CreateAclrtStream() { + // Create aclrt Stream + aclError acl_ret = aclrtCreateStream(&aclrt_stream_); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("acl create stream failed (acl error code: %d)\n", acl_ret); + } + return TNN_OK; +} + +ModelType& AtlasContext::GetModelType() { + return this->model_type_; +} + +void AtlasContext::SetModelType(ModelType model_type) { + this->model_type_ = model_type; +} + + + +} // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_context.h b/source/tnn/device/atlas/atlas_context.h new file mode 100644 index 000000000..f57cb7e0a --- /dev/null +++ b/source/tnn/device/atlas/atlas_context.h @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making TNN available. 
+// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_ +#define TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_ + +#include "tnn/core/context.h" +#include "tnn/device/atlas/atlas_common_types.h" +#include "tnn/interpreter/raw_buffer.h" + +namespace TNN_NS { + +class AtlasContext : public Context { +public: + // @brief deconstructor + ~AtlasContext(); + + // @brief setup with specified device id + Status Setup(int device_id); + + // @brief load library + virtual Status LoadLibrary(std::vector path) override; + + // @brief get tnn command queue + // @param command_queue device command queue for forward + virtual Status GetCommandQueue(void** command_queue) override; + + // @brief set tnn command queue + // @param command_queue device command queue for forward + virtual Status SetCommandQueue(void* command_queue) override; + + // @brief share tnn command queue to another context + virtual Status ShareCommandQueue(Context* context); + + // @brief before instance forward + virtual Status OnInstanceForwardBegin() override; + + // @brief after instance forward + virtual Status OnInstanceForwardEnd() override; + + // @brief wait for jobs in the current context to complete + virtual Status Synchronize() override; + + // @brief get Atlas stream + aclrtStream& GetAclrtStream(); + + // @brief set Atlas stream + void SetAclrtStream(const aclrtStream& stream); + + // @brief create Atlas stream + Status CreateAclrtStream(); + + // @brief get ModelType + ModelType& GetModelType(); + + // @brief set ModelType + void SetModelType(ModelType model_type); + +private: + ModelType model_type_; + int device_id_ = INT_MAX; + + // ACL Runtime Related + aclrtStream aclrt_stream_ = nullptr; +}; + +} // namespace TNN_NS; + +#endif // TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_CONTEXT_H_ diff --git a/source/tnn/device/atlas/atlas_device.cc b/source/tnn/device/atlas/atlas_device.cc index 032546cae..4f57e6d46 100644 --- a/source/tnn/device/atlas/atlas_device.cc +++ b/source/tnn/device/atlas/atlas_device.cc @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making TNN available. // -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -12,8 +12,9 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. 
-#include "tnn/device/atlas/atlas_device.h" #include "acl/ops/acl_dvpp.h" +#include "tnn/device/atlas/atlas_context.h" +#include "tnn/device/atlas/atlas_device.h" #include "tnn/utils/blob_memory_size_utils.h" #include "tnn/utils/dims_vector_utils.h" @@ -107,8 +108,17 @@ AbstractLayerAcc* AtlasDevice::CreateLayerAcc(LayerType type) { return nullptr; } -Context* AtlasDevice::CreateContext(int) { - return nullptr; +Context* AtlasDevice::CreateContext(int device_id) { + auto context = new AtlasContext(); + + Status ret = context->Setup(device_id); + if (ret != TNN_OK) { + LOGE("Atlas context setup failed."); + delete context; + return NULL; + } + + return context; } NetworkType AtlasDevice::ConvertAutoNetworkType() { diff --git a/source/tnn/device/atlas/atlas_mat_converter.cc b/source/tnn/device/atlas/atlas_mat_converter.cc index d51d9be6b..55510c48d 100644 --- a/source/tnn/device/atlas/atlas_mat_converter.cc +++ b/source/tnn/device/atlas/atlas_mat_converter.cc @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making TNN available. // -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -14,7 +14,6 @@ #include "tnn/device/atlas/atlas_mat_converter.h" #include "tnn/core/macro.h" -#include "tnn/device/atlas/atlas_runtime.h" #include "tnn/device/atlas/atlas_utils.h" #include "tnn/utils/data_format_converter.h" #include "tnn/utils/dims_vector_utils.h" @@ -106,11 +105,11 @@ Status AtlasMatConverterAcc::Copy(Mat& src, Mat& dst, void* command_queue) { return Status(TNNERR_NULL_PARAM, "init mat converter failed!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { - LOGE("get atlas command queue failed!\n"); - return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); - } + //auto atlas_cmd_queue = static_cast(command_queue); + //if (atlas_cmd_queue == nullptr) { + // LOGE("get atlas command queue failed!\n"); + // return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); + //} aclrtMemcpyKind memcpy_type; if (DEVICE_ATLAS == src.GetDeviceType() && DEVICE_ATLAS == dst.GetDeviceType()) { @@ -157,8 +156,13 @@ Status AtlasMatConverterAcc::Resize(Mat& src, Mat& dst, ResizeParam param, void* return Status(TNNERR_NULL_PARAM, "init mat converter failed!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { + //auto atlas_cmd_queue = static_cast(command_queue); + //if (atlas_cmd_queue == nullptr) { + // LOGE("get atlas command queue failed!\n"); + // return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); + //} + aclrtStream* stream_ptr = static_cast(command_queue); + if (stream_ptr == nullptr) { LOGE("get atlas command queue failed!\n"); return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); } @@ -196,7 +200,7 @@ Status AtlasMatConverterAcc::Resize(Mat& src, Mat& dst, ResizeParam param, void* } acl_ret = acldvppVpcResizeAsync(dvpp_channel_desc_, input_desc_, output_desc_, resize_config, atlas_cmd_queue->stream); + acldvppVpcResizeAsync(dvpp_channel_desc_, input_desc_, output_desc_, resize_config, *stream_ptr); if (ACL_ERROR_NONE != acl_ret) { LOGE("acldvppVpcResizeAsync failed, ret = %d\n", acl_ret); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acldvppVpcResizeAsync failed"); } @@ 
-210,7 +214,7 @@ Status AtlasMatConverterAcc::Resize(Mat& src, Mat& dst, ResizeParam param, void* resize_config = nullptr; } - aclrtSynchronizeStream(atlas_cmd_queue->stream); + aclrtSynchronizeStream(*stream_ptr); ret = ProcessOutput(dst); if (TNN_OK != ret) { @@ -226,8 +230,13 @@ Status AtlasMatConverterAcc::Crop(Mat& src, Mat& dst, CropParam param, void* com return Status(TNNERR_NULL_PARAM, "init mat converter failed!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { + //auto atlas_cmd_queue = static_cast(command_queue); + //if (atlas_cmd_queue == nullptr) { + // LOGE("get atlas command queue failed!\n"); + // return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); + //} + aclrtStream* stream_ptr = static_cast(command_queue); + if (stream_ptr == nullptr) { LOGE("get atlas command queue failed!\n"); return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); } @@ -257,13 +266,13 @@ Status AtlasMatConverterAcc::Crop(Mat& src, Mat& dst, CropParam param, void* com aclError acl_ret = ACL_ERROR_NONE; acl_ret = - acldvppVpcCropAsync(dvpp_channel_desc_, input_desc_, output_desc_, crop_roi_config, atlas_cmd_queue->stream); + acldvppVpcCropAsync(dvpp_channel_desc_, input_desc_, output_desc_, crop_roi_config, *stream_ptr); if (ACL_ERROR_NONE != acl_ret) { LOGE("acldvppVpcResizeAsync failed, ret = %d\n", acl_ret); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acldvppVpcResizeAsync failed"); } - aclrtSynchronizeStream(atlas_cmd_queue->stream); + aclrtSynchronizeStream(*stream_ptr); ret = ProcessOutput(dst); if (TNN_OK != ret) { @@ -287,12 +296,6 @@ Status AtlasMatConverterAcc::WarpAffine(Mat& src, Mat& dst, WarpAffineParam para return Status(TNNERR_NULL_PARAM, "init mat converter failed!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { - LOGE("get atlas command queue failed!\n"); - return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); - } - return Status(TNNERR_ATLAS_DVPP_NOT_SUPPORT, "atlas mat not support WarpAffine"); } @@ -302,12 +305,6 @@ Status AtlasMatConverterAcc::CvtColor(Mat& src, Mat& dst, ColorConversionType ty return Status(TNNERR_NULL_PARAM, "init mat converter failed!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { - LOGE("get atlas command queue failed!\n"); - return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); - } - return Status(TNNERR_ATLAS_DVPP_NOT_SUPPORT, "atlas mat not support CvtColor"); } @@ -317,12 +314,6 @@ Status AtlasMatConverterAcc::CopyMakeBorder(Mat& src, Mat& dst, CopyMakeBorderPa return Status(TNNERR_NULL_PARAM, "init mat converter failed!"); } - auto atlas_cmd_queue = static_cast(command_queue); - if (atlas_cmd_queue == nullptr) { - LOGE("get atlas command queue failed!\n"); - return Status(TNNERR_NULL_PARAM, "get atlas command queue failed!"); - } - return Status(TNNERR_ATLAS_DVPP_NOT_SUPPORT, "atlas mat not support CopyMakeBorder"); } @@ -650,6 +641,14 @@ Status AtlasMatConverterAcc::MatCopyAsync(Mat& dst, Mat& src, int dst_offset, vo return TNN_OK; } +Status AtlasMatConverterAcc::ResizeAndPaste(Mat& src, Mat& dst, ResizeParam param, PasteParam paste_param, void* command_queue) { + return TNN_OK; +} + +Status AtlasMatConverterAcc::ConcatMatWithBatch(std::vector& src_vec, Mat& dst, void* command_queue) { + return TNN_OK; +} + DECLARE_MAT_CONVERTER_CREATER(Atlas); REGISTER_MAT_CONVERTER(Atlas, DEVICE_ATLAS); diff --git a/source/tnn/device/atlas/atlas_mat_converter.h 
b/source/tnn/device/atlas/atlas_mat_converter.h index 9f0c1c4d9..dbbfa4ec6 100644 --- a/source/tnn/device/atlas/atlas_mat_converter.h +++ b/source/tnn/device/atlas/atlas_mat_converter.h @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making TNN available. // -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -32,6 +32,9 @@ class AtlasMatConverterAcc : public MatConverterAcc { virtual Status Crop(Mat& src, Mat& dst, CropParam param, void* command_queue = NULL) override; virtual Status WarpAffine(Mat& src, Mat& dst, WarpAffineParam param, void* command_queue = NULL) override; virtual Status CvtColor(Mat& src, Mat& dst, ColorConversionType type, void* command_queue = NULL) override; + virtual Status ResizeAndPaste(Mat& src, Mat& dst, ResizeParam param, PasteParam paste_param, + void* command_queue = NULL) override; + virtual Status ConcatMatWithBatch(std::vector& src_vec, Mat& dst, void* command_queue = NULL) override; virtual Status CopyMakeBorder(Mat& src, Mat& dst, CopyMakeBorderParam param, void* command_queue = NULL) override; private: diff --git a/source/tnn/device/atlas/atlas_model_interpreter.cc b/source/tnn/device/atlas/atlas_model_interpreter.cc deleted file mode 100644 index 24c834983..000000000 --- a/source/tnn/device/atlas/atlas_model_interpreter.cc +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2019 Tencent. All Rights Reserved - -#include "tnn/device/atlas/atlas_model_interpreter.h" - -#include -#include "tnn/device/atlas/atlas_utils.h" -#include "tnn/device/atlas/atlas_runtime.h" -#include "tnn/utils/split_utils.h" - -namespace TNN_NS { - -AtlasModelInterpreter::AtlasModelInterpreter() {} - -AtlasModelInterpreter::~AtlasModelInterpreter() { - LOGD("~AtlasModelInterpreter()\n"); - for (auto &item : model_weight_map_) { - if (nullptr != item.second.weight_mem_ptr && nullptr != item.second.context) { - aclError ret = aclrtSetCurrentContext(item.second.context); - if (ret != ACL_ERROR_NONE) { - LOGE("set context failed\n"); - } - - aclrtFree(item.second.weight_mem_ptr); - LOGD("acl free model weight ptr (device: %d)\n", item.first); - item.second.weight_mem_ptr = nullptr; - - ret = aclrtDestroyContext(item.second.context); - if (ret != ACL_ERROR_NONE) { - LOGE("destroy context failed\n"); - } - item.second.context = nullptr; - } - } - model_weight_map_.clear(); - model_weight_size_ = 0; - AtlasRuntime::DecreaseRef(); -} - -Status AtlasModelInterpreter::Interpret(std::vector ¶ms) { - model_config_.om_str = params[0]; - model_config_.is_path = false; - if (model_config_.om_str.length() < 1024) { - std::ifstream om_file(model_config_.om_str); - if (!om_file) { - LOGE("Invalied om file path! 
(param[0] : %s) take as memory content\n", model_config_.om_str.c_str()); - model_config_.is_path = false; - } else { - model_config_.is_path = true; - } - } - - // Init ACL - Status tnn_ret = AtlasRuntime::GetInstance()->Init(); - if (tnn_ret != TNN_OK) { - LOGE("acl init falied\n"); - return tnn_ret; - } - - size_t model_mem_size; - aclError acl_ret = ACL_ERROR_NONE; - if (model_config_.is_path) { - acl_ret = aclmdlQuerySize(model_config_.om_str.c_str(), &model_mem_size, &model_weight_size_); - } else { - acl_ret = aclmdlQuerySizeFromMem(model_config_.om_str.data(), model_config_.om_str.length(), &model_mem_size, - &model_weight_size_); - } - if (acl_ret != ACL_ERROR_NONE) { - LOGE("query model failed (%d)\n", acl_ret); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "query model failed"); - } - LOGD("atlas model weight size: %d model mem size: %d\n", model_weight_size_, model_mem_size); - - return TNN_OK; -} - -AtlasModelConfig& AtlasModelInterpreter::GetModelConfig() { - return model_config_; -} - -void* AtlasModelInterpreter::GetModelWeightsBufferPtr(int device_id) { - std::unique_lock lck(mutex_); - if (model_weight_map_.find(device_id) == model_weight_map_.end()) { - WeightPacket packet; - // create context related to device - aclError acl_ret = aclrtCreateContext(&packet.context, device_id); - if (acl_ret != ACL_ERROR_NONE) { - LOGE("acl create context failed (device %d) (acl error code: %d)\n", device_id, acl_ret); - return nullptr; - } - - // alloc device memory - acl_ret = aclrtMalloc(&packet.weight_mem_ptr, model_weight_size_, ACL_MEM_MALLOC_HUGE_FIRST); - if (acl_ret != ACL_ERROR_NONE) { - LOGE("malloc buffer for weight failed (ret=%d), require size is %zu\n", acl_ret, model_weight_size_); - return nullptr; - } - LOGD("malloc buffer for weight success (size %zu)\n", model_weight_size_); - - model_weight_map_[device_id] = packet; - } - - return model_weight_map_[device_id].weight_mem_ptr; -} - -size_t AtlasModelInterpreter::GetModelWeightsBufferSize() { - return model_weight_size_; -} - -TypeModelInterpreterRegister> g_atlas_model_interpreter_register( - MODEL_TYPE_ATLAS); - -} // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_model_interpreter.h b/source/tnn/device/atlas/atlas_model_interpreter.h deleted file mode 100644 index 1ca0d92fe..000000000 --- a/source/tnn/device/atlas/atlas_model_interpreter.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2019 Tencent. 
All Rights Reserved - -#ifndef TNN_SOURCE_DEVICE_ATLAS_ATLAS_MODEL_INTERPRETER_H_ -#define TNN_SOURCE_DEVICE_ATLAS_ATLAS_MODEL_INTERPRETER_H_ - -#include -#include -#include -#include -#include "atlas_common_types.h" -#include "tnn/core/macro.h" -#include "tnn/core/status.h" -#include "tnn/interpreter/abstract_model_interpreter.h" - -namespace TNN_NS { - -struct WeightPacket { - void* weight_mem_ptr = nullptr; - aclrtContext context = nullptr; -}; - -// @brief Atlas model interpreter interpret Atlas model -class AtlasModelInterpreter : public AbstractModelInterpreter { -public: - AtlasModelInterpreter(); - - // @brief virtual destructor - virtual ~AtlasModelInterpreter(); - - // @brief different interpreter has different order param - virtual Status Interpret(std::vector ¶ms); - - // @brief get model config info - AtlasModelConfig& GetModelConfig(); - - // @brief get buffer ptr for model weights - void* GetModelWeightsBufferPtr(int device_id); - - // @brief get buffer size for model weights - size_t GetModelWeightsBufferSize(); - -private: - AtlasModelConfig model_config_; - std::map model_weight_map_; - size_t model_weight_size_ = 0; - std::mutex mutex_; -}; - -} // namespace TNN_NS - -#endif // TNN_SOURCE_DEVICE_ATLAS_ATLAS_MODEL_INTERPRETER_H_ diff --git a/source/tnn/device/atlas/atlas_network.cc b/source/tnn/device/atlas/atlas_network.cc index da546a0fb..3f3f20772 100644 --- a/source/tnn/device/atlas/atlas_network.cc +++ b/source/tnn/device/atlas/atlas_network.cc @@ -1,137 +1,513 @@ -// Copyright 2019 Tencent. All Rights Reserved +// Tencent is pleased to support the open source community by making TNN available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ -#include "tnn/device/atlas/atlas_network.h" #include #include #include "tnn/device/atlas/atlas_common_types.h" -#include "tnn/device/atlas/atlas_model_interpreter.h" -#include "tnn/device/atlas/atlas_runtime.h" +#include "tnn/device/atlas/atlas_network.h" +#include "tnn/device/atlas/atlas_om_model_interpreter.h" #include "tnn/device/atlas/atlas_utils.h" #include "tnn/utils/dims_vector_utils.h" #include "tnn/utils/dims_vector_utils.h" + namespace TNN_NS { NetworkImplFactoryRegister> g_network_impl_atlas_factory_register(NETWORK_TYPE_ATLAS); +// Default initialize global variable defined in "atlas_common_types.h" +std::map> global_blob_om_model_info_map; +std::map global_stream_context_map; + AtlasNetwork::~AtlasNetwork() { - if (need_to_deinit) { - DeInit(); + if (!this->network_init_called_) { + LOGD("TNN ATLAS Network DeInit() called without Inited, do nothing.\n"); + } + this->network_init_called_ = false; + + for (auto item : input_blob_map_) { + if (nullptr != item.second) { + delete item.second; + } + } + input_blob_map_.clear(); + + for (auto item : output_blob_map_) { + if (nullptr != item.second) { + delete item.second; + } + } + output_blob_map_.clear(); + + LOGD("TNN AtlasNetwork Destructor: aclmdl destroy input dataset\n"); + if (this->aclmdl_input_dataset_ != nullptr) { + DestroyDataset(this->aclmdl_input_dataset_); + } + LOGD("TNN AtlasNetwork Destructor: aclmdl destroy output dataset\n"); + if (this->aclmdl_output_dataset_ != nullptr) { + DestroyDataset(this->aclmdl_output_dataset_); + } + + if (this->model_type_ == MODEL_TYPE_ATLAS) { + // Release OM model related classes and resources. + aclError acl_ret; + if (this->om_model_info_->model_id != INT_MAX) { + LOGD("Unload ATLAS ACL Model id & Model Desc.\n"); + acl_ret = aclmdlUnload(this->om_model_info_->model_id); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("unload model failed, modelId is %u\n", this->om_model_info_->model_id); + } + } + + if (nullptr != this->om_model_info_->model_desc) { + (void)aclmdlDestroyDesc(this->om_model_info_->model_desc); + this->om_model_info_->model_desc = nullptr; + } + + AtlasContext* tnn_atlas_context = dynamic_cast(context_); + if(tnn_atlas_context == nullptr) { + LOGE("TNN ATLAS Network: fail to cast to tnn atlas context\n"); + } + if (tnn_atlas_context->GetAclrtStream() != nullptr) { + acl_ret = aclrtSetCurrentContext(om_model_info_->aclrt_context); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("TNN ATLAS Network: on destroy stream set context failed\n"); + } + acl_ret = aclrtDestroyStream(tnn_atlas_context->GetAclrtStream()); + LOGD("aclrt destroy stream\n"); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("TNN ATLAS Network: destroy stream failed\n"); + } + tnn_atlas_context->SetAclrtStream(nullptr); + } + + if (om_model_info_->aclrt_context != nullptr) { + acl_ret = aclrtDestroyContext(om_model_info_->aclrt_context); + LOGD("aclrt destroy aclrt context\n"); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("TNN ATLAS Network: destroy context failed\n"); + } + om_model_info_->aclrt_context = nullptr; + } + + if (nullptr != this->om_model_memory_ptr_) { + aclrtFree(this->om_model_memory_ptr_); + LOGD("Unload ATLAS ACL Model Memory.\n"); + this->om_model_memory_ptr_ = nullptr; + this->om_model_info_->memory_size = 0; + } + + if (nullptr != this->om_model_weight_ptr_) { + aclrtFree(this->om_model_weight_ptr_); + LOGD("Unload ATLAS ACL Model Weight.\n"); + this->om_model_weight_ptr_ = nullptr; + this->om_model_info_->weight_size = 0; + } } + + // Call DeInit() of DefaultNetwork + DeInit(); } -Status 
AtlasNetwork::Init(NetworkConfig &net_config, ModelConfig &model_config, AbstractModelInterpreter *interpreter, - InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape, InputDataTypeMap inputs_data_type, - bool enable_const_folder) { - need_to_deinit = true; +Status AtlasNetwork::LoadOMModelFromFile(const std::string &om_file) { + // Step 1: Query Model Weight And Memory Size + aclError acl_ret = aclmdlQuerySize(om_file.c_str(), &(om_model_info_->memory_size), &(om_model_info_->weight_size)); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlQuerySize failed with Error Code: (%d)\n", acl_ret); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Atlas API: aclmdlQuerySize failed"); + } + LOGD("Load Atlas OM Model From FILE. Weight Size: %d, Memory Size: %d\n", om_model_info_->weight_size, om_model_info_->memory_size); - AtlasModelInterpreter *atlas_interpreter = dynamic_cast(interpreter); - model_weight_size_ = atlas_interpreter->GetModelWeightsBufferSize(); + // Step 2: Load Model & Alloc Model Memory + if (om_model_info_->memory_size > 0) { + acl_ret = aclrtMalloc(&om_model_memory_ptr_, om_model_info_->memory_size, ACL_MEM_MALLOC_HUGE_FIRST); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclrtMalloc for model memory failed, require size is %zu\n", om_model_info_->memory_size); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Atlas API: aclrtMalloc for model memory failed"); + } - // Init ACL - Status ret = AtlasRuntime::Init(); - if (ret != TNN_OK) { - LOGE("acl init falied\n"); - return ret; + acl_ret = aclmdlLoadFromFileWithMem(om_file.c_str(), &(om_model_info_->model_id), om_model_memory_ptr_, om_model_info_->memory_size, + om_model_weight_ptr_, om_model_info_->weight_size); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlLoadFromFileWithMem failed, model file is %s\n", om_file.c_str()); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Atlas API: aclmdlLoadFromFileWithMem failed"); + } + } else { + // Some model, e.g model Converted with atc config: --input_shape_range, + // Does not have model_mem_size, aclrtMalloc EMPTY mem is NOT ALLOWED. 
+ acl_ret = aclmdlLoadFromFile(om_file.c_str(), &(om_model_info_->model_id)); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlLoadFromFile without memory failed, model file is %s\n", om_file.c_str()); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Atlas API: aclmdlLoadFromFile without memory failed"); + } } - // Set Device - ret = AtlasRuntime::GetInstance()->SetDevice(net_config.device_id); - if (ret != TNN_OK) { - LOGE("acl set device falied\n"); - return ret; + // Step 3: Create Model Desc to get Model Info + om_model_info_->model_desc = aclmdlCreateDesc(); + if (nullptr == om_model_info_->model_desc) { + LOGE("Atlas API: aclmdlCreateDesc failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "create model description failed"); } - // Get model weights buffer ptr - model_weight_ptr_ = atlas_interpreter->GetModelWeightsBufferPtr(net_config.device_id); - if (model_weight_ptr_ == nullptr) { - LOGE("get model weight buffer ptr falied\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get model weight buffer ptr falied"); + acl_ret = aclmdlGetDesc(om_model_info_->model_desc, om_model_info_->model_id); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlGetDesc failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get model description failed"); } - // Create Context - aclError acl_ret = aclrtCreateContext(&context_, net_config.device_id); + return TNN_OK; +} + +Status AtlasNetwork::LoadOMModelFromMemory(const std::string &om_content) { + // Step 1: Query Model Weight And Memory Size + aclError acl_ret = aclmdlQuerySizeFromMem(om_content.data(), om_content.length(), &(om_model_info_->memory_size), &(om_model_info_->weight_size)); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlQuerySizeFromMem failed with Error Code: (%d)\n", acl_ret); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Atlas API: aclmdlQuerySizeFromMem failed"); + } + LOGD("Load Atlas OM Model From MEMORY. 
Weight Size: %d, Memory Size: %d\n", om_model_info_->weight_size, om_model_info_->memory_size); + + // Step 2: Load Model & Alloc Model Memory + if (om_model_info_->memory_size > 0) { + acl_ret = aclrtMalloc(&om_model_memory_ptr_, om_model_info_->memory_size, ACL_MEM_MALLOC_HUGE_FIRST); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclrtMalloc for model memory failed, require size is %zu\n", om_model_info_->memory_size); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Atlas API: aclrtMalloc for model memory failed"); + } + + acl_ret = aclmdlLoadFromMemWithMem(om_content.data(), om_content.length(), &(om_model_info_->model_id), om_model_memory_ptr_, + om_model_info_->memory_size, om_model_weight_ptr_, om_model_info_->weight_size); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlLoadFromMemWithMem, load om content from memory with model memory failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Load om content from memory with model memory"); + } + } else { + // Some model, e.g model Converted with atc config: --input_shape_range, + // Does not need model_mem_size, + acl_ret = aclmdlLoadFromMem(om_content.data(), om_content.length(), &(om_model_info_->model_id)); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlLoadFromMem, load om content from memory without model memory failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Load om content from memory without model memory failed"); + } + } + + // Step 3: Create Model Desc to get Model Info + om_model_info_->model_desc = aclmdlCreateDesc(); + if (nullptr == om_model_info_->model_desc) { + LOGE("Atlas API: aclmdlCreateDesc failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "create model description failed"); + } + + acl_ret = aclmdlGetDesc(om_model_info_->model_desc, om_model_info_->model_id); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("Atlas API: aclmdlGetDesc failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get model description failed"); + } + + return TNN_OK; +} + +Status AtlasNetwork::DeduceOMModelDynamicMode() { + // ATC Converted HUAWEI atlas .om dynamic Models are divided into: + // + // 1. Traditional dynamic models with only 1 dynamic input. + // Min/Max value of the dynamic dim has been explicitly defined in ATC Conversion. + // ---- 1.1: + // dynamic batch + // --input_shape="img:-1,2,224,224;img_info:-1,4" + // --dynamic_batch_size="1,2,4,8" + // ---- 1.2: + // dynamic hw size + // --input_shape="data:8,3,-1,-1;img_info:8,4,-1,-1" + // --dynamic_image_size="416,416;832,832" + // ---- 1.3 + // dynamic dims + // --input_shape="data:-1,1,256,256", --dynamic_dims="1,2" + // + // 2. More flexible dynamic input models. + // Min/Max Value is not explicitly defined in ATC Conversion. + // ---- 2.1: + // input_shape_range + // --input_shape_range="input1:[8~20,3,5,-1];input2:[5,3~9,10,-1]" + // ---- 2.2: + // input_shape (without "dynamic_batch_size" or "dynamic_image_size") + // --input_shape="input1:[8~20,3,5,-1];input2:[5,3~9,10,-1]" + + // Get Number of Inputs by Calling ACL API + int count = aclmdlGetNumInputs(this->om_model_info_->model_desc); + LOGD("TNN Atlas Loaded OM Model has %d inputs.\n", count); + + // Type 1 OM model has an extra input called "ascend_mbatch_shape_data" + // Check if the input exists. 
+ bool is_om_model_dynamic = false; + + for (int i = 0; i < count; i++) { + std::string input_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, i); + if (input_name.find(ACL_DYNAMIC_TENSOR_NAME) != std::string::npos) { + LOGD("Network is converted with dynamic batch/hw/dims.\n"); + is_om_model_dynamic = true; + } + } + + // Traditional Type 1 Dynamic + if (is_om_model_dynamic) { + if (count != 2) { + // TODO: SUPPORT Type 1 Model with more than ONE input in the future. + LOGD("Dynamic batch/hw/dims ATLAS with more than ONE input not supported yet.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, + "Dynamic batch/hw/dims ATLAS with more than ONE input not supported yet."); + } + + // TODO: Update this part for multiple inputs + for (int i = 0; i < count; i++) { + std::string input_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, i); + if (input_name.find(ACL_DYNAMIC_TENSOR_NAME) == std::string::npos) { + aclmdlIODims acl_dims; + aclError acl_ret = aclmdlGetInputDims(this->om_model_info_->model_desc, i, &acl_dims); + if (ACL_ERROR_NONE != acl_ret) { + LOGE("ACL API Call aclmdlGetInputDims failed!\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "ACL API Call aclmdlGetInputDims failed."); + } + + int minus_one_count = 0; + for (int d = 0; d < acl_dims.dimCount; d++) { + if (acl_dims.dims[d] == -1) { + minus_one_count++; + } + } + if (minus_one_count == 0) { + LOGE("The Only Input %s is not dynamic But Model is dynamic. Not Supported.\n", input_name.c_str()); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, + "The Only Input is not dynamic But Model is dynamic. Not Supported."); + } + + if (minus_one_count == 1 && acl_dims.dims[0] == -1) { + LOGD("Deduced Dynamic Batch Mode from input: %s.\n", input_name.c_str()); + this->om_model_info_->dynamic_mode = AtlasOmModelDynamicMode::DynamicBatch; + return TNN_OK; + } + if (minus_one_count == 2 && acl_dims.dimCount == 4 && acl_dims.dims[2] == -1 && + acl_dims.dims[3] == -1) { + LOGD("Deduced Dynamic HW Mode from input: %s.\n", input_name.c_str()); + this->om_model_info_->dynamic_mode = AtlasOmModelDynamicMode::DynamicHW; + return TNN_OK; + } + // ELSE + LOGD("Deduced Generic Dynamic Dim Mode from input: %s.\n", input_name.c_str()); + this->om_model_info_->dynamic_mode = AtlasOmModelDynamicMode::GenericDynamic; + return TNN_OK; + } + } + } + + // No Dynamic Or Type 2 Dynamic Input by --input_shape_range + for (int i = 0; i < count; i++) { + aclmdlIODims acl_dims; + aclError acl_ret = aclmdlGetInputDims(this->om_model_info_->model_desc, i, &acl_dims); + if (ACL_ERROR_NONE != acl_ret) { + LOGE("ACL API Call aclmdlGetInputDims failed!\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "ACL API Call aclmdlGetInputDims failed."); + } + + int minus_one_count = 0; + for (int d = 0; d < acl_dims.dimCount; d++) { + if (acl_dims.dims[d] == -1) { + minus_one_count++; + } + } + + if (minus_one_count > 0) { + std::string input_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, i); + LOGD("Input: '%s' is dynamic by --input_shape_range.\n", input_name.c_str()); + this->om_model_info_->generic_dynamic_input_names.insert(input_name); + } + } + + if (this->om_model_info_->generic_dynamic_input_names.empty()) { + LOGD("No Dynamic Input.\n"); + } + return TNN_OK; +} + +Status AtlasNetwork::DeduceOMModelAIPPInputFormat() { + // Get Number of Inputs by Calling ACL API + int count = aclmdlGetNumInputs(this->om_model_info_->model_desc); + + for (int i = 0; i < count; i++) { + std::string input_name = 
aclmdlGetInputNameByIndex(om_model_info_->model_desc, i);
+        aclAippInfo aipp_info;
+        aclError acl_ret = aclmdlGetFirstAippInfo(this->om_model_info_->model_id, i, &aipp_info);
+        if (acl_ret == ACL_ERROR_NONE) {
+            LOGD("Found AIPP Input, shapeCount: %d srcDimNum: %d\n", aipp_info.shapeCount, aipp_info.srcDimNum);
+            this->om_model_info_->aipp_input_format_map[input_name] = aipp_info.inputFormat;
+        }
+    }
+
+    return TNN_OK;
+}
+
+Status AtlasNetwork::InitOMModel(ModelConfig &model_config, AbstractModelInterpreter *interpreter,
+                                 InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape,
+                                 InputDataTypeMap inputs_data_type, bool enable_const_folder) {
+    AtlasOMModelInterpreter *om_interpreter = dynamic_cast<AtlasOMModelInterpreter *>(interpreter);
+    CHECK_PARAM_NULL(om_interpreter);
+    AtlasContext* atlas_context = dynamic_cast<AtlasContext *>(context_);
+    CHECK_PARAM_NULL(atlas_context);
+
+    std::string& om_str = om_interpreter->GetOmString();
+
+    // Part 1: Load (Interpret) Model. The aclmdl load APIs load an OM model directly onto the Device,
+    // so loading can only be done here in AtlasNetwork, not in the ModelInterpreter.
+    // Step 1: Create OM Model Info, aclrt load_model_context & load model stream
+    this->om_model_info_ = std::make_shared();
+
+    aclError acl_ret = aclrtSetDevice(this->config_.device_id);
+    if (acl_ret != ACL_ERROR_NONE) {
+        LOGE("acl open device %d failed (acl error code: %d)\n", this->config_.device_id, acl_ret);
+        return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl open device failed");
+    }
+    acl_ret = aclrtCreateContext(&(om_model_info_->aclrt_context), this->config_.device_id);
     if (acl_ret != ACL_ERROR_NONE) {
         LOGE("acl create context failed (acl error code: %d)\n", acl_ret);
         return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl create context falied");
     }
-
-    // Create Stream
-    acl_ret = aclrtCreateStream(&stream_);
+    acl_ret = aclrtSetCurrentContext(om_model_info_->aclrt_context);
+    if (acl_ret != ACL_ERROR_NONE) {
+        LOGE("TNN ATLAS InitOMModel: set context failed before creating the model load stream\n");
+    }
+    aclrtStream aclrt_stream;
+    acl_ret = aclrtCreateStream(&aclrt_stream);
     if (acl_ret != ACL_ERROR_NONE) {
         LOGE("acl create stream failed (acl error code: %d)\n", acl_ret);
         return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl create stream falied");
     }
-
-    command_queue_.reset(new AtlasCommandQueue());
-    command_queue_->context = context_;
-    command_queue_->stream = stream_;
-
-    // Load model
-    if (atlas_interpreter->GetModelConfig().is_path) {
-        LOGD("load model form file\n");
-        ret = LoadModelFromFile(atlas_interpreter->GetModelConfig().om_str);
+    atlas_context->SetAclrtStream(aclrt_stream);
+    global_stream_context_map[atlas_context->GetAclrtStream()] = om_model_info_->aclrt_context;
+
+    // Step 2: Load Model From Path or From Memory
+    // Determine whether the OM string is a model path or model content
+    Status tnn_ret;
+    if (om_str.length() < 1024) {
+        std::ifstream om_file(om_str);
+        if (!om_file) {
+            LOGE("Invalid om file path! (om_str : %s), and om_str is too short to be model content\n", om_str.c_str());
+            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Invalid om file path, cannot determine whether om_str is a path or model content.");
+        }
+        tnn_ret = LoadOMModelFromFile(om_str);
+        if (tnn_ret != TNN_OK) {
+            LOGE("TNN Atlas Load OM Model from File Failed.\n");
+            return tnn_ret;
+        }
    } else {
-        LOGD("load model form memory\n");
-        ret = LoadModelFromMemory(atlas_interpreter->GetModelConfig().om_str);
+        tnn_ret = LoadOMModelFromMemory(om_str);
+        if (tnn_ret != TNN_OK) {
+            LOGE("TNN Atlas Load OM Model from Model Content Failed.\n");
+            return tnn_ret;
+        }
+    }
+    // Synchronize Device and Destroy Model Load Stream
+    acl_ret = aclrtSynchronizeDevice();
+    if (acl_ret != ACL_ERROR_NONE) {
+        LOGE("acl device synchronize failed (acl error code: %d)\n", acl_ret);
+        return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl device synchronize failed");
    }
-    if (ret != TNN_OK)
-        return ret;
-    // deduce if dynamic input exists
-    // get type of dynamic input if exists
-    // type 1: Traditional Types
-    //     --dynamic_batch_size
-    //     --dynamic_image_size (hw)
-    //     --dynamic_dims
-    // type 2: Flexible Dynamic
-    //     --input_shape_range
-    ret = DeduceDynamicInputType();
-    if (ret != TNN_OK)
-        return ret;
+    // Step 3: Deduce Atlas OM Model Dynamic Type
+    tnn_ret = DeduceOMModelDynamicMode();
+    if (tnn_ret != TNN_OK) {
+        LOGE("TNN Atlas Deduce Model Dynamic Mode Failed.\n");
+        return tnn_ret;
+    }
-    // allocate input and output
-    ret = AllocateDatasetCreateBlob(&input_, max_inputs_shape, true);
-    if (ret != TNN_OK)
-        return ret;
-    ret = AllocateDatasetCreateBlob(&output_, max_inputs_shape, false);
-    if (ret != TNN_OK)
-        return ret;
-
-    // add model info
-    AtlasModelInfo model_info;
-    model_info.model_desc    = model_desc_;
-    model_info.model_id      = model_id_;
-    model_info.input_dataset = input_;
-    model_info.has_aipp      = has_aipp_;
-    for (auto item : input_blob_map_) {
-        if (aipp_input_format_map_.find(item.first) != aipp_input_format_map_.end())
-            model_info.aipp_input_format = aipp_input_format_map_[item.first];
-        else
-            model_info.aipp_input_format = ACL_AIPP_RESERVED;
-        AtlasRuntime::GetInstance()->AddModelInfo(item.second, model_info);
+    // Step 4: Deduce Atlas OM Model AIPP input format if input is AIPP Mode.
+    tnn_ret = DeduceOMModelAIPPInputFormat();
+    if (tnn_ret != TNN_OK) {
+        LOGE("TNN Atlas Deduce Model AIPP input format Failed.\n");
+        return tnn_ret;
    }
-    // set dynamic batch size
-    // must do if input is dynamic batch
-    if (this->atc_mode_dynamic_batch_hw_dim_) {
+
+    // Part 2: Allocate Input/Output, Reshape etc.
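+    //
+    // Note (illustrative, not part of the original patch): for dynamic OM models the
+    // max_inputs_shape argument handed down from the tnn::CreateInst() API is what
+    // sizes the buffers allocated in Step 5 below; e.g. a model converted with
+    //     --input_shape_range="input:[1~8,3,-1,-1]"
+    // would typically be created with something like
+    //     InputShapesMap max_shapes = {{"input", {8, 3, 416, 416}}};
+    // so that the largest expected shape is known up front (names and dims here are
+    // only examples).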
+ // Step 5: allocate input and output + tnn_ret = AllocateDatasetCreateBlob(&aclmdl_input_dataset_, max_inputs_shape, true); + if (tnn_ret != TNN_OK) + return tnn_ret; + tnn_ret = AllocateDatasetCreateBlob(&aclmdl_output_dataset_, max_inputs_shape, false); + if (tnn_ret != TNN_OK) + return tnn_ret; + + // Step 6: set dynamic batch size + // must do if input is dynamic batch + if (this->om_model_info_->dynamic_mode != AtlasOmModelDynamicMode::Static) { for (auto item : input_blob_map_) { - ret = SetDynamicBatchSize(item.first, item.second->GetBlobDesc().dims[0]); - if (ret != TNN_OK) - return ret; + tnn_ret = SetDynamicBatchSize(item.first, item.second->GetBlobDesc().dims[0]); + if (tnn_ret != TNN_OK) + return tnn_ret; } } - // reshape if needed - ret = Reshape(max_inputs_shape); - if (ret != TNN_OK) - return ret; + // Step 7: reshape if needed + tnn_ret = Reshape(max_inputs_shape); + if (tnn_ret != TNN_OK) + return tnn_ret; return TNN_OK; } +Status AtlasNetwork::Init(NetworkConfig &net_config, ModelConfig &model_config, AbstractModelInterpreter *interpreter, + InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape, InputDataTypeMap inputs_data_type, + bool enable_const_folder) { + this->network_init_called_ = true; + this->config_ = net_config; + this->model_type_ = model_config.model_type; + + // GetDevice and Context + this->device_ = GetDevice(net_config.device_type); + CHECK_PARAM_NULL(this->device_); + this->context_ = device_->CreateContext(net_config.device_id); + CHECK_PARAM_NULL(this->context_); + + // Set AtlasContext model type + AtlasContext* atlas_context = dynamic_cast(context_); + CHECK_PARAM_NULL(atlas_context); + atlas_context->SetModelType(model_config.model_type); + + // Init Model For different Model Types + if (model_config.model_type == MODEL_TYPE_TORCHSCRIPT) { + LOGE("Fail to init AtlasNetwork, MODEL_TYPE_TORCHSCRIPT not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to init AtlasNetwork, MODEL_TYPE_TORCHSCRIPT not supported YET"); + } else if (model_config.model_type == MODEL_TYPE_TNN || + model_config.model_type == MODEL_TYPE_RAPIDNET) { + LOGE("Fail to init AtlasNetwork, MODEL_TYPE_TNN not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to init AtlasNetwork, MODEL_TYPE_TNN not supported YET"); + } else if (model_config.model_type == MODEL_TYPE_ATLAS) { + return InitOMModel(model_config, interpreter, min_inputs_shape, max_inputs_shape, + inputs_data_type, enable_const_folder); + } else { + LOGE("Fail to init AtlasNetwork, model type not supported.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to init AtlasNetwork, model type not supported"); + } +} + + Status AtlasNetwork::GetForwardMemorySize(size_t &memory_size) { - memory_size = model_mem_size_; + if (model_type_ == MODEL_TYPE_ATLAS) { + if (!om_model_info_) { + LOGE("Unable to Get ForwardMemorySize, ATLAS om ModelInfo Missing.\n"); + return Status(TNNERR_DEVICE_NOT_SUPPORT, "Unable to Get ForwardMemorySize, ATLAS om ModelInfo Missing."); + } + memory_size = om_model_info_->memory_size + om_model_info_->weight_size; + } return TNN_OK; } @@ -154,22 +530,19 @@ Status AtlasNetwork::GetAllOutputBlobs(BlobMap &blobs) { blobs = output_blob_map_; return TNN_OK; } - -// @brief get atlas model id of current network -uint32_t AtlasNetwork::GetModelId() const { - return this->model_id_; -} - -// @brief get atlas model desc of current network -aclmdlDesc* AtlasNetwork::GetModelDesc() const { - return this->model_desc_; + +std::shared_ptr 
AtlasNetwork::GetOMModelInfo() { + return this->om_model_info_; } -Status AtlasNetwork::Reshape(const InputShapesMap &inputs) { - aclError ret = aclrtSetCurrentContext(context_); - if (ret != ACL_ERROR_NONE) { - LOGE("set context failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set context failed"); +Status AtlasNetwork::ReshapeOMModel(const InputShapesMap &inputs) { + AtlasContext* atlas_context = dynamic_cast(context_); + CHECK_PARAM_NULL(atlas_context); + + aclError acl_ret = aclrtSetCurrentContext(om_model_info_->aclrt_context); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("ReshapeOMModel set context failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "ReshapeOMModel set context failed"); } for (auto item : inputs) { @@ -194,7 +567,7 @@ Status AtlasNetwork::Reshape(const InputShapesMap &inputs) { } // Traditional Dynamic Batch, Set Input/Output Blob Shape. - if (this->atc_mode_dynamic_batch_) { + if (this->om_model_info_->dynamic_mode == AtlasOmModelDynamicMode::DynamicBatch) { Status tnn_ret = SetDynamicBatchSize(item.first, dims[0]); if (TNN_OK != tnn_ret) return tnn_ret; @@ -203,8 +576,8 @@ Status AtlasNetwork::Reshape(const InputShapesMap &inputs) { // Range input for Model Converted with --input_shape_range // Range input output shape cannot be infered from input shape. // Output Shape will be deduced after ACL Forward() API is called. - if (this->dynamic_input_shape_range_names_.find(item.first) != - this->dynamic_input_shape_range_names_.end()) { + if (this->om_model_info_->generic_dynamic_input_names.find(item.first) != + this->om_model_info_->generic_dynamic_input_names.end()) { Status tnn_ret = SetRangeDynamicInputDim(item.first, dims); if (TNN_OK != tnn_ret) return tnn_ret; @@ -215,230 +588,84 @@ Status AtlasNetwork::Reshape(const InputShapesMap &inputs) { return TNN_OK; } -Status AtlasNetwork::DeInit() { - aclError ret = ACL_ERROR_NONE; - if (nullptr != context_) { - ret = aclrtSetCurrentContext(context_); - if (ret != ACL_ERROR_NONE) { - LOGE("set context failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set context failed"); - } - } - for (auto item : input_blob_map_) { - if (nullptr != item.second) { - // delete model info - AtlasRuntime::GetInstance()->DelModelInfo(item.second); - delete item.second; - } - } - input_blob_map_.clear(); - for (auto item : output_blob_map_) { - if (nullptr != item.second) { - delete item.second; - } - } - output_blob_map_.clear(); - - LOGD("acl destroy input dataset\n"); - if (nullptr != input_) { - DestroyDataset(input_); - } - LOGD("acl destroy output dataset\n"); - if (nullptr != output_) { - DestroyDataset(output_); - } - - UnloadModel(); - - if (nullptr != stream_) { - ret = aclrtDestroyStream(stream_); - LOGD("acl destroy stream\n"); - if (ret != ACL_ERROR_NONE) { - LOGE("destroy stream failed\n"); - } - stream_ = nullptr; - } - - if (nullptr != context_) { - ret = aclrtDestroyContext(context_); - LOGD("acl destroy context\n"); - if (ret != ACL_ERROR_NONE) { - LOGE("destroy context failed\n"); - } - context_ = nullptr; +Status AtlasNetwork::Reshape(const InputShapesMap &inputs) { + // Reshape Model For different Model Types + if (this->model_type_ == MODEL_TYPE_TORCHSCRIPT) { + LOGE("Fail to reshape AtlasNetwork, MODEL_TYPE_TORCHSCRIPT not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to reshape AtlasNetwork, MODEL_TYPE_TORCHSCRIPT not supported YET"); + } else if (this->model_type_ == MODEL_TYPE_TNN || this->model_type_ == MODEL_TYPE_RAPIDNET) { + LOGE("Fail to reshape AtlasNetwork, 
MODEL_TYPE_TNN not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to reshape AtlasNetwork, MODEL_TYPE_TNN not supported YET"); + } else if (this->model_type_ == MODEL_TYPE_ATLAS) { + return ReshapeOMModel(inputs); + } else { + LOGE("Fail to reshape AtlasNetwork, model type not supported.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to reshape AtlasNetwork, model type not supported"); } - - AtlasRuntime::DecreaseRef(); - return TNN_OK; } Status AtlasNetwork::GetCommandQueue(void **command_queue) { - *command_queue = command_queue_.get(); - return TNN_OK; + return context_->GetCommandQueue(command_queue); } Status AtlasNetwork::Forward() { - LOGD("Atlas Forward!\n"); - - aclError ret = aclrtSetCurrentContext(context_); - if (ret != ACL_ERROR_NONE) { - LOGE("set context & synchronize stream failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set context & synchronized failed"); - } - - ret = aclrtSynchronizeStream(stream_); - if (ret != ACL_ERROR_NONE) { - LOGE("before forward synchronize stream failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "before forward synchronize stream failed"); - } - - ret = aclmdlExecute(model_id_, input_, output_); - if (ret != ACL_ERROR_NONE) { - LOGE("execute model failed, modelId is %u\n", model_id_); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "execute model failed"); - } - - // For Range Dynamic Models with --input_shape_range - // Update Output Blob Shapes here. - if (!this->dynamic_input_shape_range_names_.empty()) { - Status tnn_ret = UpdateRangeDynamicOutputDims(); - if (TNN_OK != tnn_ret) { - return tnn_ret; - } - } - - return TNN_OK; -} - -Status AtlasNetwork::ForwardAsync(Callback call_back) { - LOGD("Atlas Async Forward! (as same as Forward by now)\n"); - return Forward(); -} - -Status AtlasNetwork::LoadModelFromFile(const std::string &om_file) { - size_t temp_size; - aclError ret = aclmdlQuerySize(om_file.c_str(), &model_mem_size_, &temp_size); - if (ret != ACL_ERROR_NONE) { - LOGE("query model failed (ret=%d), model file is %s\n", ret, om_file.c_str()); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "query model failed"); - } - LOGD("atlas model mem size: %d\n", model_mem_size_); - - // Some model, e.g model Converted with atc config: --input_shape_range, - // Does not have model_mem_size, aclrtMalloc EMPTY mem is NOT ALLOWED. 
- if (model_mem_size_) { - ret = aclrtMalloc(&model_mem_ptr_, model_mem_size_, ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != ACL_ERROR_NONE) { - LOGE("malloc buffer for mem failed, require size is %zu\n", model_mem_size_); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "malloc buffer for mem failed"); + // Reshape Model For different Model Types + if (this->model_type_ == MODEL_TYPE_TORCHSCRIPT) { + LOGE("Fail to execute AtlasNetwork, MODEL_TYPE_TORCHSCRIPT not supported YET.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to execute AtlasNetwork, MODEL_TYPE_TORCHSCRIPT not supported YET"); + } else if (this->model_type_ == MODEL_TYPE_TNN || this->model_type_ == MODEL_TYPE_RAPIDNET || + this->model_type_ == MODEL_TYPE_ATLAS) { + LOGD("Atlas Forward!\n"); + AtlasContext* atlas_context = dynamic_cast(context_); + CHECK_PARAM_NULL(atlas_context); + + aclError acl_ret = aclrtSetCurrentContext(om_model_info_->aclrt_context); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("ReshapeOMModel set context failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "ReshapeOMModel set context failed"); } - ret = aclmdlLoadFromFileWithMem(om_file.c_str(), &model_id_, model_mem_ptr_, model_mem_size_, model_weight_ptr_, - model_weight_size_); - if (ret != ACL_ERROR_NONE) { - LOGE("load model from file with mem failed, model file is %s\n", om_file.c_str()); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "load model from file with mem failed"); - } - } else { - ret = aclmdlLoadFromFile(om_file.c_str(), &model_id_); - if (ret != ACL_ERROR_NONE) { - LOGE("load model from file without mem failed, model file is %s\n", om_file.c_str()); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "load model from file without mem failed"); + acl_ret = aclrtSynchronizeStream(atlas_context->GetAclrtStream()); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("before forward synchronize stream failed\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "before forward synchronize stream failed"); } - } - // create model desc to get model info - model_desc_ = aclmdlCreateDesc(); - if (nullptr == model_desc_) { - LOGE("create model description failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "create model description failed"); - } - - ret = aclmdlGetDesc(model_desc_, model_id_); - if (ret != ACL_ERROR_NONE) { - LOGE("get model description failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get model description failed"); - } - - return TNN_OK; -} - -Status AtlasNetwork::LoadModelFromMemory(const std::string &om_content) { - size_t temp_size; - aclError ret = aclmdlQuerySizeFromMem(om_content.data(), om_content.length(), &model_mem_size_, &temp_size); - if (ret != ACL_ERROR_NONE) { - LOGE("query model failed (ret=%d)\n", ret); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "query model failed"); - } - LOGD("atlas model mem size: %d\n", model_mem_size_); - - // Some model, e.g model Converted with atc config: --input_shape_range, - // Does not need model_mem_size, - if (model_mem_size_) { - ret = aclrtMalloc(&model_mem_ptr_, model_mem_size_, ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != ACL_ERROR_NONE) { - LOGE("malloc buffer for mem failed, require size is %zu\n", model_mem_size_); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "malloc buffer for mem failed"); + acl_ret = aclmdlExecute(this->om_model_info_->model_id, aclmdl_input_dataset_, aclmdl_output_dataset_); + if (acl_ret != ACL_ERROR_NONE) { + LOGE("execute model failed, modelId is %u\n", this->om_model_info_->model_id); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "execute model failed"); } - ret = 
aclmdlLoadFromMemWithMem(om_content.data(), om_content.length(), &model_id_, model_mem_ptr_, - model_mem_size_, model_weight_ptr_, model_weight_size_); - if (ret != ACL_ERROR_NONE) { - LOGE("load model from file with mem failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "load model from file with mem failed"); + // For Range Dynamic Models with --input_shape_range + // Update Output Blob Shapes here. + if (!this->om_model_info_->generic_dynamic_input_names.empty()) { + Status tnn_ret = UpdateRangeDynamicOutputDims(); + if (TNN_OK != tnn_ret) { + return tnn_ret; + } } } else { - ret = aclmdlLoadFromMem(om_content.data(), om_content.length(), &model_id_); - if (ret != ACL_ERROR_NONE) { - LOGE("load model from file without mem failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "load model from file without mem failed"); - } - } - - // create model desc to get model info - model_desc_ = aclmdlCreateDesc(); - if (nullptr == model_desc_) { - LOGE("create model description failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "create model description failed"); - } - - ret = aclmdlGetDesc(model_desc_, model_id_); - if (ret != ACL_ERROR_NONE) { - LOGE("get model description failed\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get model description failed"); + LOGE("Fail to reshape AtlasNetwork, model type not supported.\n"); + return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to reshape AtlasNetwork, model type not supported"); } return TNN_OK; } -void AtlasNetwork::UnloadModel() { - aclError ret = aclmdlUnload(model_id_); - LOGD("acl unload model\n"); - if (ret != ACL_ERROR_NONE) { - LOGE("unload model failed, modelId is %u\n", model_id_); - } - - if (nullptr != model_desc_) { - (void)aclmdlDestroyDesc(model_desc_); - LOGD("acl destroy model desc\n"); - model_desc_ = nullptr; - } - - if (nullptr != model_mem_ptr_) { - aclrtFree(model_mem_ptr_); - LOGD("acl free model mem ptr\n"); - model_mem_ptr_ = nullptr; - model_mem_size_ = 0; - } +Status AtlasNetwork::ForwardAsync(Callback call_back) { + LOGD("Atlas Async Forward! (as same as Forward by now)\n"); + return Forward(); } + Status AtlasNetwork::AllocateDatasetCreateBlob(aclmdlDataset **data_set, const InputShapesMap &max_input_shapes_map, bool is_input) { // This Function should be called twice. // Input should be called first, then output should also be called. - if (nullptr == model_desc_) { + if (nullptr == om_model_info_->model_desc) { LOGE("no model description, create ouput failed\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "no model description, create ouput failed"); } @@ -453,27 +680,22 @@ Status AtlasNetwork::AllocateDatasetCreateBlob(aclmdlDataset **data_set, const I size_t count = 0; if (is_input) { - count = aclmdlGetNumInputs(model_desc_); + count = aclmdlGetNumInputs(om_model_info_->model_desc); LOGD("AllocateDataset for input (count=%d)\n", count); } else { - count = aclmdlGetNumOutputs(model_desc_); + count = aclmdlGetNumOutputs(om_model_info_->model_desc); LOGD("AllocateDataset for output (count=%d)\n", count); - - //ret = InferOutputShapeIfNecessery(); - //if (ret != TNN_OK) - // return ret; } - for (size_t i = 0; i < count; ++i) { size_t buffer_size = 0; // OM Model Converted with atc config "--input_shape_range" // does not have buffer_size info. 
buffer_size should be provided externally // from MAX_INPUTS_SHAPE in "tnn::CreateInst() API" if (is_input) { - buffer_size = aclmdlGetInputSizeByIndex(model_desc_, i); + buffer_size = aclmdlGetInputSizeByIndex(om_model_info_->model_desc, i); if (buffer_size == 0) { - std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i); + std::string input_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, i); auto iter = max_input_shapes_map.find(input_name); if (iter == max_input_shapes_map.end()) { LOGE("Shape of dynamic input: %s, not found in max_input_shapes_map.\n", input_name.c_str()); @@ -483,9 +705,9 @@ Status AtlasNetwork::AllocateDatasetCreateBlob(aclmdlDataset **data_set, const I buffer_size = sizeof(int64_t)*DimsVectorUtils::Count(iter->second); } } else { - buffer_size = aclmdlGetOutputSizeByIndex(model_desc_, i); + buffer_size = aclmdlGetOutputSizeByIndex(om_model_info_->model_desc, i); if (buffer_size == 0) { - std::string output_name = aclmdlGetOutputNameByIndex(model_desc_, i); + std::string output_name = aclmdlGetOutputNameByIndex(om_model_info_->model_desc, i); auto iter = max_input_shapes_map.find(output_name); if (iter == max_input_shapes_map.end()) { LOGE("Shape of dynamic output: %s, not found in max_input_shapes_map.\n", output_name.c_str()); @@ -530,9 +752,9 @@ Status AtlasNetwork::AllocateDatasetCreateBlob(aclmdlDataset **data_set, const I // Create Tensor Desc for dynamic Input // https://www.hiascend.com/document/detail/zh/canncommercial/601/inferapplicationdev/atctool/atctool_0053.html if (is_input) { - std::string input_name = aclmdlGetInputNameByIndex(this->model_desc_, i); - if (this->dynamic_input_shape_range_names_.find(input_name) != - this->dynamic_input_shape_range_names_.end()) { + std::string input_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, i); + if (om_model_info_->generic_dynamic_input_names.find(input_name) != + om_model_info_->generic_dynamic_input_names.end()) { auto iter = max_input_shapes_map.find(input_name); if (iter == max_input_shapes_map.end()) { LOGE("MAX shape of Dynamic Input Range input '%s' not found.\n", input_name.c_str()); @@ -546,8 +768,8 @@ Status AtlasNetwork::AllocateDatasetCreateBlob(aclmdlDataset **data_set, const I // Input TensorDesc should only be created ONCE. 
// It will be destroyed in DeInit() aclTensorDesc *input_desc = - aclCreateTensorDesc(aclmdlGetInputDataType(this->model_desc_, i), iter->second.size(), dim_arr, - aclmdlGetInputFormat(this->model_desc_, i)); + aclCreateTensorDesc(aclmdlGetInputDataType(om_model_info_->model_desc, i), iter->second.size(), dim_arr, + aclmdlGetInputFormat(om_model_info_->model_desc, i)); acl_ret = aclmdlSetDatasetTensorDesc(*data_set, input_desc, i); if (acl_ret != ACL_ERROR_NONE) { LOGE("API aclmdlSetDatasetTensorDesc failed for input '%s'.\n", input_name.c_str()); @@ -561,12 +783,12 @@ Status AtlasNetwork::AllocateDatasetCreateBlob(aclmdlDataset **data_set, const I } Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, size_t index, void *data, bool is_input) { - if (nullptr == model_desc_) { + if (om_model_info_->model_desc == nullptr) { LOGE("no model description\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "no model description"); } - Status ret = TNN_OK; + Status ret = TNN_OK; std::string blob_name = ""; std::vector io_dims; aclDataType data_type; @@ -575,7 +797,7 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si io_dims.clear(); if (is_input) { // get blob name - blob_name = aclmdlGetInputNameByIndex(model_desc_, index); + blob_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, index); // skip dynamic aipp input if (blob_name.find(ACL_DYNAMIC_AIPP_NAME) != std::string::npos) { LOGD("find dynamic aipp input (%s) and skip...\n", blob_name.c_str()); @@ -584,7 +806,7 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si // skip dynamic batch input if (blob_name.find(ACL_DYNAMIC_TENSOR_NAME) != std::string::npos) { LOGD("find dynamic batch/hw/dims input (%s) and skip...\n", blob_name.c_str()); - atc_mode_dynamic_batch_hw_dim_ = true; + //atc_mode_dynamic_batch_hw_dim_ = true; //dynamic_batch_name_.push_back(blob_name); return TNN_OK; } @@ -597,8 +819,8 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si // If "max_input_shapes" is externally provided. // Set io_dims to max_input_shape. auto max_input_shape_iter = max_input_shapes_map.find(blob_name); - auto max_input_range_iter = this->dynamic_input_shape_range_names_.find(blob_name); - if (max_input_range_iter != this->dynamic_input_shape_range_names_.end() && + auto max_input_range_iter = om_model_info_->generic_dynamic_input_names.find(blob_name); + if (max_input_range_iter != om_model_info_->generic_dynamic_input_names.end() && max_input_shape_iter == max_input_shapes_map.end()) { // For MODELS with '--input_shape_range' type dynamic input, // external "max_input_shapes" is REQUIRED. 
@@ -619,18 +841,18 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si } } else { // get blob name - blob_name = aclmdlGetOutputNameByIndex(model_desc_, index); + blob_name = aclmdlGetOutputNameByIndex(om_model_info_->model_desc, index); // get dims info aclmdlIODims acl_dims; - aclError acl_ret = aclmdlGetOutputDims(model_desc_, index, &acl_dims); + aclError acl_ret = aclmdlGetOutputDims(om_model_info_->model_desc, index, &acl_dims); if (acl_ret != ACL_ERROR_NONE) { LOGE("can't get output dims\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "can't get output dims"); } - if (this->atc_mode_dynamic_batch_) { + if (om_model_info_->dynamic_mode == AtlasOmModelDynamicMode::DynamicBatch) { // get dims0 - int max_batch = GetMaxBatchSize(model_desc_, 1); + int max_batch = GetMaxBatchSize(om_model_info_->model_desc, 1); if (0 == max_batch) { LOGE("get batch size failed\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get batch size failed"); @@ -638,9 +860,9 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si output_dim0_map_[blob_name] = std::max(1, (int)acl_dims.dims[0] / max_batch); } // get data type - data_type = aclmdlGetOutputDataType(model_desc_, index); + data_type = aclmdlGetOutputDataType(om_model_info_->model_desc, index); // get data format - data_format = aclmdlGetOutputFormat(model_desc_, index); + data_format = aclmdlGetOutputFormat(om_model_info_->model_desc, index); for (int i = 0; i < acl_dims.dimCount; ++i) { io_dims.push_back((int)acl_dims.dims[i]); } @@ -691,6 +913,10 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si blob_handle.base = data; Blob *blob = new Blob(blob_desc, blob_handle); + + // Add Blob To global_blob_om_model_map + global_blob_om_model_info_map[blob] = om_model_info_; + LOGD("Added Blob to global_blob_model_info_map, map.size = %d\n", global_blob_om_model_info_map.size()); if (is_input) { input_blob_map_[blob_name] = blob; @@ -703,17 +929,14 @@ Status AtlasNetwork::AddBlobToMap(const InputShapesMap &max_input_shapes_map, si Status AtlasNetwork::GetInputInfo(size_t index, std::vector &input_dims, aclFormat &input_format, aclDataType &input_data_type) { - std::string blob_name = aclmdlGetInputNameByIndex(model_desc_, index); + std::string blob_name = aclmdlGetInputNameByIndex(om_model_info_->model_desc, index); aclAippInfo aipp_info; - aclError acl_ret = aclmdlGetFirstAippInfo(model_id_, index, &aipp_info); + aclError acl_ret = aclmdlGetFirstAippInfo(om_model_info_->model_id, index, &aipp_info); input_dims.clear(); if (ACL_ERROR_NONE == acl_ret) { // has static aipp - has_aipp_ = true; LOGD("shapeCount: %d srcDimNum: %d\n", aipp_info.shapeCount, aipp_info.srcDimNum); - // get aipp input format - aipp_input_format_map_[blob_name] = aipp_info.inputFormat; // get data format input_format = aipp_info.srcFormat; @@ -739,25 +962,23 @@ Status AtlasNetwork::GetInputInfo(size_t index, std::vector &input_dims, ac } } else { LOGD("get aipp info failed (ret=%d), use input info directly\n", acl_ret); - // get aipp input format - aipp_input_format_map_[blob_name] = ACL_AIPP_RESERVED; // get data format - input_format = aclmdlGetInputFormat(model_desc_, index); + input_format = aclmdlGetInputFormat(om_model_info_->model_desc, index); // get data type - input_data_type = aclmdlGetInputDataType(model_desc_, index); + input_data_type = aclmdlGetInputDataType(om_model_info_->model_desc, index); // get dims info aclmdlIODims acl_dims; - aclError acl_ret = 
aclmdlGetInputDims(model_desc_, index, &acl_dims); + aclError acl_ret = aclmdlGetInputDims(om_model_info_->model_desc, index, &acl_dims); if (acl_ret != ACL_ERROR_NONE) { LOGE("can't get input dims\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "can't get input dims"); } // in dynamic batch input, reset batch if (-1 == acl_dims.dims[0]) { - auto buffer_size = aclmdlGetInputSizeByIndex(model_desc_, index); + auto buffer_size = aclmdlGetInputSizeByIndex(om_model_info_->model_desc, index); int chw_size = aclDataTypeSize(input_data_type); for (int i = 1; i < acl_dims.dimCount; ++i) { chw_size *= acl_dims.dims[i]; @@ -776,14 +997,14 @@ Status AtlasNetwork::GetInputInfo(size_t index, std::vector &input_dims, ac Status AtlasNetwork::SetRangeDynamicInputDim(std::string input_name, const DimsVector& target_input_shape) { size_t index = 0; - aclError acl_ret = aclmdlGetInputIndexByName(model_desc_, input_name.c_str(), &index); + aclError acl_ret = aclmdlGetInputIndexByName(om_model_info_->model_desc, input_name.c_str(), &index); if (acl_ret != ACL_ERROR_NONE) { LOGE("get dynamic batch input index falied!\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get dynamic batch input index falied"); } // Get & Destroy Old Output TensorDesc - aclTensorDesc* old_input_desc = aclmdlGetDatasetTensorDesc(this->input_, index); + aclTensorDesc* old_input_desc = aclmdlGetDatasetTensorDesc(this->aclmdl_input_dataset_, index); if (old_input_desc == nullptr) { LOGE("failed to get existing TensorDesc for input '%s'.\n", input_name.c_str()); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "failed to get existing TensorDesc for dynamic input."); @@ -796,9 +1017,9 @@ Status AtlasNetwork::SetRangeDynamicInputDim(std::string input_name, const DimsV dim_arr[d] = target_input_shape[d]; } aclTensorDesc *new_input_desc = - aclCreateTensorDesc(aclmdlGetInputDataType(this->model_desc_, index), target_input_shape.size(), dim_arr, - aclmdlGetInputFormat(this->model_desc_, index)); - acl_ret = aclmdlSetDatasetTensorDesc(this->input_, new_input_desc, index); + aclCreateTensorDesc(aclmdlGetInputDataType(om_model_info_->model_desc, index), target_input_shape.size(), dim_arr, + aclmdlGetInputFormat(om_model_info_->model_desc, index)); + acl_ret = aclmdlSetDatasetTensorDesc(this->aclmdl_input_dataset_, new_input_desc, index); if (acl_ret != ACL_ERROR_NONE) { LOGE("API aclmdlSetDatasetTensorDesc failed for input '%s'.\n", input_name.c_str()); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "API aclmdlSetDatasetTensorDesc failed."); @@ -808,10 +1029,10 @@ Status AtlasNetwork::SetRangeDynamicInputDim(std::string input_name, const DimsV } Status AtlasNetwork::UpdateRangeDynamicOutputDims() { - int out_count = aclmdlGetNumOutputs(model_desc_); + int out_count = aclmdlGetNumOutputs(this->om_model_info_->model_desc); for (int i=0; ioutput_, i); - std::string output_name = aclmdlGetOutputNameByIndex(this->model_desc_, i); + aclTensorDesc* desc_i = aclmdlGetDatasetTensorDesc(this->aclmdl_output_dataset_, i); + std::string output_name = aclmdlGetOutputNameByIndex(this->om_model_info_->model_desc, i); if (output_blob_map_.find(output_name) == output_blob_map_.end()) { LOGE("Unable to find output '%s' in output blob map.\n", output_name.c_str()); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Unable to find output in output blob map."); @@ -844,15 +1065,16 @@ Status AtlasNetwork::UpdateRangeDynamicOutputDims() { } Status AtlasNetwork::SetDynamicBatchSize(std::string blob_name, int batch_size) { - if (IsDynamicBatch(model_desc_, blob_name) && 
atc_mode_dynamic_batch_hw_dim_) { + if (IsDynamicBatch(this->om_model_info_->model_desc, blob_name) && + om_model_info_->dynamic_mode != AtlasOmModelDynamicMode::Static) { // set dynamic batch size_t index = 0; - aclError acl_ret = aclmdlGetInputIndexByName(model_desc_, ACL_DYNAMIC_TENSOR_NAME, &index); + aclError acl_ret = aclmdlGetInputIndexByName(om_model_info_->model_desc, ACL_DYNAMIC_TENSOR_NAME, &index); if (acl_ret != ACL_ERROR_NONE) { LOGE("get dynamic batch input index falied!\n"); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "get dynamic batch input index falied"); } - acl_ret = aclmdlSetDynamicBatchSize(model_id_, input_, index, batch_size); + acl_ret = aclmdlSetDynamicBatchSize(om_model_info_->model_id, aclmdl_input_dataset_, index, batch_size); if (acl_ret != ACL_ERROR_NONE) { LOGE("set batch size (%s) in reshape failed\n", blob_name.c_str()); return Status(TNNERR_ATLAS_RUNTIME_ERROR, "set batch size in reshape failed"); @@ -871,123 +1093,6 @@ Status AtlasNetwork::SetDynamicBatchSize(std::string blob_name, int batch_size) } -Status AtlasNetwork::DeduceDynamicInputType() { - // ATC Converted HUAWEI atlas .om dynamic Models are devided into: - // - // 1. Traditional dynamic models with only 1 dynamic inputs. - // Min/Max value of the dynamic dim has been explicitly defined in ATC Conversion. - // ---- 1.1: - // dynmaic batch - // --input_shape="img:-1,2,224,224;img_info:-1,4" - // --dynamic_batch_size="1,2,4,8" - // ---- 1.2: - // dynamic hw size - // --input_shape="data:8,3,-1,-1;img_info:8,4,-1,-1" - // --dynamic_image_size="416,416;832,832" - // ---- 1.3 - // dynamic dims - // --input_shape="data:-1,1,256,256", --dynamic_dims="1,2" - // - // 2. More flexible dynamic input models. - // Min/Max Value is not explictly defined in ATC Conversion. - // ---- 2.1: - // input_shape_range - // --input_shape_range="input1:[8~20,3,5,-1];input2:[5,3~9,10,-1]" - - // Get Number of Inputs by Calling ACL API - int count = aclmdlGetNumInputs(this->model_desc_); - LOGD("Network have %d inputs.\n", count); - - // Type 1 OM model has an extra input called "ascend_mbatch_shape_data" - // Check if the input exists. - for (int i = 0; i < count; i++) { - std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i); - if (input_name.find(ACL_DYNAMIC_TENSOR_NAME) != std::string::npos) { - LOGD("Network is converted with dynamic batch/hw/dims.\n"); - atc_mode_dynamic_batch_hw_dim_ = true; - } - } - - // Traditional Type 1 Dynamic - if (atc_mode_dynamic_batch_hw_dim_) { - if (count != 2) { - // TODO: SUPPORT Type 1 Model with more than ONE input in the future. - LOGD("Dynamic batch/hw/dims ATLAS with more than ONE input not supported yet.\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, - "Dynamic batch/hw/dims ATLAS with more than ONE input not supported yet."); - } - - // TODO: Update this part for multiple inputs - for (int i = 0; i < count; i++) { - std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i); - if (input_name.find(ACL_DYNAMIC_TENSOR_NAME) == std::string::npos) { - aclmdlIODims acl_dims; - aclError acl_ret = aclmdlGetInputDims(this->model_desc_, i, &acl_dims); - if (ACL_ERROR_NONE != acl_ret) { - LOGE("ACL API Call aclmdlGetInputDims falied!\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "ACL API Call aclmdlGetInputDims falied."); - } - - int minus_one_count = 0; - for (int d = 0; d < acl_dims.dimCount; d++) { - if (acl_dims.dims[d] == -1) { - minus_one_count++; - } - } - if (minus_one_count == 0) { - LOGE("The Only Input %s is not dynamic But Model is dynamic. 
Not Supported.\n", input_name.c_str()); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, - "The Only Input is not dynamic But Model is dynamic. Not Supported.."); - } - - if (minus_one_count == 1 && acl_dims.dims[0] == -1) { - LOGD("Deduced Dynamic Batch Mode from input: %s.\n", input_name.c_str()); - this->atc_mode_dynamic_batch_ = true; - return TNN_OK; - } - if (minus_one_count == 2 && acl_dims.dimCount == 4 && acl_dims.dims[2] == -1 && - acl_dims.dims[3] == -1) { - LOGD("Deduced Dynamic HW Mode from input: %s.\n", input_name.c_str()); - this->atc_mode_dynamic_hw_ = true; - return TNN_OK; - } - // ELSE - LOGD("Deduced Dynamic Dim Mode from input: %s.\n", input_name.c_str()); - this->atc_mode_dynamic_dim_ = true; - return TNN_OK; - } - } - } - - // No Dynamic Or Type 2 Dynamic Input by --input_shape_range - for (int i = 0; i < count; i++) { - aclmdlIODims acl_dims; - aclError acl_ret = aclmdlGetInputDims(this->model_desc_, i, &acl_dims); - if (ACL_ERROR_NONE != acl_ret) { - LOGE("ACL API Call aclmdlGetInputDims falied!\n"); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "ACL API Call aclmdlGetInputDims falied."); - } - - int minus_one_count = 0; - for (int d = 0; d < acl_dims.dimCount; d++) { - if (acl_dims.dims[d] == -1) { - minus_one_count++; - } - } - - if (minus_one_count > 0) { - std::string input_name = aclmdlGetInputNameByIndex(model_desc_, i); - LOGD("Input: '%s' is dynamic by --input_shape_range.\n", input_name.c_str()); - this->dynamic_input_shape_range_names_.insert(input_name); - } - } - - if (this->dynamic_input_shape_range_names_.empty()) { - LOGD("No Dynamic Input.\n"); - } - return TNN_OK; -} - void AtlasNetwork::DestroyDataset(aclmdlDataset *&data_set) { if (nullptr == data_set) { return; diff --git a/source/tnn/device/atlas/atlas_network.h b/source/tnn/device/atlas/atlas_network.h index e4fb4738e..535edd4bd 100644 --- a/source/tnn/device/atlas/atlas_network.h +++ b/source/tnn/device/atlas/atlas_network.h @@ -1,4 +1,16 @@ -// Copyright 2019 Tencent. All Rights Reserved +// Tencent is pleased to support the open source community by making TNN available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
#ifndef TNN_SOURCE_DEVICE_ATLAS_ATLAS_NETWORK_H_ #define TNN_SOURCE_DEVICE_ATLAS_ATLAS_NETWORK_H_ @@ -9,13 +21,14 @@ #include #include #include "acl/acl.h" -#include "tnn/core/abstract_network.h" +#include "tnn/core/default_network.h" #include "tnn/core/macro.h" #include "tnn/device/atlas/atlas_common_types.h" +#include "tnn/device/atlas/atlas_context.h" namespace TNN_NS { -class AtlasNetwork : public AbstractNetwork { +class AtlasNetwork : public DefaultNetwork { public: // @brief virtual default destructor virtual ~AtlasNetwork(); @@ -27,9 +40,6 @@ class AtlasNetwork : public AbstractNetwork { InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape, InputDataTypeMap inputs_data_type, bool enable_const_folder = true); - // @brief deinit release init create resource - virtual Status DeInit(); - // @brief return the amount of memory required for forward // @param memory_size: the memory size used by tnn layers for // forward @@ -48,7 +58,7 @@ class AtlasNetwork : public AbstractNetwork { // virtual Status SetForwardMemory(void *memory); - // @brief network infer + // @brief reshape network virtual Status Reshape(const InputShapesMap &inputs); // @brief get tnn command queue @@ -72,38 +82,35 @@ class AtlasNetwork : public AbstractNetwork { // @param blobs output blobs name map virtual Status GetAllOutputBlobs(BlobMap &blobs); - // @brief get atlas model id of current network - uint32_t GetModelId() const; - - // @brief get atlas model desc of current network - aclmdlDesc* GetModelDesc() const; + // @brief get OM info of ATLAS OM model + std::shared_ptr GetOMModelInfo(); private: + // OM RELATED + // @brief load model from om file - Status LoadModelFromFile(const std::string &om_file); + Status LoadOMModelFromFile(const std::string &om_file); // @brief load model from memory - Status LoadModelFromMemory(const std::string &om_file); + Status LoadOMModelFromMemory(const std::string &om_content); - // @brief unload model - void UnloadModel(); + // @brief deduce model dynamic input mode + Status DeduceOMModelDynamicMode(); + + // @brief deduce model AIPP input format + Status DeduceOMModelAIPPInputFormat(); - // @brief allocate data set and create Blob - Status AllocateDatasetCreateBlob(aclmdlDataset **data_set, const InputShapesMap &max_input_shapes_map, - bool is_input); + // @brief internal init network for OM model + Status InitOMModel(ModelConfig &model_config, AbstractModelInterpreter *interpreter, + InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape, + InputDataTypeMap inputs_data_type, bool enable_const_folder); - // @brief add blob into map - Status AddBlobToMap(const InputShapesMap &max_input_shapes_map, size_t index, void *data, bool is_input); + // @brief internal reshape network for OM model + virtual Status ReshapeOMModel(const InputShapesMap &inputs); // @brief get input dims Status GetInputInfo(size_t index, std::vector &input_dims, aclFormat &input_format, aclDataType &input_data_type); - - // @brief deduce dynamic input type - Status DeduceDynamicInputType(); - - // @brief get output shape from max input shape if output shape is missing. 
- //Status InferOutputShapeIfNecessery(); // @brief set dynamic input dims for OM models converted with --input_shape_range Status SetRangeDynamicInputDim(std::string input_name, const DimsVector& target_input_shape); @@ -114,36 +121,30 @@ class AtlasNetwork : public AbstractNetwork { // @brief set dynmaic batch size Status SetDynamicBatchSize(std::string blob_name, int batch_size); + std::map output_dim0_map_; + void* om_model_memory_ptr_ = nullptr; + void* om_model_weight_ptr_ = nullptr; + std::shared_ptr om_model_info_ = nullptr; + + + + // @brief add blob into map + Status AddBlobToMap(const InputShapesMap &max_input_shapes_map, size_t index, void *data, bool is_input); + + // @brief allocate data set and create Blob + Status AllocateDatasetCreateBlob(aclmdlDataset **data_set, const InputShapesMap &max_input_shapes_map, + bool is_input); + // @brief destory dataset void DestroyDataset(aclmdlDataset *&data_set); + ModelType model_type_; BlobMap input_blob_map_; BlobMap output_blob_map_; - bool need_to_deinit = false; - std::shared_ptr command_queue_ = nullptr; - aclrtContext context_ = nullptr; - aclrtStream stream_ = nullptr; - size_t model_mem_size_ = 0; - size_t model_weight_size_ = 0; - void *model_mem_ptr_ = nullptr; - void *model_weight_ptr_ = nullptr; - uint32_t model_id_ = 0; - aclmdlDesc *model_desc_ = nullptr; - aclmdlDataset *input_ = nullptr; - aclmdlDataset *output_ = nullptr; - - // Traditional Type 1 Dynamic Input - bool atc_mode_dynamic_batch_hw_dim_ = false; // Be one of dynamic batch, hw, or dim - bool atc_mode_dynamic_batch_ = false; - bool atc_mode_dynamic_hw_ = false; - bool atc_mode_dynamic_dim_ = false; - // More Flexible Type 2 Dynamic Input (--input_shape_range) - std::unordered_set dynamic_input_shape_range_names_; - - bool has_aipp_ = false; - std::map aipp_input_format_map_; - std::map output_dim0_map_; + bool network_init_called_ = false; + aclmdlDataset* aclmdl_input_dataset_ = nullptr; + aclmdlDataset* aclmdl_output_dataset_ = nullptr; }; } // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_om_model_interpreter.cc b/source/tnn/device/atlas/atlas_om_model_interpreter.cc new file mode 100644 index 000000000..3fe1bf669 --- /dev/null +++ b/source/tnn/device/atlas/atlas_om_model_interpreter.cc @@ -0,0 +1,45 @@ +// Tencent is pleased to support the open source community by making TNN available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include +#include "tnn/device/atlas/atlas_om_model_interpreter.h" +#include "tnn/device/atlas/atlas_utils.h" +#include "tnn/utils/split_utils.h" + +namespace TNN_NS { + +AtlasOMModelInterpreter::AtlasOMModelInterpreter() {} + +AtlasOMModelInterpreter::~AtlasOMModelInterpreter() {} + +Status AtlasOMModelInterpreter::Interpret(std::vector ¶ms) { + // OM Model Load API only support LOAD model directly ONTO device (Card) + // So the real model interpret path is in AtlasNetwork instead. 
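+    //
+    // Illustrative usage (assumption, not from the original patch): params is normally
+    // ModelConfig::params, so a caller would typically do something like
+    //     ModelConfig model_config;
+    //     model_config.model_type = MODEL_TYPE_ATLAS;
+    //     model_config.params.push_back(om_path_or_content);  // hypothetical variable, consumed (moved) below
+    // and TNN then routes model_config.params into this Interpret() call.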
+ + // The only thing we need to do here is to store om_string locally, + // we USE MOVE here to save memory for large OM model. + this->om_str_ = std::move(params[0]); + //this->om_str_ = params[0]; + + return TNN_OK; +} + +std::string& AtlasOMModelInterpreter::GetOmString() { + return this->om_str_; +} + +TypeModelInterpreterRegister> g_atlas_model_interpreter_register( + MODEL_TYPE_ATLAS); + +} // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_om_model_interpreter.h b/source/tnn/device/atlas/atlas_om_model_interpreter.h new file mode 100644 index 000000000..773900681 --- /dev/null +++ b/source/tnn/device/atlas/atlas_om_model_interpreter.h @@ -0,0 +1,50 @@ +// Tencent is pleased to support the open source community by making TNN available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef TNN_SOURCE_DEVICE_ATLAS_ATLAS_OM_MODEL_INTERPRETER_H_ +#define TNN_SOURCE_DEVICE_ATLAS_ATLAS_OM_MODEL_INTERPRETER_H_ + +#include +#include +#include +#include +#include +#include "tnn/core/macro.h" +#include "tnn/core/status.h" +#include "tnn/device/atlas/atlas_common_types.h" +#include "tnn/interpreter/abstract_model_interpreter.h" + +namespace TNN_NS { + +// @brief Atlas OM model interpreter that interprets Atlas OM Model +class AtlasOMModelInterpreter : public AbstractModelInterpreter { +public: + AtlasOMModelInterpreter(); + + // @brief virtual destructor + virtual ~AtlasOMModelInterpreter(); + + // @brief different interpreter has different order param + virtual Status Interpret(std::vector ¶ms); + + // @brief get model om string + std::string& GetOmString(); + +private: + std::string om_str_; +}; + +} // namespace TNN_NS + +#endif // TNN_SOURCE_DEVICE_ATLAS_ATLAS_OM_MODEL_INTERPRETER_H_ diff --git a/source/tnn/device/atlas/atlas_runtime.cc b/source/tnn/device/atlas/atlas_runtime.cc deleted file mode 100644 index 490897839..000000000 --- a/source/tnn/device/atlas/atlas_runtime.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Tencent is pleased to support the open source community by making TNN available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. 
- -#include "tnn/device/atlas/atlas_runtime.h" -#include "acl/acl.h" -#include "tnn/core/macro.h" - -namespace TNN_NS { - -static std::mutex g_mtx; - -std::shared_ptr AtlasRuntime::atlas_runtime_singleton_ = nullptr; -int AtlasRuntime::ref_count_ = 0; -bool AtlasRuntime::init_done_ = false; - -AtlasRuntime* AtlasRuntime::GetInstance() { - if (nullptr == atlas_runtime_singleton_.get()) { - std::unique_lock lck(g_mtx); - if (nullptr == atlas_runtime_singleton_.get()) { - atlas_runtime_singleton_.reset(new AtlasRuntime()); - } - } - - return atlas_runtime_singleton_.get(); -} - -void AtlasRuntime::DecreaseRef() { - std::unique_lock lck(g_mtx); - ref_count_--; - LOGD("AtlasRuntime::DecreaseRef() count=%d\n", ref_count_); - if (ref_count_ <= 0) { - atlas_runtime_singleton_.reset(); - ref_count_ = 0; - } -} - -AtlasRuntime::AtlasRuntime() { - device_list_.clear(); -} - -// Init Atlas Runtime and increase reference count -Status AtlasRuntime::Init() { - std::unique_lock lck(g_mtx); - - ref_count_++; - LOGD("AtlasRuntime::Init() reference count=%d\n", ref_count_); - - // only init once. - if (!init_done_) { - LOGD("Init Atlas Acl\n"); - - LOGD("acl begin init...\n"); - aclError ret = aclInit(nullptr); - if (ret != ACL_ERROR_NONE && ret != ACL_ERROR_REPEAT_INITIALIZE) { - LOGE("acl init failed!\n"); - return TNNERR_ATLAS_RUNTIME_ERROR; - } - LOGD("acl init done!\n"); - - init_done_ = true; - } - - return TNN_OK; -} - -Status AtlasRuntime::SetDevice(int device_id) { - std::unique_lock lck(g_mtx); - if (device_list_.find(device_id) == device_list_.end()) { - LOGD("set device: %d\n", device_id); - aclError acl_ret = aclrtSetDevice(device_id); - if (acl_ret != ACL_ERROR_NONE) { - LOGE("acl open device %d failed (acl error code: %d)\n", device_id, acl_ret); - return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl open device falied"); - } - LOGD("set device done!\n"); - device_list_.emplace(device_id); - } - - return TNN_OK; -} - -Status AtlasRuntime::AddModelInfo(Blob* blob, AtlasModelInfo model_info) { - std::unique_lock lck(g_mtx); - model_info_map_[blob] = model_info; - return TNN_OK; -} - -Status AtlasRuntime::DelModelInfo(Blob* blob) { - std::unique_lock lck(g_mtx); - auto blob_it = model_info_map_.find(blob); - if (blob_it != model_info_map_.end()) { - model_info_map_.erase(blob_it); - } - return TNN_OK; -} - -std::map& AtlasRuntime::GetModleInfoMap() { - return model_info_map_; -} - -AtlasRuntime::~AtlasRuntime() { - LOGD("~AtlasRuntime() begin \n"); - - aclError ret; - for (auto id : device_list_) { - LOGD("reset device: %d\n", id); - ret = aclrtResetDevice(id); - if (ret != ACL_ERROR_NONE) { - LOGE("acl reset device failed!\n"); - } - } - device_list_.clear(); - - LOGD("aclFinalize()\n"); - ret = aclFinalize(); - if (ret != ACL_ERROR_NONE) { - LOGD("acl finalize failed!\n"); - } - - LOGD("~AtlasRuntime() end \n"); -} - -} // namespace TNN_NS diff --git a/source/tnn/device/atlas/atlas_runtime.h b/source/tnn/device/atlas/atlas_runtime.h deleted file mode 100644 index c2981a203..000000000 --- a/source/tnn/device/atlas/atlas_runtime.h +++ /dev/null @@ -1,57 +0,0 @@ -// Tencent is pleased to support the open source community by making TNN available. -// -// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except -// in compliance with the License. 
You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software distributed -// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, either express or implied. See the License for the -// specific language governing permissions and limitations under the License. - -#ifndef TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_RUNTIME_H_ -#define TNN_SOURCE_TNN_DEVICE_ATLAS_ATLAS_RUNTIME_H_ - -#include -#include -#include -#include -#include -#include "tnn/core/status.h" -#include "tnn/device/atlas/atlas_common_types.h" - -namespace TNN_NS { - -class AtlasRuntime { -public: - static AtlasRuntime *GetInstance(); - static Status Init(); - static void DecreaseRef(); - - ~AtlasRuntime(); - AtlasRuntime(const AtlasRuntime &) = delete; - AtlasRuntime &operator=(const AtlasRuntime &) = delete; - - Status SetDevice(int device_id); - Status AddModelInfo(Blob *blob, AtlasModelInfo model_info); - Status DelModelInfo(Blob *blob); - std::map &GetModleInfoMap(); - -private: - AtlasRuntime(); - -private: - std::set device_list_; - std::map model_info_map_; - - static std::shared_ptr atlas_runtime_singleton_; - static int ref_count_; - static bool init_done_; -}; - -} // namespace TNN_NS - -#endif // TNN_SOURCE_TNN_DEVICE_OPENCL_OPENCL_RUNTIME_H_ diff --git a/source/tnn/device/atlas/atlas_utils.cc b/source/tnn/device/atlas/atlas_utils.cc index d79655459..12c9da093 100644 --- a/source/tnn/device/atlas/atlas_utils.cc +++ b/source/tnn/device/atlas/atlas_utils.cc @@ -7,47 +7,6 @@ namespace TNN_NS { -std::vector SplitPath(const std::string& str, const std::set delimiters) { - std::vector result; - char const* pch = str.c_str(); - char const* start = pch; - for (; *pch; ++pch) { - if (delimiters.find(*pch) != delimiters.end()) { - if (start != pch) { - std::string str(start, pch); - result.push_back(str); - } else { - result.push_back(""); - } - start = pch + 1; - } - } - result.push_back(start); - return result; -} - -long GetCurentTime() { - struct timeval tv; - gettimeofday(&tv, NULL); - return tv.tv_sec * 1000 + tv.tv_usec / 1000; -} - -int SaveMemToFile(std::string file_name, void* data, int size) { - FILE* fd = fopen(file_name.c_str(), "wb"); - if (fd == nullptr) { - return -1; - } - - int ret = fwrite(data, 1, size, fd); - if (ret != size) { - fclose(fd); - return -1; - } - - fclose(fd); - return 0; -} - Status ConvertFromAclDataTypeToTnnDataType(aclDataType acl_datatype, DataType& tnn_datatype) { if (ACL_FLOAT == acl_datatype) { tnn_datatype = DATA_TYPE_FLOAT; diff --git a/source/tnn/device/atlas/atlas_utils.h b/source/tnn/device/atlas/atlas_utils.h index 5571039a1..add2203c8 100644 --- a/source/tnn/device/atlas/atlas_utils.h +++ b/source/tnn/device/atlas/atlas_utils.h @@ -17,12 +17,6 @@ namespace TNN_NS { -std::vector SplitPath(const std::string& str, const std::set delimiters); - -long GetCurentTime(); - -int SaveMemToFile(std::string file_name, void* data, int size); - Status ConvertFromAclDataTypeToTnnDataType(aclDataType acl_datatype, DataType& tnn_datatype); Status ConvertAclDataFormatToTnnDataFormat(aclFormat acl_format, DataFormat& tnn_dataformat); diff --git a/source/tnn/device/atlas/tnn_impl_atlas.cc b/source/tnn/device/atlas/tnn_impl_atlas.cc index d2f5e5421..2dc8b6adc 100644 --- a/source/tnn/device/atlas/tnn_impl_atlas.cc +++ b/source/tnn/device/atlas/tnn_impl_atlas.cc @@ -1,10 +1,22 @@ -// Copyright 2019 Tencent. 
+// Tencent is pleased to support the open source community by making TNN available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
 
-#include "tnn_impl_atlas.h"
-#include "atlas_network.h"
-#include "atlas_utils.h"
 #include
 #include "tnn/core/instance.h"
+#include "tnn/device/atlas/atlas_network.h"
+#include "tnn/device/atlas/atlas_utils.h"
+#include "tnn/device/atlas/tnn_impl_atlas.h"
 #include "tnn/interpreter/abstract_model_interpreter.h"
 
 namespace TNN_NS {
 
@@ -17,6 +29,22 @@ TNNImplAtlas::~TNNImplAtlas() {}
 
 Status TNNImplAtlas::Init(ModelConfig& config) {
     TNNImpl::Init(config);
+
+    this->model_type_ = config.model_type;
+
+    if (config.model_type == TNN_NS::MODEL_TYPE_TNN ||
+        config.model_type == TNN_NS::MODEL_TYPE_RAPIDNET ||
+        config.model_type == TNN_NS::MODEL_TYPE_ATLAS) {
+        LOGD("Model Type is TNN or ATLAS OM, ACL API Required. Call aclInit() ...\n");
+        aclError acl_ret = aclInit(nullptr);
+        if (acl_ret != ACL_ERROR_NONE && acl_ret != ACL_ERROR_REPEAT_INITIALIZE) {
+            LOGE("Atlas API: aclInit failed!\n");
+            return TNNERR_ATLAS_RUNTIME_ERROR;
+        }
+        LOGD("Model Type is TNN or ATLAS OM, ACL API Required. Call aclInit() ... done.\n");
+        this->acl_init_called_ = true;
+    }
+
     auto interpreter = CreateModelInterpreter(config.model_type);
     if (!interpreter) {
         return Status(TNNERR_NET_ERR, "interpreter is nil");
@@ -26,6 +54,14 @@ Status TNNImplAtlas::Init(ModelConfig& config) {
 }
 
 Status TNNImplAtlas::DeInit() {
+    if (this->acl_init_called_) {
+        LOGD("TNNImplAtlas DeInit: to call aclFinalize().\n");
+        aclError ret = aclFinalize();
+        if (ret != ACL_ERROR_NONE) {
+            LOGD("TNNImplAtlas DeInit: ATLAS API: aclFinalize failed!\n");
+        }
+    }
+
     return TNN_OK;
 }
 
@@ -35,95 +71,109 @@ Status TNNImplAtlas::AddOutput(const std::string& layer_name, int output_index)
 }
 
 Status TNNImplAtlas::GetModelInputNames(std::vector<std::string>& input_names) {
-    if (this->model_desc_of_the_first_instance_ == nullptr) {
-        LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.");
-        return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
-    }
-
-    size_t num_inputs = aclmdlGetNumInputs(this->model_desc_of_the_first_instance_);
-    std::vector<std::string> in_names_vec;
-    for (size_t i=0; i<num_inputs; i++) {
-        std::string input_name;
-        input_name.assign(aclmdlGetInputNameByIndex(this->model_desc_of_the_first_instance_, i));
-        in_names_vec.emplace_back(input_name);
+    if (model_type_ == MODEL_TYPE_ATLAS) {
+        if (this->om_model_desc_of_the_first_instance_ == nullptr) {
+            LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.");
+            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
+        }
+
+        size_t num_inputs = aclmdlGetNumInputs(this->om_model_desc_of_the_first_instance_);
+        std::vector<std::string> in_names_vec;
+        for (size_t i=0; i<num_inputs; i++) {
+            std::string input_name;
+            input_name.assign(aclmdlGetInputNameByIndex(this->om_model_desc_of_the_first_instance_, i));
+            in_names_vec.emplace_back(input_name);
+        }
+        input_names = in_names_vec;
+    } else {
+        LOGE("API not supported for current MODEL TYPE.\n");
     }
-    input_names = in_names_vec;
     return TNN_OK;
 }
 
 Status TNNImplAtlas::GetModelOutputNames(std::vector<std::string>& output_names) {
-    if (this->model_desc_of_the_first_instance_ == nullptr) {
-        LOGE("Fail to Get TNN Atlas ModelOutputNames, model desc missing.\n");
-        return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelOutputNames, model desc missing.");
-    }
-
-    size_t num_outputs = aclmdlGetNumOutputs(this->model_desc_of_the_first_instance_);
-    std::vector<std::string> out_names_vec;
-    for (size_t i=0; i<num_outputs; i++) {
-        std::string output_name;
-        output_name.assign(aclmdlGetOutputNameByIndex(this->model_desc_of_the_first_instance_, i));
-        out_names_vec.emplace_back(output_name);
+    if (model_type_ == MODEL_TYPE_ATLAS) {
+        if (this->om_model_desc_of_the_first_instance_ == nullptr) {
+            LOGE("Fail to Get TNN Atlas ModelOutputNames, model desc missing.\n");
+            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelOutputNames, model desc missing.");
+        }
+
+        size_t num_outputs = aclmdlGetNumOutputs(this->om_model_desc_of_the_first_instance_);
+        std::vector<std::string> out_names_vec;
+        for (size_t i=0; i<num_outputs; i++) {
+            std::string output_name;
+            output_name.assign(aclmdlGetOutputNameByIndex(this->om_model_desc_of_the_first_instance_, i));
+            out_names_vec.emplace_back(output_name);
+        }
+        output_names = out_names_vec;
+    } else {
+        LOGE("API not supported for current MODEL TYPE.\n");
     }
-    output_names = out_names_vec;
     return TNN_OK;
 }
 
 Status TNNImplAtlas::GetModelInputShapesMap(InputShapesMap& shapes_map) {
-    if (this->model_desc_of_the_first_instance_ == nullptr) {
-        LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.\n");
-        return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
-    }
-
-    size_t num_inputs = aclmdlGetNumInputs(this->model_desc_of_the_first_instance_);
-    InputShapesMap in_shapes_map;
-    for (size_t i=0; i<num_inputs; i++) {
-        aclmdlIODims acl_dims;
-        aclError acl_ret = aclmdlGetInputDims(this->model_desc_of_the_first_instance_, i, &acl_dims);
-        if (acl_ret != ACL_ERROR_NONE) {
-            LOGE("acl get input dim failed (acl error code: %d)\n", acl_ret);
-            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl get input dim falied");
+    if (model_type_ == MODEL_TYPE_ATLAS) {
+        if (this->om_model_desc_of_the_first_instance_ == nullptr) {
+            LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.\n");
+            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
        }
-        std::string input_name;
-        input_name.assign(aclmdlGetInputNameByIndex(this->model_desc_of_the_first_instance_, i));
-        std::vector<int> in_dims;
-        for (int d=0; d<acl_dims.dimCount; d++) {
-            in_dims.push_back(acl_dims.dims[d]);
-        }
-        in_shapes_map[input_name] = in_dims;
+
+        size_t num_inputs = aclmdlGetNumInputs(this->om_model_desc_of_the_first_instance_);
+        InputShapesMap in_shapes_map;
+        for (size_t i=0; i<num_inputs; i++) {
+            aclmdlIODims acl_dims;
+            aclError acl_ret = aclmdlGetInputDims(this->om_model_desc_of_the_first_instance_, i, &acl_dims);
+            if (acl_ret != ACL_ERROR_NONE) {
+                LOGE("acl get input dim failed (acl error code: %d)\n", acl_ret);
+                return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl get input dim failed");
+            }
+            std::string input_name;
+            input_name.assign(aclmdlGetInputNameByIndex(this->om_model_desc_of_the_first_instance_, i));
+            std::vector<int> in_dims;
+            for (int d=0; d<acl_dims.dimCount; d++) {
+                in_dims.push_back(acl_dims.dims[d]);
+            }
+            in_shapes_map[input_name] = in_dims;
+        }
+        shapes_map = in_shapes_map;
+    } else {
+        LOGE("API not supported for current MODEL TYPE.\n");
     }
-    shapes_map = in_shapes_map;
     return TNN_OK;
 }
 
 Status TNNImplAtlas::GetModelInputDataTypeMap(InputDataTypeMap& data_type_map) {
-    if (this->model_desc_of_the_first_instance_ == nullptr) {
-        LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.\n");
-        return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
-    }
-
-    size_t num_inputs = aclmdlGetNumInputs(this->model_desc_of_the_first_instance_);
-    InputDataTypeMap in_dtype_map;
-    for (size_t i=0; i<num_inputs; i++) {
-        std::string input_name;
-        input_name.assign(aclmdlGetInputNameByIndex(this->model_desc_of_the_first_instance_, i));
-        aclDataType acl_dtype = aclmdlGetInputDataType(this->model_desc_of_the_first_instance_, i);
-        DataType tnn_dtype;
-        aclError acl_ret = ConvertFromAclDataTypeToTnnDataType(acl_dtype, tnn_dtype);
-        if (acl_ret != ACL_ERROR_NONE) {
-            LOGE("acl get input data type failed, maybe unsupported data type (acl error code: %d)\n", acl_ret);
-            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl get input data type failed");
+    if (model_type_ == MODEL_TYPE_ATLAS) {
+        if (this->om_model_desc_of_the_first_instance_ == nullptr) {
+            LOGE("Fail to Get TNN Atlas ModelInputNames, model desc missing.\n");
+            return Status(TNNERR_ATLAS_RUNTIME_ERROR, "Fail to Get TNN Atlas ModelInputNames, model desc missing.");
        }
-        in_dtype_map[input_name] = tnn_dtype;
+
+        size_t num_inputs = aclmdlGetNumInputs(this->om_model_desc_of_the_first_instance_);
+        InputDataTypeMap in_dtype_map;
+        for (size_t i=0; i<num_inputs; i++) {
+            std::string input_name;
+            input_name.assign(aclmdlGetInputNameByIndex(this->om_model_desc_of_the_first_instance_, i));
+            aclDataType acl_dtype = aclmdlGetInputDataType(this->om_model_desc_of_the_first_instance_, i);
+            DataType tnn_dtype;
+            aclError acl_ret = ConvertFromAclDataTypeToTnnDataType(acl_dtype, tnn_dtype);
+            if (acl_ret != ACL_ERROR_NONE) {
+                LOGE("acl get input data type failed, maybe unsupported data type (acl error code: %d)\n", acl_ret);
+                return Status(TNNERR_ATLAS_RUNTIME_ERROR, "acl get input data type failed");
+            }
+            in_dtype_map[input_name] = tnn_dtype;
+        }
+        data_type_map = in_dtype_map;
+    } else {
+        LOGE("API not supported for current MODEL TYPE.\n");
     }
-    data_type_map = in_dtype_map;
-
     return TNN_OK;
 }
 
@@ -135,19 +185,21 @@ std::shared_ptr<Instance> TNNImplAtlas::CreateInst(NetworkConfig& net_config, St
 }
 
 std::shared_ptr<Instance> TNNImplAtlas::CreateInst(NetworkConfig& net_config, Status& status,
-                                                   InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape, InputDataTypeMap inputs_data_type) {
+                                                   InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape,
+                                                   InputDataTypeMap inputs_data_type) {
     auto instance = std::make_shared<Instance>(net_config, model_config_);
     status = instance->Init(interpreter_, min_inputs_shape, max_inputs_shape, inputs_data_type);
 
-    AtlasNetwork* atlas_net = reinterpret_cast<AtlasNetwork*>(instance->GetNetwork());
-    if (this->model_id_of_the_first_instance_ == 0) {
-        this->model_id_of_the_first_instance_ = atlas_net->GetModelId();
-        LOGD("TNNImplAtlas init the first Instance, get model id.\n");
-    }
-
-    if (this->model_desc_of_the_first_instance_ == nullptr) {
-        this->model_desc_of_the_first_instance_ = atlas_net->GetModelDesc();
-        LOGD("TNNImplAtlas init the first Instance, get model desc.\n");
+    if (model_type_ == MODEL_TYPE_ATLAS) {
+        AtlasNetwork* atlas_net = reinterpret_cast<AtlasNetwork*>(instance->GetNetwork());
+        if (this->om_model_id_of_the_first_instance_ == 0) {
+            this->om_model_id_of_the_first_instance_ = atlas_net->GetOMModelInfo()->model_id;
+            LOGD("TNNImplAtlas init the first Instance, get model id.\n");
+        }
+        if (this->om_model_desc_of_the_first_instance_ == nullptr) {
+            this->om_model_desc_of_the_first_instance_ = atlas_net->GetOMModelInfo()->model_desc;
+            LOGD("TNNImplAtlas init the first Instance, get model desc.\n");
+        }
     }
 
     return instance;
diff --git a/source/tnn/device/atlas/tnn_impl_atlas.h b/source/tnn/device/atlas/tnn_impl_atlas.h
index 38bbdf12d..ff4023cb6 100644
--- a/source/tnn/device/atlas/tnn_impl_atlas.h
+++ b/source/tnn/device/atlas/tnn_impl_atlas.h
@@ -1,4 +1,16 @@
-// Copyright 2019 Tencent. All Rights Reserved
+// Tencent is pleased to support the open source community by making TNN available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
 
 #ifndef TNN_SOURCE_DEVICE_ATLAS_TNN_IMPL_ATLAS_H_
 #define TNN_SOURCE_DEVICE_ATLAS_TNN_IMPL_ATLAS_H_
 
@@ -20,8 +32,7 @@ class TNNImplAtlas : public TNNImpl {
 
     // @brief init the tnn, contruct model interpreter
     // @param config config model type and params
-    // @return status code: Successful, returns zero. Otherwise, returns
-    // error code.
+    // @return status code: 0 if successful, otherwise an error code
     virtual Status Init(ModelConfig& config);
 
     // @brief release model interpreter
@@ -31,9 +42,7 @@ class TNNImplAtlas : public TNNImpl {
     // outputIndex.
     //@param output_name Name of the output blob
    //@param output_index Index of the output layer
-    //@return status code: If successful, returns zero. Otherwise, returns
-    // error
-    // code.
+    //@return status code: 0 if successful, otherwise an error code
     virtual Status AddOutput(const std::string& output_name, int output_index = 0);
 
     //@brief get input shapes map from model
@@ -50,10 +59,8 @@ class TNNImplAtlas : public TNNImpl {
 
     // @brief create an instance
     // @param instance: The instance to be created.
-    // @param inputs_shape: modify input shape, or it will use the shape in the
-    // proto
-    // @param status code: If successful, returns zero. Otherwise, returns
-    // error code.
+    // @param inputs_shape: modify input shape, or it will use the shape in the proto
+    // @param status code: 0 if successful, otherwise an error code
     virtual std::shared_ptr<Instance> CreateInst(NetworkConfig& config, Status& status,
                                                  InputShapesMap inputs_shape = InputShapesMap(),
                                                  InputDataTypeMap inputs_data_type = InputDataTypeMap());
 
@@ -62,19 +69,20 @@ class TNNImplAtlas : public TNNImpl {
     // @param instance: The instance to be created.
     // @param min_inputs_shape: support min shape
     // @param max_inputs_shape: support max shape
-    // @param status code: If successful, returns zero. Otherwise, returns
-    // error code.
+    // @param status code: 0 if successful, otherwise an error code
     virtual std::shared_ptr<Instance> CreateInst(NetworkConfig& config, Status& status,
                                                  InputShapesMap min_inputs_shape, InputShapesMap max_inputs_shape,
                                                  InputDataTypeMap inputs_data_type = InputDataTypeMap());
 
 private:
     std::shared_ptr<AbstractModelInterpreter> interpreter_;
+    ModelType model_type_;
+    bool acl_init_called_ = false;
 
-    // Model Desc and Model id for the first instance.
+    // OM Model Desc and OM Model id for the first instance.
     // Set when the first Effective CreateInst is called.
     // Usage: Get input/output names, shapes, datatypes ... etc.
-    uint32_t model_id_of_the_first_instance_ = 0;
-    aclmdlDesc* model_desc_of_the_first_instance_ = nullptr;
+    uint32_t om_model_id_of_the_first_instance_ = 0;
+    aclmdlDesc* om_model_desc_of_the_first_instance_ = nullptr;
 };
 
 }  // namespace TNN_NS
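
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the refactored lifecycle is expected to look from the caller side, now that aclInit() runs inside TNNImplAtlas::Init(), aclFinalize() runs inside DeInit(), and the first CreateInst() caches the OM aclmdlDesc that backs the GetModel* queries. The .om path and device id are placeholders, and it assumes the public TNN facade forwards GetModelInputNames() to this implementation; the override added in this patch suggests that, but the forwarding itself is not shown here.

// Illustrative sketch under the assumptions stated above; values marked
// "placeholder" are not taken from this patch.
#include <string>
#include <vector>

#include "tnn/core/tnn.h"

using namespace TNN_NS;

int RunAtlasOmModel() {
    ModelConfig model_config;
    model_config.model_type = MODEL_TYPE_ATLAS;
    model_config.params     = {"/path/to/model.om"};     // placeholder path

    TNN tnn;
    Status status = tnn.Init(model_config);              // aclInit() happens in TNNImplAtlas::Init
    if (status != TNN_OK) {
        return -1;
    }

    NetworkConfig net_config;
    net_config.device_type = DEVICE_ATLAS;
    net_config.device_id   = 0;                          // placeholder device id

    auto instance = tnn.CreateInst(net_config, status);  // first instance caches the OM model id/desc
    if (status != TNN_OK || instance == nullptr) {
        return -1;
    }

    std::vector<std::string> input_names;
    tnn.GetModelInputNames(input_names);                 // assumed to forward to TNNImplAtlas

    // ... fill input blobs and call instance->Forward() here ...

    tnn.DeInit();                                        // aclFinalize() happens in TNNImplAtlas::DeInit
    return 0;
}

Compared with the removed AtlasRuntime singleton, this design ties ACL initialization and finalization to the TNN object's own Init/DeInit rather than a global reference count, and keeps per-model state in the OM model info owned by each network instance.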