diff --git a/.github/workflows/build-macos-10.15-cpu.yml b/.github/workflows/build-macos-10.15-cpu.yml new file mode 100644 index 000000000..824915b81 --- /dev/null +++ b/.github/workflows/build-macos-10.15-cpu.yml @@ -0,0 +1,51 @@ +name: macos-10.15-cpu + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: macos-10.15 + + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: recursive + + - name: Install dependencies + run: brew install openblas protobuf + + # Openblas location is exported explicitly because openblas is keg-only, + # which means it was not symlinked into /usr/local/. + # CMake cannot find BLAS on GitHub runners if Marian is being compiled + # statically, hence USE_STATIC_LIBS=off + - name: Configure CMake + run: | + export LDFLAGS="-L/usr/local/opt/openblas/lib" + export CPPFLAGS="-I/usr/local/opt/openblas/include" + mkdir -p build + cd build + cmake .. -DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \ + -DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on -DUSE_STATIC_LIBS=off + + - name: Compile + working-directory: build + run: make -j2 + + - name: Run unit tests + working-directory: build + run: make test + + - name: Print versions + working-directory: build + run: | + ./marian --version + ./marian-decoder --version + ./marian-scorer --version + ./spm_encode --version +
diff --git a/.github/workflows/build-ubuntu-18.04-cpu.yml b/.github/workflows/build-ubuntu-18.04-cpu.yml new file mode 100644 index 000000000..0b37e4f4b --- /dev/null +++ b/.github/workflows/build-ubuntu-18.04-cpu.yml @@ -0,0 +1,64 @@ +name: ubuntu-18.04-cpu + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-18.04 + + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: recursive + + # The following packages are already installed on GitHub-hosted runners: build-essential openssl libssl-dev + - name: Install dependencies + run: sudo apt-get install --no-install-recommends libgoogle-perftools-dev libprotobuf10 libprotobuf-dev protobuf-compiler + + # https://software.intel.com/content/www/us/en/develop/articles/installing-intel-free-libs-and-python-apt-repo.html + - name: Install MKL + run: | + wget -qO- "https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2019.PUB" | sudo apt-key add - + sudo sh -c "echo deb https://apt.repos.intel.com/mkl all main > /etc/apt/sources.list.d/intel-mkl.list" + sudo apt-get update -o Dir::Etc::sourcelist="/etc/apt/sources.list.d/intel-mkl.list" + sudo apt-get install --no-install-recommends intel-mkl-64bit-2020.0-088 + + - name: Print Boost paths + run: | + ls $BOOST_ROOT_1_69_0 + ls $BOOST_ROOT_1_69_0/include + ls $BOOST_ROOT_1_69_0/lib + + # Boost is already installed on GitHub-hosted runners in a non-standard location + # https://github.com/actions/virtual-environments/issues/687#issuecomment-610471671 + - name: Configure CMake + run: | + mkdir -p build + cd build + cmake .. 
-DCOMPILE_CPU=on -DCOMPILE_CUDA=off -DCOMPILE_EXAMPLES=on -DCOMPILE_SERVER=on -DCOMPILE_TESTS=on \ + -DUSE_FBGEMM=on -DUSE_SENTENCEPIECE=on \ + -DBOOST_ROOT=$BOOST_ROOT_1_69_0 -DBOOST_INCLUDEDIR=$BOOST_ROOT_1_69_0/include -DBOOST_LIBRARYDIR=$BOOST_ROOT_1_69_0/lib \ + -DBoost_ARCHITECTURE=-x64 + + - name: Compile + working-directory: build + run: make -j2 + + - name: Run unit tests + working-directory: build + run: make test + + - name: Print versions + working-directory: build + run: | + ./marian --version + ./marian-decoder --version + ./marian-scorer --version + ./spm_encode --version + diff --git a/.github/workflows/build-windows-2019-cpu.yml b/.github/workflows/build-windows-2019-cpu.yml new file mode 100644 index 000000000..13ef66141 --- /dev/null +++ b/.github/workflows/build-windows-2019-cpu.yml @@ -0,0 +1,49 @@ +name: windows-2019-cpu + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: windows-2019 + + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + submodules: recursive + + - name: Prepare vcpkg + uses: lukka/run-vcpkg@v3 + with: + vcpkgArguments: protobuf + vcpkgGitCommitId: 6185aa76504a5025f36754324abf307cc776f3da + vcpkgDirectory: ${{ github.workspace }}/vcpkg/ + vcpkgTriplet: x64-windows-static + + # Note that we build with a simplified CMake settings JSON file + - name: Run CMake + uses: lukka/run-cmake@v2 + with: + buildDirectory: ${{ github.workspace }}/build/ + cmakeAppendedArgs: -G Ninja + cmakeListsOrSettingsJson: CMakeSettingsJson + cmakeSettingsJsonPath: ${{ github.workspace }}/CMakeSettingsCI.json + useVcpkgToolchainFile: true + + - name: Run unit tests + working-directory: build/Debug/ + run: ctest + + - name: Print versions + working-directory: build/Debug/ + run: | + .\marian.exe --version + .\marian-decoder.exe --version + .\marian-scorer.exe --version + .\spm_encode.exe --version + diff --git a/CHANGELOG.md b/CHANGELOG.md index 7926c17ee..6433dfec9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,18 +9,35 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- Optimize LayerNormalization on CPU by 6x through vectorization (ffast-math) and fixing performance regression introduced with strides in 77a420 +- Decoding multi-source models in marian-server with --tsv +- GitHub workflows on Ubuntu, Windows, and MacOS +- LSH indexing to replace short list +- ONNX support for transformer models +- Add topk operator like PyTorch's topk - Use *cblas_sgemm_batch* instead of a for loop of *cblas_sgemm* on CPU as the batched_gemm implementation - Supporting relative paths in shortlist and sqlite options - Training and scoring from STDIN - Support for reading from TSV files from STDIN and other sources during training and translation with options --tsv and --tsv-fields n. +- Internal optional parameter in n-best list generation that skips empty hypotheses. ### Fixed +- Fix compilation without BLAS installed +- Providing a single value to vector-like options using the equals sign, e.g. --models=model.npz +- Fix quiet-translation in marian-server +- CMake-based compilation on Windows +- Fix minor issues with compilation on MacOS +- Fix warnings in Windows MSVC builds using CMake - Fix building server with Boost 1.72 - Make mini-batch scaling depend on mini-batch-words and not on mini-batch-words-ref - In concatenation make sure that we do not multiply 0 with nan (which results in nan) - Change Approx.epsilon(0.01) to Approx.margin(0.001) in unit tests. 
Tolerance is now absolute and not relative. We assumed incorrectly that epsilon is absolute tolerance. +- Fixed bug in finding .git/logs/HEAD when Marian is a submodule in another project. +- Properly record cmake variables in the cmake build directory instead of the source tree. +- Added default "none" for option shuffle in BatchGenerator, so that it works in executables where shuffle is not an option. +- Added a few missing header files in shortlist.h and beam_search.h. - Improved handling for graceful shutdown upon receiving SIGTERM. SIGTERM now also interrupts batch prefetching, which runs in a separate thread. diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cf8bf922..6c85259ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}") execute_process(COMMAND git submodule update --init --recursive --no-fetch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +# Note that with CMake MSVC build, the option CMAKE_BUILD_TYPE is automatically derived from the key +# 'configurationType' in CMakeSettings.json configurations if(NOT CMAKE_BUILD_TYPE) message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release") set(CMAKE_BUILD_TYPE "Release") @@ -62,15 +64,17 @@ if(MSVC) # These are used in src/CMakeLists.txt on a per-target basis list(APPEND ALL_WARNINGS /WX; /W4;) - # Disabled bogus warnings for CPU intrincics: + # Disabled bogus warnings for CPU intrinsics: # C4310: cast truncates constant value # C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier - set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"") + # C4702: unreachable code; note it is also disabled globally in the VS project file + set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\" /wd\"4702\"") - set(INTRINSICS "/arch:AVX") + # set(INTRINSICS "/arch:AVX") + add_definitions(-DUSE_SSE2=1) # Or maybe use these? 
- # set(INTRINSICS "/arch:AVX2") + set(INTRINSICS "/arch:AVX2") # set(INTRINSICS "/arch:AVX512") set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS ${DISABLE_GLOBALLY}") @@ -78,7 +82,9 @@ if(MSVC) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG") # ignores warning LNK4049: locally defined symbol free imported - this comes from zlib - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /ignore:4049") + set(CMAKE_EXE_LINKER_FLAGS_RELEASE "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRT") + set(CMAKE_EXE_LINKER_FLAGS_DEBUG "${CMAKE_EXE_LINKER_FLAGS} /NODEFAULTLIB:MSVCRTD") set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental") find_library(SHLWAPI Shlwapi.lib) @@ -195,6 +201,12 @@ if(USE_SENTENCEPIECE) set(EXT_LIBS ${EXT_LIBS} sentencepiece sentencepiece_train) endif() +if(USE_ONNX) + message(STATUS "Enabling experimental ONNX support") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_ONNX") + set(EXT_LIBS ${EXT_LIBS} protobuf) +endif() + # Find packages set(EXT_LIBS ${EXT_LIBS} ${CMAKE_DL_LIBS}) @@ -203,7 +215,9 @@ if(COMPILE_CUDA) if(USE_STATIC_LIBS) # link statically to stdlib libraries - set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++") + if(NOT MSVC) + set(CMAKE_EXE_LINKER_FLAGS "-static-libgcc -static-libstdc++") + endif() # look for libraries that have .a suffix set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES}) @@ -235,12 +249,22 @@ if(CUDA_FOUND) endif(COMPILE_CUDA_SM70) if(USE_STATIC_LIBS) - find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64) - set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}) - set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}) + set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}) + set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}) + find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64) + # The cuLIBOS library does not seem to exist in Windows CUDA toolkit installs + if(CUDA_culibos_LIBRARY) + set(EXT_LIBS ${EXT_LIBS} ${CUDA_culibos_LIBRARY}) + set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_culibos_LIBRARY}) + elseif(NOT WIN32) + message(FATAL_ERROR "cuLIBOS library not found") + endif() # CUDA 10.1 introduces cublasLt library that is required on static build if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1")) - find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64) + find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64) + if(NOT CUDA_cublasLt_LIBRARY) + message(FATAL_ERROR "cuBLASLt library not found") + endif() set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY}) set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY}) endif() @@ -304,7 +328,7 @@ if(NOT MSVC) list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;) list(APPEND 
CUDA_NVCC_FLAGS ${INTRINSICS_NVCC}) else() - list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; ) + list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; -Xcompiler\ /MT$<$<CONFIG:Debug>:d>; ) endif() list(REMOVE_DUPLICATES CUDA_NVCC_FLAGS) @@ -351,6 +375,7 @@ if(COMPILE_CPU) if(MKL_FOUND) include_directories(${MKL_INCLUDE_DIR}) set(EXT_LIBS ${EXT_LIBS} ${MKL_LIBRARIES}) + set(BLAS_FOUND TRUE) add_definitions(-DBLAS_FOUND=1 -DMKL_FOUND=1) else(MKL_FOUND) set(BLAS_VENDOR "OpenBLAS") @@ -420,8 +445,11 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h.in # Generate build_info.cpp with CMake cache variables include(GetCacheVariables) +# make sure src/common/build_info.cpp has been removed +execute_process(COMMAND rm ${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp + OUTPUT_QUIET ERROR_QUIET) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in - ${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp @ONLY) + ${CMAKE_CURRENT_BINARY_DIR}/src/common/build_info.cpp @ONLY) # Compile source files include_directories(${marian_SOURCE_DIR}/src)
diff --git a/CMakeSettingsCI.json b/CMakeSettingsCI.json new file mode 100644 index 000000000..00cdb9183 --- /dev/null +++ b/CMakeSettingsCI.json @@ -0,0 +1,52 @@ +{ + "configurations": [ + { + "name": "Release", + "generator": "Ninja", + "configurationType": "Release", + "inheritEnvironments": [ "msvc_x64" ], + "cmakeCommandArgs": "", + "buildCommandArgs": "-v", + "ctestCommandArgs": "", + "variables": [ + { "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" }, + { "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" }, + + { "name": "COMPILE_CUDA:BOOL", "value": "FALSE" }, + { "name": "COMPILE_CPU:BOOL", "value": "TRUE" }, + { "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" }, + { "name": "COMPILE_SERVER:BOOL", "value": "FALSE" }, + { "name": "COMPILE_TESTS:BOOL", "value": "TRUE" }, + + { "name": "USE_FBGEMM:BOOL", "value": "TRUE" }, + { "name": "USE_MPI:BOOL", "value": "FALSE" }, + { "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" }, + { "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" } + ] + }, + { + "name": "Debug", + "generator": "Ninja", + "configurationType": "Debug", + "inheritEnvironments": [ "msvc_x64" ], + "cmakeCommandArgs": "", + "buildCommandArgs": "-v", + "ctestCommandArgs": "", + "variables": [ + { "name": "OPENSSL_MSVC_STATIC_RT:BOOL", "value": "TRUE" }, + { "name": "OPENSSL_USE_STATIC_LIBS:BOOL", "value": "TRUE" }, + + { "name": "COMPILE_CUDA:BOOL", "value": "FALSE" }, + { "name": "COMPILE_CPU:BOOL", "value": "TRUE" }, + { "name": "COMPILE_EXAMPLES:BOOL", "value": "FALSE" }, + { "name": "COMPILE_SERVER:BOOL", "value": "FALSE" }, + { "name": "COMPILE_TESTS:BOOL", "value": "TRUE" }, + + { "name": "USE_FBGEMM:BOOL", "value": "TRUE" }, + { "name": "USE_MPI:BOOL", "value": "FALSE" }, + { "name": "USE_SENTENCEPIECE:BOOL", "value": "TRUE" }, + { "name": "USE_STATIC_LIBS:BOOL", "value": "TRUE" } + ] + } + ] +}
diff --git a/VERSION b/VERSION index 32d68684f..79e20fd31 100644 --- a/VERSION +++ b/VERSION @@ -1 +1,2 @@ -v1.9.10 +v1.9.31 +
diff --git a/cmake/FindMKL.cmake b/cmake/FindMKL.cmake index 4e8a99eee..be5e0106e 100644 --- a/cmake/FindMKL.cmake +++ b/cmake/FindMKL.cmake @@ -89,10 +89,17 @@ find_library(MKL_CORE_LIBRARY NO_DEFAULT_PATH) set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR}) -# Added -Wl block to avoid circular dependencies. 
-# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options -# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor -set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group) +set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY}) + +if(NOT WIN32 AND NOT APPLE) + # Added -Wl block to avoid circular dependencies. + # https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options + # https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor + set(MKL_LIBRARIES -Wl,--start-group ${MKL_LIBRARIES} -Wl,--end-group) +elseif(APPLE) + # MacOS does not support --start-group and --end-group + set(MKL_LIBRARIES -Wl,${MKL_LIBRARIES} -Wl,) +endif() # message("1 ${MKL_INCLUDE_DIR}") # message("2 ${MKL_INTERFACE_LIBRARY}") @@ -130,4 +137,4 @@ endif() INCLUDE(FindPackageHandleStandardArgs) FIND_PACKAGE_HANDLE_STANDARD_ARGS(MKL DEFAULT_MSG MKL_LIBRARIES MKL_INCLUDE_DIRS MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY) -MARK_AS_ADVANCED(MKL_INCLUDE_DIRS MKL_LIBRARIES MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY) \ No newline at end of file +MARK_AS_ADVANCED(MKL_INCLUDE_DIRS MKL_LIBRARIES MKL_INTERFACE_LIBRARY MKL_SEQUENTIAL_LAYER_LIBRARY MKL_CORE_LIBRARY) diff --git a/cmake/GetCacheVariables.cmake b/cmake/GetCacheVariables.cmake index 563ade79e..c2fe376cc 100644 --- a/cmake/GetCacheVariables.cmake +++ b/cmake/GetCacheVariables.cmake @@ -34,17 +34,18 @@ foreach(_variableName ${_variableNames}) (NOT "${_variableType}" STREQUAL "INTERNAL") AND (NOT "${_variableValue}" STREQUAL "") ) - - set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValue}\\n\"\n") + string(REPLACE "\"" " " _variableValueEscapedQuotes ${_variableValue}) + string(REPLACE "\\" "/" _variableValueEscaped ${_variableValueEscapedQuotes}) + set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValueEscaped}\\n\"\n") # Get the variable's advanced flag get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET) if(NOT _isAdvanced) - set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValue}\\n\"\n") + set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValueEscaped}\\n\"\n") endif() # Print variables for debugging - #message(STATUS "${_variableName}=${${_variableName}}") + #message(STATUS "${_variableName}=${_variableValueEscaped}") #message(STATUS " Type=${_variableType}") #message(STATUS " Advanced=${_isAdvanced}") endif() diff --git a/cmake/GetVersionFromFile.cmake b/cmake/GetVersionFromFile.cmake index 29b9d2bca..0e9c54fe8 100644 --- a/cmake/GetVersionFromFile.cmake +++ b/cmake/GetVersionFromFile.cmake @@ -18,7 +18,7 @@ if(PROJECT_VERSION_FILE) file(STRINGS ${PROJECT_VERSION_FILE} PROJECT_VERSION_STRING) else() - file(STRINGS ${CMAKE_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING) + file(STRINGS ${CMAKE_CURRENT_SOURCE_DIR}/VERSION PROJECT_VERSION_STRING) endif() # Get current commit SHA from git diff --git a/regression-tests b/regression-tests index 0f8cabf13..7b8f6ee5b 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit 0f8cabf13ec362d50544d33490024e00c3a763be +Subproject commit 7b8f6ee5b6ff7779fd993df7f77adf1e2d9adbe5 diff --git a/scripts/onnx/example-greedy.py b/scripts/onnx/example-greedy.py new 
file mode 100644 index 000000000..382e8a85a --- /dev/null +++ b/scripts/onnx/example-greedy.py @@ -0,0 +1,81 @@ +import onnxruntime as ort +import numpy as np +import onnx +import os, sys, time + +os.environ['OMP_NUM_THREADS'] = '1' +sess_options = ort.SessionOptions() +sess_options.intra_op_num_threads = 1 +sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL +sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + +def get_function(path, output_vars): + print("Reading ONNX function from", path) + #model = onnx.load(path) + #print("Done", flush=True) + #print(model) + ort_sess = ort.InferenceSession(path, sess_options) + output_defs = ort_sess.get_outputs() + for input in ort_sess.get_inputs(): + print(" input: ", input.name, input.shape, input.type) + for output in output_defs: + print(" output: ", output.name, output.shape, output.type) + def invoke_model(**kwargs): + def to_numpy(val): + arr = np.array(val) + if arr.dtype == np.double: + arr = arr.astype(np.float32) + elif arr.dtype == np.int64: + arr = arr.astype(np.int32) + return arr + kwargs = { name: to_numpy(val) for name, val in kwargs.items() } + output_vals = ort_sess.run(None, kwargs) + output_dict = { output_def.name : output_val for output_val, output_def in zip(output_vals, output_defs) } + return [output_dict[output_var] for output_var in output_vars] + return invoke_model + +id2word = { id : word.rstrip() for id, word in enumerate(open('c:/work/marian-dev/local/model/vocab_v1.wl', encoding='utf-8').readlines()) } +word2id = { word : id for id, word in id2word.items() } +unk_id = word2id["<unk>"] + +model_path_prefix = "c:/work/marian-dev/local/model/model.npz.best-ce-mean-words-debug-sin-uniq-notrans-nounk" +encode_source = get_function(model_path_prefix + '.encode_source.onnx', + ['encoder_context_0']) +decode_first = get_function(model_path_prefix + '.decode_first.onnx', + ['first_logits', 'first_decoder_state_0', 'first_decoder_state_1', 'first_decoder_state_2', 'first_decoder_state_3', 'first_decoder_state_4', 'first_decoder_state_5']) +decode_next = get_function(model_path_prefix + '.decode_next.onnx', + ['next_logits', 'next_decoder_state_0', 'next_decoder_state_1', 'next_decoder_state_2', 'next_decoder_state_3', 'next_decoder_state_4', 'next_decoder_state_5']) + +def greedy_decode(data_0): + if len(data_0) == 1: # special handling for the empty sentence, like Marian + return data_0 + data_0_mask = [[[1.]]] * len(data_0) + data_0_index_range = [[[float(t)]] for t in range(len(data_0))] + #print(data_0, data_0_mask, data_0_index_range) + + max_len = len(data_0) * 3 + Y = [] + encoder_context_0, *_ = encode_source(data_0=data_0, data_0_mask=data_0_mask, data_0_posrange=data_0_index_range) + logp, *out_decoder_states = decode_first(data_1_posrange=[[[float(0)]]], + encoder_context_0=encoder_context_0, data_0_mask=data_0_mask) + logp[:,:,:,unk_id] = -1e8 # suppress <unk>, like Marian + Y.append(np.argmax(logp[0][0])) + while Y[-1] != 0 and len(Y) < max_len: + logp, *out_decoder_states = decode_next(prev_word=[Y[-1]], data_1_posrange=[[[float(len(Y))]]], + encoder_context_0=encoder_context_0, data_0_mask=data_0_mask, + decoder_state_0=out_decoder_states[0], decoder_state_1=out_decoder_states[1], + decoder_state_2=out_decoder_states[2], decoder_state_3=out_decoder_states[3], + decoder_state_4=out_decoder_states[4], decoder_state_5=out_decoder_states[5]) + logp[:,:,:,unk_id] = -1e8 # suppress <unk>, like Marian + Y.append(np.argmax(logp[0][0])) + return Y + +start_time = time.time() +with 
open("C:/work/marian-dev/local/model/predictions.out-onnx-debug-sin-notrans-first100-d.tok", 'wt', encoding='utf-8') as out_f: + for line in open("C:/work/marian-dev/local/model/predictions.in-first100.tok", encoding='utf-8').readlines(): + data = [word2id.get(w, unk_id) for w in (line.rstrip() + " </s>").split(' ') if w] + Y = greedy_decode(data) + print("input: ", ' '.join(id2word[x] for x in data)) + print("output:", ' '.join(id2word[y] for y in Y)) + print(' '.join(id2word[y] for y in Y[:-1]), file=out_f, flush=True) # strip </s> for output to file +print("--- %s seconds ---" % (time.time() - start_time))
diff --git a/src/3rd_party/CLI/App.hpp b/src/3rd_party/CLI/App.hpp index 14ddd1e7f..bf493959b 100644 --- a/src/3rd_party/CLI/App.hpp +++ b/src/3rd_party/CLI/App.hpp @@ -1595,7 +1595,8 @@ class App { if(num < 0) { // RG: We need to keep track if the vector option is empty and handle this separately as // otherwise the parser will mark the command-line option as not set - bool emptyVectorArgs = true; + // RG: An option value after '=' was already collected + bool emptyVectorArgs = (collected <= 0); while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) { if(collected >= -num) { // We could break here for allow extras, but we don't
diff --git a/src/3rd_party/CMakeLists.txt b/src/3rd_party/CMakeLists.txt index 9f3981af2..6f465e5d1 100644 --- a/src/3rd_party/CMakeLists.txt +++ b/src/3rd_party/CMakeLists.txt @@ -5,6 +5,8 @@ add_subdirectory(./yaml-cpp) add_subdirectory(./SQLiteCpp) add_subdirectory(./pathie-cpp) add_subdirectory(./zlib) +add_subdirectory(./faiss) +include_directories(./faiss) if(USE_FBGEMM) # @TODO: find out if this is somehow harmful. This is suppressing CMake warnings for CMAKE_SUPPRESS_DEVELOPER_WARNINGS @@ -17,6 +19,9 @@ if(USE_FBGEMM) # only locally disabled for the 3rd_party folder # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused") + else() + # Do not compile cpuinfo executables due to a linker error, and they are not needed + set(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "Build command-line tools") endif() set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests") @@ -43,7 +48,7 @@ if(USE_SENTENCEPIECE) endif() endif() - set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available." FORCE) + set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available.") if(USE_STATIC_LIBS) message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \ @@ -53,8 +58,8 @@ if(USE_SENTENCEPIECE) set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE) set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE) else(USE_STATIC_LIBS) - set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE) - set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC." 
FORCE) + set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries.") + set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC.") endif(USE_STATIC_LIBS) add_subdirectory(./sentencepiece) @@ -64,7 +69,7 @@ if(USE_SENTENCEPIECE) PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") - foreach(t sentencepiece sentencepiece_train sentencepiece_train-static + foreach(t sentencepiece-static sentencepiece_train-static spm_decode spm_encode spm_export_vocab spm_normalize spm_train) set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused") if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0) @@ -98,8 +103,6 @@ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") " -fPIC -Wno-unused-value") endif() - - include_directories(./zlib) include(ExternalProject)
diff --git a/src/3rd_party/cnpy/cnpy.h b/src/3rd_party/cnpy/cnpy.h index 15053ed04..89e607cab 100644 --- a/src/3rd_party/cnpy/cnpy.h +++ b/src/3rd_party/cnpy/cnpy.h @@ -5,6 +5,8 @@ #ifndef LIBCNPY_H_ #define LIBCNPY_H_ +#include "3rd_party/zlib/zlib.h" + #include #include #include @@ -13,10 +15,13 @@ #include #include #include -#include #include #include +#ifdef __APPLE__ +#include <unistd.h> +#endif + namespace cnpy { struct NpyArray { @@ -260,8 +265,8 @@ namespace cnpy { static inline void npz_save(std::string zipname, const std::vector<NpzItem>& items) { - auto tmpname = zipname + "$$"; // TODO: add thread id or something - unlink(tmpname.c_str()); // when saving to HDFS, we cannot overwrite an existing file + auto tmpname = zipname + "$$"; // TODO: add thread id or something + unlink(tmpname.c_str()); // when saving to HDFS, we cannot overwrite an existing file FILE* fp = fopen(tmpname.c_str(),"wb"); if (!fp) throw std::runtime_error("npz_save: error opening file for writing: " + tmpname); @@ -351,15 +356,15 @@ namespace cnpy { fclose(fp); // move to final location (atomically) -#ifdef _MSC_VER - unlink(zipname.c_str()); // needed for Windows -#endif - bad = bad || (rename(tmpname.c_str(), zipname.c_str()) == -1); - - if (bad) - { - unlink(tmpname.c_str()); - throw std::runtime_error("npz_save: error saving to file: " + zipname); +#ifdef _MSC_VER + unlink(zipname.c_str()); // needed for Windows +#endif + bad = bad || (rename(tmpname.c_str(), zipname.c_str()) == -1); + + if (bad) + { + unlink(tmpname.c_str()); + throw std::runtime_error("npz_save: error saving to file: " + zipname); } }
diff --git a/src/3rd_party/faiss/CMakeLists.txt b/src/3rd_party/faiss/CMakeLists.txt new file mode 100644 index 000000000..dbc9fbcc2 --- /dev/null +++ b/src/3rd_party/faiss/CMakeLists.txt @@ -0,0 +1,7 @@ +# adding a new file requires explicitly modifying the CMakeLists.txt + +add_definitions(-DFINTEGER=uint64_t) + +include_directories("impl") +FILE(GLOB FaissCppSources *.cpp impl/*.cpp utils/*.cpp) +add_library(faiss OBJECT ${FaissCppSources})
diff --git a/src/3rd_party/faiss/Index.cpp b/src/3rd_party/faiss/Index.cpp new file mode 100644 index 000000000..eac5f3d93 --- /dev/null +++ b/src/3rd_party/faiss/Index.cpp @@ -0,0 +1,119 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// -*- c++ -*- + +#include "Index.h" +#include "common/logging.h" +#include <cstring> + +namespace faiss { + +Index::~Index () +{ +} + + +void Index::train(idx_t /*n*/, const float* /*x*/) { + // does nothing by default +} + + +void Index::range_search (idx_t , const float *, float, + RangeSearchResult *) const +{ + ABORT ("range search not implemented"); +} + +void Index::assign (idx_t n, const float * x, idx_t * labels, idx_t k) +{ + float * distances = new float[n * k]; + ScopeDeleter<float> del(distances); + search (n, x, k, distances, labels); +} + +void Index::add_with_ids( + idx_t /*n*/, + const float* /*x*/, + const idx_t* /*xids*/) { + ABORT ("add_with_ids not implemented for this type of index"); +} + +size_t Index::remove_ids(const IDSelector& /*sel*/) { + ABORT ("remove_ids not implemented for this type of index"); + return -1; +} + + +void Index::reconstruct (idx_t, float * ) const { + ABORT ("reconstruct not implemented for this type of index"); +} + + +void Index::reconstruct_n (idx_t i0, idx_t ni, float *recons) const { + for (idx_t i = 0; i < ni; i++) { + reconstruct (i0 + i, recons + i * d); + } +} + + +void Index::search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const { + search (n, x, k, distances, labels); + for (idx_t i = 0; i < n; ++i) { + for (idx_t j = 0; j < k; ++j) { + idx_t ij = i * k + j; + idx_t key = labels[ij]; + float* reconstructed = recons + ij * d; + if (key < 0) { + // Fill with NaNs + memset(reconstructed, -1, sizeof(*reconstructed) * d); + } else { + reconstruct (key, reconstructed); + } + } + } +} + +void Index::compute_residual (const float * x, + float * residual, idx_t key) const { + reconstruct (key, residual); + for (size_t i = 0; i < d; i++) { + residual[i] = x[i] - residual[i]; + } +} + +void Index::compute_residual_n (idx_t n, const float* xs, + float* residuals, + const idx_t* keys) const { +//#pragma omp parallel for + for (idx_t i = 0; i < n; ++i) { + compute_residual(&xs[i * d], &residuals[i * d], keys[i]); + } +} + + + +size_t Index::sa_code_size () const +{ + ABORT ("standalone codec not implemented for this type of index"); +} + +void Index::sa_encode (idx_t, const float *, + uint8_t *) const +{ + ABORT ("standalone codec not implemented for this type of index"); +} + +void Index::sa_decode (idx_t, const uint8_t *, + float *) const +{ + ABORT ("standalone codec not implemented for this type of index"); +} + +}
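The residual convention implemented by Index::compute_residual above is simply "reconstruct the stored code, then subtract from the query". A tiny self-contained C++ sketch of that relationship (not part of the patch; the two-entry codebook and toy_reconstruct are made-up stand-ins for a real index's reconstruct):

#include <cstdio>

const int d = 4;
const float codebook[2][4] = {{0, 0, 1, 1}, {1, 1, 0, 0}};

// Stand-in for Index::reconstruct: decode the vector stored under `key`.
void toy_reconstruct(long key, float* recons) {
    for (int j = 0; j < d; j++) recons[j] = codebook[key][j];
}

// Mirrors Index::compute_residual: residual = x - reconstruct(key).
void toy_compute_residual(const float* x, float* residual, long key) {
    toy_reconstruct(key, residual);
    for (int j = 0; j < d; j++) residual[j] = x[j] - residual[j];
}

int main() {
    float x[d] = {0.9f, 0.1f, 1.2f, 0.8f}, r[d];
    toy_compute_residual(x, r, 0);
    for (int j = 0; j < d; j++) printf("%g ", r[j]); // 0.9 0.1 0.2 -0.2
    printf("\n");
    return 0;
}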
diff --git a/src/3rd_party/faiss/Index.h b/src/3rd_party/faiss/Index.h new file mode 100644 index 000000000..deaabcaad --- /dev/null +++ b/src/3rd_party/faiss/Index.h @@ -0,0 +1,233 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_INDEX_H +#define FAISS_INDEX_H + +#include "utils/misc.h" +#include <cstdio> +#include <typeinfo> +#include <string> +#include <sstream> + +#define FAISS_VERSION_MAJOR 1 +#define FAISS_VERSION_MINOR 6 +#define FAISS_VERSION_PATCH 3 + +/** + * @namespace faiss + * + * Throughout the library, vectors are provided as float * pointers. + * Most algorithms can be optimized when several vectors are processed + * (added/searched) together in a batch. In this case, they are passed + * in as a matrix. When n vectors of size d are provided as float * x, + * component j of vector i is + * + * x[ i * d + j ] + * + * where 0 <= i < n and 0 <= j < d. In other words, matrices are + * always compact. When specifying the size of the matrix, we call it + * an n*d matrix, which implies a row-major storage. + */ + + +namespace faiss { + +/// Forward declarations see AuxIndexStructures.h +struct IDSelector; +struct RangeSearchResult; +struct DistanceComputer; + +/** Abstract structure for an index, supports adding vectors and searching them. + * + * All vectors provided at add or search time are 32-bit float arrays, + * although the internal representation may vary. + */ +struct Index { + using idx_t = int64_t; ///< all indices are this type + using component_t = float; + using distance_t = float; + + int d; ///< vector dimension + idx_t ntotal; ///< total nb of indexed vectors + bool verbose; ///< verbosity level + + /// set if the Index does not require training, or if training is + /// done already + bool is_trained; + + /// type of metric this index uses for search + MetricType metric_type; + float metric_arg; ///< argument of the metric type + + explicit Index (idx_t d = 0, MetricType metric = METRIC_L2): + d((int)d), + ntotal(0), + verbose(false), + is_trained(true), + metric_type (metric), + metric_arg(0) {} + + virtual ~Index (); + + + /** Perform training on a representative set of vectors + * + * @param n nb of training vectors + * @param x training vectors, size n * d + */ + virtual void train(idx_t n, const float* x); + + /** Add n vectors of dimension d to the index. + * + * Vectors are implicitly assigned labels ntotal .. ntotal + n - 1 + * This function slices the input vectors in chunks smaller than + * blocksize_add and calls add_core. + * @param x input matrix, size n * d + */ + virtual void add (idx_t n, const float *x) = 0; + + /** Same as add, but stores xids instead of sequential ids. + * + * The default implementation fails with an assertion, as it is + * not supported by all indexes. + * + * @param xids if non-null, ids to store for the vectors (size n) + */ + virtual void add_with_ids (idx_t n, const float * x, const idx_t *xids); + + /** query n vectors of dimension d to the index. + * + * return at most k vectors. If there are not enough results for a + * query, the result array is padded with -1s. + * + * @param x input vectors to search, size n * d + * @param labels output labels of the NNs, size n*k + * @param distances output pairwise distances, size n*k + */ + virtual void search (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels) const = 0; + + /** query n vectors of dimension d to the index. + * + * return all vectors with distance < radius. Note that many + * indexes do not implement the range_search (only the k-NN search + * is mandatory). + * + * @param x input vectors to search, size n * d + * @param radius search radius + * @param result result table + */ + virtual void range_search (idx_t n, const float *x, float radius, + RangeSearchResult *result) const; + + /** return the indexes of the k vectors closest to the query x. + * + * This function is identical to search but only returns labels of neighbors. + * @param x input vectors to search, size n * d + * @param labels output labels of the NNs, size n*k + */ + void assign (idx_t n, const float * x, idx_t * labels, idx_t k = 1); + + /// removes all elements from the database. + virtual void reset() = 0; + + /** removes IDs from the index. Not supported by all + * indexes. Returns the number of elements removed. 
+ */ + virtual size_t remove_ids (const IDSelector & sel); + + /** Reconstruct a stored vector (or an approximation if lossy coding) + * + * this function may not be defined for some indexes + * @param key id of the vector to reconstruct + * @param recons reconstructed vector (size d) + */ + virtual void reconstruct (idx_t key, float * recons) const; + + /** Reconstruct vectors i0 to i0 + ni - 1 + * + * this function may not be defined for some indexes + * @param recons reconstructed vector (size ni * d) + */ + virtual void reconstruct_n (idx_t i0, idx_t ni, float *recons) const; + + /** Similar to search, but also reconstructs the stored vectors (or an + * approximation in the case of lossy coding) for the search results. + * + * If there are not enough results for a query, the resulting arrays + * are padded with -1s. + * + * @param recons reconstructed vectors size (n, k, d) + **/ + virtual void search_and_reconstruct (idx_t n, const float *x, idx_t k, + float *distances, idx_t *labels, + float *recons) const; + + /** Computes a residual vector after indexing encoding. + * + * The residual vector is the difference between a vector and the + * reconstruction that can be decoded from its representation in + * the index. The residual can be used for multiple-stage indexing + * methods, like IndexIVF's methods. + * + * @param x input vector, size d + * @param residual output residual vector, size d + * @param key encoded index, as returned by search and assign + */ + virtual void compute_residual (const float * x, + float * residual, idx_t key) const; + + /** Computes a residual vector after indexing encoding (batch form). + * Equivalent to calling compute_residual for each vector. + * + * The residual vector is the difference between a vector and the + * reconstruction that can be decoded from its representation in + * the index. The residual can be used for multiple-stage indexing + * methods, like IndexIVF's methods. + * + * @param n number of vectors + * @param xs input vectors, size (n x d) + * @param residuals output residual vectors, size (n x d) + * @param keys encoded index, as returned by search and assign + */ + virtual void compute_residual_n (idx_t n, const float* xs, + float* residuals, + const idx_t* keys) const; + + /* The standalone codec interface */ + + /** size of the produced codes in bytes */ + virtual size_t sa_code_size () const; + + /** encode a set of vectors + * + * @param n number of vectors + * @param x input vectors, size n * d + * @param bytes output encoded vectors, size n * sa_code_size() + */ + virtual void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const; + + /** decode a set of vectors + * + * @param n number of vectors + * @param bytes input encoded vectors, size n * sa_code_size() + * @param x output vectors, size n * d + */ + virtual void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const; + + +}; + +} + + +#endif
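To make the storage convention documented in the faiss namespace comment above concrete (n vectors of dimension d stored in one flat float array, component j of vector i at x[i * d + j]), here is a minimal standalone illustration; the names and sizes are invented for the example:

#include <cstdio>
#include <vector>

int main() {
    const int n = 3, d = 4;              // 3 vectors of dimension 4
    std::vector<float> x(n * d);
    for (int i = 0; i < n; i++)          // vector index
        for (int j = 0; j < d; j++)      // component index
            x[i * d + j] = float(10 * i + j);
    // Row-major and compact: vector 1 is the contiguous slice x[4..7].
    for (int j = 0; j < d; j++) printf("%g ", x[1 * d + j]); // 10 11 12 13
    printf("\n");
    return 0;
}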
diff --git a/src/3rd_party/faiss/IndexLSH.cpp b/src/3rd_party/faiss/IndexLSH.cpp new file mode 100644 index 000000000..6df843312 --- /dev/null +++ b/src/3rd_party/faiss/IndexLSH.cpp @@ -0,0 +1,224 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include "IndexLSH.h" + +#include <cstdio> +#include <cstring> + +#include <algorithm> + +#include "utils/hamming.h" +#include "common/logging.h" + + +namespace faiss { + +/*************************************************************** + * IndexLSH + ***************************************************************/ + + +IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds): + Index(d), nbits(nbits), rotate_data(rotate_data), + train_thresholds (train_thresholds), rrot(d, nbits) +{ + is_trained = !train_thresholds; + + bytes_per_vec = (nbits + 7) / 8; + + if (rotate_data) { + rrot.init(5); + } else { + ABORT_UNLESS(d >= nbits, "d >= nbits"); + } +} + +IndexLSH::IndexLSH (): + nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false) +{ +} + + +const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const +{ + + float *xt = nullptr; + if (rotate_data) { + // also applies bias if exists + xt = rrot.apply (n, x); + } else if (d != nbits) { + assert (nbits < d); + xt = new float [nbits * n]; + float *xp = xt; + for (idx_t i = 0; i < n; i++) { + const float *xl = x + i * d; + for (int j = 0; j < nbits; j++) + *xp++ = xl [j]; + } + } + + if (train_thresholds) { + + if (xt == NULL) { + xt = new float [nbits * n]; + memcpy (xt, x, sizeof(*x) * n * nbits); + } + + float *xp = xt; + for (idx_t i = 0; i < n; i++) + for (int j = 0; j < nbits; j++) + *xp++ -= thresholds [j]; + } + + return xt ? xt : x; +} + + + +void IndexLSH::train (idx_t n, const float *x) +{ + if (train_thresholds) { + thresholds.resize (nbits); + train_thresholds = false; + const float *xt = apply_preprocess (n, x); + ScopeDeleter<float> del (xt == x ? nullptr : xt); + train_thresholds = true; + + float * transposed_x = new float [n * nbits]; + ScopeDeleter<float> del2 (transposed_x); + + for (idx_t i = 0; i < n; i++) + for (idx_t j = 0; j < nbits; j++) + transposed_x [j * n + i] = xt [i * nbits + j]; + + for (idx_t i = 0; i < nbits; i++) { + float *xi = transposed_x + i * n; + // std::nth_element + std::sort (xi, xi + n); + if (n % 2 == 1) + thresholds [i] = xi [n / 2]; + else + thresholds [i] = (xi [n / 2 - 1] + xi [n / 2]) / 2; + + } + } + is_trained = true; +} + + +void IndexLSH::add (idx_t n, const float *x) +{ + ABORT_UNLESS (is_trained, "is_trained"); + codes.resize ((ntotal + n) * bytes_per_vec); + + sa_encode (n, x, &codes[ntotal * bytes_per_vec]); + + ntotal += n; +} + + +void IndexLSH::search ( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels) const +{ + ABORT_UNLESS (is_trained, "is_trained"); + const float *xt = apply_preprocess (n, x); + ScopeDeleter<float> del (xt == x ? 
nullptr : xt); + + uint8_t * qcodes = new uint8_t [n * bytes_per_vec]; + ScopeDeleter<uint8_t> del2 (qcodes); + + fvecs2bitvecs (xt, qcodes, nbits, n); + + int * idistances = new int [n * k]; + ScopeDeleter<int> del3 (idistances); + + int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances}; + + hammings_knn_hc (&res, qcodes, codes.data(), + ntotal, bytes_per_vec, true); + + + // convert distances to floats + for (int i = 0; i < k * n; i++) + distances[i] = idistances[i]; + +} + + +void IndexLSH::transfer_thresholds (LinearTransform *vt) { + if (!train_thresholds) return; + ABORT_UNLESS (nbits == vt->d_out, "nbits == vt->d_out"); + if (!vt->have_bias) { + vt->b.resize (nbits, 0); + vt->have_bias = true; + } + for (int i = 0; i < nbits; i++) + vt->b[i] -= thresholds[i]; + train_thresholds = false; + thresholds.clear(); +} + +void IndexLSH::reset() { + codes.clear(); + ntotal = 0; +} + + +size_t IndexLSH::sa_code_size () const +{ + return bytes_per_vec; +} + +void IndexLSH::sa_encode (idx_t n, const float *x, + uint8_t *bytes) const +{ + ABORT_UNLESS (is_trained, "is_trained"); + const float *xt = apply_preprocess (n, x); + ScopeDeleter<float> del (xt == x ? nullptr : xt); + fvecs2bitvecs (xt, bytes, nbits, n); +} + +void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes, + float *x) const +{ + float *xt = x; + ScopeDeleter<float> del; + if (rotate_data || nbits != d) { + xt = new float [n * nbits]; + del.set(xt); + } + bitvecs2fvecs (bytes, xt, nbits, n); + + if (train_thresholds) { + float *xp = xt; + for (idx_t i = 0; i < n; i++) { + for (int j = 0; j < nbits; j++) { + *xp++ += thresholds [j]; + } + } + } + + if (rotate_data) { + rrot.reverse_transform (n, xt, x); + } else if (nbits != d) { + for (idx_t i = 0; i < n; i++) { + memcpy (x + i * d, xt + i * nbits, + nbits * sizeof(xt[0])); + } + } +} + + + +} // namespace faiss
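IndexLSH::sa_encode above delegates the actual binarization to fvecs2bitvecs, which lives in the extracted faiss utils and is not shown in this diff. As a rough standalone sketch of the packing it performs, assuming the usual faiss sign convention (a bit is set when the preprocessed component is non-negative, packed 8 bits per byte, matching bytes_per_vec = (nbits + 7) / 8):

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch only: one sign bit per component, least significant bit first.
static std::vector<uint8_t> sign_bits(const float* x, int nbits) {
    std::vector<uint8_t> code((nbits + 7) / 8, 0);
    for (int j = 0; j < nbits; j++)
        if (x[j] >= 0)                       // assumed sign convention
            code[j >> 3] |= uint8_t(1u << (j & 7));
    return code;
}

int main() {
    float v[8] = {0.3f, -1.2f, 0.0f, 4.5f, -0.1f, 2.2f, -3.3f, 1.0f};
    std::vector<uint8_t> code = sign_bits(v, 8);
    printf("0x%02x\n", code[0]);             // bits 0,2,3,5,7 set: 0xad
    return 0;
}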
diff --git a/src/3rd_party/faiss/IndexLSH.h b/src/3rd_party/faiss/IndexLSH.h new file mode 100644 index 000000000..66435363a --- /dev/null +++ b/src/3rd_party/faiss/IndexLSH.h @@ -0,0 +1,90 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef INDEX_LSH_H +#define INDEX_LSH_H + +#include <vector> + +#include "Index.h" +#include "VectorTransform.h" + +namespace faiss { + + +/** The sign of each vector component is put in a binary signature */ +struct IndexLSH:Index { + typedef unsigned char uint8_t; + + int nbits; ///< nb of bits per vector + int bytes_per_vec; ///< nb of 8-bits per encoded vector + bool rotate_data; ///< whether to apply a random rotation to input + bool train_thresholds; ///< whether we train thresholds or use 0 + + RandomRotationMatrix rrot; ///< optional random rotation + + std::vector<float> thresholds; ///< thresholds to compare with + + /// encoded dataset + std::vector<uint8_t> codes; + + IndexLSH ( + idx_t d, int nbits, + bool rotate_data = true, + bool train_thresholds = false); + + /** Preprocesses and resizes the input to the size required to + * binarize the data + * + * @param x input vectors, size n * d + * @return output vectors, size n * bits. May be the same pointer + * as x, otherwise it should be deleted by the caller + */ + const float *apply_preprocess (idx_t n, const float *x) const; + + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void search( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels) const override; + + void reset() override; + + /// transfer the thresholds to a pre-processing stage (and unset + /// train_thresholds) + void transfer_thresholds (LinearTransform * vt); + + ~IndexLSH() override {} + + IndexLSH (); + + /* standalone codec interface. + * + * The vectors are decoded to +/- 1 (not 0, 1) */ + + size_t sa_code_size () const override; + + void sa_encode (idx_t n, const float *x, + uint8_t *bytes) const override; + + void sa_decode (idx_t n, const uint8_t *bytes, + float *x) const override; + +}; + + +} + + +#endif
diff --git a/src/3rd_party/faiss/LICENSE b/src/3rd_party/faiss/LICENSE new file mode 100644 index 000000000..b96dcb048 --- /dev/null +++ b/src/3rd_party/faiss/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.
diff --git a/src/3rd_party/faiss/README b/src/3rd_party/faiss/README new file mode 100644 index 000000000..0c97ad606 --- /dev/null +++ b/src/3rd_party/faiss/README @@ -0,0 +1 @@ +This is code extracted from the original FAISS repository: https://github.com/facebookresearch/faiss \ No newline at end of file
diff --git a/src/3rd_party/faiss/VectorTransform.cpp b/src/3rd_party/faiss/VectorTransform.cpp new file mode 100644 index 000000000..103b0910e --- /dev/null +++ b/src/3rd_party/faiss/VectorTransform.cpp @@ -0,0 +1,731 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#include "VectorTransform.h" + +#include <cstdio> +#include <cmath> +#include <cstring> +#include <memory> + +#include "utils/random.h" +#include "common/logging.h" + +using namespace faiss; + + +extern "C" { + +// this is to keep the clang syntax checker happy +#ifndef FINTEGER +#define FINTEGER uint64_t // MJD: only really safe type for use between Linux and Windows and different MKL versions. 
Not tested with non-MKL CBLAS +#endif + + +/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */ + +int sgemm_ ( + const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const float *alpha, const float *a, + FINTEGER *lda, const float *b, + FINTEGER *ldb, float *beta, + float *c, FINTEGER *ldc); + +int dgemm_ ( + const char *transa, const char *transb, FINTEGER *m, FINTEGER * + n, FINTEGER *k, const double *alpha, const double *a, + FINTEGER *lda, const double *b, + FINTEGER *ldb, double *beta, + double *c, FINTEGER *ldc); + +int ssyrk_ ( + const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k, + float *alpha, float *a, FINTEGER *lda, + float *beta, float *c, FINTEGER *ldc); + +/* Lapack functions from http://www.netlib.org/clapack/old/single/ */ + +int ssyev_ ( + const char *jobz, const char *uplo, FINTEGER *n, float *a, + FINTEGER *lda, float *w, float *work, FINTEGER *lwork, + FINTEGER *info); + +int dsyev_ ( + const char *jobz, const char *uplo, FINTEGER *n, double *a, + FINTEGER *lda, double *w, double *work, FINTEGER *lwork, + FINTEGER *info); + +int sgesvd_( + const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n, + float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt, + FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info); + + +int dgesvd_( + const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n, + double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt, + FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info); + +} + +/////////////////////////////////////////////// +extern "C" { + /* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */ + + int sgeqrf_(FINTEGER *m, FINTEGER *n, float *a, FINTEGER *lda, + float *tau, float *work, FINTEGER *lwork, FINTEGER *info); + + int sorgqr_(FINTEGER *m, FINTEGER *n, FINTEGER *k, float *a, + FINTEGER *lda, float *tau, float *work, + FINTEGER *lwork, FINTEGER *info); + +} + +void matrix_qr(int m, int n, float *a) +{ + ABORT_UNLESS(m >= n, "m >= n"); + FINTEGER mi = m, ni = n, ki = mi < ni ? 
mi : ni; + std::vector<float> tau(ki); + FINTEGER lwork = -1, info; + float work_size; + + sgeqrf_(&mi, &ni, a, &mi, tau.data(), + &work_size, &lwork, &info); + lwork = size_t(work_size); + std::vector<float> work(lwork); + + sgeqrf_(&mi, &ni, a, &mi, + tau.data(), work.data(), &lwork, &info); + + sorgqr_(&mi, &ni, &ki, a, &mi, tau.data(), + work.data(), &lwork, &info); + +} + +/////////////////////////////////////////////// +const float *fvecs_maybe_subsample( + size_t d, size_t *n, size_t nmax, const float *x, + bool verbose = false, int64_t seed = 1234) +{ + + if (*n <= nmax) return x; // nothing to do + + size_t n2 = nmax; + if (verbose) { + printf(" Input training set too big (max size is %zu), sampling " + "%zu / %zu vectors\n", nmax, n2, *n); + } + std::vector<int> subset(*n); + rand_perm(subset.data(), *n, seed); + float *x_subset = new float[n2 * d]; + for (int64_t i = 0; i < n2; i++) + memcpy(&x_subset[i * d], + &x[subset[i] * size_t(d)], + sizeof(x[0]) * d); + *n = n2; + return x_subset; +} + +#if 1 // def __SSE__ +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read(int d, const float *x) +{ + assert(0 <= d && d < 4); +#ifdef _MSC_VER + __declspec(align(16)) +#else + __attribute__((__aligned__(16))) +#endif + float buf[4] = { 0, 0, 0, 0 }; + switch (d) { + case 3: + buf[2] = x[2]; + case 2: + buf[1] = x[1]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} + +float fvec_norm_L2sqr(const float * x, + size_t d) +{ + __m128 mx; + __m128 msum1 = _mm_setzero_ps(); + + while (d >= 4) { + mx = _mm_loadu_ps(x); x += 4; + msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx)); + d -= 4; + } + + mx = masked_read(d, x); + msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx)); + + msum1 = _mm_hadd_ps(msum1, msum1); + msum1 = _mm_hadd_ps(msum1, msum1); + return _mm_cvtss_f32(msum1); +} +#else +// scalar implementation +float fvec_norm_L2sqr(const float *x, size_t d) +{ + return fvec_norm_L2sqr_ref(x, d); +} +#endif + +void fvec_renorm_L2(size_t d, size_t nx, float * __restrict x) +{ +//#pragma omp parallel for + for (size_t i = 0; i < nx; i++) { + float * __restrict xi = x + i * d; + + float nr = fvec_norm_L2sqr(xi, d); + + if (nr > 0) { + size_t j; + const float inv_nr = 1.0 / sqrtf(nr); + for (j = 0; j < d; j++) + xi[j] *= inv_nr; + } + } +} + +/********************************************* + * VectorTransform + *********************************************/ + + + +float * VectorTransform::apply (Index::idx_t n, const float * x) const +{ + float * xt = new float[n * d_out]; + apply_noalloc (n, x, xt); + return xt; +} + + +void VectorTransform::train (idx_t, const float *) { + // does nothing by default +} + + +void VectorTransform::reverse_transform ( + idx_t , const float *, + float *) const +{ + ABORT ("reverse transform not implemented"); +} + + + + +/********************************************* + * LinearTransform + *********************************************/ + +/// both d_in > d_out and d_out < d_in are supported +LinearTransform::LinearTransform (int d_in, int d_out, + bool have_bias): + VectorTransform (d_in, d_out), have_bias (have_bias), + is_orthonormal (false), verbose (false) +{ + is_trained = false; // will be trained when A and b are initialized +} + +void LinearTransform::apply_noalloc (Index::idx_t n, const float * x, + float * xt) const +{ + ABORT_UNLESS(is_trained, "Transformation not trained yet"); + + float c_factor; + if (have_bias) { + ABORT_UNLESS(b.size() == d_out, "Bias not initialized"); + float * xi = xt; + for (int i = 0; i 
< n; i++) + for(int j = 0; j < d_out; j++) + *xi++ = b[j]; + c_factor = 1.0; + } else { + c_factor = 0.0; + } + + ABORT_UNLESS(A.size() == d_out * d_in, + "Transformation matrix not initialized"); + + float one = 1; + FINTEGER nbiti = d_out, ni = n, di = d_in; + sgemm_ ("Transposed", "Not transposed", + &nbiti, &ni, &di, + &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti); + +} + + +void LinearTransform::transform_transpose (idx_t n, const float * y, + float *x) const +{ + if (have_bias) { // allocate buffer to store bias-corrected data + float *y_new = new float [n * d_out]; + const float *yr = y; + float *yw = y_new; + for (idx_t i = 0; i < n; i++) { + for (int j = 0; j < d_out; j++) { + *yw++ = *yr++ - b [j]; + } + } + y = y_new; + } + + { + FINTEGER dii = d_in, doi = d_out, ni = n; + float one = 1.0, zero = 0.0; + sgemm_ ("Not", "Not", &dii, &ni, &doi, + &one, A.data (), &dii, y, &doi, &zero, x, &dii); + } + + if (have_bias) delete [] y; +} + +void LinearTransform::set_is_orthonormal () +{ + if (d_out > d_in) { + // not clear what we should do in this case + is_orthonormal = false; + return; + } + if (d_out == 0) { // borderline case, unnormalized matrix + is_orthonormal = true; + return; + } + + double eps = 4e-5; + ABORT_UNLESS(A.size() >= d_out * d_in, "A.size() >= d_out * d_in"); + { + std::vector<float> ATA(d_out * d_out); + FINTEGER dii = d_in, doi = d_out; + float one = 1.0, zero = 0.0; + + sgemm_ ("Transposed", "Not", &doi, &doi, &dii, + &one, A.data (), &dii, + A.data(), &dii, + &zero, ATA.data(), &doi); + + is_orthonormal = true; + for (long i = 0; i < d_out; i++) { + for (long j = 0; j < d_out; j++) { + float v = ATA[i + j * d_out]; + if (i == j) v -= 1; + if (fabs(v) > eps) { + is_orthonormal = false; + } + } + } + } + +} + + +void LinearTransform::reverse_transform (idx_t n, const float * xt, + float *x) const +{ + if (is_orthonormal) { + transform_transpose (n, xt, x); + } else { + ABORT("reverse transform not implemented for non-orthonormal matrices"); + } +} + + +void LinearTransform::print_if_verbose ( + const char*name, const std::vector<float> &mat, + int n, int d) const +{ + if (!verbose) return; + printf("matrix %s: %d*%d [\n", name, n, d); + ABORT_UNLESS(mat.size() >= n * d, "mat.size() >= n * d"); + for (int i = 0; i < n; i++) { + for (int j = 0; j < d; j++) { + printf("%10.5g ", mat[i * d + j]); + } + printf("\n"); + } + printf("]\n"); +} + +/********************************************* + * RandomRotationMatrix + *********************************************/ + +void RandomRotationMatrix::init (int seed) +{ + + if(d_out <= d_in) { + A.resize (d_out * d_in); + float *q = A.data(); + float_randn(q, d_out * d_in, seed); + matrix_qr(d_in, d_out, q); + } else { + // use tight-frame transformation + A.resize (d_out * d_out); + float *q = A.data(); + float_randn(q, d_out * d_out, seed); + matrix_qr(d_out, d_out, q); + // remove columns + int i, j; + for (i = 0; i < d_out; i++) { + for(j = 0; j < d_in; j++) { + q[i * d_in + j] = q[i * d_out + j]; + } + } + A.resize(d_in * d_out); + } + is_orthonormal = true; + is_trained = true; +} + +void RandomRotationMatrix::train (Index::idx_t /*n*/, const float* /*x*/) +{ + // initialize with some arbitrary seed + init (12345); +} + + +/********************************************* + * PCAMatrix + *********************************************/ + +PCAMatrix::PCAMatrix (int d_in, int d_out, + float eigen_power, bool random_rotation): + LinearTransform(d_in, d_out, true), + eigen_power(eigen_power), random_rotation(random_rotation) +{ + 
is_trained = false; + max_points_per_d = 1000; + balanced_bins = 0; +} + + +namespace { + +/// Compute the eigenvalue decomposition of symmetric matrix cov, +/// dimensions d_in-by-d_in. Output eigenvectors in cov. + +void eig(size_t d_in, double *cov, double *eigenvalues, int verbose) +{ + { // compute eigenvalues and vectors + FINTEGER info = 0, lwork = -1, di = d_in; + double workq; + + dsyev_ ("Vectors as well", "Upper", + &di, cov, &di, eigenvalues, &workq, &lwork, &info); + lwork = FINTEGER(workq); + double *work = new double[lwork]; + + dsyev_ ("Vectors as well", "Upper", + &di, cov, &di, eigenvalues, work, &lwork, &info); + + delete [] work; + + if (info != 0) { + fprintf (stderr, "WARN ssyev info returns %d, " + "a very bad PCA matrix is learnt\n", + int(info)); + // do not throw exception, as the matrix could still be useful + } + + + if(verbose && d_in <= 10) { + printf("info=%ld new eigvals=[", long(info)); + for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]); + printf("]\n"); + + double *ci = cov; + printf("eigenvecs=\n"); + for(int i = 0; i < d_in; i++) { + for(int j = 0; j < d_in; j++) + printf("%10.4g ", *ci++); + printf("\n"); + } + } + + } + + // revert order of eigenvectors & values + + for(int i = 0; i < d_in / 2; i++) { + + std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]); + double *v1 = cov + i * d_in; + double *v2 = cov + (d_in - 1 - i) * d_in; + for(int j = 0; j < d_in; j++) + std::swap(v1[j], v2[j]); + } + +} + + +} + +void PCAMatrix::train (Index::idx_t n, const float *x) +{ + const float * x_in = x; + + x = fvecs_maybe_subsample (d_in, (size_t*)&n, + max_points_per_d * d_in, x, verbose); + + ScopeDeleter del_x (x != x_in ? x : nullptr); + + // compute mean + mean.clear(); mean.resize(d_in, 0.0); + if (have_bias) { // we may want to skip the bias + const float *xi = x; + for (int i = 0; i < n; i++) { + for(int j = 0; j < d_in; j++) + mean[j] += *xi++; + } + for(int j = 0; j < d_in; j++) + mean[j] /= n; + } + if(verbose) { + printf("mean=["); + for(int j = 0; j < d_in; j++) printf("%g ", mean[j]); + printf("]\n"); + } + + if(n >= d_in) { + // compute covariance matrix, store it in PCA matrix + PCAMat.resize(d_in * d_in); + float * cov = PCAMat.data(); + { // initialize with mean * mean^T term + float *ci = cov; + for(int i = 0; i < d_in; i++) { + for(int j = 0; j < d_in; j++) + *ci++ = - n * mean[i] * mean[j]; + } + } + { + FINTEGER di = d_in, ni = n; + float one = 1.0; + ssyrk_ ("Up", "Non transposed", + &di, &ni, &one, (float*)x, &di, &one, cov, &di); + + } + if(verbose && d_in <= 10) { + float *ci = cov; + printf("cov=\n"); + for(int i = 0; i < d_in; i++) { + for(int j = 0; j < d_in; j++) + printf("%10g ", *ci++); + printf("\n"); + } + } + + std::vector covd (d_in * d_in); + for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i]; + + std::vector eigenvaluesd (d_in); + + eig (d_in, covd.data (), eigenvaluesd.data (), verbose); + + for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i]; + eigenvalues.resize (d_in); + + for (size_t i = 0; i < d_in; i++) + eigenvalues [i] = eigenvaluesd [i]; + + + } else { + + std::vector xc (n * d_in); + + for (size_t i = 0; i < n; i++) + for(size_t j = 0; j < d_in; j++) + xc [i * d_in + j] = x [i * d_in + j] - mean[j]; + + // compute Gram matrix + std::vector gram (n * n); + { + FINTEGER di = d_in, ni = n; + float one = 1.0, zero = 0.0; + ssyrk_ ("Up", "Transposed", + &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni); + } + + if(verbose && d_in <= 10) { + float *ci = gram.data(); + 
printf("gram=\n"); + for(int i = 0; i < n; i++) { + for(int j = 0; j < n; j++) + printf("%10g ", *ci++); + printf("\n"); + } + } + + std::vector gramd (n * n); + for (size_t i = 0; i < n * n; i++) + gramd [i] = gram [i]; + + std::vector eigenvaluesd (n); + + // eig will fill in only the n first eigenvals + + eig (n, gramd.data (), eigenvaluesd.data (), verbose); + + PCAMat.resize(d_in * n); + + for (size_t i = 0; i < n * n; i++) + gram [i] = gramd [i]; + + eigenvalues.resize (d_in); + // fill in only the n first ones + for (size_t i = 0; i < n; i++) + eigenvalues [i] = eigenvaluesd [i]; + + { // compute PCAMat = x' * v + FINTEGER di = d_in, ni = n; + float one = 1.0; + + sgemm_ ("Non", "Non Trans", + &di, &ni, &ni, + &one, xc.data(), &di, gram.data(), &ni, + &one, PCAMat.data(), &di); + } + + if(verbose && d_in <= 10) { + float *ci = PCAMat.data(); + printf("PCAMat=\n"); + for(int i = 0; i < n; i++) { + for(int j = 0; j < d_in; j++) + printf("%10g ", *ci++); + printf("\n"); + } + } + fvec_renorm_L2 (d_in, n, PCAMat.data()); + + } + + prepare_Ab(); + is_trained = true; +} + +void PCAMatrix::copy_from (const PCAMatrix & other) +{ + ABORT_UNLESS(other.is_trained, "other.is_trained"); + mean = other.mean; + eigenvalues = other.eigenvalues; + PCAMat = other.PCAMat; + prepare_Ab (); + is_trained = true; +} + +void PCAMatrix::prepare_Ab () +{ + ABORT_UNLESS( + d_out * d_in <= PCAMat.size(), + "PCA matrix cannot output %d dimensions from %d ", + d_out, d_in); + + if (!random_rotation) { + A = PCAMat; + A.resize(d_out * d_in); // strip off useless dimensions + + // first scale the components + if (eigen_power != 0) { + float *ai = A.data(); + for (int i = 0; i < d_out; i++) { + float factor = pow(eigenvalues[i], eigen_power); + for(int j = 0; j < d_in; j++) + *ai++ *= factor; + } + } + + if (balanced_bins != 0) { + ABORT_UNLESS(d_out % balanced_bins == 0, "d_out % balanced_bins == 0"); + int dsub = d_out / balanced_bins; + std::vector Ain; + std::swap(A, Ain); + A.resize(d_out * d_in); + + std::vector accu(balanced_bins); + std::vector counter(balanced_bins); + + // greedy assignment + for (int i = 0; i < d_out; i++) { + // find best bin + int best_j = -1; + float min_w = 1e30; + for (int j = 0; j < balanced_bins; j++) { + if (counter[j] < dsub && accu[j] < min_w) { + min_w = accu[j]; + best_j = j; + } + } + int row_dst = best_j * dsub + counter[best_j]; + accu[best_j] += eigenvalues[i]; + counter[best_j] ++; + memcpy (&A[row_dst * d_in], &Ain[i * d_in], + d_in * sizeof (A[0])); + } + + if (verbose) { + printf(" bin accu=["); + for (int i = 0; i < balanced_bins; i++) + printf("%g ", accu[i]); + printf("]\n"); + } + } + + + } else { + ABORT_UNLESS(balanced_bins == 0, + "both balancing bins and applying a random rotation " + "does not make sense"); + RandomRotationMatrix rr(d_out, d_out); + + rr.init(5); + + // apply scaling on the rotation matrix (right multiplication) + if (eigen_power != 0) { + for (int i = 0; i < d_out; i++) { + float factor = pow(eigenvalues[i], eigen_power); + for(int j = 0; j < d_out; j++) + rr.A[j * d_out + i] *= factor; + } + } + + A.resize(d_in * d_out); + { + FINTEGER dii = d_in, doo = d_out; + float one = 1.0, zero = 0.0; + + sgemm_ ("Not", "Not", &dii, &doo, &doo, + &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero, + A.data(), &dii); + + } + + } + + b.clear(); b.resize(d_out); + + for (int i = 0; i < d_out; i++) { + float accu = 0; + for (int j = 0; j < d_in; j++) + accu -= mean[j] * A[j + i * d_in]; + b[i] = accu; + } + + is_orthonormal = eigen_power == 0; + +} + + 
diff --git a/src/3rd_party/faiss/VectorTransform.h b/src/3rd_party/faiss/VectorTransform.h new file mode 100644 index 000000000..5fc96bc46 --- /dev/null +++ b/src/3rd_party/faiss/VectorTransform.h @@ -0,0 +1,187 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +#ifndef FAISS_VECTOR_TRANSFORM_H +#define FAISS_VECTOR_TRANSFORM_H + +/** Defines a few objects that apply transformations to a set of + * vectors Often these are pre-processing steps. + */ + +#include +#include + +#include +#ifdef __APPLE__ +#include +#endif + + +namespace faiss { + + +/** Any transformation applied on a set of vectors */ +struct VectorTransform { + + typedef Index::idx_t idx_t; + + int d_in; ///! input dimension + int d_out; ///! output dimension + + explicit VectorTransform (int d_in = 0, int d_out = 0): + d_in(d_in), d_out(d_out), is_trained(true) + {} + + + /// set if the VectorTransform does not require training, or if + /// training is done already + bool is_trained; + + + /** Perform training on a representative set of vectors. Does + * nothing by default. + * + * @param n nb of training vectors + * @param x training vecors, size n * d + */ + virtual void train (idx_t n, const float *x); + + /** apply the random roation, return new allocated matrix + * @param x size n * d_in + * @return size n * d_out + */ + float *apply (idx_t n, const float * x) const; + + /// same as apply, but result is pre-allocated + virtual void apply_noalloc (idx_t n, const float * x, + float *xt) const = 0; + + /// reverse transformation. May not be implemented or may return + /// approximate result + virtual void reverse_transform (idx_t n, const float * xt, + float *x) const; + + virtual ~VectorTransform () {} + +}; + + + +/** Generic linear transformation, with bias term applied on output + * y = A * x + b + */ +struct LinearTransform: VectorTransform { + + bool have_bias; ///! 
whether to use the bias term + + /// check if matrix A is orthonormal (enables reverse_transform) + bool is_orthonormal; + + /// Transformation matrix, size d_out * d_in + std::vector A; + + /// bias vector, size d_out + std::vector b; + + /// both d_in > d_out and d_out < d_in are supported + explicit LinearTransform (int d_in = 0, int d_out = 0, + bool have_bias = false); + + /// same as apply, but result is pre-allocated + void apply_noalloc(idx_t n, const float* x, float* xt) const override; + + /// compute x = A^T * (x - b) + /// is reverse transform if A has orthonormal lines + void transform_transpose (idx_t n, const float * y, + float *x) const; + + /// works only if is_orthonormal + void reverse_transform (idx_t n, const float * xt, + float *x) const override; + + /// compute A^T * A to set the is_orthonormal flag + void set_is_orthonormal (); + + bool verbose; + void print_if_verbose (const char*name, const std::vector &mat, + int n, int d) const; + + ~LinearTransform() override {} +}; + + + +/// Randomly rotate a set of vectors +struct RandomRotationMatrix: LinearTransform { + + /// both d_in > d_out and d_out < d_in are supported + RandomRotationMatrix (int d_in, int d_out): + LinearTransform(d_in, d_out, false) {} + + /// must be called before the transform is used + void init(int seed); + + // intializes with an arbitrary seed + void train(idx_t n, const float* x) override; + + RandomRotationMatrix () {} +}; + + +/** Applies a principal component analysis on a set of vectors, + * with optionally whitening and random rotation. */ +struct PCAMatrix: LinearTransform { + + /** after transformation the components are multiplied by + * eigenvalues^eigen_power + * + * =0: no whitening + * =-0.5: full whitening + */ + float eigen_power; + + /// random rotation after PCA + bool random_rotation; + + /// ratio between # training vectors and dimension + size_t max_points_per_d; + + /// try to distribute output eigenvectors in this many bins + int balanced_bins; + + /// Mean, size d_in + std::vector mean; + + /// eigenvalues of covariance matrix (= squared singular values) + std::vector eigenvalues; + + /// PCA matrix, size d_in * d_in + std::vector PCAMat; + + // the final matrix is computed after random rotation and/or whitening + explicit PCAMatrix (int d_in = 0, int d_out = 0, + float eigen_power = 0, bool random_rotation = false); + + /// train on n vectors. If n < d_in then the eigenvector matrix + /// will be completed with 0s + void train(idx_t n, const float* x) override; + + /// copy pre-trained PCA matrix + void copy_from (const PCAMatrix & other); + + /// called after mean, PCAMat and eigenvalues are computed + void prepare_Ab(); + +}; + + +} // namespace faiss + + +#endif diff --git a/src/3rd_party/faiss/utils/Heap.cpp b/src/3rd_party/faiss/utils/Heap.cpp new file mode 100644 index 000000000..5f0e3be29 --- /dev/null +++ b/src/3rd_party/faiss/utils/Heap.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// -*- c++ -*- + +/* Function for soft heap */ + +#include + + +namespace faiss { + + +template +void HeapArray::heapify () +{ +//#pragma omp parallel for + for (size_t j = 0; j < nh; j++) + heap_heapify (k, val + j * k, ids + j * k); +} + +template +void HeapArray::reorder () +{ +//#pragma omp parallel for + for (size_t j = 0; j < nh; j++) + heap_reorder (k, val + j * k, ids + j * k); +} + +template +void HeapArray::addn (size_t nj, const T *vin, TI j0, + size_t i0, int64_t ni) +{ + if (ni == -1) ni = nh; + assert (i0 >= 0 && i0 + ni <= nh); +//#pragma omp parallel for + for (size_t i = i0; i < i0 + ni; i++) { + T * __restrict simi = get_val(i); + TI * __restrict idxi = get_ids (i); + const T *ip_line = vin + (i - i0) * nj; + + for (size_t j = 0; j < nj; j++) { + T ip = ip_line [j]; + if (C::cmp(simi[0], ip)) { + heap_pop (k, simi, idxi); + heap_push (k, simi, idxi, ip, j + j0); + } + } + } +} + +template +void HeapArray::addn_with_ids ( + size_t nj, const T *vin, const TI *id_in, + int64_t id_stride, size_t i0, int64_t ni) +{ + if (id_in == nullptr) { + addn (nj, vin, 0, i0, ni); + return; + } + if (ni == -1) ni = nh; + assert (i0 >= 0 && i0 + ni <= nh); +//#pragma omp parallel for + for (size_t i = i0; i < i0 + ni; i++) { + T * __restrict simi = get_val(i); + TI * __restrict idxi = get_ids (i); + const T *ip_line = vin + (i - i0) * nj; + const TI *id_line = id_in + (i - i0) * id_stride; + + for (size_t j = 0; j < nj; j++) { + T ip = ip_line [j]; + if (C::cmp(simi[0], ip)) { + heap_pop (k, simi, idxi); + heap_push (k, simi, idxi, ip, id_line [j]); + } + } + } +} + +template +void HeapArray::per_line_extrema ( + T * out_val, + TI * out_ids) const +{ +//#pragma omp parallel for + for (size_t j = 0; j < nh; j++) { + int64_t imin = -1; + typename C::T xval = C::Crev::neutral (); + const typename C::T * x_ = val + j * k; + for (size_t i = 0; i < k; i++) + if (C::cmp (x_[i], xval)) { + xval = x_[i]; + imin = i; + } + if (out_val) + out_val[j] = xval; + + if (out_ids) { + if (ids && imin != -1) + out_ids[j] = ids [j * k + imin]; + else + out_ids[j] = imin; + } + } +} + + + + +// explicit instanciations + +template struct HeapArray >; +template struct HeapArray >; +template struct HeapArray >; +template struct HeapArray >; + + +} // END namespace fasis diff --git a/src/3rd_party/faiss/utils/Heap.h b/src/3rd_party/faiss/utils/Heap.h new file mode 100644 index 000000000..e691c36c7 --- /dev/null +++ b/src/3rd_party/faiss/utils/Heap.h @@ -0,0 +1,495 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * C++ support for heaps. The set of functions is tailored for + * efficient similarity search. + * + * There is no specific object for a heap, and the functions that + * operate on a signle heap are inlined, because heaps are often + * small. More complex functions are implemented in Heaps.cpp + * + */ + + +#ifndef FAISS_Heap_h +#define FAISS_Heap_h + +#include +#include +#include + +#include +#include +#include + +#include + + +namespace faiss { + +/******************************************************************* + * C object: uniform handling of min and max heap + *******************************************************************/ + +/** The C object gives the type T of the values in the heap, the type + * of the keys, TI and the comparison that is done: > for the minheap + * and < for the maxheap. 
The neutral value will always be dropped in + * favor of any other value in the heap. + */ + +template +struct CMax; + +// traits of minheaps = heaps where the minimum value is stored on top +// useful to find the *max* values of an array +template +struct CMin { + typedef T_ T; + typedef TI_ TI; + typedef CMax Crev; + inline static bool cmp (T a, T b) { + return a < b; + } + // value that will be popped first -> must be smaller than all others + // for int types this is not strictly the smallest val (-max - 1) + inline static T neutral () { + return -std::numeric_limits::max(); + } +}; + + +template +struct CMax { + typedef T_ T; + typedef TI_ TI; + typedef CMin Crev; + inline static bool cmp (T a, T b) { + return a > b; + } + inline static T neutral () { + return std::numeric_limits::max(); + } +}; + + +/******************************************************************* + * Basic heap ops: push and pop + *******************************************************************/ + +/** Pops the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1]. on output the element at k-1 is undefined. + */ +template inline +void heap_pop (size_t k, typename C::T * bh_val, typename C::TI * bh_ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + typename C::T val = bh_val[k]; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if (i2 == k + 1 || C::cmp(bh_val[i1], bh_val[i2])) { + if (C::cmp(val, bh_val[i1])) + break; + bh_val[i] = bh_val[i1]; + bh_ids[i] = bh_ids[i1]; + i = i1; + } + else { + if (C::cmp(val, bh_val[i2])) + break; + bh_val[i] = bh_val[i2]; + bh_ids[i] = bh_ids[i2]; + i = i2; + } + } + bh_val[i] = bh_val[k]; + bh_ids[i] = bh_ids[k]; +} + + + +/** Pushes the element (val, ids) into the heap bh_val[0..k-2] and + * bh_ids[0..k-2]. on output the element at k-1 is defined. + */ +template inline +void heap_push (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + typename C::T val, typename C::TI ids) +{ + bh_val--; /* Use 1-based indexing for easier node->child translation */ + bh_ids--; + size_t i = k, i_father; + while (i > 1) { + i_father = i >> 1; + if (!C::cmp (val, bh_val[i_father])) /* the heap structure is ok */ + break; + bh_val[i] = bh_val[i_father]; + bh_ids[i] = bh_ids[i_father]; + i = i_father; + } + bh_val[i] = val; + bh_ids[i] = ids; +} + + + +/* Partial instanciation for heaps with TI = int64_t */ + +template inline +void minheap_pop (size_t k, T * bh_val, int64_t * bh_ids) +{ + heap_pop > (k, bh_val, bh_ids); +} + + +template inline +void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_push > (k, bh_val, bh_ids, val, ids); +} + + +template inline +void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids) +{ + heap_pop > (k, bh_val, bh_ids); +} + + +template inline +void maxheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids) +{ + heap_push > (k, bh_val, bh_ids, val, ids); +} + + + +/******************************************************************* + * Heap initialization + *******************************************************************/ + +/* Initialization phase for the heap (with unconditionnal pushes). + * Store k0 elements in a heap containing up to k values. 
Note that + * (bh_val, bh_ids) can be the same as (x, ids) */ +template inline +void heap_heapify ( + size_t k, + typename C::T * bh_val, + typename C::TI * bh_ids, + const typename C::T * x = nullptr, + const typename C::TI * ids = nullptr, + size_t k0 = 0) +{ + if (k0 > 0) assert (x); + + if (ids) { + for (size_t i = 0; i < k0; i++) + heap_push (i+1, bh_val, bh_ids, x[i], ids[i]); + } else { + for (size_t i = 0; i < k0; i++) + heap_push (i+1, bh_val, bh_ids, x[i], i); + } + + for (size_t i = k0; i < k; i++) { + bh_val[i] = C::neutral(); + bh_ids[i] = -1; + } + +} + +template inline +void minheap_heapify ( + size_t k, T * bh_val, + int64_t * bh_ids, + const T * x = nullptr, + const int64_t * ids = nullptr, + size_t k0 = 0) +{ + heap_heapify< CMin > (k, bh_val, bh_ids, x, ids, k0); +} + + +template inline +void maxheap_heapify ( + size_t k, + T * bh_val, + int64_t * bh_ids, + const T * x = nullptr, + const int64_t * ids = nullptr, + size_t k0 = 0) +{ + heap_heapify< CMax > (k, bh_val, bh_ids, x, ids, k0); +} + + + +/******************************************************************* + * Add n elements to the heap + *******************************************************************/ + + +/* Add some elements to the heap */ +template inline +void heap_addn (size_t k, + typename C::T * bh_val, typename C::TI * bh_ids, + const typename C::T * x, + const typename C::TI * ids, + size_t n) +{ + size_t i; + if (ids) + for (i = 0; i < n; i++) { + if (C::cmp (bh_val[0], x[i])) { + heap_pop (k, bh_val, bh_ids); + heap_push (k, bh_val, bh_ids, x[i], ids[i]); + } + } + else + for (i = 0; i < n; i++) { + if (C::cmp (bh_val[0], x[i])) { + heap_pop (k, bh_val, bh_ids); + heap_push (k, bh_val, bh_ids, x[i], i); + } + } +} + + +/* Partial instanciation for heaps with TI = int64_t */ + +template inline +void minheap_addn (size_t k, T * bh_val, int64_t * bh_ids, + const T * x, const int64_t * ids, size_t n) +{ + heap_addn > (k, bh_val, bh_ids, x, ids, n); +} + +template inline +void maxheap_addn (size_t k, T * bh_val, int64_t * bh_ids, + const T * x, const int64_t * ids, size_t n) +{ + heap_addn > (k, bh_val, bh_ids, x, ids, n); +} + + + + + + +/******************************************************************* + * Heap finalization (reorder elements) + *******************************************************************/ + + +/* This function maps a binary heap into an sorted structure. 
+ It returns the number */ +template inline +size_t heap_reorder (size_t k, typename C::T * bh_val, typename C::TI * bh_ids) +{ + size_t i, ii; + + for (i = 0, ii = 0; i < k; i++) { + /* top element should be put at the end of the list */ + typename C::T val = bh_val[0]; + typename C::TI id = bh_ids[0]; + + /* boundary case: we will over-ride this value if not a true element */ + heap_pop (k-i, bh_val, bh_ids); + bh_val[k-ii-1] = val; + bh_ids[k-ii-1] = id; + if (id != -1) ii++; + } + /* Count the number of elements which are effectively returned */ + size_t nel = ii; + + memmove (bh_val, bh_val+k-ii, ii * sizeof(*bh_val)); + memmove (bh_ids, bh_ids+k-ii, ii * sizeof(*bh_ids)); + + for (; ii < k; ii++) { + bh_val[ii] = C::neutral(); + bh_ids[ii] = -1; + } + return nel; +} + +template inline +size_t minheap_reorder (size_t k, T * bh_val, int64_t * bh_ids) +{ + return heap_reorder< CMin > (k, bh_val, bh_ids); +} + +template inline +size_t maxheap_reorder (size_t k, T * bh_val, int64_t * bh_ids) +{ + return heap_reorder< CMax > (k, bh_val, bh_ids); +} + + + + + +/******************************************************************* + * Operations on heap arrays + *******************************************************************/ + +/** a template structure for a set of [min|max]-heaps it is tailored + * so that the actual data of the heaps can just live in compact + * arrays. + */ +template +struct HeapArray { + typedef typename C::TI TI; + typedef typename C::T T; + + size_t nh; ///< number of heaps + size_t k; ///< allocated size per heap + TI * ids; ///< identifiers (size nh * k) + T * val; ///< values (distances or similarities), size nh * k + + /// Return the list of values for a heap + T * get_val (size_t key) { return val + key * k; } + + /// Correspponding identifiers + TI * get_ids (size_t key) { return ids + key * k; } + + /// prepare all the heaps before adding + void heapify (); + + /** add nj elements to heaps i0:i0+ni, with sequential ids + * + * @param nj nb of elements to add to each heap + * @param vin elements to add, size ni * nj + * @param j0 add this to the ids that are added + * @param i0 first heap to update + * @param ni nb of elements to update (-1 = use nh) + */ + void addn (size_t nj, const T *vin, TI j0 = 0, + size_t i0 = 0, int64_t ni = -1); + + /** same as addn + * + * @param id_in ids of the elements to add, size ni * nj + * @param id_stride stride for id_in + */ + void addn_with_ids ( + size_t nj, const T *vin, const TI *id_in = nullptr, + int64_t id_stride = 0, size_t i0 = 0, int64_t ni = -1); + + /// reorder all the heaps + void reorder (); + + /** this is not really a heap function. 
It just finds the per-line + * extrema of each line of array D + * @param vals_out extreme value of each line (size nh, or NULL) + * @param idx_out index of extreme value (size nh or NULL) + */ + void per_line_extrema (T *vals_out, TI *idx_out) const; + +}; + + +/* Define useful heaps */ +typedef HeapArray > float_minheap_array_t; +typedef HeapArray > int_minheap_array_t; + +typedef HeapArray > float_maxheap_array_t; +typedef HeapArray > int_maxheap_array_t; + +// The heap templates are instanciated explicitly in Heap.cpp + + + + + + + + + + + + + + + + + + + +/********************************************************************* + * Indirect heaps: instead of having + * + * node i = (bh_ids[i], bh_val[i]), + * + * in indirect heaps, + * + * node i = (bh_ids[i], bh_val[bh_ids[i]]), + * + *********************************************************************/ + + +template +inline +void indirect_heap_pop ( + size_t k, + const typename C::T * bh_val, + typename C::TI * bh_ids) +{ + bh_ids--; /* Use 1-based indexing for easier node->child translation */ + typename C::T val = bh_val[bh_ids[k]]; + size_t i = 1; + while (1) { + size_t i1 = i << 1; + size_t i2 = i1 + 1; + if (i1 > k) + break; + typename C::TI id1 = bh_ids[i1], id2 = bh_ids[i2]; + if (i2 == k + 1 || C::cmp(bh_val[id1], bh_val[id2])) { + if (C::cmp(val, bh_val[id1])) + break; + bh_ids[i] = id1; + i = i1; + } else { + if (C::cmp(val, bh_val[id2])) + break; + bh_ids[i] = id2; + i = i2; + } + } + bh_ids[i] = bh_ids[k]; +} + + + +template +inline +void indirect_heap_push (size_t k, + const typename C::T * bh_val, typename C::TI * bh_ids, + typename C::TI id) +{ + bh_ids--; /* Use 1-based indexing for easier node->child translation */ + typename C::T val = bh_val[id]; + size_t i = k; + while (i > 1) { + size_t i_father = i >> 1; + if (!C::cmp (val, bh_val[bh_ids[i_father]])) + break; + bh_ids[i] = bh_ids[i_father]; + i = i_father; + } + bh_ids[i] = id; +} + + +} // namespace faiss + +#endif /* FAISS_Heap_h */ diff --git a/src/3rd_party/faiss/utils/hamming-inl.h b/src/3rd_party/faiss/utils/hamming-inl.h new file mode 100644 index 000000000..d32da7580 --- /dev/null +++ b/src/3rd_party/faiss/utils/hamming-inl.h @@ -0,0 +1,475 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + + + +namespace faiss { + + +#ifdef _MSC_VER +#define bzero(p,n) (memset((p),0,(n))) +#endif +inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size): + code (code), code_size (code_size), i(0) +{ + bzero (code, code_size); +} + +inline void BitstringWriter::write(uint64_t x, int nbit) { + assert (code_size * 8 >= nbit + i); + // nb of available bits in i / 8 + int na = 8 - (i & 7); + + if (nbit <= na) { + code[i >> 3] |= x << (i & 7); + i += nbit; + return; + } else { + int j = i >> 3; + code[j++] |= x << (i & 7); + i += nbit; + x >>= na; + while (x != 0) { + code[j++] |= x; + x >>= 8; + } + } +} + + +inline BitstringReader::BitstringReader(const uint8_t *code, int code_size): + code (code), code_size (code_size), i(0) +{} + +inline uint64_t BitstringReader::read(int nbit) { + assert (code_size * 8 >= nbit + i); + // nb of available bits in i / 8 + int na = 8 - (i & 7); + // get available bits in current byte + uint64_t res = code[i >> 3] >> (i & 7); + if (nbit <= na) { + res &= (1 << nbit) - 1; + i += nbit; + return res; + } else { + int ofs = na; + int j = (i >> 3) + 1; + i += nbit; + nbit -= na; + while (nbit > 8) { + res |= ((uint64_t)code[j++]) << ofs; + ofs += 8; + nbit -= 8; // TODO remove nbit + } + uint64_t last_byte = code[j]; + last_byte &= (1 << nbit) - 1; + res |= last_byte << ofs; + return res; + } +} + + +/****************************************************************** + * The HammingComputer series of classes compares a single code of + * size 4 to 32 to incoming codes. They are intended for use as a + * template class where it would be inefficient to switch on the code + * size in the inner loop. Hopefully the compiler will inline the + * hamming() functions and put the a0, a1, ... in registers. + ******************************************************************/ + + +struct HammingComputer4 { + uint32_t a0; + + HammingComputer4 () {} + + HammingComputer4 (const uint8_t *a, int code_size) { + set (a, code_size); + } + + void set (const uint8_t *a, int code_size) { + assert (code_size == 4); + a0 = *(uint32_t *)a; + } + + inline int hamming (const uint8_t *b) const { + return popcount64 (*(uint32_t *)b ^ a0); + } + +}; + +struct HammingComputer8 { + uint64_t a0; + + HammingComputer8 () {} + + HammingComputer8 (const uint8_t *a, int code_size) { + set (a, code_size); + } + + void set (const uint8_t *a, int code_size) { + assert (code_size == 8); + a0 = *(uint64_t *)a; + } + + inline int hamming (const uint8_t *b) const { + return popcount64 (*(uint64_t *)b ^ a0); + } + +}; + + +struct HammingComputer16 { + uint64_t a0, a1; + + HammingComputer16 () {} + + HammingComputer16 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1); + } + +}; + +// when applied to an array, 1/2 of the 64-bit accesses are unaligned. +// This incurs a penalty of ~10% wrt. fully aligned accesses. 
+struct HammingComputer20 { + uint64_t a0, a1; + uint32_t a2; + + HammingComputer20 () {} + + HammingComputer20 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 20); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) + + popcount64 (*(uint32_t*)(b + 2) ^ a2); + } +}; + +struct HammingComputer32 { + uint64_t a0, a1, a2, a3; + + HammingComputer32 () {} + + HammingComputer32 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) + + popcount64 (b[2] ^ a2) + popcount64 (b[3] ^ a3); + } + +}; + +struct HammingComputer64 { + uint64_t a0, a1, a2, a3, a4, a5, a6, a7; + + HammingComputer64 () {} + + HammingComputer64 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size == 64); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + a4 = a[4]; a5 = a[5]; a6 = a[6]; a7 = a[7]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return popcount64 (b[0] ^ a0) + popcount64 (b[1] ^ a1) + + popcount64 (b[2] ^ a2) + popcount64 (b[3] ^ a3) + + popcount64 (b[4] ^ a4) + popcount64 (b[5] ^ a5) + + popcount64 (b[6] ^ a6) + popcount64 (b[7] ^ a7); + } + +}; + +// very inefficient... +struct HammingComputerDefault { + const uint8_t *a; + int n; + + HammingComputerDefault () {} + + HammingComputerDefault (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + a = a8; + n = code_size; + } + + int hamming (const uint8_t *b8) const { + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64 (a[i] ^ b8[i]); + return accu; + } + +}; + +struct HammingComputerM8 { + const uint64_t *a; + int n; + + HammingComputerM8 () {} + + HammingComputerM8 (const uint8_t *a8, int code_size) { + set (a8, code_size); + } + + void set (const uint8_t *a8, int code_size) { + assert (code_size % 8 == 0); + a = (uint64_t *)a8; + n = code_size / 8; + } + + int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64 (a[i] ^ b[i]); + return accu; + } + +}; + +// even more inefficient! 
+struct HammingComputerM4 { + const uint32_t *a; + int n; + + HammingComputerM4 () {} + + HammingComputerM4 (const uint8_t *a4, int code_size) { + set (a4, code_size); + } + + void set (const uint8_t *a4, int code_size) { + assert (code_size % 4 == 0); + a = (uint32_t *)a4; + n = code_size / 4; + } + + int hamming (const uint8_t *b8) const { + const uint32_t *b = (uint32_t *)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += popcount64 (a[i] ^ b[i]); + return accu; + } + +}; + +/*************************************************************************** + * Equivalence with a template class when code size is known at compile time + **************************************************************************/ + +// default template +template +struct HammingComputer: HammingComputerM8 { + HammingComputer (const uint8_t *a, int code_size): + HammingComputerM8(a, code_size) {} +}; + +#define SPECIALIZED_HC(CODE_SIZE) \ + template<> struct HammingComputer: \ + HammingComputer ## CODE_SIZE { \ + HammingComputer (const uint8_t *a): \ + HammingComputer ## CODE_SIZE(a, CODE_SIZE) {} \ + } + +SPECIALIZED_HC(4); +SPECIALIZED_HC(8); +SPECIALIZED_HC(16); +SPECIALIZED_HC(20); +SPECIALIZED_HC(32); +SPECIALIZED_HC(64); + +#undef SPECIALIZED_HC + + +/*************************************************************************** + * generalized Hamming = number of bytes that are different between + * two codes. + ***************************************************************************/ + + +inline int generalized_hamming_64 (uint64_t a) { + a |= a >> 1; + a |= a >> 2; + a |= a >> 4; + a &= 0x0101010101010101UL; + return popcount64 (a); +} + + +struct GenHammingComputer8 { + uint64_t a0; + + GenHammingComputer8 (const uint8_t *a, int code_size) { + assert (code_size == 8); + a0 = *(uint64_t *)a; + } + + inline int hamming (const uint8_t *b) const { + return generalized_hamming_64 (*(uint64_t *)b ^ a0); + } + +}; + + +struct GenHammingComputer16 { + uint64_t a0, a1; + GenHammingComputer16 (const uint8_t *a8, int code_size) { + assert (code_size == 16); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return generalized_hamming_64 (b[0] ^ a0) + + generalized_hamming_64 (b[1] ^ a1); + } + +}; + +struct GenHammingComputer32 { + uint64_t a0, a1, a2, a3; + + GenHammingComputer32 (const uint8_t *a8, int code_size) { + assert (code_size == 32); + const uint64_t *a = (uint64_t *)a8; + a0 = a[0]; a1 = a[1]; a2 = a[2]; a3 = a[3]; + } + + inline int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + return generalized_hamming_64 (b[0] ^ a0) + + generalized_hamming_64 (b[1] ^ a1) + + generalized_hamming_64 (b[2] ^ a2) + + generalized_hamming_64 (b[3] ^ a3); + } + +}; + +struct GenHammingComputerM8 { + const uint64_t *a; + int n; + + GenHammingComputerM8 (const uint8_t *a8, int code_size) { + assert (code_size % 8 == 0); + a = (uint64_t *)a8; + n = code_size / 8; + } + + int hamming (const uint8_t *b8) const { + const uint64_t *b = (uint64_t *)b8; + int accu = 0; + for (int i = 0; i < n; i++) + accu += generalized_hamming_64 (a[i] ^ b[i]); + return accu; + } + +}; + + +/** generalized Hamming distances (= count number of code bytes that + are the same) */ +void generalized_hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t code_size, + int ordered = true); + + + +/** This class maintains a list of best distances seen so far. 
+ * + * Since the distances are in a limited range (0 to nbit), the + * object maintains one list per possible distance, and fills + * in only the n-first lists, such that the sum of sizes of the + * n lists is below k. + */ +template +struct HCounterState { + int *counters; + int64_t *ids_per_dis; + + HammingComputer hc; + int thres; + int count_lt; + int count_eq; + int k; + + HCounterState(int *counters, int64_t *ids_per_dis, + const uint8_t *x, int d, int k) + : counters(counters), + ids_per_dis(ids_per_dis), + hc(x, d / 8), + thres(d + 1), + count_lt(0), + count_eq(0), + k(k) {} + + void update_counter(const uint8_t *y, size_t j) { + int32_t dis = hc.hamming(y); + + if (dis <= thres) { + if (dis < thres) { + ids_per_dis[dis * k + counters[dis]++] = j; + ++count_lt; + while (count_lt == k && thres > 0) { + --thres; + count_eq = counters[thres]; + count_lt -= count_eq; + } + } else if (count_eq < k) { + ids_per_dis[dis * k + count_eq++] = j; + counters[dis] = count_eq; + } + } + } +}; + + +} // namespace faiss diff --git a/src/3rd_party/faiss/utils/hamming.cpp b/src/3rd_party/faiss/utils/hamming.cpp new file mode 100644 index 000000000..cc0df2750 --- /dev/null +++ b/src/3rd_party/faiss/utils/hamming.cpp @@ -0,0 +1,879 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * Implementation of Hamming related functions (distances, smallest distance + * selection with regular heap|radix and probabilistic heap|radix. + * + * IMPLEMENTATION NOTES + * Bitvectors are generally assumed to be multiples of 64 bits. + * + * hamdis_t is used for distances because at this time + * it is not clear how we will need to balance + * - flexibility in vector size (unclear more than 2^16 or even 2^8 bitvectors) + * - memory usage + * - cache-misses when dealing with large volumes of data (lower bits is better) + * + * The hamdis_t should optimally be compatibe with one of the Torch Storage + * (Byte,Short,Long) and therefore should be signed for 2-bytes and 4-bytes +*/ + +#include + +#include +#include +#include +#include + +#include +#include "common/logging.h" +#include "misc.h" + +static const size_t BLOCKSIZE_QUERY = 8192; + + +namespace faiss { + void binary_to_real(size_t d, const uint8_t *x_in, float *x_out) { + for (size_t i = 0; i < d; ++i) { + x_out[i] = 2 * ((x_in[i >> 3] >> (i & 7)) & 1) - 1; + } + } + + ////////////////////////////////////////// +size_t hamming_batch_size = 65536; + +static const uint8_t hamdis_tab_ham_bytes[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8 +}; + + +/* Elementary Hamming distance computation: unoptimized */ +template +T hamming (const uint8_t *bs1, + 
const uint8_t *bs2) +{ + const size_t nbytes = nbits / 8; + size_t i; + T h = 0; + for (i = 0; i < nbytes; i++) + h += (T) hamdis_tab_ham_bytes[bs1[i]^bs2[i]]; + return h; +} + + +/* Hamming distances for multiples of 64 bits */ +template +hamdis_t hamming (const uint64_t * bs1, const uint64_t * bs2) +{ + const size_t nwords = nbits / 64; + size_t i; + hamdis_t h = 0; + for (i = 0; i < nwords; i++) + h += popcount64 (bs1[i] ^ bs2[i]); + return h; +} + + + +/* specialized (optimized) functions */ +template <> +hamdis_t hamming<64> (const uint64_t * pa, const uint64_t * pb) +{ + return popcount64 (pa[0] ^ pb[0]); +} + + +template <> +hamdis_t hamming<128> (const uint64_t *pa, const uint64_t *pb) +{ + return popcount64 (pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]); +} + + +template <> +hamdis_t hamming<256> (const uint64_t * pa, const uint64_t * pb) +{ + return popcount64 (pa[0] ^ pb[0]) + + popcount64 (pa[1] ^ pb[1]) + + popcount64 (pa[2] ^ pb[2]) + + popcount64 (pa[3] ^ pb[3]); +} + + +/* Hamming distances for multiple of 64 bits */ +hamdis_t hamming ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t nwords) +{ + size_t i; + hamdis_t h = 0; + for (i = 0; i < nwords; i++) + h += popcount64 (bs1[i] ^ bs2[i]); + return h; +} + + + +template +void hammings ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, size_t n2, + hamdis_t * dis) + +{ + size_t i, j; + const size_t nwords = nbits / 64; + for (i = 0; i < n1; i++) { + const uint64_t * __restrict bs1_ = bs1 + i * nwords; + hamdis_t * __restrict dis_ = dis + i * n2; + for (j = 0; j < n2; j++) + dis_[j] = hamming(bs1_, bs2 + j * nwords); + } +} + + + +void hammings ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, + size_t n2, + size_t nwords, + hamdis_t * __restrict dis) +{ + size_t i, j; + n1 *= nwords; + n2 *= nwords; + for (i = 0; i < n1; i+=nwords) { + const uint64_t * bs1_ = bs1+i; + for (j = 0; j < n2; j+=nwords) + dis[j] = hamming (bs1_, bs2+j, nwords); + } +} + + + + +/* Count number of matches given a max threshold */ +template +void hamming_count_thres ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t * nptr) +{ + const size_t nwords = nbits / 64; + size_t i, j, posm = 0; + const uint64_t * bs2_ = bs2; + + for (i = 0; i < n1; i++) { + bs2 = bs2_; + for (j = 0; j < n2; j++) { + /* collect the match only if this satisfies the threshold */ + if (hamming (bs1, bs2) <= ht) + posm++; + bs2 += nwords; + } + bs1 += nwords; /* next signature */ + } + *nptr = posm; +} + + +template +void crosshamming_count_thres ( + const uint64_t * dbs, + size_t n, + int ht, + size_t * nptr) +{ + const size_t nwords = nbits / 64; + size_t i, j, posm = 0; + const uint64_t * bs1 = dbs; + for (i = 0; i < n; i++) { + const uint64_t * bs2 = bs1 + 2; + for (j = i + 1; j < n; j++) { + /* collect the match only if this satisfies the threshold */ + if (hamming (bs1, bs2) <= ht) + posm++; + bs2 += nwords; + } + bs1 += nwords; + } + *nptr = posm; +} + + +template +size_t match_hamming_thres ( + const uint64_t * bs1, + const uint64_t * bs2, + size_t n1, + size_t n2, + int ht, + int64_t * idx, + hamdis_t * hams) +{ + const size_t nwords = nbits / 64; + size_t i, j, posm = 0; + hamdis_t h; + const uint64_t * bs2_ = bs2; + for (i = 0; i < n1; i++) { + bs2 = bs2_; + for (j = 0; j < n2; j++) { + /* Here perform the real work of computing the distance */ + h = hamming (bs1, bs2); + + /* collect the match only if this satisfies the threshold */ + if (h <= ht) { + /* Enough space to store another 
match ? */ + *idx = i; idx++; + *idx = j; idx++; + *hams = h; + hams++; + posm++; + } + bs2+=nwords; /* next signature */ + } + bs1+=nwords; + } + return posm; +} + + +/* Return closest neighbors w.r.t Hamming distance, using a heap. */ +template +static +void hammings_knn_hc ( + int bytes_per_code, + int_maxheap_array_t * ha, + const uint8_t * bs1, + const uint8_t * bs2, + size_t n2, + bool order = true, + bool init_heap = true) +{ + size_t k = ha->k; + if (init_heap) ha->heapify (); + + const size_t block_size = hamming_batch_size; + for (size_t j0 = 0; j0 < n2; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, n2); +//#pragma omp parallel for + for (size_t i = 0; i < ha->nh; i++) { + HammingComputer hc (bs1 + i * bytes_per_code, bytes_per_code); + + const uint8_t * bs2_ = bs2 + j0 * bytes_per_code; + hamdis_t dis; + hamdis_t * __restrict bh_val_ = ha->val + i * k; + int64_t * __restrict bh_ids_ = ha->ids + i * k; + size_t j; + for (j = j0; j < j1; j++, bs2_+= bytes_per_code) { + dis = hc.hamming (bs2_); + if (dis < bh_val_[0]) { + faiss::maxheap_pop (k, bh_val_, bh_ids_); + faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + } + } + } + } + if (order) ha->reorder (); + } + +/* Return closest neighbors w.r.t Hamming distance, using max count. */ +template +static +void hammings_knn_mc ( + int bytes_per_code, + const uint8_t *a, + const uint8_t *b, + size_t na, + size_t nb, + size_t k, + int32_t *distances, + int64_t *labels) +{ + const int nBuckets = bytes_per_code * 8 + 1; + std::vector all_counters(na * nBuckets, 0); + std::unique_ptr all_ids_per_dis(new int64_t[na * nBuckets * k]); + + std::vector> cs; + for (size_t i = 0; i < na; ++i) { + cs.push_back(HCounterState( + all_counters.data() + i * nBuckets, + all_ids_per_dis.get() + i * nBuckets * k, + a + i * bytes_per_code, + 8 * bytes_per_code, + k + )); + } + + const size_t block_size = hamming_batch_size; + for (size_t j0 = 0; j0 < nb; j0 += block_size) { + const size_t j1 = std::min(j0 + block_size, nb); +//#pragma omp parallel for + for (size_t i = 0; i < na; ++i) { + for (size_t j = j0; j < j1; ++j) { + cs[i].update_counter(b + j * bytes_per_code, j); + } + } + } + + for (size_t i = 0; i < na; ++i) { + HCounterState& csi = cs[i]; + + int nres = 0; + for (int b = 0; b < nBuckets && nres < k; b++) { + for (int l = 0; l < csi.counters[b] && nres < k; l++) { + labels[i * k + nres] = csi.ids_per_dis[b * k + l]; + distances[i * k + nres] = b; + nres++; + } + } + while (nres < k) { + labels[i * k + nres] = -1; + distances[i * k + nres] = std::numeric_limits::max(); + ++nres; + } + } +} + + + +// works faster than the template version +static +void hammings_knn_hc_1 ( + int_maxheap_array_t * ha, + const uint64_t * bs1, + const uint64_t * bs2, + size_t n2, + bool order = true, + bool init_heap = true) +{ + const size_t nwords = 1; + size_t k = ha->k; + + + if (init_heap) { + ha->heapify (); + } + +//#pragma omp parallel for + for (size_t i = 0; i < ha->nh; i++) { + const uint64_t bs1_ = bs1 [i]; + const uint64_t * bs2_ = bs2; + hamdis_t dis; + hamdis_t * bh_val_ = ha->val + i * k; + hamdis_t bh_val_0 = bh_val_[0]; + int64_t * bh_ids_ = ha->ids + i * k; + size_t j; + for (j = 0; j < n2; j++, bs2_+= nwords) { + dis = popcount64 (bs1_ ^ *bs2_); + if (dis < bh_val_0) { + faiss::maxheap_pop (k, bh_val_, bh_ids_); + faiss::maxheap_push (k, bh_val_, bh_ids_, dis, j); + bh_val_0 = bh_val_[0]; + } + } + } + if (order) { + ha->reorder (); + } +} + + + + +/* Functions to maps vectors to bits. 
Assume proper allocation done beforehand, + meaning that b should be be able to receive as many bits as x may produce. */ + +/* + * dimension 0 corresponds to the least significant bit of b[0], or + * equivalently to the lsb of the first byte that is stored. + */ +void fvec2bitvec (const float * x, uint8_t * b, size_t d) +{ + for (int i = 0; i < d; i += 8) { + uint8_t w = 0; + uint8_t mask = 1; + int nj = i + 8 <= d ? 8 : d - i; + for (int j = 0; j < nj; j++) { + if (x[i + j] >= 0) + w |= mask; + mask <<= 1; + } + *b = w; + b++; + } +} + + + +/* Same but for n vectors. + Ensure that the ouptut b is byte-aligned (pad with 0s). */ +void fvecs2bitvecs (const float * x, uint8_t * b, size_t d, size_t n) +{ + const int64_t ncodes = ((d + 7) / 8); +//#pragma omp parallel for if(n > 100000) + for (size_t i = 0; i < n; i++) + fvec2bitvec (x + i * d, b + i * ncodes, d); +} + + + +void bitvecs2fvecs ( + const uint8_t * b, + float * x, + size_t d, + size_t n) { + + const int64_t ncodes = ((d + 7) / 8); +//#pragma omp parallel for if(n > 100000) + for (size_t i = 0; i < n; i++) { + binary_to_real (d, b + i * ncodes, x + i * d); + } +} + + +/* Reverse bit (NOT a optimized function, only used for print purpose) */ +static uint64_t uint64_reverse_bits (uint64_t b) +{ + int i; + uint64_t revb = 0; + for (i = 0; i < 64; i++) { + revb <<= 1; + revb |= b & 1; + b >>= 1; + } + return revb; +} + + +/* print the bit vector */ +void bitvec_print (const uint8_t * b, size_t d) +{ + size_t i, j; + for (i = 0; i < d; ) { + uint64_t brev = uint64_reverse_bits (* (uint64_t *) b); + for (j = 0; j < 64 && i < d; j++, i++) { + printf ("%d", (int) (brev & 1)); + brev >>= 1; + } + b += 8; + printf (" "); + } +} + + +void bitvec_shuffle (size_t n, size_t da, size_t db, + const int *order, + const uint8_t *a, + uint8_t *b) +{ + for(size_t i = 0; i < db; i++) { + ABORT_UNLESS (order[i] >= 0 && order[i] < da, "order[i] >= 0 && order[i] < da"); + } + size_t lda = (da + 7) / 8; + size_t ldb = (db + 7) / 8; + +//#pragma omp parallel for if(n > 10000) + for (size_t i = 0; i < n; i++) { + const uint8_t *ai = a + i * lda; + uint8_t *bi = b + i * ldb; + memset (bi, 0, ldb); + for(size_t i = 0; i < db; i++) { + int o = order[i]; + uint8_t the_bit = (ai[o >> 3] >> (o & 7)) & 1; + bi[i >> 3] |= the_bit << (i & 7); + } + } + +} + + + +/*----------------------------------------*/ +/* Hamming distance computation and k-nn */ + + +#define C64(x) ((uint64_t *)x) + + +/* Compute a set of Hamming distances */ +void hammings ( + const uint8_t * a, + const uint8_t * b, + size_t na, size_t nb, + size_t ncodes, + hamdis_t * __restrict dis) +{ + ABORT_UNLESS (ncodes % 8 == 0, "ncodes % 8 == 0"); + switch (ncodes) { + case 8: + faiss::hammings <64> (C64(a), C64(b), na, nb, dis); return; + case 16: + faiss::hammings <128> (C64(a), C64(b), na, nb, dis); return; + case 32: + faiss::hammings <256> (C64(a), C64(b), na, nb, dis); return; + case 64: + faiss::hammings <512> (C64(a), C64(b), na, nb, dis); return; + default: + faiss::hammings (C64(a), C64(b), na, nb, ncodes * 8, dis); return; + } +} + +void hammings_knn( + int_maxheap_array_t *ha, + const uint8_t *a, + const uint8_t *b, + size_t nb, + size_t ncodes, + int order) +{ + hammings_knn_hc(ha, a, b, nb, ncodes, order); +} + +void hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int order) +{ + switch (ncodes) { + case 4: + hammings_knn_hc + (4, ha, a, b, nb, order, true); + break; + case 8: + hammings_knn_hc_1 (ha, C64(a), C64(b), 
nb, order, true); + // hammings_knn_hc + // (8, ha, a, b, nb, order, true); + break; + case 16: + hammings_knn_hc + (16, ha, a, b, nb, order, true); + break; + case 32: + hammings_knn_hc + (32, ha, a, b, nb, order, true); + break; + default: + if(ncodes % 8 == 0) { + hammings_knn_hc + (ncodes, ha, a, b, nb, order, true); + } else { + hammings_knn_hc + (ncodes, ha, a, b, nb, order, true); + + } + } +} + +void hammings_knn_mc( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + size_t k, + size_t ncodes, + int32_t *distances, + int64_t *labels) +{ + switch (ncodes) { + case 4: + hammings_knn_mc( + 4, a, b, na, nb, k, distances, labels + ); + break; + case 8: + // TODO(hoss): Write analog to hammings_knn_hc_1 + // hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true); + hammings_knn_mc( + 8, a, b, na, nb, k, distances, labels + ); + break; + case 16: + hammings_knn_mc( + 16, a, b, na, nb, k, distances, labels + ); + break; + case 32: + hammings_knn_mc( + 32, a, b, na, nb, k, distances, labels + ); + break; + default: + if(ncodes % 8 == 0) { + hammings_knn_mc( + ncodes, a, b, na, nb, k, distances, labels + ); + } else { + hammings_knn_mc( + ncodes, a, b, na, nb, k, distances, labels + ); + } + } +} +template +static +void hamming_range_search_template ( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + int radius, + size_t code_size, + RangeSearchResult *res) +{ + +//#pragma omp parallel + { + RangeSearchPartialResult pres (res); + +//#pragma omp for + for (size_t i = 0; i < na; i++) { + HammingComputer hc (a + i * code_size, code_size); + const uint8_t * yi = b; + RangeQueryResult & qres = pres.new_result (i); + + for (size_t j = 0; j < nb; j++) { + int dis = hc.hamming (yi); + if (dis < radius) { + qres.add(dis, j); + } + yi += code_size; + } + } + pres.finalize (); + } +} + +void hamming_range_search ( + const uint8_t * a, + const uint8_t * b, + size_t na, + size_t nb, + int radius, + size_t code_size, + RangeSearchResult *result) +{ + +#define HC(name) hamming_range_search_template (a, b, na, nb, radius, code_size, result) + + switch(code_size) { + case 4: HC(HammingComputer4); break; + case 8: HC(HammingComputer8); break; + case 16: HC(HammingComputer16); break; + case 32: HC(HammingComputer32); break; + default: + if (code_size % 8 == 0) { + HC(HammingComputerM8); + } else { + HC(HammingComputerDefault); + } + } +#undef HC +} + + + +/* Count number of matches given a max threshold */ +void hamming_count_thres ( + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t ncodes, + size_t * nptr) +{ + switch (ncodes) { + case 8: + faiss::hamming_count_thres <64> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + case 16: + faiss::hamming_count_thres <128> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + case 32: + faiss::hamming_count_thres <256> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + case 64: + faiss::hamming_count_thres <512> (C64(bs1), C64(bs2), + n1, n2, ht, nptr); + return; + default: + ABORT ("not implemented for %zu bits", ncodes); + } +} + + +/* Count number of cross-matches given a threshold */ +void crosshamming_count_thres ( + const uint8_t * dbs, + size_t n, + hamdis_t ht, + size_t ncodes, + size_t * nptr) +{ + switch (ncodes) { + case 8: + faiss::crosshamming_count_thres <64> (C64(dbs), n, ht, nptr); + return; + case 16: + faiss::crosshamming_count_thres <128> (C64(dbs), n, ht, nptr); + return; + case 32: + faiss::crosshamming_count_thres <256> (C64(dbs), n, ht, nptr); + return; 
+ case 64: + faiss::crosshamming_count_thres <512> (C64(dbs), n, ht, nptr); + return; + default: + ABORT ("not implemented for %zu bits", ncodes); + } +} + + +/* Returns all matches given a threshold */ +size_t match_hamming_thres ( + const uint8_t * bs1, + const uint8_t * bs2, + size_t n1, + size_t n2, + hamdis_t ht, + size_t ncodes, + int64_t * idx, + hamdis_t * dis) +{ + switch (ncodes) { + case 8: + return faiss::match_hamming_thres <64> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + case 16: + return faiss::match_hamming_thres <128> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + case 32: + return faiss::match_hamming_thres <256> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + case 64: + return faiss::match_hamming_thres <512> (C64(bs1), C64(bs2), + n1, n2, ht, idx, dis); + default: + ABORT ("not implemented for %zu bits", ncodes); + return 0; + } +} + + +#undef C64 + + + +/************************************* + * generalized Hamming distances + ************************************/ + + + +template +static void hamming_dis_inner_loop ( + const uint8_t *ca, + const uint8_t *cb, + size_t nb, + size_t code_size, + int k, + hamdis_t * bh_val_, + int64_t * bh_ids_) +{ + + HammingComputer hc (ca, code_size); + + for (size_t j = 0; j < nb; j++) { + int ndiff = hc.hamming (cb); + cb += code_size; + if (ndiff < bh_val_[0]) { + maxheap_pop (k, bh_val_, bh_ids_); + maxheap_push (k, bh_val_, bh_ids_, ndiff, j); + } + } +} + +void generalized_hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t code_size, + int ordered) +{ + int na = ha->nh; + int k = ha->k; + + if (ordered) + ha->heapify (); + +//#pragma omp parallel for + for (int i = 0; i < na; i++) { + const uint8_t *ca = a + i * code_size; + const uint8_t *cb = b; + + hamdis_t * bh_val_ = ha->val + i * k; + int64_t * bh_ids_ = ha->ids + i * k; + + switch (code_size) { + case 8: + hamming_dis_inner_loop + (ca, cb, nb, 8, k, bh_val_, bh_ids_); + break; + case 16: + hamming_dis_inner_loop + (ca, cb, nb, 16, k, bh_val_, bh_ids_); + break; + case 32: + hamming_dis_inner_loop + (ca, cb, nb, 32, k, bh_val_, bh_ids_); + break; + default: + hamming_dis_inner_loop + (ca, cb, nb, code_size, k, bh_val_, bh_ids_); + break; + } + } + + if (ordered) + ha->reorder (); + +} + + +} // namespace faiss diff --git a/src/3rd_party/faiss/utils/hamming.h b/src/3rd_party/faiss/utils/hamming.h new file mode 100644 index 000000000..762d3773c --- /dev/null +++ b/src/3rd_party/faiss/utils/hamming.h @@ -0,0 +1,244 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +// -*- c++ -*- + +/* + * Hamming distances. The binary vector dimensionality should be a + * multiple of 8, as the elementary operations operate on bytes. If + * you need other sizes, just pad with 0s (this is done by function + * fvecs2bitvecs). 
+ * + * User-defined type hamdis_t is used for distances because at this time + * it is still uncler clear how we will need to balance + * - flexibility in vector size (may need 16- or even 8-bit vectors) + * - memory usage + * - cache-misses when dealing with large volumes of data (fewer bits is better) + * + */ + +#ifndef FAISS_hamming_h +#define FAISS_hamming_h + + +#include + +#include + +#ifdef _MSC_VER +#include // needed for some intrinsics in +#define __builtin_popcountl __popcnt64 +#endif + +/* The Hamming distance type */ +typedef int32_t hamdis_t; + +namespace faiss { + +/************************************************** + * General bit vector functions + **************************************************/ + +struct RangeSearchResult; + +void bitvec_print (const uint8_t * b, size_t d); + + +/* Functions for casting vectors of regular types to compact bits. + They assume proper allocation done beforehand, meaning that b + should be be able to receive as many bits as x may produce. */ + +/* Makes an array of bits from the signs of a float array. The length + of the output array b is rounded up to byte size (allocate + accordingly) */ +void fvecs2bitvecs ( + const float * x, + uint8_t * b, + size_t d, + size_t n); + +void bitvecs2fvecs ( + const uint8_t * b, + float * x, + size_t d, + size_t n); + + +void fvec2bitvec (const float * x, uint8_t * b, size_t d); + +/** Shuffle the bits from b(i, j) := a(i, order[j]) + */ +void bitvec_shuffle (size_t n, size_t da, size_t db, + const int *order, + const uint8_t *a, + uint8_t *b); + + +/*********************************************** + * Generic reader/writer for bit strings + ***********************************************/ + + +struct BitstringWriter { + uint8_t *code; + size_t code_size; + size_t i; // current bit offset + + // code_size in bytes + BitstringWriter(uint8_t *code, int code_size); + + // write the nbit low bits of x + void write(uint64_t x, int nbit); +}; + +struct BitstringReader { + const uint8_t *code; + size_t code_size; + size_t i; + + // code_size in bytes + BitstringReader(const uint8_t *code, int code_size); + + // read nbit bits from the code + uint64_t read(int nbit); +}; + +/************************************************** + * Hamming distance computation functions + **************************************************/ + + + +extern size_t hamming_batch_size; + +static inline int popcount64(uint64_t x) { + return __builtin_popcountl(x); +} + + +/** Compute a set of Hamming distances between na and nb binary vectors + * + * @param a size na * nbytespercode + * @param b size nb * nbytespercode + * @param nbytespercode should be multiple of 8 + * @param dis output distances, size na * nb + */ +void hammings ( + const uint8_t * a, + const uint8_t * b, + size_t na, size_t nb, + size_t nbytespercode, + hamdis_t * dis); + + + + +/** Return the k smallest Hamming distances for a set of binary query vectors, + * using a max heap. + * @param a queries, size ha->nh * ncodes + * @param b database, size nb * ncodes + * @param nb number of database vectors + * @param ncodes size of the binary codes (bytes) + * @param ordered if != 0: order the results by decreasing distance + * (may be bottleneck for k/n > 0.01) */ +void hammings_knn_hc ( + int_maxheap_array_t * ha, + const uint8_t * a, + const uint8_t * b, + size_t nb, + size_t ncodes, + int ordered); + +/* Legacy alias to hammings_knn_hc. 
+
+/* Legacy alias to hammings_knn_hc. */
+void hammings_knn (
+        int_maxheap_array_t * ha,
+        const uint8_t * a,
+        const uint8_t * b,
+        size_t nb,
+        size_t ncodes,
+        int ordered);
+
+/** Return the k smallest Hamming distances for a set of binary query vectors,
+ * using counting max.
+ * @param a         queries, size na * ncodes
+ * @param b         database, size nb * ncodes
+ * @param na        number of query vectors
+ * @param nb        number of database vectors
+ * @param k         number of vectors/distances to return
+ * @param ncodes    size of the binary codes (bytes)
+ * @param distances output distances from each query vector to its k nearest
+ *                  neighbors
+ * @param labels    output ids of the k nearest neighbors to each query vector
+ */
+void hammings_knn_mc (
+        const uint8_t * a,
+        const uint8_t * b,
+        size_t na,
+        size_t nb,
+        size_t k,
+        size_t ncodes,
+        int32_t *distances,
+        int64_t *labels);
+
+/** same as hammings_knn except we are doing a range search with radius */
+void hamming_range_search (
+        const uint8_t * a,
+        const uint8_t * b,
+        size_t na,
+        size_t nb,
+        int radius,
+        size_t ncodes,
+        RangeSearchResult *result);
+
+
+/* Counting the number of matches or of cross-matches (without returning them)
+   For use with functions that assume pre-allocated memory */
+void hamming_count_thres (
+        const uint8_t * bs1,
+        const uint8_t * bs2,
+        size_t n1,
+        size_t n2,
+        hamdis_t ht,
+        size_t ncodes,
+        size_t * nptr);
+
+/* Return all Hamming distances/index passing a thres. Pre-allocation of output
+   is required. Use hamming_count_thres to determine the proper size. */
+size_t match_hamming_thres (
+        const uint8_t * bs1,
+        const uint8_t * bs2,
+        size_t n1,
+        size_t n2,
+        hamdis_t ht,
+        size_t ncodes,
+        int64_t * idx,
+        hamdis_t * dis);
+
+/* Cross-matching in a set of vectors */
+void crosshamming_count_thres (
+        const uint8_t * dbs,
+        size_t n,
+        hamdis_t ht,
+        size_t ncodes,
+        size_t * nptr);
+
+
+/* compute the Hamming distance between two codewords of nwords*64 bits */
+hamdis_t hamming (
+        const uint64_t * bs1,
+        const uint64_t * bs2,
+        size_t nwords);
+
+
+
+} // namespace faiss
+
+// inlined definitions of HammingComputerXX and GenHammingComputerXX
+
+#include "hamming-inl.h"
+
+#endif /* FAISS_hamming_h */
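Everything in this header ultimately reduces to `popcount64`: the Hamming distance of two codes is the population count of their XOR. A minimal illustration over `nwords` 64-bit words, equivalent in spirit to the `hamming(const uint64_t*, const uint64_t*, size_t)` declaration above (the function name here is illustrative):

#include <cstdint>
#include "hamming.h"  // for faiss::popcount64, assuming include path

// Hamming distance of two codes stored as nwords 64-bit words:
// XOR the words, then count the set bits.
static int hamming_words(const uint64_t* x, const uint64_t* y, size_t nwords) {
    int accu = 0;
    for (size_t i = 0; i < nwords; i++)
        accu += faiss::popcount64(x[i] ^ y[i]);
    return accu;
}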
diff --git a/src/3rd_party/faiss/utils/misc.cpp b/src/3rd_party/faiss/utils/misc.cpp
new file mode 100644
index 000000000..b9a1c92da
--- /dev/null
+++ b/src/3rd_party/faiss/utils/misc.cpp
@@ -0,0 +1,162 @@
+#include "misc.h"
+#include <cstring>
+
+namespace faiss {
+    /***********************************************************************
+     * BufferList
+     ***********************************************************************/
+
+
+    BufferList::BufferList(size_t buffer_size) :
+        buffer_size(buffer_size)
+    {
+        wp = buffer_size;
+    }
+
+    BufferList::~BufferList()
+    {
+        for (int i = 0; i < buffers.size(); i++) {
+            delete[] buffers[i].ids;
+            delete[] buffers[i].dis;
+        }
+    }
+
+    void BufferList::add(idx_t id, float dis) {
+        if (wp == buffer_size) { // need new buffer
+            append_buffer();
+        }
+        Buffer & buf = buffers.back();
+        buf.ids[wp] = id;
+        buf.dis[wp] = dis;
+        wp++;
+    }
+
+
+    void BufferList::append_buffer()
+    {
+        Buffer buf = { new idx_t[buffer_size], new float[buffer_size] };
+        buffers.push_back(buf);
+        wp = 0;
+    }
+
+    /// copy elements ofs:ofs+n-1 seen as linear data in the buffers to
+    /// tables dest_ids, dest_dis
+    void BufferList::copy_range(size_t ofs, size_t n,
+        idx_t * dest_ids, float *dest_dis)
+    {
+        size_t bno = ofs / buffer_size;
+        ofs -= bno * buffer_size;
+        while (n > 0) {
+            size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
+            Buffer buf = buffers[bno];
+            memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
+            memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
+            dest_ids += ncopy;
+            dest_dis += ncopy;
+            ofs = 0;
+            bno++;
+            n -= ncopy;
+        }
+    }
+
+    /***********************************************************************
+     * RangeSearchPartialResult
+     ***********************************************************************/
+
+    void RangeQueryResult::add(float dis, idx_t id) {
+        nres++;
+        pres->add(id, dis);
+    }
+
+
+
+    RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult * res_in) :
+        BufferList(res_in->buffer_size),
+        res(res_in)
+    {}
+
+
+    /// begin a new result
+    RangeQueryResult &
+        RangeSearchPartialResult::new_result(idx_t qno)
+    {
+        RangeQueryResult qres = { qno, 0, this };
+        queries.push_back(qres);
+        return queries.back();
+    }
+
+
+    void RangeSearchPartialResult::finalize()
+    {
+        set_lims();
+//#pragma omp barrier
+
+//#pragma omp single
+        res->do_allocation();
+
+//#pragma omp barrier
+        copy_result();
+    }
+
+
+    /// called by range_search before do_allocation
+    void RangeSearchPartialResult::set_lims()
+    {
+        for (int i = 0; i < queries.size(); i++) {
+            RangeQueryResult & qres = queries[i];
+            res->lims[qres.qno] = qres.nres;
+        }
+    }
+
+    /// called by range_search after do_allocation
+    void RangeSearchPartialResult::copy_result(bool incremental)
+    {
+        size_t ofs = 0;
+        for (int i = 0; i < queries.size(); i++) {
+            RangeQueryResult & qres = queries[i];
+
+            copy_range(ofs, qres.nres,
+                res->labels + res->lims[qres.qno],
+                res->distances + res->lims[qres.qno]);
+            if (incremental) {
+                res->lims[qres.qno] += qres.nres;
+            }
+            ofs += qres.nres;
+        }
+    }
+
+    void RangeSearchPartialResult::merge(std::vector<RangeSearchPartialResult *> &
+        partial_results, bool do_delete)
+    {
+
+        int npres = partial_results.size();
+        if (npres == 0) return;
+        RangeSearchResult *result = partial_results[0]->res;
+        size_t nx = result->nq;
+
+        // count
+        for (const RangeSearchPartialResult * pres : partial_results) {
+            if (!pres) continue;
+            for (const RangeQueryResult &qres : pres->queries) {
+                result->lims[qres.qno] += qres.nres;
+            }
+        }
+        result->do_allocation();
+        for (int j = 0; j < npres; j++) {
+            if (!partial_results[j]) continue;
+            partial_results[j]->copy_result(true);
+            if (do_delete) {
+                delete partial_results[j];
+                partial_results[j] = nullptr;
+            }
+        }
+
+        // reset the limits
+        for (size_t i = nx; i > 0; i--) {
+            result->lims[i] = result->lims[i - 1];
+        }
+        result->lims[0] = 0;
+    }
+
+}
+
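The classes just defined implement buffered result collection: each worker appends hits into its own `RangeSearchPartialResult` without reallocating, and a single `merge` at the end sizes the shared `RangeSearchResult::lims` exactly, allocates the output tables, and copies everything over. A sketch of that life cycle, assuming a single worker and fabricated hits:

#include <vector>
#include "misc.h"  // assuming the vendored faiss headers are on the include path

faiss::RangeSearchResult* collect(size_t nq) {
    auto* res = new faiss::RangeSearchResult(nq);  // lims allocated, still empty
    std::vector<faiss::RangeSearchPartialResult*> partials;

    // normally one partial result per worker; a single one shown here
    auto* pres = new faiss::RangeSearchPartialResult(res);
    for (faiss::idx_t qno = 0; qno < (faiss::idx_t)nq; qno++) {
        faiss::RangeQueryResult& qres = pres->new_result(qno);
        qres.add(/*dis=*/0.0f, /*id=*/qno);  // buffered append, no realloc
    }
    partials.push_back(pres);

    // counts per query -> do_allocation() -> copy; frees the partials
    faiss::RangeSearchPartialResult::merge(partials);
    return res;
}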
diff --git a/src/3rd_party/faiss/utils/misc.h b/src/3rd_party/faiss/utils/misc.h
new file mode 100644
index 000000000..86821ab19
--- /dev/null
+++ b/src/3rd_party/faiss/utils/misc.h
@@ -0,0 +1,144 @@
+#pragma once
+#include <stdint.h>
+#include <vector>
+
+namespace faiss {
+    /// The metric space for vector comparison for Faiss indices and algorithms.
+    ///
+    /// Most algorithms support both inner product and L2, with the flat
+    /// (brute-force) indices supporting additional metric types for vector
+    /// comparison.
+    enum MetricType {
+        METRIC_INNER_PRODUCT = 0,  ///< maximum inner product search
+        METRIC_L2 = 1,             ///< squared L2 search
+        METRIC_L1,                 ///< L1 (aka cityblock)
+        METRIC_Linf,               ///< infinity distance
+        METRIC_Lp,                 ///< L_p distance, p is given by a faiss::Index
+                                   /// metric_arg
+
+        /// some additional metrics defined in scipy.spatial.distance
+        METRIC_Canberra = 20,
+        METRIC_BrayCurtis,
+        METRIC_JensenShannon,
+    };
+
+    template <class T>
+    struct ScopeDeleter {
+        const T * ptr;
+        explicit ScopeDeleter(const T* ptr = nullptr) : ptr(ptr) {}
+        void release() { ptr = nullptr; }
+        void set(const T * ptr_in) { ptr = ptr_in; }
+        void swap(ScopeDeleter &other) { std::swap(ptr, other.ptr); }
+        ~ScopeDeleter() {
+            delete[] ptr;
+        }
+    };
+
+    //////////////////////////////////////////
+    using idx_t = int64_t;
+
+    /** List of temporary buffers used to store results before they are
+     *  copied to the RangeSearchResult object. */
+    struct BufferList {
+        typedef faiss::idx_t idx_t;
+
+        // buffer sizes in # entries
+        size_t buffer_size;
+
+        struct Buffer {
+            idx_t *ids;
+            float *dis;
+        };
+
+        std::vector<Buffer> buffers;
+        size_t wp; ///< write pointer in the last buffer.
+
+        explicit BufferList(size_t buffer_size);
+
+        ~BufferList();
+
+        /// create a new buffer
+        void append_buffer();
+
+        /// add one result, possibly appending a new buffer if needed
+        void add(idx_t id, float dis);
+
+        /// copy elements ofs:ofs+n-1 seen as linear data in the buffers to
+        /// tables dest_ids, dest_dis
+        void copy_range(size_t ofs, size_t n,
+            idx_t * dest_ids, float *dest_dis);
+
+    };
+
+    /** The objective is to have a simple result structure while
+     *  minimizing the number of mem copies in the result. The method
+     *  do_allocation can be overloaded to allocate the result tables in
+     *  the matrix type of a scripting language like Lua or Python. */
+    struct RangeSearchResult {
+        size_t nq;      ///< nb of queries
+        size_t *lims;   ///< size (nq + 1)
+
+        typedef faiss::idx_t idx_t;
+
+        idx_t *labels;     ///< result for query i is labels[lims[i]:lims[i+1]]
+        float *distances;  ///< corresponding distances (not sorted)
+
+        size_t buffer_size; ///< size of the result buffers used
+
+        /// lims must be allocated on input to range_search.
+        explicit RangeSearchResult(idx_t nq, bool alloc_lims = true);
+
+        /// called when lims contains the number of result entries
+        /// for each query
+
+        virtual void do_allocation();
+
+        virtual ~RangeSearchResult();
+    };
+
+    struct RangeSearchPartialResult;
+
+    /// result structure for a single query
+    struct RangeQueryResult {
+        using idx_t = faiss::idx_t;
+        idx_t qno;    //< id of the query
+        size_t nres;  //< nb of results for this query
+        RangeSearchPartialResult * pres;
+
+        /// called by search function to report a new result
+        void add(float dis, idx_t id);
+    };
+
+    /// the entries in the buffers are split per query
+    struct RangeSearchPartialResult : BufferList {
+        RangeSearchResult * res;
+
+        /// eventually the result will be stored in res_in
+        explicit RangeSearchPartialResult(RangeSearchResult * res_in);
+
+        /// query ids + nb of results per query.
+        std::vector<RangeQueryResult> queries;
+
+        /// begin a new result
+        RangeQueryResult & new_result(idx_t qno);
+
+        /*****************************************
+         * functions used at the end of the search to merge the result
+         * lists */
+        void finalize();
+
+        /// called by range_search before do_allocation
+        void set_lims();
+
+        /// called by range_search after do_allocation
+        void copy_result(bool incremental = false);
+
+        /// merge a set of PartialResult's into one RangeSearchResult
+        /// on output the partial results are empty!
+        static void merge(std::vector<RangeSearchPartialResult *> &
+            partial_results, bool do_delete = true);
+
+    };
+
+
+} // namespace faiss
diff --git a/src/3rd_party/faiss/utils/random.cpp b/src/3rd_party/faiss/utils/random.cpp
new file mode 100644
index 000000000..ee4d6f559
--- /dev/null
+++ b/src/3rd_party/faiss/utils/random.cpp
@@ -0,0 +1,192 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include "random.h"
+
+namespace faiss {
+
+/**************************************************
+ * Random data generation functions
+ **************************************************/
+
+RandomGenerator::RandomGenerator (int64_t seed)
+    : mt((unsigned int)seed) {}
+
+int RandomGenerator::rand_int ()
+{
+    return mt() & 0x7fffffff;
+}
+
+int64_t RandomGenerator::rand_int64 ()
+{
+    return int64_t(rand_int()) | int64_t(rand_int()) << 31;
+}
+
+int RandomGenerator::rand_int (int max)
+{
+    return mt() % max;
+}
+
+float RandomGenerator::rand_float ()
+{
+    return mt() / float(mt.max());
+}
+
+double RandomGenerator::rand_double ()
+{
+    return mt() / double(mt.max());
+}
+
+
+/***********************************************************************
+ * Random functions in this C file only exist because Torch
+ * counterparts are slow and not multi-threaded. Typical use is for
+ * more than 1-100 billion values. */
+
+
+/* Generate a set of random floating point values such that x[i] in [0,1].
+   The work is split into blocks so it can be multi-threaded; for this
+   reason, we rely on re-entrant functions. */
+void float_rand (float * x, size_t n, int64_t seed)
+{
+    // only try to parallelize on large enough arrays
+    const size_t nblock = n < 1024 ? 1 : 1024;
+
+    RandomGenerator rng0 (seed);
+    int a0 = rng0.rand_int (), b0 = rng0.rand_int ();
+
+//#pragma omp parallel for
+    for (size_t j = 0; j < nblock; j++) {
+
+        RandomGenerator rng (a0 + j * b0);
+
+        const size_t istart = j * n / nblock;
+        const size_t iend = (j + 1) * n / nblock;
+
+        for (size_t i = istart; i < iend; i++)
+            x[i] = rng.rand_float ();
+    }
+}
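The block structure above is what keeps the generators reproducible under threading: each of the (up to) 1024 blocks derives its own generator from the root seed as `a0 + j * b0`, so block `j`'s output does not depend on which thread runs it. In this vendored copy the `#pragma omp parallel for` is commented out, but the invariant still holds either way. A small illustrative check of that property (test code, not part of the patch):

#include <cassert>
#include <vector>
#include "random.h"  // assuming the vendored faiss headers are on the include path

void check_reproducible() {
    std::vector<float> x(1 << 20), y(1 << 20);
    faiss::float_rand(x.data(), x.size(), /*seed=*/1234);
    faiss::float_rand(y.data(), y.size(), /*seed=*/1234);
    assert(x == y);  // same seed, same stream -- with or without OpenMP
}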
+
+
+void float_randn (float * x, size_t n, int64_t seed)
+{
+    // only try to parallelize on large enough arrays
+    const size_t nblock = n < 1024 ? 1 : 1024;
+
+    RandomGenerator rng0 (seed);
+    int a0 = rng0.rand_int (), b0 = rng0.rand_int ();
+
+//#pragma omp parallel for
+    for (size_t j = 0; j < nblock; j++) {
+        RandomGenerator rng (a0 + j * b0);
+
+        double a = 0, b = 0, s = 0;
+        int state = 0; /* generate two numbers per "do-while" loop */
+
+        const size_t istart = j * n / nblock;
+        const size_t iend = (j + 1) * n / nblock;
+
+        for (size_t i = istart; i < iend; i++) {
+            /* Marsaglia's method (see Knuth) */
+            if (state == 0) {
+                do {
+                    a = 2.0 * rng.rand_double () - 1;
+                    b = 2.0 * rng.rand_double () - 1;
+                    s = a * a + b * b;
+                } while (s >= 1.0);
+                x[i] = a * sqrt(-2.0 * log(s) / s);
+            }
+            else
+                x[i] = b * sqrt(-2.0 * log(s) / s);
+            state = 1 - state;
+        }
+    }
+}
+
+
+/* Integer versions */
+void int64_rand (int64_t * x, size_t n, int64_t seed)
+{
+    // only try to parallelize on large enough arrays
+    const size_t nblock = n < 1024 ? 1 : 1024;
+
+    RandomGenerator rng0 (seed);
+    int a0 = rng0.rand_int (), b0 = rng0.rand_int ();
+
+//#pragma omp parallel for
+    for (size_t j = 0; j < nblock; j++) {
+
+        RandomGenerator rng (a0 + j * b0);
+
+        const size_t istart = j * n / nblock;
+        const size_t iend = (j + 1) * n / nblock;
+        for (size_t i = istart; i < iend; i++)
+            x[i] = rng.rand_int64 ();
+    }
+}
+
+void int64_rand_max (int64_t * x, size_t n, uint64_t max, int64_t seed)
+{
+    // only try to parallelize on large enough arrays
+    const size_t nblock = n < 1024 ? 1 : 1024;
+
+    RandomGenerator rng0 (seed);
+    int a0 = rng0.rand_int (), b0 = rng0.rand_int ();
+
+//#pragma omp parallel for
+    for (size_t j = 0; j < nblock; j++) {
+
+        RandomGenerator rng (a0 + j * b0);
+
+        const size_t istart = j * n / nblock;
+        const size_t iend = (j + 1) * n / nblock;
+        for (size_t i = istart; i < iend; i++)
+            x[i] = rng.rand_int64 () % max;
+    }
+}
+
+
+void rand_perm (int *perm, size_t n, int64_t seed)
+{
+    for (size_t i = 0; i < n; i++) perm[i] = i;
+
+    RandomGenerator rng (seed);
+
+    for (size_t i = 0; i + 1 < n; i++) {
+        int i2 = i + rng.rand_int (n - i);
+        std::swap(perm[i], perm[i2]);
+    }
+}
+
+
+
+
+void byte_rand (uint8_t * x, size_t n, int64_t seed)
+{
+    // only try to parallelize on large enough arrays
+    const size_t nblock = n < 1024 ? 1 : 1024;
+
+    RandomGenerator rng0 (seed);
+    int a0 = rng0.rand_int (), b0 = rng0.rand_int ();
+
+//#pragma omp parallel for
+    for (size_t j = 0; j < nblock; j++) {
+
+        RandomGenerator rng (a0 + j * b0);
+
+        const size_t istart = j * n / nblock;
+        const size_t iend = (j + 1) * n / nblock;
+
+        size_t i;
+        for (i = istart; i < iend; i++)
+            x[i] = rng.rand_int64 ();
+    }
+}
+
+} // namespace faiss
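`float_randn` above uses Marsaglia's polar method: draw (a, b) uniformly in the unit disc, and with s = a*a + b*b, both a*sqrt(-2 ln s / s) and b*sqrt(-2 ln s / s) are independent N(0, 1) samples, which is why the loop emits two values per accepted pair. A standalone restatement of that transform, using std::mt19937 directly (names here are illustrative):

#include <cmath>
#include <random>

// One accepted (a, b) pair in the unit disc yields two
// independent standard normal samples.
static void gaussian_pair(std::mt19937& mt, double out[2]) {
    std::uniform_real_distribution<double> u(-1.0, 1.0);
    double a, b, s;
    do {
        a = u(mt);
        b = u(mt);
        s = a * a + b * b;
    } while (s >= 1.0 || s == 0.0);  // reject points outside the disc
    const double m = std::sqrt(-2.0 * std::log(s) / s);
    out[0] = a * m;
    out[1] = b * m;
}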
diff --git a/src/3rd_party/faiss/utils/random.h b/src/3rd_party/faiss/utils/random.h
new file mode 100644
index 000000000..e94ac068c
--- /dev/null
+++ b/src/3rd_party/faiss/utils/random.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+/* Random generators. Implemented here for speed and to make
+ * sequences reproducible.
+ */
+
+#pragma once
+
+#include <random>
+
+#include <stdint.h>
+
+
+namespace faiss {
+
+/**************************************************
+ * Random data generation functions
+ **************************************************/
+
+/// random generator that can be used in multithreaded contexts
+struct RandomGenerator {
+
+    std::mt19937 mt;
+
+    /// random positive integer
+    int rand_int ();
+
+    /// random int64_t
+    int64_t rand_int64 ();
+
+    /// generate random integer between 0 and max-1
+    int rand_int (int max);
+
+    /// between 0 and 1
+    float rand_float ();
+
+    double rand_double ();
+
+    explicit RandomGenerator (int64_t seed = 1234);
+};
+
+/* Generate an array of uniform random floats / multi-threaded implementation */
+void float_rand (float * x, size_t n, int64_t seed);
+void float_randn (float * x, size_t n, int64_t seed);
+void int64_rand (int64_t * x, size_t n, int64_t seed);
+void byte_rand (uint8_t * x, size_t n, int64_t seed);
+// max is actually the maximum value + 1
+void int64_rand_max (int64_t * x, size_t n, uint64_t max, int64_t seed);
+
+/* random permutation */
+void rand_perm (int * perm, size_t n, int64_t seed);
+
+
+} // namespace faiss
diff --git a/src/3rd_party/fbgemm b/src/3rd_party/fbgemm
index 84e66a976..da28b0abb 160000
--- a/src/3rd_party/fbgemm
+++ b/src/3rd_party/fbgemm
@@ -1 +1 @@
-Subproject commit 84e66a976046180187724aff60a236c5378fde7c
+Subproject commit da28b0abb0e54f4ed6e1444309e23879018689d7
diff --git a/src/3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.cpp b/src/3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.cpp
new file mode 100755
index 000000000..8bcc371e2
--- /dev/null
+++ b/src/3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.cpp
@@ -0,0 +1,37 @@
+// protobuf-generated files don't compile clean. This compiles them with warnings
+// disabled, without having to disable them for the entire project wholesale.
+
+#ifdef USE_ONNX
+
+// Get protobuf this way:
+// sudo apt-get install cmake pkg-config libprotobuf9v5 protobuf-compiler libprotobuf-dev libgoogle-perftools-dev
+
+// Since we don't develop the ONNX .proto file, I just hand-created the .pb. files.
+// The automatic process that CMake would invoke fails because protobuf generates
+// source code that is not warning-free. So let's use this manual process for now,
+// and just version-control the resulting files. The command is simple enough:
+// cd src/3rd_party/onnx/protobuf
+// protoc -I=. --cpp_out=. onnx-ml.proto
+
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4100 4125 4127 4244 4267 4512 4456 4510 4610 4800)
+#endif
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#pragma GCC diagnostic ignored "-Wsuggest-override"
+#endif
+
+#define AuxillaryParseTableField AuxiliaryParseTableField // in protobuf 3.12, the generated source has a spelling error
+
+#include "onnx-ml.pb.cc" // this is the actual file we compile
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+#endif // USE_ONNX
diff --git a/src/3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.h b/src/3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.h
new file mode 100755
index 000000000..d92348a21
--- /dev/null
+++ b/src/3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.h
@@ -0,0 +1,23 @@
+// protobuf-generated files don't compile clean. This compiles them with warnings
+// disabled, without having to disable them for the entire project wholesale.
+ +#pragma once + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4800 4610 4512 4510 4267 4127 4125 4100 4456) +#endif +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wsuggest-override" +#endif + +#include "onnx-ml.pb.h" // this is the actual file we include + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif +#ifdef _MSC_VER +#pragma warning(pop) +#endif diff --git a/src/3rd_party/onnx/protobuf/onnx-ml.pb.cc b/src/3rd_party/onnx/protobuf/onnx-ml.pb.cc new file mode 100755 index 000000000..789d6c980 --- /dev/null +++ b/src/3rd_party/onnx/protobuf/onnx-ml.pb.cc @@ -0,0 +1,7244 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: onnx-ml.proto + +#include "onnx-ml.pb.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +// @@protoc_insertion_point(includes) +#include +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_AttributeProto_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_OperatorSetIdProto_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_StringStringEntryProto_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorProto_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorProto_Segment_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<3> scc_info_TypeProto_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TypeProto_Opaque_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_SparseTensor_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_Tensor_onnx_2dml_2eproto; +extern PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_ValueInfoProto_onnx_2dml_2eproto; +namespace onnx { +class AttributeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _AttributeProto_default_instance_; +class ValueInfoProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ValueInfoProto_default_instance_; +class NodeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _NodeProto_default_instance_; +class ModelProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _ModelProto_default_instance_; +class StringStringEntryProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _StringStringEntryProto_default_instance_; +class GraphProtoDefaultTypeInternal { + public: + 
::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _GraphProto_default_instance_; +class TensorProto_SegmentDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorProto_Segment_default_instance_; +class TensorProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorProto_default_instance_; +class TensorShapeProto_DimensionDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + ::PROTOBUF_NAMESPACE_ID::int64 dim_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr dim_param_; +} _TensorShapeProto_Dimension_default_instance_; +class TensorShapeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TensorShapeProto_default_instance_; +class TypeProto_TensorDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_Tensor_default_instance_; +class TypeProto_SequenceDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_Sequence_default_instance_; +class TypeProto_MapDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_Map_default_instance_; +class TypeProto_OpaqueDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_Opaque_default_instance_; +class TypeProto_SparseTensorDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _TypeProto_SparseTensor_default_instance_; +class TypeProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; + const ::onnx::TypeProto_Tensor* tensor_type_; + const ::onnx::TypeProto_Sequence* sequence_type_; + const ::onnx::TypeProto_Map* map_type_; + const ::onnx::TypeProto_Opaque* opaque_type_; + const ::onnx::TypeProto_SparseTensor* sparse_tensor_type_; +} _TypeProto_default_instance_; +class OperatorSetIdProtoDefaultTypeInternal { + public: + ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed _instance; +} _OperatorSetIdProto_default_instance_; +} // namespace onnx +static void InitDefaultsscc_info_AttributeProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_AttributeProto_default_instance_; + new (ptr) ::onnx::AttributeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::onnx::_NodeProto_default_instance_; + new (ptr) ::onnx::NodeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::onnx::_GraphProto_default_instance_; + new (ptr) ::onnx::GraphProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::AttributeProto::InitAsDefaultInstance(); + ::onnx::NodeProto::InitAsDefaultInstance(); + ::onnx::GraphProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<2> scc_info_AttributeProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 2, 0, InitDefaultsscc_info_AttributeProto_onnx_2dml_2eproto}, { + &scc_info_TensorProto_onnx_2dml_2eproto.base, + &scc_info_ValueInfoProto_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_ModelProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = 
&::onnx::_ModelProto_default_instance_; + new (ptr) ::onnx::ModelProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::ModelProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<3> scc_info_ModelProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 3, 0, InitDefaultsscc_info_ModelProto_onnx_2dml_2eproto}, { + &scc_info_OperatorSetIdProto_onnx_2dml_2eproto.base, + &scc_info_AttributeProto_onnx_2dml_2eproto.base, + &scc_info_StringStringEntryProto_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_OperatorSetIdProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_OperatorSetIdProto_default_instance_; + new (ptr) ::onnx::OperatorSetIdProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::OperatorSetIdProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_OperatorSetIdProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_OperatorSetIdProto_onnx_2dml_2eproto}, {}}; + +static void InitDefaultsscc_info_StringStringEntryProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_StringStringEntryProto_default_instance_; + new (ptr) ::onnx::StringStringEntryProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::StringStringEntryProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_StringStringEntryProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_StringStringEntryProto_onnx_2dml_2eproto}, {}}; + +static void InitDefaultsscc_info_TensorProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TensorProto_default_instance_; + new (ptr) ::onnx::TensorProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TensorProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TensorProto_onnx_2dml_2eproto}, { + &scc_info_TensorProto_Segment_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_TensorProto_Segment_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TensorProto_Segment_default_instance_; + new (ptr) ::onnx::TensorProto_Segment(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TensorProto_Segment::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorProto_Segment_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TensorProto_Segment_onnx_2dml_2eproto}, {}}; + +static void InitDefaultsscc_info_TensorShapeProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TensorShapeProto_default_instance_; + new (ptr) ::onnx::TensorShapeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TensorShapeProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TensorShapeProto_onnx_2dml_2eproto = + 
{{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TensorShapeProto_onnx_2dml_2eproto}, { + &scc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TensorShapeProto_Dimension_default_instance_; + new (ptr) ::onnx::TensorShapeProto_Dimension(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TensorShapeProto_Dimension::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto}, {}}; + +static void InitDefaultsscc_info_TypeProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TypeProto_Sequence_default_instance_; + new (ptr) ::onnx::TypeProto_Sequence(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::onnx::_TypeProto_Map_default_instance_; + new (ptr) ::onnx::TypeProto_Map(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + { + void* ptr = &::onnx::_TypeProto_default_instance_; + new (ptr) ::onnx::TypeProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TypeProto_Sequence::InitAsDefaultInstance(); + ::onnx::TypeProto_Map::InitAsDefaultInstance(); + ::onnx::TypeProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<3> scc_info_TypeProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 3, 0, InitDefaultsscc_info_TypeProto_onnx_2dml_2eproto}, { + &scc_info_TypeProto_Tensor_onnx_2dml_2eproto.base, + &scc_info_TypeProto_Opaque_onnx_2dml_2eproto.base, + &scc_info_TypeProto_SparseTensor_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_TypeProto_Opaque_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TypeProto_Opaque_default_instance_; + new (ptr) ::onnx::TypeProto_Opaque(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TypeProto_Opaque::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_TypeProto_Opaque_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_TypeProto_Opaque_onnx_2dml_2eproto}, {}}; + +static void InitDefaultsscc_info_TypeProto_SparseTensor_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TypeProto_SparseTensor_default_instance_; + new (ptr) ::onnx::TypeProto_SparseTensor(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TypeProto_SparseTensor::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_SparseTensor_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TypeProto_SparseTensor_onnx_2dml_2eproto}, { + &scc_info_TensorShapeProto_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_TypeProto_Tensor_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_TypeProto_Tensor_default_instance_; + new (ptr) ::onnx::TypeProto_Tensor(); + 
::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::TypeProto_Tensor::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_TypeProto_Tensor_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_TypeProto_Tensor_onnx_2dml_2eproto}, { + &scc_info_TensorShapeProto_onnx_2dml_2eproto.base,}}; + +static void InitDefaultsscc_info_ValueInfoProto_onnx_2dml_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::onnx::_ValueInfoProto_default_instance_; + new (ptr) ::onnx::ValueInfoProto(); + ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr); + } + ::onnx::ValueInfoProto::InitAsDefaultInstance(); +} + +::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_ValueInfoProto_onnx_2dml_2eproto = + {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_ValueInfoProto_onnx_2dml_2eproto}, { + &scc_info_TypeProto_onnx_2dml_2eproto.base,}}; + +static ::PROTOBUF_NAMESPACE_ID::Metadata file_level_metadata_onnx_2dml_2eproto[17]; +static const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* file_level_enum_descriptors_onnx_2dml_2eproto[3]; +static constexpr ::PROTOBUF_NAMESPACE_ID::ServiceDescriptor const** file_level_service_descriptors_onnx_2dml_2eproto = nullptr; + +const ::PROTOBUF_NAMESPACE_ID::uint32 TableStruct_onnx_2dml_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, name_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, ref_attr_name_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, doc_string_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, type_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, f_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, i_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, s_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, t_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, g_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, floats_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, ints_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, strings_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, tensors_), + PROTOBUF_FIELD_OFFSET(::onnx::AttributeProto, graphs_), + 0, + 3, + 2, + 8, + 7, + 6, + 1, + 4, + 5, + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + PROTOBUF_FIELD_OFFSET(::onnx::ValueInfoProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::ValueInfoProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::ValueInfoProto, name_), + PROTOBUF_FIELD_OFFSET(::onnx::ValueInfoProto, type_), + PROTOBUF_FIELD_OFFSET(::onnx::ValueInfoProto, doc_string_), + 0, + 2, + 1, + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, input_), + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, output_), + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, name_), + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, op_type_), + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, domain_), + PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, attribute_), + 
PROTOBUF_FIELD_OFFSET(::onnx::NodeProto, doc_string_), + ~0u, + ~0u, + 0, + 1, + 3, + ~0u, + 2, + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, ir_version_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, opset_import_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, producer_name_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, producer_version_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, domain_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, model_version_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, doc_string_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, graph_), + PROTOBUF_FIELD_OFFSET(::onnx::ModelProto, metadata_props_), + 5, + ~0u, + 0, + 1, + 2, + 6, + 3, + 4, + ~0u, + PROTOBUF_FIELD_OFFSET(::onnx::StringStringEntryProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::StringStringEntryProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::StringStringEntryProto, key_), + PROTOBUF_FIELD_OFFSET(::onnx::StringStringEntryProto, value_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, node_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, name_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, initializer_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, doc_string_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, input_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, output_), + PROTOBUF_FIELD_OFFSET(::onnx::GraphProto, value_info_), + ~0u, + 0, + ~0u, + 1, + ~0u, + ~0u, + ~0u, + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto_Segment, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto_Segment, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto_Segment, begin_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto_Segment, end_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, dims_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, data_type_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, segment_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, float_data_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, int32_data_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, string_data_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, int64_data_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, name_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, doc_string_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, raw_data_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, double_data_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorProto, uint64_data_), + ~0u, + 4, + 3, + ~0u, + ~0u, + ~0u, + ~0u, + 0, + 2, + 1, + ~0u, + ~0u, + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto_Dimension, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto_Dimension, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto_Dimension, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + 
offsetof(::onnx::TensorShapeProto_DimensionDefaultTypeInternal, dim_value_), + offsetof(::onnx::TensorShapeProto_DimensionDefaultTypeInternal, dim_param_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto_Dimension, denotation_), + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto_Dimension, value_), + ~0u, + ~0u, + 0, + ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TensorShapeProto, dim_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Tensor, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Tensor, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Tensor, elem_type_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Tensor, shape_), + 1, + 0, + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Sequence, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Sequence, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Sequence, elem_type_), + 0, + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Map, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Map, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Map, key_type_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Map, value_type_), + 1, + 0, + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Opaque, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Opaque, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Opaque, domain_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_Opaque, name_), + 0, + 1, + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_SparseTensor, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_SparseTensor, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_SparseTensor, elem_type_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto_SparseTensor, shape_), + 1, + 0, + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto, _internal_metadata_), + ~0u, // no _extensions_ + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto, _oneof_case_[0]), + ~0u, // no _weak_field_map_ + offsetof(::onnx::TypeProtoDefaultTypeInternal, tensor_type_), + offsetof(::onnx::TypeProtoDefaultTypeInternal, sequence_type_), + offsetof(::onnx::TypeProtoDefaultTypeInternal, map_type_), + offsetof(::onnx::TypeProtoDefaultTypeInternal, opaque_type_), + offsetof(::onnx::TypeProtoDefaultTypeInternal, sparse_tensor_type_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto, denotation_), + PROTOBUF_FIELD_OFFSET(::onnx::TypeProto, value_), + ~0u, + ~0u, + ~0u, + ~0u, + ~0u, + 0, + PROTOBUF_FIELD_OFFSET(::onnx::OperatorSetIdProto, _has_bits_), + PROTOBUF_FIELD_OFFSET(::onnx::OperatorSetIdProto, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + PROTOBUF_FIELD_OFFSET(::onnx::OperatorSetIdProto, domain_), + PROTOBUF_FIELD_OFFSET(::onnx::OperatorSetIdProto, version_), + 0, + 1, +}; +static const ::PROTOBUF_NAMESPACE_ID::internal::MigrationSchema schemas[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { + { 0, 19, sizeof(::onnx::AttributeProto)}, + { 33, 41, sizeof(::onnx::ValueInfoProto)}, + { 44, 56, 
sizeof(::onnx::NodeProto)}, + { 63, 77, sizeof(::onnx::ModelProto)}, + { 86, 93, sizeof(::onnx::StringStringEntryProto)}, + { 95, 107, sizeof(::onnx::GraphProto)}, + { 114, 121, sizeof(::onnx::TensorProto_Segment)}, + { 123, 140, sizeof(::onnx::TensorProto)}, + { 152, 161, sizeof(::onnx::TensorShapeProto_Dimension)}, + { 164, -1, sizeof(::onnx::TensorShapeProto)}, + { 170, 177, sizeof(::onnx::TypeProto_Tensor)}, + { 179, 185, sizeof(::onnx::TypeProto_Sequence)}, + { 186, 193, sizeof(::onnx::TypeProto_Map)}, + { 195, 202, sizeof(::onnx::TypeProto_Opaque)}, + { 204, 211, sizeof(::onnx::TypeProto_SparseTensor)}, + { 213, 225, sizeof(::onnx::TypeProto)}, + { 231, 238, sizeof(::onnx::OperatorSetIdProto)}, +}; + +static ::PROTOBUF_NAMESPACE_ID::Message const * const file_default_instances[] = { + reinterpret_cast(&::onnx::_AttributeProto_default_instance_), + reinterpret_cast(&::onnx::_ValueInfoProto_default_instance_), + reinterpret_cast(&::onnx::_NodeProto_default_instance_), + reinterpret_cast(&::onnx::_ModelProto_default_instance_), + reinterpret_cast(&::onnx::_StringStringEntryProto_default_instance_), + reinterpret_cast(&::onnx::_GraphProto_default_instance_), + reinterpret_cast(&::onnx::_TensorProto_Segment_default_instance_), + reinterpret_cast(&::onnx::_TensorProto_default_instance_), + reinterpret_cast(&::onnx::_TensorShapeProto_Dimension_default_instance_), + reinterpret_cast(&::onnx::_TensorShapeProto_default_instance_), + reinterpret_cast(&::onnx::_TypeProto_Tensor_default_instance_), + reinterpret_cast(&::onnx::_TypeProto_Sequence_default_instance_), + reinterpret_cast(&::onnx::_TypeProto_Map_default_instance_), + reinterpret_cast(&::onnx::_TypeProto_Opaque_default_instance_), + reinterpret_cast(&::onnx::_TypeProto_SparseTensor_default_instance_), + reinterpret_cast(&::onnx::_TypeProto_default_instance_), + reinterpret_cast(&::onnx::_OperatorSetIdProto_default_instance_), +}; + +const char descriptor_table_protodef_onnx_2dml_2eproto[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = + "\n\ronnx-ml.proto\022\004onnx\"\340\003\n\016AttributeProto" + "\022\014\n\004name\030\001 \001(\t\022\025\n\rref_attr_name\030\025 \001(\t\022\022\n" + "\ndoc_string\030\r \001(\t\0220\n\004type\030\024 \001(\0162\".onnx.A" + "ttributeProto.AttributeType\022\t\n\001f\030\002 \001(\002\022\t" + "\n\001i\030\003 \001(\003\022\t\n\001s\030\004 \001(\014\022\034\n\001t\030\005 \001(\0132\021.onnx.T" + "ensorProto\022\033\n\001g\030\006 \001(\0132\020.onnx.GraphProto\022" + "\016\n\006floats\030\007 \003(\002\022\014\n\004ints\030\010 \003(\003\022\017\n\007strings" + "\030\t \003(\014\022\"\n\007tensors\030\n \003(\0132\021.onnx.TensorPro" + "to\022 \n\006graphs\030\013 \003(\0132\020.onnx.GraphProto\"\221\001\n" + "\rAttributeType\022\r\n\tUNDEFINED\020\000\022\t\n\005FLOAT\020\001" + "\022\007\n\003INT\020\002\022\n\n\006STRING\020\003\022\n\n\006TENSOR\020\004\022\t\n\005GRA" + "PH\020\005\022\n\n\006FLOATS\020\006\022\010\n\004INTS\020\007\022\013\n\007STRINGS\020\010\022" + "\013\n\007TENSORS\020\t\022\n\n\006GRAPHS\020\n\"Q\n\016ValueInfoPro" + "to\022\014\n\004name\030\001 \001(\t\022\035\n\004type\030\002 \001(\0132\017.onnx.Ty" + "peProto\022\022\n\ndoc_string\030\003 \001(\t\"\226\001\n\tNodeProt" + "o\022\r\n\005input\030\001 \003(\t\022\016\n\006output\030\002 \003(\t\022\014\n\004name" + "\030\003 \001(\t\022\017\n\007op_type\030\004 \001(\t\022\016\n\006domain\030\007 \001(\t\022" + "\'\n\tattribute\030\005 \003(\0132\024.onnx.AttributeProto" + "\022\022\n\ndoc_string\030\006 
\001(\t\"\223\002\n\nModelProto\022\022\n\ni" + "r_version\030\001 \001(\003\022.\n\014opset_import\030\010 \003(\0132\030." + "onnx.OperatorSetIdProto\022\025\n\rproducer_name" + "\030\002 \001(\t\022\030\n\020producer_version\030\003 \001(\t\022\016\n\006doma" + "in\030\004 \001(\t\022\025\n\rmodel_version\030\005 \001(\003\022\022\n\ndoc_s" + "tring\030\006 \001(\t\022\037\n\005graph\030\007 \001(\0132\020.onnx.GraphP" + "roto\0224\n\016metadata_props\030\016 \003(\0132\034.onnx.Stri" + "ngStringEntryProto\"4\n\026StringStringEntryP" + "roto\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t\"\352\001\n\nGra" + "phProto\022\035\n\004node\030\001 \003(\0132\017.onnx.NodeProto\022\014" + "\n\004name\030\002 \001(\t\022&\n\013initializer\030\005 \003(\0132\021.onnx" + ".TensorProto\022\022\n\ndoc_string\030\n \001(\t\022#\n\005inpu" + "t\030\013 \003(\0132\024.onnx.ValueInfoProto\022$\n\006output\030" + "\014 \003(\0132\024.onnx.ValueInfoProto\022(\n\nvalue_inf" + "o\030\r \003(\0132\024.onnx.ValueInfoProto\"\275\004\n\013Tensor" + "Proto\022\014\n\004dims\030\001 \003(\003\022-\n\tdata_type\030\002 \001(\0162\032" + ".onnx.TensorProto.DataType\022*\n\007segment\030\003 " + "\001(\0132\031.onnx.TensorProto.Segment\022\026\n\nfloat_" + "data\030\004 \003(\002B\002\020\001\022\026\n\nint32_data\030\005 \003(\005B\002\020\001\022\023" + "\n\013string_data\030\006 \003(\014\022\026\n\nint64_data\030\007 \003(\003B" + "\002\020\001\022\014\n\004name\030\010 \001(\t\022\022\n\ndoc_string\030\014 \001(\t\022\020\n" + "\010raw_data\030\t \001(\014\022\027\n\013double_data\030\n \003(\001B\002\020\001" + "\022\027\n\013uint64_data\030\013 \003(\004B\002\020\001\032%\n\007Segment\022\r\n\005" + "begin\030\001 \001(\003\022\013\n\003end\030\002 \001(\003\"\332\001\n\010DataType\022\r\n" + "\tUNDEFINED\020\000\022\t\n\005FLOAT\020\001\022\t\n\005UINT8\020\002\022\010\n\004IN" + "T8\020\003\022\n\n\006UINT16\020\004\022\t\n\005INT16\020\005\022\t\n\005INT32\020\006\022\t" + "\n\005INT64\020\007\022\n\n\006STRING\020\010\022\010\n\004BOOL\020\t\022\013\n\007FLOAT" + "16\020\n\022\n\n\006DOUBLE\020\013\022\n\n\006UINT32\020\014\022\n\n\006UINT64\020\r" + "\022\r\n\tCOMPLEX64\020\016\022\016\n\nCOMPLEX128\020\017\022\014\n\010BFLOA" + "T16\020\020\"\225\001\n\020TensorShapeProto\022-\n\003dim\030\001 \003(\0132" + " .onnx.TensorShapeProto.Dimension\032R\n\tDim" + "ension\022\023\n\tdim_value\030\001 \001(\003H\000\022\023\n\tdim_param" + "\030\002 \001(\tH\000\022\022\n\ndenotation\030\003 \001(\tB\007\n\005value\"\226\005" + "\n\tTypeProto\022-\n\013tensor_type\030\001 \001(\0132\026.onnx." 
+ "TypeProto.TensorH\000\0221\n\rsequence_type\030\004 \001(" + "\0132\030.onnx.TypeProto.SequenceH\000\022\'\n\010map_typ" + "e\030\005 \001(\0132\023.onnx.TypeProto.MapH\000\022-\n\013opaque" + "_type\030\007 \001(\0132\026.onnx.TypeProto.OpaqueH\000\022:\n" + "\022sparse_tensor_type\030\010 \001(\0132\034.onnx.TypePro" + "to.SparseTensorH\000\022\022\n\ndenotation\030\006 \001(\t\032^\n" + "\006Tensor\022-\n\telem_type\030\001 \001(\0162\032.onnx.Tensor" + "Proto.DataType\022%\n\005shape\030\002 \001(\0132\026.onnx.Ten" + "sorShapeProto\032.\n\010Sequence\022\"\n\telem_type\030\001" + " \001(\0132\017.onnx.TypeProto\032X\n\003Map\022,\n\010key_type" + "\030\001 \001(\0162\032.onnx.TensorProto.DataType\022#\n\nva" + "lue_type\030\002 \001(\0132\017.onnx.TypeProto\032&\n\006Opaqu" + "e\022\016\n\006domain\030\001 \001(\t\022\014\n\004name\030\002 \001(\t\032d\n\014Spars" + "eTensor\022-\n\telem_type\030\001 \001(\0162\032.onnx.Tensor" + "Proto.DataType\022%\n\005shape\030\002 \001(\0132\026.onnx.Ten" + "sorShapeProtoB\007\n\005value\"5\n\022OperatorSetIdP" + "roto\022\016\n\006domain\030\001 \001(\t\022\017\n\007version\030\002 \001(\003*c\n" + "\007Version\022\022\n\016_START_VERSION\020\000\022\031\n\025IR_VERSI" + "ON_2017_10_10\020\001\022\031\n\025IR_VERSION_2017_10_30" + "\020\002\022\016\n\nIR_VERSION\020\003" + ; +static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_onnx_2dml_2eproto_deps[1] = { +}; +static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_onnx_2dml_2eproto_sccs[13] = { + &scc_info_AttributeProto_onnx_2dml_2eproto.base, + &scc_info_ModelProto_onnx_2dml_2eproto.base, + &scc_info_OperatorSetIdProto_onnx_2dml_2eproto.base, + &scc_info_StringStringEntryProto_onnx_2dml_2eproto.base, + &scc_info_TensorProto_onnx_2dml_2eproto.base, + &scc_info_TensorProto_Segment_onnx_2dml_2eproto.base, + &scc_info_TensorShapeProto_onnx_2dml_2eproto.base, + &scc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto.base, + &scc_info_TypeProto_onnx_2dml_2eproto.base, + &scc_info_TypeProto_Opaque_onnx_2dml_2eproto.base, + &scc_info_TypeProto_SparseTensor_onnx_2dml_2eproto.base, + &scc_info_TypeProto_Tensor_onnx_2dml_2eproto.base, + &scc_info_ValueInfoProto_onnx_2dml_2eproto.base, +}; +static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_onnx_2dml_2eproto_once; +const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_onnx_2dml_2eproto = { + false, false, descriptor_table_protodef_onnx_2dml_2eproto, "onnx-ml.proto", 2858, + &descriptor_table_onnx_2dml_2eproto_once, descriptor_table_onnx_2dml_2eproto_sccs, descriptor_table_onnx_2dml_2eproto_deps, 13, 0, + schemas, file_default_instances, TableStruct_onnx_2dml_2eproto::offsets, + file_level_metadata_onnx_2dml_2eproto, 17, file_level_enum_descriptors_onnx_2dml_2eproto, file_level_service_descriptors_onnx_2dml_2eproto, +}; + +// Force running AddDescriptors() at dynamic initialization time. 
+static bool dynamic_init_dummy_onnx_2dml_2eproto = (static_cast(::PROTOBUF_NAMESPACE_ID::internal::AddDescriptors(&descriptor_table_onnx_2dml_2eproto)), true); +namespace onnx { +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* AttributeProto_AttributeType_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_onnx_2dml_2eproto); + return file_level_enum_descriptors_onnx_2dml_2eproto[0]; +} +bool AttributeProto_AttributeType_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + case 9: + case 10: + return true; + default: + return false; + } +} + +#if (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +constexpr AttributeProto_AttributeType AttributeProto::UNDEFINED; +constexpr AttributeProto_AttributeType AttributeProto::FLOAT; +constexpr AttributeProto_AttributeType AttributeProto::INT; +constexpr AttributeProto_AttributeType AttributeProto::STRING; +constexpr AttributeProto_AttributeType AttributeProto::TENSOR; +constexpr AttributeProto_AttributeType AttributeProto::GRAPH; +constexpr AttributeProto_AttributeType AttributeProto::FLOATS; +constexpr AttributeProto_AttributeType AttributeProto::INTS; +constexpr AttributeProto_AttributeType AttributeProto::STRINGS; +constexpr AttributeProto_AttributeType AttributeProto::TENSORS; +constexpr AttributeProto_AttributeType AttributeProto::GRAPHS; +constexpr AttributeProto_AttributeType AttributeProto::AttributeType_MIN; +constexpr AttributeProto_AttributeType AttributeProto::AttributeType_MAX; +constexpr int AttributeProto::AttributeType_ARRAYSIZE; +#endif // (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* TensorProto_DataType_descriptor() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_onnx_2dml_2eproto); + return file_level_enum_descriptors_onnx_2dml_2eproto[1]; +} +bool TensorProto_DataType_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + case 9: + case 10: + case 11: + case 12: + case 13: + case 14: + case 15: + case 16: + return true; + default: + return false; + } +} + +#if (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +constexpr TensorProto_DataType TensorProto::UNDEFINED; +constexpr TensorProto_DataType TensorProto::FLOAT; +constexpr TensorProto_DataType TensorProto::UINT8; +constexpr TensorProto_DataType TensorProto::INT8; +constexpr TensorProto_DataType TensorProto::UINT16; +constexpr TensorProto_DataType TensorProto::INT16; +constexpr TensorProto_DataType TensorProto::INT32; +constexpr TensorProto_DataType TensorProto::INT64; +constexpr TensorProto_DataType TensorProto::STRING; +constexpr TensorProto_DataType TensorProto::BOOL; +constexpr TensorProto_DataType TensorProto::FLOAT16; +constexpr TensorProto_DataType TensorProto::DOUBLE; +constexpr TensorProto_DataType TensorProto::UINT32; +constexpr TensorProto_DataType TensorProto::UINT64; +constexpr TensorProto_DataType TensorProto::COMPLEX64; +constexpr TensorProto_DataType TensorProto::COMPLEX128; +constexpr TensorProto_DataType TensorProto::BFLOAT16; +constexpr TensorProto_DataType TensorProto::DataType_MIN; +constexpr TensorProto_DataType TensorProto::DataType_MAX; +constexpr int TensorProto::DataType_ARRAYSIZE; +#endif // (__cplusplus < 201703) && (!defined(_MSC_VER) || _MSC_VER >= 1900) +const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Version_descriptor() { + 
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&descriptor_table_onnx_2dml_2eproto); + return file_level_enum_descriptors_onnx_2dml_2eproto[2]; +} +bool Version_IsValid(int value) { + switch (value) { + case 0: + case 1: + case 2: + case 3: + return true; + default: + return false; + } +} + + +// =================================================================== + +void AttributeProto::InitAsDefaultInstance() { + ::onnx::_AttributeProto_default_instance_._instance.get_mutable()->t_ = const_cast< ::onnx::TensorProto*>( + ::onnx::TensorProto::internal_default_instance()); + ::onnx::_AttributeProto_default_instance_._instance.get_mutable()->g_ = const_cast< ::onnx::GraphProto*>( + ::onnx::GraphProto::internal_default_instance()); +} +class AttributeProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_ref_attr_name(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_doc_string(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static void set_has_type(HasBits* has_bits) { + (*has_bits)[0] |= 256u; + } + static void set_has_f(HasBits* has_bits) { + (*has_bits)[0] |= 128u; + } + static void set_has_i(HasBits* has_bits) { + (*has_bits)[0] |= 64u; + } + static void set_has_s(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::onnx::TensorProto& t(const AttributeProto* msg); + static void set_has_t(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } + static const ::onnx::GraphProto& g(const AttributeProto* msg); + static void set_has_g(HasBits* has_bits) { + (*has_bits)[0] |= 32u; + } +}; + +const ::onnx::TensorProto& +AttributeProto::_Internal::t(const AttributeProto* msg) { + return *msg->t_; +} +const ::onnx::GraphProto& +AttributeProto::_Internal::g(const AttributeProto* msg) { + return *msg->g_; +} +AttributeProto::AttributeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + floats_(arena), + ints_(arena), + strings_(arena), + tensors_(arena), + graphs_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.AttributeProto) +} +AttributeProto::AttributeProto(const AttributeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + floats_(from.floats_), + ints_(from.ints_), + strings_(from.strings_), + tensors_(from.tensors_), + graphs_(from.graphs_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + s_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_s()) { + s_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_s(), + GetArena()); + } + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_doc_string()) { + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_doc_string(), + GetArena()); + } + ref_attr_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_ref_attr_name()) { + 
ref_attr_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_ref_attr_name(), + GetArena()); + } + if (from._internal_has_t()) { + t_ = new ::onnx::TensorProto(*from.t_); + } else { + t_ = nullptr; + } + if (from._internal_has_g()) { + g_ = new ::onnx::GraphProto(*from.g_); + } else { + g_ = nullptr; + } + ::memcpy(&i_, &from.i_, + static_cast(reinterpret_cast(&type_) - + reinterpret_cast(&i_)) + sizeof(type_)); + // @@protoc_insertion_point(copy_constructor:onnx.AttributeProto) +} + +void AttributeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_onnx_2dml_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + s_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ref_attr_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&t_, 0, static_cast( + reinterpret_cast(&type_) - + reinterpret_cast(&t_)) + sizeof(type_)); +} + +AttributeProto::~AttributeProto() { + // @@protoc_insertion_point(destructor:onnx.AttributeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void AttributeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + s_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ref_attr_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete t_; + if (this != internal_default_instance()) delete g_; +} + +void AttributeProto::ArenaDtor(void* object) { + AttributeProto* _this = reinterpret_cast< AttributeProto* >(object); + (void)_this; +} +void AttributeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void AttributeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const AttributeProto& AttributeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void AttributeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.AttributeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + floats_.Clear(); + ints_.Clear(); + strings_.Clear(); + tensors_.Clear(); + graphs_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000003fu) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + s_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + doc_string_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000008u) { + ref_attr_name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000010u) { + GOOGLE_DCHECK(t_ != nullptr); + t_->Clear(); + } + if (cached_has_bits & 0x00000020u) { + GOOGLE_DCHECK(g_ != nullptr); + g_->Clear(); + } + } + if (cached_has_bits & 0x000000c0u) { + ::memset(&i_, 0, static_cast( + reinterpret_cast(&f_) - + reinterpret_cast(&i_)) + sizeof(f_)); + } + type_ = 0; + _has_bits_.Clear(); + 
_internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* AttributeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.AttributeProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional float f = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 21)) { + _Internal::set_has_f(&has_bits); + f_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // optional int64 i = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 24)) { + _Internal::set_has_i(&has_bits); + i_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional bytes s = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_s(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .onnx.TensorProto t = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ctx->ParseMessage(_internal_mutable_t(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .onnx.GraphProto g = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr = ctx->ParseMessage(_internal_mutable_g(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float floats = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 61)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_floats(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<61>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_floats(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated int64 ints = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 64)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_ints(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<64>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_ints(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated bytes strings = 9; + case 9: + if 
(PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_strings(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<74>(ptr)); + } else goto handle_unusual; + continue; + // repeated .onnx.TensorProto tensors = 10; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_tensors(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<82>(ptr)); + } else goto handle_unusual; + continue; + // repeated .onnx.GraphProto graphs = 11; + case 11: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 90)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_graphs(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<90>(ptr)); + } else goto handle_unusual; + continue; + // optional string doc_string = 13; + case 13: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 106)) { + auto str = _internal_mutable_doc_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.AttributeProto.doc_string"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .onnx.AttributeProto.AttributeType type = 20; + case 20: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 160)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::onnx::AttributeProto_AttributeType_IsValid(val))) { + _internal_set_type(static_cast<::onnx::AttributeProto_AttributeType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(20, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional string ref_attr_name = 21; + case 21: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 170)) { + auto str = _internal_mutable_ref_attr_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.AttributeProto.ref_attr_name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* AttributeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.AttributeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + 
::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.AttributeProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional float f = 2; + if (cached_has_bits & 0x00000080u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(2, this->_internal_f(), target); + } + + // optional int64 i = 3; + if (cached_has_bits & 0x00000040u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(3, this->_internal_i(), target); + } + + // optional bytes s = 4; + if (cached_has_bits & 0x00000002u) { + target = stream->WriteBytesMaybeAliased( + 4, this->_internal_s(), target); + } + + // optional .onnx.TensorProto t = 5; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 5, _Internal::t(this), target, stream); + } + + // optional .onnx.GraphProto g = 6; + if (cached_has_bits & 0x00000020u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 6, _Internal::g(this), target, stream); + } + + // repeated float floats = 7; + for (int i = 0, n = this->_internal_floats_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(7, this->_internal_floats(i), target); + } + + // repeated int64 ints = 8; + for (int i = 0, n = this->_internal_ints_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(8, this->_internal_ints(i), target); + } + + // repeated bytes strings = 9; + for (int i = 0, n = this->_internal_strings_size(); i < n; i++) { + const auto& s = this->_internal_strings(i); + target = stream->WriteBytes(9, s, target); + } + + // repeated .onnx.TensorProto tensors = 10; + for (unsigned int i = 0, + n = static_cast(this->_internal_tensors_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(10, this->_internal_tensors(i), target, stream); + } + + // repeated .onnx.GraphProto graphs = 11; + for (unsigned int i = 0, + n = static_cast(this->_internal_graphs_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(11, this->_internal_graphs(i), target, stream); + } + + // optional string doc_string = 13; + if (cached_has_bits & 0x00000004u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_doc_string().data(), static_cast(this->_internal_doc_string().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.AttributeProto.doc_string"); + target = stream->WriteStringMaybeAliased( + 13, this->_internal_doc_string(), target); + } + + // optional .onnx.AttributeProto.AttributeType type = 20; + if (cached_has_bits & 0x00000100u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 20, this->_internal_type(), target); + } + + // optional string ref_attr_name = 21; + if 
(cached_has_bits & 0x00000008u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_ref_attr_name().data(), static_cast(this->_internal_ref_attr_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.AttributeProto.ref_attr_name"); + target = stream->WriteStringMaybeAliased( + 21, this->_internal_ref_attr_name(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.AttributeProto) + return target; +} + +size_t AttributeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.AttributeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated float floats = 7; + { + unsigned int count = static_cast(this->_internal_floats_size()); + size_t data_size = 4UL * count; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_floats_size()); + total_size += data_size; + } + + // repeated int64 ints = 8; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->ints_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_ints_size()); + total_size += data_size; + } + + // repeated bytes strings = 9; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(strings_.size()); + for (int i = 0, n = strings_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + strings_.Get(i)); + } + + // repeated .onnx.TensorProto tensors = 10; + total_size += 1UL * this->_internal_tensors_size(); + for (const auto& msg : this->tensors_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .onnx.GraphProto graphs = 11; + total_size += 1UL * this->_internal_graphs_size(); + for (const auto& msg : this->graphs_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x000000ffu) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional bytes s = 4; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_s()); + } + + // optional string doc_string = 13; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_doc_string()); + } + + // optional string ref_attr_name = 21; + if (cached_has_bits & 0x00000008u) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_ref_attr_name()); + } + + // optional .onnx.TensorProto t = 5; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *t_); + } + + // optional .onnx.GraphProto g = 6; + if (cached_has_bits & 0x00000020u) { + total_size += 1 + + 
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *g_); + } + + // optional int64 i = 3; + if (cached_has_bits & 0x00000040u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_i()); + } + + // optional float f = 2; + if (cached_has_bits & 0x00000080u) { + total_size += 1 + 4; + } + + } + // optional .onnx.AttributeProto.AttributeType type = 20; + if (cached_has_bits & 0x00000100u) { + total_size += 2 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_type()); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void AttributeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.AttributeProto) + GOOGLE_DCHECK_NE(&from, this); + const AttributeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.AttributeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.AttributeProto) + MergeFrom(*source); + } +} + +void AttributeProto::MergeFrom(const AttributeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.AttributeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + floats_.MergeFrom(from.floats_); + ints_.MergeFrom(from.ints_); + strings_.MergeFrom(from.strings_); + tensors_.MergeFrom(from.tensors_); + graphs_.MergeFrom(from.graphs_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x000000ffu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_s(from._internal_s()); + } + if (cached_has_bits & 0x00000004u) { + _internal_set_doc_string(from._internal_doc_string()); + } + if (cached_has_bits & 0x00000008u) { + _internal_set_ref_attr_name(from._internal_ref_attr_name()); + } + if (cached_has_bits & 0x00000010u) { + _internal_mutable_t()->::onnx::TensorProto::MergeFrom(from._internal_t()); + } + if (cached_has_bits & 0x00000020u) { + _internal_mutable_g()->::onnx::GraphProto::MergeFrom(from._internal_g()); + } + if (cached_has_bits & 0x00000040u) { + i_ = from.i_; + } + if (cached_has_bits & 0x00000080u) { + f_ = from.f_; + } + _has_bits_[0] |= cached_has_bits; + } + if (cached_has_bits & 0x00000100u) { + _internal_set_type(from._internal_type()); + } +} + +void AttributeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.AttributeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void AttributeProto::CopyFrom(const AttributeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.AttributeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool AttributeProto::IsInitialized() const { + return true; +} + +void AttributeProto::InternalSwap(AttributeProto* other) { + using 
std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + floats_.InternalSwap(&other->floats_); + ints_.InternalSwap(&other->ints_); + strings_.InternalSwap(&other->strings_); + tensors_.InternalSwap(&other->tensors_); + graphs_.InternalSwap(&other->graphs_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + s_.Swap(&other->s_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + doc_string_.Swap(&other->doc_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ref_attr_name_.Swap(&other->ref_attr_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(AttributeProto, type_) + + sizeof(AttributeProto::type_) + - PROTOBUF_FIELD_OFFSET(AttributeProto, t_)>( + reinterpret_cast(&t_), + reinterpret_cast(&other->t_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata AttributeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void ValueInfoProto::InitAsDefaultInstance() { + ::onnx::_ValueInfoProto_default_instance_._instance.get_mutable()->type_ = const_cast< ::onnx::TypeProto*>( + ::onnx::TypeProto::internal_default_instance()); +} +class ValueInfoProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static const ::onnx::TypeProto& type(const ValueInfoProto* msg); + static void set_has_type(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static void set_has_doc_string(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +const ::onnx::TypeProto& +ValueInfoProto::_Internal::type(const ValueInfoProto* msg) { + return *msg->type_; +} +ValueInfoProto::ValueInfoProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.ValueInfoProto) +} +ValueInfoProto::ValueInfoProto(const ValueInfoProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_doc_string()) { + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_doc_string(), + GetArena()); + } + if (from._internal_has_type()) { + type_ = new ::onnx::TypeProto(*from.type_); + } else { + type_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:onnx.ValueInfoProto) +} + +void ValueInfoProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_ValueInfoProto_onnx_2dml_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + type_ = nullptr; +} + +ValueInfoProto::~ValueInfoProto() { + // 
@@protoc_insertion_point(destructor:onnx.ValueInfoProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ValueInfoProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete type_; +} + +void ValueInfoProto::ArenaDtor(void* object) { + ValueInfoProto* _this = reinterpret_cast< ValueInfoProto* >(object); + (void)_this; +} +void ValueInfoProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ValueInfoProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ValueInfoProto& ValueInfoProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_ValueInfoProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void ValueInfoProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.ValueInfoProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + doc_string_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + GOOGLE_DCHECK(type_ != nullptr); + type_->Clear(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ValueInfoProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string name = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.ValueInfoProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .onnx.TypeProto type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string doc_string = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_doc_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.ValueInfoProto.doc_string"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + 
_has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ValueInfoProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.ValueInfoProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.ValueInfoProto.name"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_name(), target); + } + + // optional .onnx.TypeProto type = 2; + if (cached_has_bits & 0x00000004u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::type(this), target, stream); + } + + // optional string doc_string = 3; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_doc_string().data(), static_cast(this->_internal_doc_string().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.ValueInfoProto.doc_string"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_doc_string(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.ValueInfoProto) + return target; +} + +size_t ValueInfoProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.ValueInfoProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + // optional string name = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional string doc_string = 3; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_doc_string()); + } + + // optional .onnx.TypeProto type = 2; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *type_); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ValueInfoProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.ValueInfoProto) + GOOGLE_DCHECK_NE(&from, this); + const ValueInfoProto* source = + 
::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.ValueInfoProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.ValueInfoProto) + MergeFrom(*source); + } +} + +void ValueInfoProto::MergeFrom(const ValueInfoProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.ValueInfoProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000007u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_doc_string(from._internal_doc_string()); + } + if (cached_has_bits & 0x00000004u) { + _internal_mutable_type()->::onnx::TypeProto::MergeFrom(from._internal_type()); + } + } +} + +void ValueInfoProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.ValueInfoProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ValueInfoProto::CopyFrom(const ValueInfoProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.ValueInfoProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool ValueInfoProto::IsInitialized() const { + return true; +} + +void ValueInfoProto::InternalSwap(ValueInfoProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + doc_string_.Swap(&other->doc_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(type_, other->type_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata ValueInfoProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void NodeProto::InitAsDefaultInstance() { +} +class NodeProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_op_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_domain(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static void set_has_doc_string(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } +}; + +NodeProto::NodeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + input_(arena), + output_(arena), + attribute_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.NodeProto) +} +NodeProto::NodeProto(const NodeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + input_(from.input_), + output_(from.output_), + attribute_(from.attribute_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), 
from._internal_name(), + GetArena()); + } + op_type_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_op_type()) { + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_op_type(), + GetArena()); + } + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_doc_string()) { + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_doc_string(), + GetArena()); + } + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_domain()) { + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_domain(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:onnx.NodeProto) +} + +void NodeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_onnx_2dml_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + op_type_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +NodeProto::~NodeProto() { + // @@protoc_insertion_point(destructor:onnx.NodeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void NodeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + op_type_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + domain_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void NodeProto::ArenaDtor(void* object) { + NodeProto* _this = reinterpret_cast< NodeProto* >(object); + (void)_this; +} +void NodeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void NodeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const NodeProto& NodeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void NodeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.NodeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + input_.Clear(); + output_.Clear(); + attribute_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + op_type_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + doc_string_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000008u) { + domain_.ClearNonDefaultToEmpty(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* NodeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* 
arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated string input = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_input(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.NodeProto.input"); + #endif // !NDEBUG + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // repeated string output = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_output(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.NodeProto.output"); + #endif // !NDEBUG + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<18>(ptr)); + } else goto handle_unusual; + continue; + // optional string name = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.NodeProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string op_type = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_op_type(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.NodeProto.op_type"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .onnx.AttributeProto attribute = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_attribute(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<42>(ptr)); + } else goto handle_unusual; + continue; + // optional string doc_string = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + auto str = _internal_mutable_doc_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.NodeProto.doc_string"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string domain = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + auto str = _internal_mutable_domain(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.NodeProto.domain"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = 
UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* NodeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.NodeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated string input = 1; + for (int i = 0, n = this->_internal_input_size(); i < n; i++) { + const auto& s = this->_internal_input(i); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + s.data(), static_cast(s.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.NodeProto.input"); + target = stream->WriteString(1, s, target); + } + + // repeated string output = 2; + for (int i = 0, n = this->_internal_output_size(); i < n; i++) { + const auto& s = this->_internal_output(i); + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + s.data(), static_cast(s.length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.NodeProto.output"); + target = stream->WriteString(2, s, target); + } + + cached_has_bits = _has_bits_[0]; + // optional string name = 3; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.NodeProto.name"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_name(), target); + } + + // optional string op_type = 4; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_op_type().data(), static_cast(this->_internal_op_type().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.NodeProto.op_type"); + target = stream->WriteStringMaybeAliased( + 4, this->_internal_op_type(), target); + } + + // repeated .onnx.AttributeProto attribute = 5; + for (unsigned int i = 0, + n = static_cast(this->_internal_attribute_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(5, this->_internal_attribute(i), target, stream); + } + + // optional string doc_string = 6; + if (cached_has_bits & 0x00000004u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_doc_string().data(), static_cast(this->_internal_doc_string().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.NodeProto.doc_string"); + target = stream->WriteStringMaybeAliased( + 6, this->_internal_doc_string(), target); + } + + // optional string domain = 7; + if (cached_has_bits & 0x00000008u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_domain().data(), static_cast(this->_internal_domain().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.NodeProto.domain"); + target = stream->WriteStringMaybeAliased( + 7, this->_internal_domain(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = 
::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.NodeProto) + return target; +} + +size_t NodeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.NodeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated string input = 1; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(input_.size()); + for (int i = 0, n = input_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + input_.Get(i)); + } + + // repeated string output = 2; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(output_.size()); + for (int i = 0, n = output_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + output_.Get(i)); + } + + // repeated .onnx.AttributeProto attribute = 5; + total_size += 1UL * this->_internal_attribute_size(); + for (const auto& msg : this->attribute_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + // optional string name = 3; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional string op_type = 4; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_op_type()); + } + + // optional string doc_string = 6; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_doc_string()); + } + + // optional string domain = 7; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_domain()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void NodeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.NodeProto) + GOOGLE_DCHECK_NE(&from, this); + const NodeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.NodeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.NodeProto) + MergeFrom(*source); + } +} + +void NodeProto::MergeFrom(const NodeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.NodeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + 
input_.MergeFrom(from.input_); + output_.MergeFrom(from.output_); + attribute_.MergeFrom(from.attribute_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_op_type(from._internal_op_type()); + } + if (cached_has_bits & 0x00000004u) { + _internal_set_doc_string(from._internal_doc_string()); + } + if (cached_has_bits & 0x00000008u) { + _internal_set_domain(from._internal_domain()); + } + } +} + +void NodeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.NodeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void NodeProto::CopyFrom(const NodeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.NodeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool NodeProto::IsInitialized() const { + return true; +} + +void NodeProto::InternalSwap(NodeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + input_.InternalSwap(&other->input_); + output_.InternalSwap(&other->output_); + attribute_.InternalSwap(&other->attribute_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + op_type_.Swap(&other->op_type_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + doc_string_.Swap(&other->doc_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + domain_.Swap(&other->domain_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata NodeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void ModelProto::InitAsDefaultInstance() { + ::onnx::_ModelProto_default_instance_._instance.get_mutable()->graph_ = const_cast< ::onnx::GraphProto*>( + ::onnx::GraphProto::internal_default_instance()); +} +class ModelProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_ir_version(HasBits* has_bits) { + (*has_bits)[0] |= 32u; + } + static void set_has_producer_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_producer_version(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static void set_has_domain(HasBits* has_bits) { + (*has_bits)[0] |= 4u; + } + static void set_has_model_version(HasBits* has_bits) { + (*has_bits)[0] |= 64u; + } + static void set_has_doc_string(HasBits* has_bits) { + (*has_bits)[0] |= 8u; + } + static const ::onnx::GraphProto& graph(const ModelProto* msg); + static void set_has_graph(HasBits* has_bits) { + (*has_bits)[0] |= 16u; + } +}; + +const ::onnx::GraphProto& +ModelProto::_Internal::graph(const ModelProto* msg) { + return *msg->graph_; +} +ModelProto::ModelProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + opset_import_(arena), + metadata_props_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.ModelProto) +} +ModelProto::ModelProto(const ModelProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + opset_import_(from.opset_import_), + metadata_props_(from.metadata_props_) { + 
_internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + producer_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_producer_name()) { + producer_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_producer_name(), + GetArena()); + } + producer_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_producer_version()) { + producer_version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_producer_version(), + GetArena()); + } + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_domain()) { + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_domain(), + GetArena()); + } + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_doc_string()) { + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_doc_string(), + GetArena()); + } + if (from._internal_has_graph()) { + graph_ = new ::onnx::GraphProto(*from.graph_); + } else { + graph_ = nullptr; + } + ::memcpy(&ir_version_, &from.ir_version_, + static_cast(reinterpret_cast(&model_version_) - + reinterpret_cast(&ir_version_)) + sizeof(model_version_)); + // @@protoc_insertion_point(copy_constructor:onnx.ModelProto) +} + +void ModelProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_ModelProto_onnx_2dml_2eproto.base); + producer_name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + producer_version_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&graph_, 0, static_cast( + reinterpret_cast(&model_version_) - + reinterpret_cast(&graph_)) + sizeof(model_version_)); +} + +ModelProto::~ModelProto() { + // @@protoc_insertion_point(destructor:onnx.ModelProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void ModelProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + producer_name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + producer_version_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + domain_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete graph_; +} + +void ModelProto::ArenaDtor(void* object) { + ModelProto* _this = reinterpret_cast< ModelProto* >(object); + (void)_this; +} +void ModelProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void ModelProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const ModelProto& ModelProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_ModelProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void ModelProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.ModelProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // 
Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + opset_import_.Clear(); + metadata_props_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + if (cached_has_bits & 0x00000001u) { + producer_name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + producer_version_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + domain_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000008u) { + doc_string_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000010u) { + GOOGLE_DCHECK(graph_ != nullptr); + graph_->Clear(); + } + } + if (cached_has_bits & 0x00000060u) { + ::memset(&ir_version_, 0, static_cast( + reinterpret_cast(&model_version_) - + reinterpret_cast(&ir_version_)) + sizeof(model_version_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* ModelProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional int64 ir_version = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_ir_version(&has_bits); + ir_version_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string producer_name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_producer_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.ModelProto.producer_name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string producer_version = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_producer_version(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.ModelProto.producer_version"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string domain = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + auto str = _internal_mutable_domain(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.ModelProto.domain"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 model_version = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 40)) { + _Internal::set_has_model_version(&has_bits); + model_version_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string doc_string = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + auto str = _internal_mutable_doc_string(); + ptr = 
::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.ModelProto.doc_string"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .onnx.GraphProto graph = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ctx->ParseMessage(_internal_mutable_graph(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .onnx.OperatorSetIdProto opset_import = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_opset_import(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<66>(ptr)); + } else goto handle_unusual; + continue; + // repeated .onnx.StringStringEntryProto metadata_props = 14; + case 14: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 114)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_metadata_props(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<114>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* ModelProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.ModelProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional int64 ir_version = 1; + if (cached_has_bits & 0x00000020u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_ir_version(), target); + } + + // optional string producer_name = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_producer_name().data(), static_cast(this->_internal_producer_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.ModelProto.producer_name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_producer_name(), target); + } + + // optional string producer_version = 3; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_producer_version().data(), static_cast(this->_internal_producer_version().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.ModelProto.producer_version"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_producer_version(), target); + } + + // optional string domain = 4; + if (cached_has_bits & 0x00000004u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_domain().data(), static_cast(this->_internal_domain().length()), + 
::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.ModelProto.domain"); + target = stream->WriteStringMaybeAliased( + 4, this->_internal_domain(), target); + } + + // optional int64 model_version = 5; + if (cached_has_bits & 0x00000040u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(5, this->_internal_model_version(), target); + } + + // optional string doc_string = 6; + if (cached_has_bits & 0x00000008u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_doc_string().data(), static_cast(this->_internal_doc_string().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.ModelProto.doc_string"); + target = stream->WriteStringMaybeAliased( + 6, this->_internal_doc_string(), target); + } + + // optional .onnx.GraphProto graph = 7; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 7, _Internal::graph(this), target, stream); + } + + // repeated .onnx.OperatorSetIdProto opset_import = 8; + for (unsigned int i = 0, + n = static_cast(this->_internal_opset_import_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(8, this->_internal_opset_import(i), target, stream); + } + + // repeated .onnx.StringStringEntryProto metadata_props = 14; + for (unsigned int i = 0, + n = static_cast(this->_internal_metadata_props_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(14, this->_internal_metadata_props(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.ModelProto) + return target; +} + +size_t ModelProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.ModelProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .onnx.OperatorSetIdProto opset_import = 8; + total_size += 1UL * this->_internal_opset_import_size(); + for (const auto& msg : this->opset_import_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .onnx.StringStringEntryProto metadata_props = 14; + total_size += 1UL * this->_internal_metadata_props_size(); + for (const auto& msg : this->metadata_props_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + // optional string producer_name = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_producer_name()); + } + + // optional string producer_version = 3; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_producer_version()); + } + + // optional 
string domain = 4; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_domain()); + } + + // optional string doc_string = 6; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_doc_string()); + } + + // optional .onnx.GraphProto graph = 7; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *graph_); + } + + // optional int64 ir_version = 1; + if (cached_has_bits & 0x00000020u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_ir_version()); + } + + // optional int64 model_version = 5; + if (cached_has_bits & 0x00000040u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_model_version()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void ModelProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.ModelProto) + GOOGLE_DCHECK_NE(&from, this); + const ModelProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.ModelProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.ModelProto) + MergeFrom(*source); + } +} + +void ModelProto::MergeFrom(const ModelProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.ModelProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + opset_import_.MergeFrom(from.opset_import_); + metadata_props_.MergeFrom(from.metadata_props_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000007fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_producer_name(from._internal_producer_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_producer_version(from._internal_producer_version()); + } + if (cached_has_bits & 0x00000004u) { + _internal_set_domain(from._internal_domain()); + } + if (cached_has_bits & 0x00000008u) { + _internal_set_doc_string(from._internal_doc_string()); + } + if (cached_has_bits & 0x00000010u) { + _internal_mutable_graph()->::onnx::GraphProto::MergeFrom(from._internal_graph()); + } + if (cached_has_bits & 0x00000020u) { + ir_version_ = from.ir_version_; + } + if (cached_has_bits & 0x00000040u) { + model_version_ = from.model_version_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void ModelProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.ModelProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void ModelProto::CopyFrom(const ModelProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.ModelProto) + if (&from == this) return; + Clear(); + MergeFrom(from); 
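+  // Generated CopyFrom is Clear() followed by MergeFrom(), i.e. a deep copy
+  // that never aliases the source's sub-messages. A minimal usage sketch
+  // (dst/src are illustrative names, not part of this file):
+  //   onnx::ModelProto dst;
+  //   dst.CopyFrom(src);   // replaces dst's contents with a copy of src
+  //   dst.MergeFrom(src);  // by contrast, appends repeated fields such as opset_import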
+}
+
+bool ModelProto::IsInitialized() const {
+  return true;
+}
+
+void ModelProto::InternalSwap(ModelProto* other) {
+  using std::swap;
+  _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_);
+  swap(_has_bits_[0], other->_has_bits_[0]);
+  opset_import_.InternalSwap(&other->opset_import_);
+  metadata_props_.InternalSwap(&other->metadata_props_);
+  producer_name_.Swap(&other->producer_name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
+  producer_version_.Swap(&other->producer_version_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
+  domain_.Swap(&other->domain_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
+  doc_string_.Swap(&other->doc_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
+  ::PROTOBUF_NAMESPACE_ID::internal::memswap<
+      PROTOBUF_FIELD_OFFSET(ModelProto, model_version_)
+      + sizeof(ModelProto::model_version_)
+      - PROTOBUF_FIELD_OFFSET(ModelProto, graph_)>(
+          reinterpret_cast<char*>(&graph_),
+          reinterpret_cast<char*>(&other->graph_));
+}
+
+::PROTOBUF_NAMESPACE_ID::Metadata ModelProto::GetMetadata() const {
+  return GetMetadataStatic();
+}
+
+
+// ===================================================================
+
+void StringStringEntryProto::InitAsDefaultInstance() {
+}
+class StringStringEntryProto::_Internal {
+ public:
+  using HasBits = decltype(std::declval<StringStringEntryProto>()._has_bits_);
+  static void set_has_key(HasBits* has_bits) {
+    (*has_bits)[0] |= 1u;
+  }
+  static void set_has_value(HasBits* has_bits) {
+    (*has_bits)[0] |= 2u;
+  }
+};
+
+StringStringEntryProto::StringStringEntryProto(::PROTOBUF_NAMESPACE_ID::Arena* arena)
+  : ::PROTOBUF_NAMESPACE_ID::Message(arena) {
+  SharedCtor();
+  RegisterArenaDtor(arena);
+  // @@protoc_insertion_point(arena_constructor:onnx.StringStringEntryProto)
+}
+StringStringEntryProto::StringStringEntryProto(const StringStringEntryProto& from)
+  : ::PROTOBUF_NAMESPACE_ID::Message(),
+    _has_bits_(from._has_bits_) {
+  _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_);
+  key_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+  if (from._internal_has_key()) {
+    key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_key(),
+      GetArena());
+  }
+  value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+  if (from._internal_has_value()) {
+    value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_value(),
+      GetArena());
+  }
+  // @@protoc_insertion_point(copy_constructor:onnx.StringStringEntryProto)
+}
+
+void StringStringEntryProto::SharedCtor() {
+  ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_StringStringEntryProto_onnx_2dml_2eproto.base);
+  key_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+  value_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+}
+
+StringStringEntryProto::~StringStringEntryProto() {
+  // @@protoc_insertion_point(destructor:onnx.StringStringEntryProto)
+  SharedDtor();
+  _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>();
+}
+
+void StringStringEntryProto::SharedDtor() {
+  GOOGLE_DCHECK(GetArena() == nullptr);
+  key_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
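+  // DestroyNoArena frees the backing heap string only when the field no
+  // longer points at the shared empty default, so it is cheap and safe to
+  // call on fields that were never set.
+  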
value_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void StringStringEntryProto::ArenaDtor(void* object) { + StringStringEntryProto* _this = reinterpret_cast< StringStringEntryProto* >(object); + (void)_this; +} +void StringStringEntryProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void StringStringEntryProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const StringStringEntryProto& StringStringEntryProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_StringStringEntryProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void StringStringEntryProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.StringStringEntryProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + key_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + value_.ClearNonDefaultToEmpty(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* StringStringEntryProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string key = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_key(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.StringStringEntryProto.key"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string value = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_value(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.StringStringEntryProto.value"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* StringStringEntryProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.StringStringEntryProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string key = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + 
this->_internal_key().data(), static_cast(this->_internal_key().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.StringStringEntryProto.key"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_key(), target); + } + + // optional string value = 2; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_value().data(), static_cast(this->_internal_value().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.StringStringEntryProto.value"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_value(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.StringStringEntryProto) + return target; +} + +size_t StringStringEntryProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.StringStringEntryProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string key = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_key()); + } + + // optional string value = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_value()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void StringStringEntryProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.StringStringEntryProto) + GOOGLE_DCHECK_NE(&from, this); + const StringStringEntryProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.StringStringEntryProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.StringStringEntryProto) + MergeFrom(*source); + } +} + +void StringStringEntryProto::MergeFrom(const StringStringEntryProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.StringStringEntryProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_key(from._internal_key()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_value(from._internal_value()); + } + } +} + +void StringStringEntryProto::CopyFrom(const 
::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.StringStringEntryProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void StringStringEntryProto::CopyFrom(const StringStringEntryProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.StringStringEntryProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool StringStringEntryProto::IsInitialized() const { + return true; +} + +void StringStringEntryProto::InternalSwap(StringStringEntryProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + key_.Swap(&other->key_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + value_.Swap(&other->value_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata StringStringEntryProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void GraphProto::InitAsDefaultInstance() { +} +class GraphProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_name(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_doc_string(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +GraphProto::GraphProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + node_(arena), + initializer_(arena), + input_(arena), + output_(arena), + value_info_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.GraphProto) +} +GraphProto::GraphProto(const GraphProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_), + node_(from.node_), + initializer_(from.initializer_), + input_(from.input_), + output_(from.output_), + value_info_(from.value_info_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_name()) { + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(), + GetArena()); + } + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_doc_string()) { + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_doc_string(), + GetArena()); + } + // @@protoc_insertion_point(copy_constructor:onnx.GraphProto) +} + +void GraphProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_AttributeProto_onnx_2dml_2eproto.base); + name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +GraphProto::~GraphProto() { + // @@protoc_insertion_point(destructor:onnx.GraphProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void GraphProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void GraphProto::ArenaDtor(void* object) { + GraphProto* 
_this = reinterpret_cast< GraphProto* >(object); + (void)_this; +} +void GraphProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void GraphProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const GraphProto& GraphProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_AttributeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void GraphProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.GraphProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + node_.Clear(); + initializer_.Clear(); + input_.Clear(); + output_.Clear(); + value_info_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + doc_string_.ClearNonDefaultToEmpty(); + } + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* GraphProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .onnx.NodeProto node = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_node(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + // optional string name = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.GraphProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .onnx.TensorProto initializer = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_initializer(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<42>(ptr)); + } else goto handle_unusual; + continue; + // optional string doc_string = 10; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + auto str = _internal_mutable_doc_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.GraphProto.doc_string"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated .onnx.ValueInfoProto input = 11; + case 11: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 90)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_input(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<90>(ptr)); + } else goto 
handle_unusual; + continue; + // repeated .onnx.ValueInfoProto output = 12; + case 12: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 98)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_output(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<98>(ptr)); + } else goto handle_unusual; + continue; + // repeated .onnx.ValueInfoProto value_info = 13; + case 13: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 106)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_value_info(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<106>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* GraphProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.GraphProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .onnx.NodeProto node = 1; + for (unsigned int i = 0, + n = static_cast(this->_internal_node_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_node(i), target, stream); + } + + cached_has_bits = _has_bits_[0]; + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.GraphProto.name"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_name(), target); + } + + // repeated .onnx.TensorProto initializer = 5; + for (unsigned int i = 0, + n = static_cast(this->_internal_initializer_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(5, this->_internal_initializer(i), target, stream); + } + + // optional string doc_string = 10; + if (cached_has_bits & 0x00000002u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_doc_string().data(), static_cast(this->_internal_doc_string().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.GraphProto.doc_string"); + target = stream->WriteStringMaybeAliased( + 10, this->_internal_doc_string(), target); + } + + // repeated .onnx.ValueInfoProto input = 11; + for (unsigned int i = 0, + n = static_cast(this->_internal_input_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(11, this->_internal_input(i), target, stream); + } + + // repeated .onnx.ValueInfoProto output = 12; + for (unsigned int i = 0, + n = static_cast(this->_internal_output_size()); i < n; i++) { + target = 
stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(12, this->_internal_output(i), target, stream); + } + + // repeated .onnx.ValueInfoProto value_info = 13; + for (unsigned int i = 0, + n = static_cast(this->_internal_value_info_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(13, this->_internal_value_info(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.GraphProto) + return target; +} + +size_t GraphProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.GraphProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .onnx.NodeProto node = 1; + total_size += 1UL * this->_internal_node_size(); + for (const auto& msg : this->node_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .onnx.TensorProto initializer = 5; + total_size += 1UL * this->_internal_initializer_size(); + for (const auto& msg : this->initializer_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .onnx.ValueInfoProto input = 11; + total_size += 1UL * this->_internal_input_size(); + for (const auto& msg : this->input_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .onnx.ValueInfoProto output = 12; + total_size += 1UL * this->_internal_output_size(); + for (const auto& msg : this->output_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + // repeated .onnx.ValueInfoProto value_info = 13; + total_size += 1UL * this->_internal_value_info_size(); + for (const auto& msg : this->value_info_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string name = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional string doc_string = 10; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_doc_string()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void GraphProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.GraphProto) + GOOGLE_DCHECK_NE(&from, this); + const GraphProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // 
@@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.GraphProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.GraphProto) + MergeFrom(*source); + } +} + +void GraphProto::MergeFrom(const GraphProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.GraphProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + node_.MergeFrom(from.node_); + initializer_.MergeFrom(from.initializer_); + input_.MergeFrom(from.input_); + output_.MergeFrom(from.output_); + value_info_.MergeFrom(from.value_info_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_doc_string(from._internal_doc_string()); + } + } +} + +void GraphProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.GraphProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void GraphProto::CopyFrom(const GraphProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.GraphProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool GraphProto::IsInitialized() const { + return true; +} + +void GraphProto::InternalSwap(GraphProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + node_.InternalSwap(&other->node_); + initializer_.InternalSwap(&other->initializer_); + input_.InternalSwap(&other->input_); + output_.InternalSwap(&other->output_); + value_info_.InternalSwap(&other->value_info_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + doc_string_.Swap(&other->doc_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} + +::PROTOBUF_NAMESPACE_ID::Metadata GraphProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TensorProto_Segment::InitAsDefaultInstance() { +} +class TensorProto_Segment::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_begin(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_end(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +TensorProto_Segment::TensorProto_Segment(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TensorProto.Segment) +} +TensorProto_Segment::TensorProto_Segment(const TensorProto_Segment& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::memcpy(&begin_, &from.begin_, + static_cast(reinterpret_cast(&end_) - + reinterpret_cast(&begin_)) + sizeof(end_)); + // @@protoc_insertion_point(copy_constructor:onnx.TensorProto.Segment) +} + +void TensorProto_Segment::SharedCtor() { + ::memset(&begin_, 0, static_cast( + reinterpret_cast(&end_) - + reinterpret_cast(&begin_)) + 
sizeof(end_)); +} + +TensorProto_Segment::~TensorProto_Segment() { + // @@protoc_insertion_point(destructor:onnx.TensorProto.Segment) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorProto_Segment::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void TensorProto_Segment::ArenaDtor(void* object) { + TensorProto_Segment* _this = reinterpret_cast< TensorProto_Segment* >(object); + (void)_this; +} +void TensorProto_Segment::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorProto_Segment::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorProto_Segment& TensorProto_Segment::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorProto_Segment_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TensorProto_Segment::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TensorProto.Segment) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + ::memset(&begin_, 0, static_cast( + reinterpret_cast(&end_) - + reinterpret_cast(&begin_)) + sizeof(end_)); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorProto_Segment::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional int64 begin = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _Internal::set_has_begin(&has_bits); + begin_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 end = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_end(&has_bits); + end_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorProto_Segment::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TensorProto.Segment) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional int64 begin = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_begin(), target); + } + + // optional int64 end = 2; + if (cached_has_bits & 0x00000002u) { + target = 
stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(2, this->_internal_end(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TensorProto.Segment) + return target; +} + +size_t TensorProto_Segment::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TensorProto.Segment) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional int64 begin = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_begin()); + } + + // optional int64 end = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_end()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorProto_Segment::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TensorProto.Segment) + GOOGLE_DCHECK_NE(&from, this); + const TensorProto_Segment* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TensorProto.Segment) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TensorProto.Segment) + MergeFrom(*source); + } +} + +void TensorProto_Segment::MergeFrom(const TensorProto_Segment& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TensorProto.Segment) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + begin_ = from.begin_; + } + if (cached_has_bits & 0x00000002u) { + end_ = from.end_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TensorProto_Segment::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TensorProto.Segment) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorProto_Segment::CopyFrom(const TensorProto_Segment& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TensorProto.Segment) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorProto_Segment::IsInitialized() const { + return true; +} + +void TensorProto_Segment::InternalSwap(TensorProto_Segment* other) { + using std::swap; + 
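+  // InternalSwap exchanges unknown fields and has-bits, then swaps the
+  // trivially-copyable fields (begin_ through end_) in a single memswap<>
+  // whose template argument is the byte length of that contiguous range.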
+  _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_);
+  swap(_has_bits_[0], other->_has_bits_[0]);
+  ::PROTOBUF_NAMESPACE_ID::internal::memswap<
+      PROTOBUF_FIELD_OFFSET(TensorProto_Segment, end_)
+      + sizeof(TensorProto_Segment::end_)
+      - PROTOBUF_FIELD_OFFSET(TensorProto_Segment, begin_)>(
+          reinterpret_cast<char*>(&begin_),
+          reinterpret_cast<char*>(&other->begin_));
+}
+
+::PROTOBUF_NAMESPACE_ID::Metadata TensorProto_Segment::GetMetadata() const {
+  return GetMetadataStatic();
+}
+
+
+// ===================================================================
+
+void TensorProto::InitAsDefaultInstance() {
+  ::onnx::_TensorProto_default_instance_._instance.get_mutable()->segment_ = const_cast< ::onnx::TensorProto_Segment*>(
+      ::onnx::TensorProto_Segment::internal_default_instance());
+}
+class TensorProto::_Internal {
+ public:
+  using HasBits = decltype(std::declval<TensorProto>()._has_bits_);
+  static void set_has_data_type(HasBits* has_bits) {
+    (*has_bits)[0] |= 16u;
+  }
+  static const ::onnx::TensorProto_Segment& segment(const TensorProto* msg);
+  static void set_has_segment(HasBits* has_bits) {
+    (*has_bits)[0] |= 8u;
+  }
+  static void set_has_name(HasBits* has_bits) {
+    (*has_bits)[0] |= 1u;
+  }
+  static void set_has_doc_string(HasBits* has_bits) {
+    (*has_bits)[0] |= 4u;
+  }
+  static void set_has_raw_data(HasBits* has_bits) {
+    (*has_bits)[0] |= 2u;
+  }
+};
+
+const ::onnx::TensorProto_Segment&
+TensorProto::_Internal::segment(const TensorProto* msg) {
+  return *msg->segment_;
+}
+TensorProto::TensorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena)
+  : ::PROTOBUF_NAMESPACE_ID::Message(arena),
+  dims_(arena),
+  float_data_(arena),
+  int32_data_(arena),
+  string_data_(arena),
+  int64_data_(arena),
+  double_data_(arena),
+  uint64_data_(arena) {
+  SharedCtor();
+  RegisterArenaDtor(arena);
+  // @@protoc_insertion_point(arena_constructor:onnx.TensorProto)
+}
+TensorProto::TensorProto(const TensorProto& from)
+  : ::PROTOBUF_NAMESPACE_ID::Message(),
+    _has_bits_(from._has_bits_),
+    dims_(from.dims_),
+    float_data_(from.float_data_),
+    int32_data_(from.int32_data_),
+    string_data_(from.string_data_),
+    int64_data_(from.int64_data_),
+    double_data_(from.double_data_),
+    uint64_data_(from.uint64_data_) {
+  _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_);
+  name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+  if (from._internal_has_name()) {
+    name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_name(),
+      GetArena());
+  }
+  raw_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+  if (from._internal_has_raw_data()) {
+    raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_raw_data(),
+      GetArena());
+  }
+  doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
+  if (from._internal_has_doc_string()) {
+    doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_doc_string(),
+      GetArena());
+  }
+  if (from._internal_has_segment()) {
+    segment_ = new ::onnx::TensorProto_Segment(*from.segment_);
+  } else {
+    segment_ = nullptr;
+  }
+  data_type_ = from.data_type_;
+  // @@protoc_insertion_point(copy_constructor:onnx.TensorProto)
+}
+
+void TensorProto::SharedCtor() {
+  ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorProto_onnx_2dml_2eproto.base);
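+  // InitSCC runs the one-time initializer for this message's strongly
+  // connected component (default instances and registry); afterwards the
+  // string fields below start out pointing at the shared empty default
+  // instead of owning an allocation.
+  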
name_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + raw_data_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + ::memset(&segment_, 0, static_cast( + reinterpret_cast(&data_type_) - + reinterpret_cast(&segment_)) + sizeof(data_type_)); +} + +TensorProto::~TensorProto() { + // @@protoc_insertion_point(destructor:onnx.TensorProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + name_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + raw_data_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + doc_string_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (this != internal_default_instance()) delete segment_; +} + +void TensorProto::ArenaDtor(void* object) { + TensorProto* _this = reinterpret_cast< TensorProto* >(object); + (void)_this; +} +void TensorProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorProto& TensorProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TensorProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + dims_.Clear(); + float_data_.Clear(); + int32_data_.Clear(); + string_data_.Clear(); + int64_data_.Clear(); + double_data_.Clear(); + uint64_data_.Clear(); + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000000fu) { + if (cached_has_bits & 0x00000001u) { + name_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000002u) { + raw_data_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000004u) { + doc_string_.ClearNonDefaultToEmpty(); + } + if (cached_has_bits & 0x00000008u) { + GOOGLE_DCHECK(segment_ != nullptr); + segment_->Clear(); + } + } + data_type_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated int64 dims = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ptr -= 1; + do { + ptr += 1; + _internal_add_dims(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<8>(ptr)); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_dims(), ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional .onnx.TensorProto.DataType data_type = 2; + case 2: + if 
(PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::onnx::TensorProto_DataType_IsValid(val))) { + _internal_set_data_type(static_cast<::onnx::TensorProto_DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(2, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional .onnx.TensorProto.Segment segment = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + ptr = ctx->ParseMessage(_internal_mutable_segment(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated float float_data = 4 [packed = true]; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedFloatParser(_internal_mutable_float_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 37) { + _internal_add_float_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(float); + } else goto handle_unusual; + continue; + // repeated int32 int32_data = 5 [packed = true]; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt32Parser(_internal_mutable_int32_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 40) { + _internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated bytes string_data = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + ptr -= 1; + do { + ptr += 1; + auto str = _internal_add_string_data(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<50>(ptr)); + } else goto handle_unusual; + continue; + // repeated int64 int64_data = 7 [packed = true]; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedInt64Parser(_internal_mutable_int64_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 56) { + _internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string name = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + auto str = _internal_mutable_name(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TensorProto.name"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional bytes raw_data = 9; + case 9: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 74)) { + auto str = _internal_mutable_raw_data(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + CHK_(ptr); + } else goto handle_unusual; + continue; + // repeated double double_data = 10 [packed = true]; + case 10: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 82)) { + ptr = 
::PROTOBUF_NAMESPACE_ID::internal::PackedDoubleParser(_internal_mutable_double_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 81) { + _internal_add_double_data(::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad(ptr)); + ptr += sizeof(double); + } else goto handle_unusual; + continue; + // repeated uint64 uint64_data = 11 [packed = true]; + case 11: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 90)) { + ptr = ::PROTOBUF_NAMESPACE_ID::internal::PackedUInt64Parser(_internal_mutable_uint64_data(), ptr, ctx); + CHK_(ptr); + } else if (static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 88) { + _internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string doc_string = 12; + case 12: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 98)) { + auto str = _internal_mutable_doc_string(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TensorProto.doc_string"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated int64 dims = 1; + for (int i = 0, n = this->_internal_dims_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_dims(i), target); + } + + cached_has_bits = _has_bits_[0]; + // optional .onnx.TensorProto.DataType data_type = 2; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 2, this->_internal_data_type(), target); + } + + // optional .onnx.TensorProto.Segment segment = 3; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::segment(this), target, stream); + } + + // repeated float float_data = 4 [packed = true]; + if (this->_internal_float_data_size() > 0) { + target = stream->WriteFixedPacked(4, _internal_float_data(), target); + } + + // repeated int32 int32_data = 5 [packed = true]; + { + int byte_size = _int32_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt32Packed( + 5, _internal_int32_data(), byte_size, target); + } + } + + // repeated bytes string_data = 6; + for (int i = 0, n = this->_internal_string_data_size(); i < n; i++) { + const auto& s = this->_internal_string_data(i); + target = stream->WriteBytes(6, s, target); + } + + // repeated int64 int64_data = 7 [packed = true]; + { + int 
+::PROTOBUF_NAMESPACE_ID::uint8* TensorProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TensorProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated int64 dims = 1; + for (int i = 0, n = this->_internal_dims_size(); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_dims(i), target); + } + + cached_has_bits = _has_bits_[0]; + // optional .onnx.TensorProto.DataType data_type = 2; + if (cached_has_bits & 0x00000010u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 2, this->_internal_data_type(), target); + } + + // optional .onnx.TensorProto.Segment segment = 3; + if (cached_has_bits & 0x00000008u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 3, _Internal::segment(this), target, stream); + } + + // repeated float float_data = 4 [packed = true]; + if (this->_internal_float_data_size() > 0) { + target = stream->WriteFixedPacked(4, _internal_float_data(), target); + } + + // repeated int32 int32_data = 5 [packed = true]; + { + int byte_size = _int32_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt32Packed( + 5, _internal_int32_data(), byte_size, target); + } + } + + // repeated bytes string_data = 6; + for (int i = 0, n = this->_internal_string_data_size(); i < n; i++) { + const auto& s = this->_internal_string_data(i); + target = stream->WriteBytes(6, s, target); + } + + // repeated int64 int64_data = 7 [packed = true]; + { + int byte_size = _int64_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteInt64Packed( + 7, _internal_int64_data(), byte_size, target); + } + } + + // optional string name = 8; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_name().data(), static_cast<int>(this->_internal_name().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.TensorProto.name"); + target = stream->WriteStringMaybeAliased( + 8, this->_internal_name(), target); + } + + // optional bytes raw_data = 9; + if (cached_has_bits & 0x00000002u) { + target = stream->WriteBytesMaybeAliased( + 9, this->_internal_raw_data(), target); + } + + // repeated double double_data = 10 [packed = true]; + if (this->_internal_double_data_size() > 0) { + target = stream->WriteFixedPacked(10, _internal_double_data(), target); + } + + // repeated uint64 uint64_data = 11 [packed = true]; + { + int byte_size = _uint64_data_cached_byte_size_.load(std::memory_order_relaxed); + if (byte_size > 0) { + target = stream->WriteUInt64Packed( + 11, _internal_uint64_data(), byte_size, target); + } + } + + // optional string doc_string = 12; + if (cached_has_bits & 0x00000004u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_doc_string().data(), static_cast<int>(this->_internal_doc_string().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.TensorProto.doc_string"); + target = stream->WriteStringMaybeAliased( + 12, this->_internal_doc_string(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TensorProto) + return target; +} + +size_t TensorProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TensorProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated int64 dims = 1; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->dims_); + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(this->_internal_dims_size()); + total_size += data_size; + } + + // repeated float float_data = 4 [packed = true]; + { + unsigned int count = static_cast<unsigned int>(this->_internal_float_data_size()); + size_t data_size = 4UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _float_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated int32 int32_data = 5 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int32Size(this->int32_data_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _int32_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated bytes string_data = 6; + total_size += 1 * + ::PROTOBUF_NAMESPACE_ID::internal::FromIntSize(string_data_.size()); + for (int i = 0, n = string_data_.size(); i < n; i++) { + total_size += ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + string_data_.Get(i)); + } + + // repeated int64 int64_data = 7 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + Int64Size(this->int64_data_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _int64_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated double double_data = 10 [packed = true]; + { + unsigned int count = static_cast<unsigned int>(this->_internal_double_data_size()); + size_t data_size = 8UL * count; + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _double_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + // repeated uint64 uint64_data = 11 [packed = true]; + { + size_t data_size = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + UInt64Size(this->uint64_data_); + if (data_size > 0) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int32Size( + static_cast<::PROTOBUF_NAMESPACE_ID::int32>(data_size)); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(data_size); + _uint64_data_cached_byte_size_.store(cached_size, + std::memory_order_relaxed); + total_size += data_size; + } + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + // optional string name = 8; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_name()); + } + + // optional bytes raw_data = 9; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::BytesSize( + this->_internal_raw_data()); + } + + // optional string doc_string = 12; + if (cached_has_bits & 0x00000004u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_doc_string()); + } + + // optional .onnx.TensorProto.Segment segment = 3; + if (cached_has_bits & 0x00000008u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *segment_); + } + + // optional .onnx.TensorProto.DataType data_type = 2; + if (cached_has_bits & 0x00000010u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_data_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} +
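Note how ByteSizeLong() above stashes each packed field's payload size in a relaxed atomic (_int32_data_cached_byte_size_ and friends) so that _InternalSerialize can emit the length prefix without measuring the payload twice. A compact, self-contained sketch of the same pattern; all names here are illustrative, not protobuf's:

#include <atomic>
#include <cstddef>
#include <vector>

struct PackedVarintField {
  std::vector<unsigned> values;
  mutable std::atomic<int> cached_byte_size{0};

  static std::size_t varint_size(unsigned v) {
    std::size_t n = 1;
    while (v >= 0x80) { v >>= 7; ++n; }
    return n;
  }

  // Measure once (the ByteSizeLong step) and cache the result...
  std::size_t byte_size() const {
    std::size_t data_size = 0;
    for (unsigned v : values) data_size += varint_size(v);
    cached_byte_size.store(static_cast<int>(data_size), std::memory_order_relaxed);
    return data_size;
  }

  // ...then reuse it during serialization instead of recomputing.
  int cached() const { return cached_byte_size.load(std::memory_order_relaxed); }
};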
+void TensorProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TensorProto) + GOOGLE_DCHECK_NE(&from, this); + const TensorProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TensorProto>( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TensorProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TensorProto) + MergeFrom(*source); + } +} + +void TensorProto::MergeFrom(const TensorProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TensorProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + dims_.MergeFrom(from.dims_); + float_data_.MergeFrom(from.float_data_); + int32_data_.MergeFrom(from.int32_data_); + string_data_.MergeFrom(from.string_data_); + int64_data_.MergeFrom(from.int64_data_); + double_data_.MergeFrom(from.double_data_); + uint64_data_.MergeFrom(from.uint64_data_); + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x0000001fu) { + if (cached_has_bits & 0x00000001u) { + _internal_set_name(from._internal_name()); + } + if (cached_has_bits & 0x00000002u) { + _internal_set_raw_data(from._internal_raw_data()); + } + if (cached_has_bits & 0x00000004u) { + _internal_set_doc_string(from._internal_doc_string()); + } + if (cached_has_bits & 0x00000008u) { + _internal_mutable_segment()->::onnx::TensorProto_Segment::MergeFrom(from._internal_segment()); + } + if (cached_has_bits & 0x00000010u) { + data_type_ = from.data_type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TensorProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TensorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorProto::CopyFrom(const TensorProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TensorProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorProto::IsInitialized() const { + return true; +} + +void TensorProto::InternalSwap(TensorProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + dims_.InternalSwap(&other->dims_); + float_data_.InternalSwap(&other->float_data_); + int32_data_.InternalSwap(&other->int32_data_); + string_data_.InternalSwap(&other->string_data_); + int64_data_.InternalSwap(&other->int64_data_); + double_data_.InternalSwap(&other->double_data_); + uint64_data_.InternalSwap(&other->uint64_data_); + name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + raw_data_.Swap(&other->raw_data_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + doc_string_.Swap(&other->doc_string_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(TensorProto, data_type_) + + sizeof(TensorProto::data_type_) + - PROTOBUF_FIELD_OFFSET(TensorProto, segment_)>( + reinterpret_cast<char*>(&segment_), + reinterpret_cast<char*>(&other->segment_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorProto::GetMetadata() const { + return GetMetadataStatic(); +} + +
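With TensorProto complete, a hedged usage sketch of the generated public API: the accessor names follow protobuf's usual set_/add_ conventions for the fields implemented above, and the header name assumes this diff's onnx-ml.proto; treat both as assumptions, not verified against the repository.

#include "onnx-ml.pb.h"
#include <string>

std::string make_scalar_tensor() {
  onnx::TensorProto t;
  t.set_name("bias");                                // optional string name = 8
  t.set_data_type(onnx::TensorProto_DataType_FLOAT); // enum field data_type = 2
  t.add_dims(1);                                     // repeated int64 dims = 1
  t.add_float_data(0.5f);                            // packed float_data = 4
  return t.SerializeAsString();                      // drives _InternalSerialize
}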
+// =================================================================== + +void TensorShapeProto_Dimension::InitAsDefaultInstance() { + ::onnx::_TensorShapeProto_Dimension_default_instance_.dim_value_ = PROTOBUF_LONGLONG(0); + ::onnx::_TensorShapeProto_Dimension_default_instance_.dim_param_.UnsafeSetDefault( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} +class TensorShapeProto_Dimension::_Internal { + public: + using HasBits = decltype(std::declval<TensorShapeProto_Dimension>()._has_bits_); + static void set_has_denotation(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +TensorShapeProto_Dimension::TensorShapeProto_Dimension(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TensorShapeProto.Dimension) +} +TensorShapeProto_Dimension::TensorShapeProto_Dimension(const TensorShapeProto_Dimension& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + denotation_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_denotation()) { + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_denotation(), + GetArena()); + } + clear_has_value(); + switch (from.value_case()) { + case kDimValue: { + _internal_set_dim_value(from._internal_dim_value()); + break; + } + case kDimParam: { + _internal_set_dim_param(from._internal_dim_param()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:onnx.TensorShapeProto.Dimension) +} + +void TensorShapeProto_Dimension::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto.base); + denotation_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + clear_has_value(); +} + +TensorShapeProto_Dimension::~TensorShapeProto_Dimension() { + // @@protoc_insertion_point(destructor:onnx.TensorShapeProto.Dimension) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorShapeProto_Dimension::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + denotation_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (has_value()) { + clear_value(); + } +} + +void TensorShapeProto_Dimension::ArenaDtor(void* object) { + TensorShapeProto_Dimension* _this = reinterpret_cast< TensorShapeProto_Dimension* >(object); + (void)_this; +} +void TensorShapeProto_Dimension::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorShapeProto_Dimension::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorShapeProto_Dimension& TensorShapeProto_Dimension::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorShapeProto_Dimension_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TensorShapeProto_Dimension::clear_value() { +// @@protoc_insertion_point(one_of_clear_start:onnx.TensorShapeProto.Dimension) + switch (value_case()) { + case kDimValue: { + // No need to clear + break; + } + case kDimParam: { + value_.dim_param_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + _oneof_case_[0] = VALUE_NOT_SET; +} + + +void TensorShapeProto_Dimension::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TensorShapeProto.Dimension) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + denotation_.ClearNonDefaultToEmpty(); + } + clear_value(); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorShapeProto_Dimension::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // int64 dim_value = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + _internal_set_dim_value(::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr)); + CHK_(ptr); + } else goto handle_unusual; + continue; + // string dim_param = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + auto str = _internal_mutable_dim_param(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TensorShapeProto.Dimension.dim_param"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string denotation = 3; + case 3: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) { + auto str = _internal_mutable_denotation(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TensorShapeProto.Dimension.denotation"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorShapeProto_Dimension::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TensorShapeProto.Dimension) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + switch (value_case()) { + case kDimValue: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(1, this->_internal_dim_value(), target); + break; + } + case kDimParam: { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_dim_param().data(), static_cast<int>(this->_internal_dim_param().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.TensorShapeProto.Dimension.dim_param"); + target = stream->WriteStringMaybeAliased( + 2, this->_internal_dim_param(), target); + break; + } + default: ; + } + cached_has_bits = _has_bits_[0]; + // optional string denotation = 3; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_denotation().data(), static_cast<int>(this->_internal_denotation().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.TensorShapeProto.Dimension.denotation"); + target = stream->WriteStringMaybeAliased( + 3, this->_internal_denotation(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TensorShapeProto.Dimension) + return target; +} + +size_t TensorShapeProto_Dimension::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TensorShapeProto.Dimension) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // optional string denotation = 3; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_denotation()); + } + + switch (value_case()) { + // int64 dim_value = 1; + case kDimValue: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_dim_value()); + break; + } + // string dim_param = 2; + case kDimParam: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_dim_param()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorShapeProto_Dimension::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TensorShapeProto.Dimension) + GOOGLE_DCHECK_NE(&from, this); + const TensorShapeProto_Dimension* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TensorShapeProto_Dimension>( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TensorShapeProto.Dimension) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TensorShapeProto.Dimension) + MergeFrom(*source); + } +} + +void TensorShapeProto_Dimension::MergeFrom(const TensorShapeProto_Dimension& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TensorShapeProto.Dimension) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from._internal_has_denotation()) { + _internal_set_denotation(from._internal_denotation()); + } + switch (from.value_case()) { + case kDimValue: { + _internal_set_dim_value(from._internal_dim_value()); + break; + } + case kDimParam: { + 
_internal_set_dim_param(from._internal_dim_param()); + break; + } + case VALUE_NOT_SET: { + break; + } + } +} + +void TensorShapeProto_Dimension::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TensorShapeProto.Dimension) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorShapeProto_Dimension::CopyFrom(const TensorShapeProto_Dimension& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TensorShapeProto.Dimension) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorShapeProto_Dimension::IsInitialized() const { + return true; +} + +void TensorShapeProto_Dimension::InternalSwap(TensorShapeProto_Dimension* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + denotation_.Swap(&other->denotation_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(value_, other->value_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorShapeProto_Dimension::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TensorShapeProto::InitAsDefaultInstance() { +} +class TensorShapeProto::_Internal { + public: +}; + +TensorShapeProto::TensorShapeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena), + dim_(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TensorShapeProto) +} +TensorShapeProto::TensorShapeProto(const TensorShapeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + dim_(from.dim_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:onnx.TensorShapeProto) +} + +void TensorShapeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TensorShapeProto_onnx_2dml_2eproto.base); +} + +TensorShapeProto::~TensorShapeProto() { + // @@protoc_insertion_point(destructor:onnx.TensorShapeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TensorShapeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); +} + +void TensorShapeProto::ArenaDtor(void* object) { + TensorShapeProto* _this = reinterpret_cast< TensorShapeProto* >(object); + (void)_this; +} +void TensorShapeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TensorShapeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TensorShapeProto& TensorShapeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TensorShapeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TensorShapeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + dim_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TensorShapeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while 
(!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // repeated .onnx.TensorShapeProto.Dimension dim = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr -= 1; + do { + ptr += 1; + ptr = ctx->ParseMessage(_internal_add_dim(), ptr); + CHK_(ptr); + if (!ctx->DataAvailable(ptr)) break; + } while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr)); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TensorShapeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + // repeated .onnx.TensorShapeProto.Dimension dim = 1; + for (unsigned int i = 0, + n = static_cast<unsigned int>(this->_internal_dim_size()); i < n; i++) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage(1, this->_internal_dim(i), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TensorShapeProto) + return target; +} + +size_t TensorShapeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TensorShapeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // repeated .onnx.TensorShapeProto.Dimension dim = 1; + total_size += 1UL * this->_internal_dim_size(); + for (const auto& msg : this->dim_) { + total_size += + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TensorShapeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TensorShapeProto) + GOOGLE_DCHECK_NE(&from, this); + const TensorShapeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TensorShapeProto>( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TensorShapeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TensorShapeProto) + MergeFrom(*source); + } +} + +void TensorShapeProto::MergeFrom(const TensorShapeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TensorShapeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + dim_.MergeFrom(from.dim_); +} + +void TensorShapeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TensorShapeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TensorShapeProto::CopyFrom(const TensorShapeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TensorShapeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TensorShapeProto::IsInitialized() const { + return true; +} + +void TensorShapeProto::InternalSwap(TensorShapeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + dim_.InternalSwap(&other->dim_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TensorShapeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TypeProto_Tensor::InitAsDefaultInstance() { + ::onnx::_TypeProto_Tensor_default_instance_._instance.get_mutable()->shape_ = const_cast< ::onnx::TensorShapeProto*>( + ::onnx::TensorShapeProto::internal_default_instance()); +} +class TypeProto_Tensor::_Internal { + public: + using HasBits = decltype(std::declval<TypeProto_Tensor>()._has_bits_); + static void set_has_elem_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::onnx::TensorShapeProto& shape(const TypeProto_Tensor* msg); + static void set_has_shape(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +const ::onnx::TensorShapeProto& +TypeProto_Tensor::_Internal::shape(const TypeProto_Tensor* msg) { + return *msg->shape_; +} +TypeProto_Tensor::TypeProto_Tensor(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TypeProto.Tensor) +} +TypeProto_Tensor::TypeProto_Tensor(const TypeProto_Tensor& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_shape()) { + shape_ = new ::onnx::TensorShapeProto(*from.shape_); + } else { + shape_ = nullptr; + } + elem_type_ = from.elem_type_; + // @@protoc_insertion_point(copy_constructor:onnx.TypeProto.Tensor) +} + +void TypeProto_Tensor::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_Tensor_onnx_2dml_2eproto.base); + ::memset(&shape_, 0, static_cast<size_t>( + reinterpret_cast<char*>(&elem_type_) - + reinterpret_cast<char*>(&shape_)) + sizeof(elem_type_)); +} + +TypeProto_Tensor::~TypeProto_Tensor() { + // @@protoc_insertion_point(destructor:onnx.TypeProto.Tensor) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto_Tensor::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete shape_; +} + +void TypeProto_Tensor::ArenaDtor(void* object) { + TypeProto_Tensor* _this = reinterpret_cast< TypeProto_Tensor* >(object); + (void)_this; +} +void TypeProto_Tensor::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void 
TypeProto_Tensor::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto_Tensor& TypeProto_Tensor::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_Tensor_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto_Tensor::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TypeProto.Tensor) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(shape_ != nullptr); + shape_->Clear(); + } + elem_type_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto_Tensor::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .onnx.TensorProto.DataType elem_type = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::onnx::TensorProto_DataType_IsValid(val))) { + _internal_set_elem_type(static_cast<::onnx::TensorProto_DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(1, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional .onnx.TensorShapeProto shape = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_shape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_Tensor::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TypeProto.Tensor) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .onnx.TensorProto.DataType elem_type = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_elem_type(), target); + } + + // optional .onnx.TensorShapeProto shape = 2; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::shape(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( 
+ _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TypeProto.Tensor) + return target; +} + +size_t TypeProto_Tensor::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TypeProto.Tensor) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional .onnx.TensorShapeProto shape = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *shape_); + } + + // optional .onnx.TensorProto.DataType elem_type = 1; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_elem_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto_Tensor::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TypeProto.Tensor) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto_Tensor* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TypeProto_Tensor>( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TypeProto.Tensor) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TypeProto.Tensor) + MergeFrom(*source); + } +} + +void TypeProto_Tensor::MergeFrom(const TypeProto_Tensor& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TypeProto.Tensor) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_mutable_shape()->::onnx::TensorShapeProto::MergeFrom(from._internal_shape()); + } + if (cached_has_bits & 0x00000002u) { + elem_type_ = from.elem_type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TypeProto_Tensor::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TypeProto.Tensor) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto_Tensor::CopyFrom(const TypeProto_Tensor& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TypeProto.Tensor) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto_Tensor::IsInitialized() const { + return true; +} + +void TypeProto_Tensor::InternalSwap(TypeProto_Tensor* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(TypeProto_Tensor, elem_type_) + + sizeof(TypeProto_Tensor::elem_type_) + - PROTOBUF_FIELD_OFFSET(TypeProto_Tensor, shape_)>( + reinterpret_cast<char*>(&shape_), + reinterpret_cast<char*>(&other->shape_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_Tensor::GetMetadata() const { + return GetMetadataStatic(); +} + +
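TypeProto_Tensor above tracks field presence in _has_bits_[0]: bit 0x00000001u for shape and 0x00000002u for elem_type, consulted before either field is serialized or counted. A self-contained sketch of that presence-bitmask pattern; the struct and constant names are ours, only the bit values mirror the generated code:

#include <cstdint>

struct HasBits {
  std::uint32_t bits = 0;
  void set(std::uint32_t mask) { bits |= mask; }
  bool has(std::uint32_t mask) const { return (bits & mask) != 0; }
};

constexpr std::uint32_t kHasShape    = 0x00000001u;
constexpr std::uint32_t kHasElemType = 0x00000002u;

// Mirrors `if (cached_has_bits & 0x00000001u)` in the generated serializer.
bool should_write_shape(const HasBits& h) { return h.has(kHasShape); }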
+// =================================================================== + +void TypeProto_Sequence::InitAsDefaultInstance() { + ::onnx::_TypeProto_Sequence_default_instance_._instance.get_mutable()->elem_type_ = const_cast< ::onnx::TypeProto*>( + ::onnx::TypeProto::internal_default_instance()); +} +class TypeProto_Sequence::_Internal { + public: + using HasBits = decltype(std::declval<TypeProto_Sequence>()._has_bits_); + static const ::onnx::TypeProto& elem_type(const TypeProto_Sequence* msg); + static void set_has_elem_type(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +const ::onnx::TypeProto& +TypeProto_Sequence::_Internal::elem_type(const TypeProto_Sequence* msg) { + return *msg->elem_type_; +} +TypeProto_Sequence::TypeProto_Sequence(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TypeProto.Sequence) +} +TypeProto_Sequence::TypeProto_Sequence(const TypeProto_Sequence& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_elem_type()) { + elem_type_ = new ::onnx::TypeProto(*from.elem_type_); + } else { + elem_type_ = nullptr; + } + // @@protoc_insertion_point(copy_constructor:onnx.TypeProto.Sequence) +} + +void TypeProto_Sequence::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_onnx_2dml_2eproto.base); + elem_type_ = nullptr; +} + +TypeProto_Sequence::~TypeProto_Sequence() { + // @@protoc_insertion_point(destructor:onnx.TypeProto.Sequence) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto_Sequence::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete elem_type_; +} + +void TypeProto_Sequence::ArenaDtor(void* object) { + TypeProto_Sequence* _this = reinterpret_cast< TypeProto_Sequence* >(object); + (void)_this; +} +void TypeProto_Sequence::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TypeProto_Sequence::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto_Sequence& TypeProto_Sequence::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto_Sequence::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TypeProto.Sequence) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(elem_type_ != nullptr); + elem_type_->Clear(); + } + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto_Sequence::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .onnx.TypeProto elem_type = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr = ctx->ParseMessage(_internal_mutable_elem_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_Sequence::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TypeProto.Sequence) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .onnx.TypeProto elem_type = 1; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 1, _Internal::elem_type(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TypeProto.Sequence) + return target; +} + +size_t TypeProto_Sequence::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TypeProto.Sequence) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // optional .onnx.TypeProto elem_type = 1; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *elem_type_); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto_Sequence::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TypeProto.Sequence) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto_Sequence* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TypeProto_Sequence>( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TypeProto.Sequence) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TypeProto.Sequence) + MergeFrom(*source); + } +} + +void TypeProto_Sequence::MergeFrom(const TypeProto_Sequence& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TypeProto.Sequence) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from._internal_has_elem_type()) { + _internal_mutable_elem_type()->::onnx::TypeProto::MergeFrom(from._internal_elem_type()); + } +} + +void TypeProto_Sequence::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TypeProto.Sequence) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto_Sequence::CopyFrom(const TypeProto_Sequence& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TypeProto.Sequence) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto_Sequence::IsInitialized() const { + return true; +} + +void TypeProto_Sequence::InternalSwap(TypeProto_Sequence* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + swap(elem_type_, other->elem_type_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_Sequence::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TypeProto_Map::InitAsDefaultInstance() { + ::onnx::_TypeProto_Map_default_instance_._instance.get_mutable()->value_type_ = const_cast< ::onnx::TypeProto*>( + ::onnx::TypeProto::internal_default_instance()); +} +class TypeProto_Map::_Internal { + public: + using HasBits = decltype(std::declval<TypeProto_Map>()._has_bits_); + static void set_has_key_type(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } + static const ::onnx::TypeProto& value_type(const TypeProto_Map* msg); + static void set_has_value_type(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +const ::onnx::TypeProto& +TypeProto_Map::_Internal::value_type(const TypeProto_Map* msg) { + return *msg->value_type_; +} +TypeProto_Map::TypeProto_Map(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TypeProto.Map) +} +TypeProto_Map::TypeProto_Map(const TypeProto_Map& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + if (from._internal_has_value_type()) { + value_type_ = new ::onnx::TypeProto(*from.value_type_); + } else { + value_type_ = nullptr; + } + key_type_ = from.key_type_; + // @@protoc_insertion_point(copy_constructor:onnx.TypeProto.Map) +} + +void TypeProto_Map::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_onnx_2dml_2eproto.base); + ::memset(&value_type_, 0, static_cast<size_t>( + reinterpret_cast<char*>(&key_type_) - + reinterpret_cast<char*>(&value_type_)) + sizeof(key_type_)); +} + +TypeProto_Map::~TypeProto_Map() { + // @@protoc_insertion_point(destructor:onnx.TypeProto.Map) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto_Map::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + if (this != internal_default_instance()) delete value_type_; +} + +void TypeProto_Map::ArenaDtor(void* object) { + TypeProto_Map* _this = reinterpret_cast< TypeProto_Map* >(object); + (void)_this; +} +void TypeProto_Map::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void 
TypeProto_Map::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto_Map& TypeProto_Map::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto_Map::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TypeProto.Map) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + GOOGLE_DCHECK(value_type_ != nullptr); + value_type_->Clear(); + } + key_type_ = 0; + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto_Map::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional .onnx.TensorProto.DataType key_type = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::onnx::TensorProto_DataType_IsValid(val))) { + _internal_set_key_type(static_cast<::onnx::TensorProto_DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(1, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional .onnx.TypeProto value_type = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_value_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_Map::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TypeProto.Map) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .onnx.TensorProto.DataType key_type = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_key_type(), target); + } + + // optional .onnx.TypeProto value_type = 2; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::value_type(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + 
_internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TypeProto.Map) + return target; +} + +size_t TypeProto_Map::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TypeProto.Map) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional .onnx.TypeProto value_type = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_type_); + } + + // optional .onnx.TensorProto.DataType key_type = 1; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_key_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto_Map::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TypeProto.Map) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto_Map* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TypeProto_Map>( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TypeProto.Map) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TypeProto.Map) + MergeFrom(*source); + } +} + +void TypeProto_Map::MergeFrom(const TypeProto_Map& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TypeProto.Map) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_mutable_value_type()->::onnx::TypeProto::MergeFrom(from._internal_value_type()); + } + if (cached_has_bits & 0x00000002u) { + key_type_ = from.key_type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TypeProto_Map::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TypeProto.Map) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto_Map::CopyFrom(const TypeProto_Map& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TypeProto.Map) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto_Map::IsInitialized() const { + return true; +} + +void TypeProto_Map::InternalSwap(TypeProto_Map* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(TypeProto_Map, key_type_) + + sizeof(TypeProto_Map::key_type_) + - PROTOBUF_FIELD_OFFSET(TypeProto_Map, value_type_)>( + reinterpret_cast<char*>(&value_type_), + reinterpret_cast<char*>(&other->value_type_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_Map::GetMetadata() const { + return GetMetadataStatic(); +} + +
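TypeProto_Map::InternalSwap above swaps value_type_ and key_type_ as one contiguous byte range whose length is derived from two field offsets, rather than member by member. The same trick in isolation; Pair is our stand-in for the generated class, not protobuf's type:

#include <cstddef>
#include <cstring>

// Swap N raw bytes between two locations, like protobuf's internal memswap.
template <std::size_t N>
void memswap(char* a, char* b) {
  char tmp[N];
  std::memcpy(tmp, a, N);
  std::memcpy(a, b, N);
  std::memcpy(b, tmp, N);
}

struct Pair { void* value_type_; int key_type_; };

void swap_fields(Pair& x, Pair& y) {
  // Region length = end of the last field minus start of the first field.
  memswap<offsetof(Pair, key_type_) + sizeof(Pair::key_type_)
          - offsetof(Pair, value_type_)>(
      reinterpret_cast<char*>(&x.value_type_),
      reinterpret_cast<char*>(&y.value_type_));
}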
TypeProto_Opaque::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) {
+#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure
+  _Internal::HasBits has_bits{};
+  ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena;
+  while (!ctx->Done(&ptr)) {
+    ::PROTOBUF_NAMESPACE_ID::uint32 tag;
+    ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
+    CHK_(ptr);
+    switch (tag >> 3) {
+      // optional string domain = 1;
+      case 1:
+        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) {
+          auto str = _internal_mutable_domain();
+          ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx);
+          #ifndef NDEBUG
+          ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TypeProto.Opaque.domain");
+          #endif  // !NDEBUG
+          CHK_(ptr);
+        } else goto handle_unusual;
+        continue;
+      // optional string name = 2;
+      case 2:
+        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) {
+          auto str = _internal_mutable_name();
+          ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx);
+          #ifndef NDEBUG
+          ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TypeProto.Opaque.name");
+          #endif  // !NDEBUG
+          CHK_(ptr);
+        } else goto handle_unusual;
+        continue;
+      default: {
+      handle_unusual:
+        if ((tag & 7) == 4 || tag == 0) {
+          ctx->SetLastTag(tag);
+          goto success;
+        }
+        ptr = UnknownFieldParse(tag,
+            _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(),
+            ptr, ctx);
+        CHK_(ptr != nullptr);
+        continue;
+      }
+    }  // switch
+  }  // while
+success:
+  _has_bits_.Or(has_bits);
+  return ptr;
+failure:
+  ptr = nullptr;
+  goto success;
+#undef CHK_
+}
+
+::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_Opaque::_InternalSerialize(
+    ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const {
+  // @@protoc_insertion_point(serialize_to_array_start:onnx.TypeProto.Opaque)
+  ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
+  (void) cached_has_bits;
+
+  cached_has_bits = _has_bits_[0];
+  // optional string domain = 1;
+  if (cached_has_bits & 0x00000001u) {
+    ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField(
+      this->_internal_domain().data(), static_cast<int>(this->_internal_domain().length()),
+      ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE,
+      "onnx.TypeProto.Opaque.domain");
+    target = stream->WriteStringMaybeAliased(
+        1, this->_internal_domain(), target);
+  }
+
+  // optional string name = 2;
+  if (cached_has_bits & 0x00000002u) {
+    ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField(
+      this->_internal_name().data(), static_cast<int>(this->_internal_name().length()),
+      ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE,
+      "onnx.TypeProto.Opaque.name");
+    target = stream->WriteStringMaybeAliased(
+        2, this->_internal_name(), target);
+  }
+
+  if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
+    target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray(
+        _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream);
+  }
+  // @@protoc_insertion_point(serialize_to_array_end:onnx.TypeProto.Opaque)
+  return target;
+}
+
+size_t TypeProto_Opaque::ByteSizeLong() const {
+// @@protoc_insertion_point(message_byte_size_start:onnx.TypeProto.Opaque)
+  size_t total_size = 0;
+
+  ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 
0;
+  // Prevent compiler warnings about cached_has_bits being unused
+  (void) cached_has_bits;
+
+  cached_has_bits = _has_bits_[0];
+  if (cached_has_bits & 0x00000003u) {
+    // optional string domain = 1;
+    if (cached_has_bits & 0x00000001u) {
+      total_size += 1 +
+        ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize(
+          this->_internal_domain());
+    }
+
+    // optional string name = 2;
+    if (cached_has_bits & 0x00000002u) {
+      total_size += 1 +
+        ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize(
+          this->_internal_name());
+    }
+
+  }
+  if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
+    return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize(
+        _internal_metadata_, total_size, &_cached_size_);
+  }
+  int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
+  SetCachedSize(cached_size);
+  return total_size;
+}
+
+void TypeProto_Opaque::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) {
+// @@protoc_insertion_point(generalized_merge_from_start:onnx.TypeProto.Opaque)
+  GOOGLE_DCHECK_NE(&from, this);
+  const TypeProto_Opaque* source =
+      ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated<TypeProto_Opaque>(
+          &from);
+  if (source == nullptr) {
+  // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TypeProto.Opaque)
+    ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this);
+  } else {
+  // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TypeProto.Opaque)
+    MergeFrom(*source);
+  }
+}
+
+void TypeProto_Opaque::MergeFrom(const TypeProto_Opaque& from) {
+// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TypeProto.Opaque)
+  GOOGLE_DCHECK_NE(&from, this);
+  _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_);
+  ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
+  (void) cached_has_bits;
+
+  cached_has_bits = from._has_bits_[0];
+  if (cached_has_bits & 0x00000003u) {
+    if (cached_has_bits & 0x00000001u) {
+      _internal_set_domain(from._internal_domain());
+    }
+    if (cached_has_bits & 0x00000002u) {
+      _internal_set_name(from._internal_name());
+    }
+  }
+}
+
+void TypeProto_Opaque::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) {
+// @@protoc_insertion_point(generalized_copy_from_start:onnx.TypeProto.Opaque)
+  if (&from == this) return;
+  Clear();
+  MergeFrom(from);
+}
+
+void TypeProto_Opaque::CopyFrom(const TypeProto_Opaque& from) {
+// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TypeProto.Opaque)
+  if (&from == this) return;
+  Clear();
+  MergeFrom(from);
+}
+
+bool TypeProto_Opaque::IsInitialized() const {
+  return true;
+}
+
+void TypeProto_Opaque::InternalSwap(TypeProto_Opaque* other) {
+  using std::swap;
+  _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_);
+  swap(_has_bits_[0], other->_has_bits_[0]);
+  domain_.Swap(&other->domain_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
+  name_.Swap(&other->name_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
+}
+
+::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_Opaque::GetMetadata() const {
+  return GetMetadataStatic();
+}
+
+
+// ===================================================================
+
+void TypeProto_SparseTensor::InitAsDefaultInstance() {
+  ::onnx::_TypeProto_SparseTensor_default_instance_._instance.get_mutable()->shape_ = const_cast< ::onnx::TensorShapeProto*>(
+      ::onnx::TensorShapeProto::internal_default_instance());
+}
+class 
TypeProto_SparseTensor::_Internal {
+ public:
+  using HasBits = decltype(std::declval<TypeProto_SparseTensor>()._has_bits_);
+  static void set_has_elem_type(HasBits* has_bits) {
+    (*has_bits)[0] |= 2u;
+  }
+  static const ::onnx::TensorShapeProto& shape(const TypeProto_SparseTensor* msg);
+  static void set_has_shape(HasBits* has_bits) {
+    (*has_bits)[0] |= 1u;
+  }
+};
+
+const ::onnx::TensorShapeProto&
+TypeProto_SparseTensor::_Internal::shape(const TypeProto_SparseTensor* msg) {
+  return *msg->shape_;
+}
+TypeProto_SparseTensor::TypeProto_SparseTensor(::PROTOBUF_NAMESPACE_ID::Arena* arena)
+  : ::PROTOBUF_NAMESPACE_ID::Message(arena) {
+  SharedCtor();
+  RegisterArenaDtor(arena);
+  // @@protoc_insertion_point(arena_constructor:onnx.TypeProto.SparseTensor)
+}
+TypeProto_SparseTensor::TypeProto_SparseTensor(const TypeProto_SparseTensor& from)
+  : ::PROTOBUF_NAMESPACE_ID::Message(),
+      _has_bits_(from._has_bits_) {
+  _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_);
+  if (from._internal_has_shape()) {
+    shape_ = new ::onnx::TensorShapeProto(*from.shape_);
+  } else {
+    shape_ = nullptr;
+  }
+  elem_type_ = from.elem_type_;
+  // @@protoc_insertion_point(copy_constructor:onnx.TypeProto.SparseTensor)
+}
+
+void TypeProto_SparseTensor::SharedCtor() {
+  ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_SparseTensor_onnx_2dml_2eproto.base);
+  ::memset(&shape_, 0, static_cast<size_t>(
+      reinterpret_cast<char*>(&elem_type_) -
+      reinterpret_cast<char*>(&shape_)) + sizeof(elem_type_));
+}
+
+TypeProto_SparseTensor::~TypeProto_SparseTensor() {
+  // @@protoc_insertion_point(destructor:onnx.TypeProto.SparseTensor)
+  SharedDtor();
+  _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>();
+}
+
+void TypeProto_SparseTensor::SharedDtor() {
+  GOOGLE_DCHECK(GetArena() == nullptr);
+  if (this != internal_default_instance()) delete shape_;
+}
+
+void TypeProto_SparseTensor::ArenaDtor(void* object) {
+  TypeProto_SparseTensor* _this = reinterpret_cast< TypeProto_SparseTensor* >(object);
+  (void)_this;
+}
+void TypeProto_SparseTensor::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) {
+}
+void TypeProto_SparseTensor::SetCachedSize(int size) const {
+  _cached_size_.Set(size);
+}
+const TypeProto_SparseTensor& TypeProto_SparseTensor::default_instance() {
+  ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_SparseTensor_onnx_2dml_2eproto.base);
+  return *internal_default_instance();
+}
+
+
+void TypeProto_SparseTensor::Clear() {
+// @@protoc_insertion_point(message_clear_start:onnx.TypeProto.SparseTensor)
+  ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
+  // Prevent compiler warnings about cached_has_bits being unused
+  (void) cached_has_bits;
+
+  cached_has_bits = _has_bits_[0];
+  if (cached_has_bits & 0x00000001u) {
+    GOOGLE_DCHECK(shape_ != nullptr);
+    shape_->Clear();
+  }
+  elem_type_ = 0;
+  _has_bits_.Clear();
+  _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>();
+}
+
+const char* TypeProto_SparseTensor::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) {
+#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure
+  _Internal::HasBits has_bits{};
+  ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena;
+  while (!ctx->Done(&ptr)) {
+    ::PROTOBUF_NAMESPACE_ID::uint32 tag;
+    ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
+    CHK_(ptr);
+    switch (tag >> 3) {
+      // optional .onnx.TensorProto.DataType elem_type = 1;
+      case 1:
+        if 
(PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 8)) { + ::PROTOBUF_NAMESPACE_ID::uint64 val = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + if (PROTOBUF_PREDICT_TRUE(::onnx::TensorProto_DataType_IsValid(val))) { + _internal_set_elem_type(static_cast<::onnx::TensorProto_DataType>(val)); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::WriteVarint(1, val, mutable_unknown_fields()); + } + } else goto handle_unusual; + continue; + // optional .onnx.TensorShapeProto shape = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) { + ptr = ctx->ParseMessage(_internal_mutable_shape(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto_SparseTensor::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TypeProto.SparseTensor) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional .onnx.TensorProto.DataType elem_type = 1; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteEnumToArray( + 1, this->_internal_elem_type(), target); + } + + // optional .onnx.TensorShapeProto shape = 2; + if (cached_has_bits & 0x00000001u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 2, _Internal::shape(this), target, stream); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TypeProto.SparseTensor) + return target; +} + +size_t TypeProto_SparseTensor::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TypeProto.SparseTensor) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional .onnx.TensorShapeProto shape = 2; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *shape_); + } + + // optional .onnx.TensorProto.DataType elem_type = 1; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::EnumSize(this->_internal_elem_type()); + } + + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = 
::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto_SparseTensor::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TypeProto.SparseTensor) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto_SparseTensor* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TypeProto.SparseTensor) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TypeProto.SparseTensor) + MergeFrom(*source); + } +} + +void TypeProto_SparseTensor::MergeFrom(const TypeProto_SparseTensor& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TypeProto.SparseTensor) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_mutable_shape()->::onnx::TensorShapeProto::MergeFrom(from._internal_shape()); + } + if (cached_has_bits & 0x00000002u) { + elem_type_ = from.elem_type_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void TypeProto_SparseTensor::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TypeProto.SparseTensor) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto_SparseTensor::CopyFrom(const TypeProto_SparseTensor& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TypeProto.SparseTensor) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto_SparseTensor::IsInitialized() const { + return true; +} + +void TypeProto_SparseTensor::InternalSwap(TypeProto_SparseTensor* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + ::PROTOBUF_NAMESPACE_ID::internal::memswap< + PROTOBUF_FIELD_OFFSET(TypeProto_SparseTensor, elem_type_) + + sizeof(TypeProto_SparseTensor::elem_type_) + - PROTOBUF_FIELD_OFFSET(TypeProto_SparseTensor, shape_)>( + reinterpret_cast(&shape_), + reinterpret_cast(&other->shape_)); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto_SparseTensor::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void TypeProto::InitAsDefaultInstance() { + ::onnx::_TypeProto_default_instance_.tensor_type_ = const_cast< ::onnx::TypeProto_Tensor*>( + ::onnx::TypeProto_Tensor::internal_default_instance()); + ::onnx::_TypeProto_default_instance_.sequence_type_ = const_cast< ::onnx::TypeProto_Sequence*>( + ::onnx::TypeProto_Sequence::internal_default_instance()); + ::onnx::_TypeProto_default_instance_.map_type_ = const_cast< ::onnx::TypeProto_Map*>( + ::onnx::TypeProto_Map::internal_default_instance()); + ::onnx::_TypeProto_default_instance_.opaque_type_ = const_cast< ::onnx::TypeProto_Opaque*>( + ::onnx::TypeProto_Opaque::internal_default_instance()); + ::onnx::_TypeProto_default_instance_.sparse_tensor_type_ = const_cast< ::onnx::TypeProto_SparseTensor*>( + ::onnx::TypeProto_SparseTensor::internal_default_instance()); +} 
+class TypeProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static const ::onnx::TypeProto_Tensor& tensor_type(const TypeProto* msg); + static const ::onnx::TypeProto_Sequence& sequence_type(const TypeProto* msg); + static const ::onnx::TypeProto_Map& map_type(const TypeProto* msg); + static const ::onnx::TypeProto_Opaque& opaque_type(const TypeProto* msg); + static const ::onnx::TypeProto_SparseTensor& sparse_tensor_type(const TypeProto* msg); + static void set_has_denotation(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } +}; + +const ::onnx::TypeProto_Tensor& +TypeProto::_Internal::tensor_type(const TypeProto* msg) { + return *msg->value_.tensor_type_; +} +const ::onnx::TypeProto_Sequence& +TypeProto::_Internal::sequence_type(const TypeProto* msg) { + return *msg->value_.sequence_type_; +} +const ::onnx::TypeProto_Map& +TypeProto::_Internal::map_type(const TypeProto* msg) { + return *msg->value_.map_type_; +} +const ::onnx::TypeProto_Opaque& +TypeProto::_Internal::opaque_type(const TypeProto* msg) { + return *msg->value_.opaque_type_; +} +const ::onnx::TypeProto_SparseTensor& +TypeProto::_Internal::sparse_tensor_type(const TypeProto* msg) { + return *msg->value_.sparse_tensor_type_; +} +void TypeProto::set_allocated_tensor_type(::onnx::TypeProto_Tensor* tensor_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (tensor_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(tensor_type); + if (message_arena != submessage_arena) { + tensor_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, tensor_type, submessage_arena); + } + set_has_tensor_type(); + value_.tensor_type_ = tensor_type; + } + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.tensor_type) +} +void TypeProto::set_allocated_sequence_type(::onnx::TypeProto_Sequence* sequence_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (sequence_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(sequence_type); + if (message_arena != submessage_arena) { + sequence_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, sequence_type, submessage_arena); + } + set_has_sequence_type(); + value_.sequence_type_ = sequence_type; + } + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.sequence_type) +} +void TypeProto::set_allocated_map_type(::onnx::TypeProto_Map* map_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (map_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(map_type); + if (message_arena != submessage_arena) { + map_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, map_type, submessage_arena); + } + set_has_map_type(); + value_.map_type_ = map_type; + } + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.map_type) +} +void TypeProto::set_allocated_opaque_type(::onnx::TypeProto_Opaque* opaque_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (opaque_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(opaque_type); + if (message_arena != submessage_arena) { + opaque_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, opaque_type, submessage_arena); + } + set_has_opaque_type(); + value_.opaque_type_ 
= opaque_type; + } + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.opaque_type) +} +void TypeProto::set_allocated_sparse_tensor_type(::onnx::TypeProto_SparseTensor* sparse_tensor_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + clear_value(); + if (sparse_tensor_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(sparse_tensor_type); + if (message_arena != submessage_arena) { + sparse_tensor_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, sparse_tensor_type, submessage_arena); + } + set_has_sparse_tensor_type(); + value_.sparse_tensor_type_ = sparse_tensor_type; + } + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.sparse_tensor_type) +} +TypeProto::TypeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.TypeProto) +} +TypeProto::TypeProto(const TypeProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + denotation_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_denotation()) { + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_denotation(), + GetArena()); + } + clear_has_value(); + switch (from.value_case()) { + case kTensorType: { + _internal_mutable_tensor_type()->::onnx::TypeProto_Tensor::MergeFrom(from._internal_tensor_type()); + break; + } + case kSequenceType: { + _internal_mutable_sequence_type()->::onnx::TypeProto_Sequence::MergeFrom(from._internal_sequence_type()); + break; + } + case kMapType: { + _internal_mutable_map_type()->::onnx::TypeProto_Map::MergeFrom(from._internal_map_type()); + break; + } + case kOpaqueType: { + _internal_mutable_opaque_type()->::onnx::TypeProto_Opaque::MergeFrom(from._internal_opaque_type()); + break; + } + case kSparseTensorType: { + _internal_mutable_sparse_tensor_type()->::onnx::TypeProto_SparseTensor::MergeFrom(from._internal_sparse_tensor_type()); + break; + } + case VALUE_NOT_SET: { + break; + } + } + // @@protoc_insertion_point(copy_constructor:onnx.TypeProto) +} + +void TypeProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_TypeProto_onnx_2dml_2eproto.base); + denotation_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + clear_has_value(); +} + +TypeProto::~TypeProto() { + // @@protoc_insertion_point(destructor:onnx.TypeProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void TypeProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + denotation_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (has_value()) { + clear_value(); + } +} + +void TypeProto::ArenaDtor(void* object) { + TypeProto* _this = reinterpret_cast< TypeProto* >(object); + (void)_this; +} +void TypeProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void TypeProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const TypeProto& TypeProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_TypeProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void TypeProto::clear_value() { +// 
@@protoc_insertion_point(one_of_clear_start:onnx.TypeProto) + switch (value_case()) { + case kTensorType: { + if (GetArena() == nullptr) { + delete value_.tensor_type_; + } + break; + } + case kSequenceType: { + if (GetArena() == nullptr) { + delete value_.sequence_type_; + } + break; + } + case kMapType: { + if (GetArena() == nullptr) { + delete value_.map_type_; + } + break; + } + case kOpaqueType: { + if (GetArena() == nullptr) { + delete value_.opaque_type_; + } + break; + } + case kSparseTensorType: { + if (GetArena() == nullptr) { + delete value_.sparse_tensor_type_; + } + break; + } + case VALUE_NOT_SET: { + break; + } + } + _oneof_case_[0] = VALUE_NOT_SET; +} + + +void TypeProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.TypeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + denotation_.ClearNonDefaultToEmpty(); + } + clear_value(); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* TypeProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while (!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // .onnx.TypeProto.Tensor tensor_type = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + ptr = ctx->ParseMessage(_internal_mutable_tensor_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .onnx.TypeProto.Sequence sequence_type = 4; + case 4: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 34)) { + ptr = ctx->ParseMessage(_internal_mutable_sequence_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .onnx.TypeProto.Map map_type = 5; + case 5: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 42)) { + ptr = ctx->ParseMessage(_internal_mutable_map_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional string denotation = 6; + case 6: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 50)) { + auto str = _internal_mutable_denotation(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.TypeProto.denotation"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // .onnx.TypeProto.Opaque opaque_type = 7; + case 7: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 58)) { + ptr = ctx->ParseMessage(_internal_mutable_opaque_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + // .onnx.TypeProto.SparseTensor sparse_tensor_type = 8; + case 8: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 66)) { + ptr = ctx->ParseMessage(_internal_mutable_sparse_tensor_type(), ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + 
_internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* TypeProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.TypeProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + switch (value_case()) { + case kTensorType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 1, _Internal::tensor_type(this), target, stream); + break; + } + case kSequenceType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 4, _Internal::sequence_type(this), target, stream); + break; + } + case kMapType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 5, _Internal::map_type(this), target, stream); + break; + } + default: ; + } + cached_has_bits = _has_bits_[0]; + // optional string denotation = 6; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_denotation().data(), static_cast(this->_internal_denotation().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.TypeProto.denotation"); + target = stream->WriteStringMaybeAliased( + 6, this->_internal_denotation(), target); + } + + switch (value_case()) { + case kOpaqueType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 7, _Internal::opaque_type(this), target, stream); + break; + } + case kSparseTensorType: { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite:: + InternalWriteMessage( + 8, _Internal::sparse_tensor_type(this), target, stream); + break; + } + default: ; + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.TypeProto) + return target; +} + +size_t TypeProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.TypeProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + // optional string denotation = 6; + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_denotation()); + } + + switch (value_case()) { + // .onnx.TypeProto.Tensor tensor_type = 1; + case kTensorType: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.tensor_type_); + break; + } + // .onnx.TypeProto.Sequence sequence_type = 4; + case kSequenceType: { + total_size += 1 + + 
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.sequence_type_); + break; + } + // .onnx.TypeProto.Map map_type = 5; + case kMapType: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.map_type_); + break; + } + // .onnx.TypeProto.Opaque opaque_type = 7; + case kOpaqueType: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.opaque_type_); + break; + } + // .onnx.TypeProto.SparseTensor sparse_tensor_type = 8; + case kSparseTensorType: { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize( + *value_.sparse_tensor_type_); + break; + } + case VALUE_NOT_SET: { + break; + } + } + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void TypeProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.TypeProto) + GOOGLE_DCHECK_NE(&from, this); + const TypeProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.TypeProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.TypeProto) + MergeFrom(*source); + } +} + +void TypeProto::MergeFrom(const TypeProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.TypeProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (from._internal_has_denotation()) { + _internal_set_denotation(from._internal_denotation()); + } + switch (from.value_case()) { + case kTensorType: { + _internal_mutable_tensor_type()->::onnx::TypeProto_Tensor::MergeFrom(from._internal_tensor_type()); + break; + } + case kSequenceType: { + _internal_mutable_sequence_type()->::onnx::TypeProto_Sequence::MergeFrom(from._internal_sequence_type()); + break; + } + case kMapType: { + _internal_mutable_map_type()->::onnx::TypeProto_Map::MergeFrom(from._internal_map_type()); + break; + } + case kOpaqueType: { + _internal_mutable_opaque_type()->::onnx::TypeProto_Opaque::MergeFrom(from._internal_opaque_type()); + break; + } + case kSparseTensorType: { + _internal_mutable_sparse_tensor_type()->::onnx::TypeProto_SparseTensor::MergeFrom(from._internal_sparse_tensor_type()); + break; + } + case VALUE_NOT_SET: { + break; + } + } +} + +void TypeProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.TypeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void TypeProto::CopyFrom(const TypeProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.TypeProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool TypeProto::IsInitialized() const { + return true; +} + +void TypeProto::InternalSwap(TypeProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + 
denotation_.Swap(&other->denotation_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(value_, other->value_); + swap(_oneof_case_[0], other->_oneof_case_[0]); +} + +::PROTOBUF_NAMESPACE_ID::Metadata TypeProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// =================================================================== + +void OperatorSetIdProto::InitAsDefaultInstance() { +} +class OperatorSetIdProto::_Internal { + public: + using HasBits = decltype(std::declval()._has_bits_); + static void set_has_domain(HasBits* has_bits) { + (*has_bits)[0] |= 1u; + } + static void set_has_version(HasBits* has_bits) { + (*has_bits)[0] |= 2u; + } +}; + +OperatorSetIdProto::OperatorSetIdProto(::PROTOBUF_NAMESPACE_ID::Arena* arena) + : ::PROTOBUF_NAMESPACE_ID::Message(arena) { + SharedCtor(); + RegisterArenaDtor(arena); + // @@protoc_insertion_point(arena_constructor:onnx.OperatorSetIdProto) +} +OperatorSetIdProto::OperatorSetIdProto(const OperatorSetIdProto& from) + : ::PROTOBUF_NAMESPACE_ID::Message(), + _has_bits_(from._has_bits_) { + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + if (from._internal_has_domain()) { + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), from._internal_domain(), + GetArena()); + } + version_ = from.version_; + // @@protoc_insertion_point(copy_constructor:onnx.OperatorSetIdProto) +} + +void OperatorSetIdProto::SharedCtor() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_OperatorSetIdProto_onnx_2dml_2eproto.base); + domain_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + version_ = PROTOBUF_LONGLONG(0); +} + +OperatorSetIdProto::~OperatorSetIdProto() { + // @@protoc_insertion_point(destructor:onnx.OperatorSetIdProto) + SharedDtor(); + _internal_metadata_.Delete<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +void OperatorSetIdProto::SharedDtor() { + GOOGLE_DCHECK(GetArena() == nullptr); + domain_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); +} + +void OperatorSetIdProto::ArenaDtor(void* object) { + OperatorSetIdProto* _this = reinterpret_cast< OperatorSetIdProto* >(object); + (void)_this; +} +void OperatorSetIdProto::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) { +} +void OperatorSetIdProto::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const OperatorSetIdProto& OperatorSetIdProto::default_instance() { + ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_OperatorSetIdProto_onnx_2dml_2eproto.base); + return *internal_default_instance(); +} + + +void OperatorSetIdProto::Clear() { +// @@protoc_insertion_point(message_clear_start:onnx.OperatorSetIdProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000001u) { + domain_.ClearNonDefaultToEmpty(); + } + version_ = PROTOBUF_LONGLONG(0); + _has_bits_.Clear(); + _internal_metadata_.Clear<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); +} + +const char* OperatorSetIdProto::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) { +#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure + _Internal::HasBits has_bits{}; + ::PROTOBUF_NAMESPACE_ID::Arena* arena = GetArena(); (void)arena; + while 
(!ctx->Done(&ptr)) { + ::PROTOBUF_NAMESPACE_ID::uint32 tag; + ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag); + CHK_(ptr); + switch (tag >> 3) { + // optional string domain = 1; + case 1: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) { + auto str = _internal_mutable_domain(); + ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx); + #ifndef NDEBUG + ::PROTOBUF_NAMESPACE_ID::internal::VerifyUTF8(str, "onnx.OperatorSetIdProto.domain"); + #endif // !NDEBUG + CHK_(ptr); + } else goto handle_unusual; + continue; + // optional int64 version = 2; + case 2: + if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) { + _Internal::set_has_version(&has_bits); + version_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint64(&ptr); + CHK_(ptr); + } else goto handle_unusual; + continue; + default: { + handle_unusual: + if ((tag & 7) == 4 || tag == 0) { + ctx->SetLastTag(tag); + goto success; + } + ptr = UnknownFieldParse(tag, + _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(), + ptr, ctx); + CHK_(ptr != nullptr); + continue; + } + } // switch + } // while +success: + _has_bits_.Or(has_bits); + return ptr; +failure: + ptr = nullptr; + goto success; +#undef CHK_ +} + +::PROTOBUF_NAMESPACE_ID::uint8* OperatorSetIdProto::_InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const { + // @@protoc_insertion_point(serialize_to_array_start:onnx.OperatorSetIdProto) + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + // optional string domain = 1; + if (cached_has_bits & 0x00000001u) { + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::VerifyUTF8StringNamedField( + this->_internal_domain().data(), static_cast(this->_internal_domain().length()), + ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::SERIALIZE, + "onnx.OperatorSetIdProto.domain"); + target = stream->WriteStringMaybeAliased( + 1, this->_internal_domain(), target); + } + + // optional int64 version = 2; + if (cached_has_bits & 0x00000002u) { + target = stream->EnsureSpace(target); + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteInt64ToArray(2, this->_internal_version(), target); + } + + if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormat::InternalSerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance), target, stream); + } + // @@protoc_insertion_point(serialize_to_array_end:onnx.OperatorSetIdProto) + return target; +} + +size_t OperatorSetIdProto::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:onnx.OperatorSetIdProto) + size_t total_size = 0; + + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + cached_has_bits = _has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + // optional string domain = 1; + if (cached_has_bits & 0x00000001u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize( + this->_internal_domain()); + } + + // optional int64 version = 2; + if (cached_has_bits & 0x00000002u) { + total_size += 1 + + ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::Int64Size( + this->_internal_version()); + } + + } + if 
(PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) { + return ::PROTOBUF_NAMESPACE_ID::internal::ComputeUnknownFieldsSize( + _internal_metadata_, total_size, &_cached_size_); + } + int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void OperatorSetIdProto::MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:onnx.OperatorSetIdProto) + GOOGLE_DCHECK_NE(&from, this); + const OperatorSetIdProto* source = + ::PROTOBUF_NAMESPACE_ID::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:onnx.OperatorSetIdProto) + ::PROTOBUF_NAMESPACE_ID::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:onnx.OperatorSetIdProto) + MergeFrom(*source); + } +} + +void OperatorSetIdProto::MergeFrom(const OperatorSetIdProto& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:onnx.OperatorSetIdProto) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(from._internal_metadata_); + ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + cached_has_bits = from._has_bits_[0]; + if (cached_has_bits & 0x00000003u) { + if (cached_has_bits & 0x00000001u) { + _internal_set_domain(from._internal_domain()); + } + if (cached_has_bits & 0x00000002u) { + version_ = from.version_; + } + _has_bits_[0] |= cached_has_bits; + } +} + +void OperatorSetIdProto::CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:onnx.OperatorSetIdProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void OperatorSetIdProto::CopyFrom(const OperatorSetIdProto& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:onnx.OperatorSetIdProto) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool OperatorSetIdProto::IsInitialized() const { + return true; +} + +void OperatorSetIdProto::InternalSwap(OperatorSetIdProto* other) { + using std::swap; + _internal_metadata_.Swap<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(&other->_internal_metadata_); + swap(_has_bits_[0], other->_has_bits_[0]); + domain_.Swap(&other->domain_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + swap(version_, other->version_); +} + +::PROTOBUF_NAMESPACE_ID::Metadata OperatorSetIdProto::GetMetadata() const { + return GetMetadataStatic(); +} + + +// @@protoc_insertion_point(namespace_scope) +} // namespace onnx +PROTOBUF_NAMESPACE_OPEN +template<> PROTOBUF_NOINLINE ::onnx::AttributeProto* Arena::CreateMaybeMessage< ::onnx::AttributeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::onnx::AttributeProto >(arena); +} +template<> PROTOBUF_NOINLINE ::onnx::ValueInfoProto* Arena::CreateMaybeMessage< ::onnx::ValueInfoProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::onnx::ValueInfoProto >(arena); +} +template<> PROTOBUF_NOINLINE ::onnx::NodeProto* Arena::CreateMaybeMessage< ::onnx::NodeProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::onnx::NodeProto >(arena); +} +template<> PROTOBUF_NOINLINE ::onnx::ModelProto* Arena::CreateMaybeMessage< ::onnx::ModelProto >(Arena* arena) { + return Arena::CreateMessageInternal< ::onnx::ModelProto >(arena); +} +template<> PROTOBUF_NOINLINE ::onnx::StringStringEntryProto* 
Arena::CreateMaybeMessage< ::onnx::StringStringEntryProto >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::StringStringEntryProto >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::GraphProto* Arena::CreateMaybeMessage< ::onnx::GraphProto >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::GraphProto >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TensorProto_Segment* Arena::CreateMaybeMessage< ::onnx::TensorProto_Segment >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TensorProto_Segment >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TensorProto* Arena::CreateMaybeMessage< ::onnx::TensorProto >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TensorProto >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TensorShapeProto_Dimension* Arena::CreateMaybeMessage< ::onnx::TensorShapeProto_Dimension >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TensorShapeProto_Dimension >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TensorShapeProto* Arena::CreateMaybeMessage< ::onnx::TensorShapeProto >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TensorShapeProto >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TypeProto_Tensor* Arena::CreateMaybeMessage< ::onnx::TypeProto_Tensor >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TypeProto_Tensor >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TypeProto_Sequence* Arena::CreateMaybeMessage< ::onnx::TypeProto_Sequence >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TypeProto_Sequence >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TypeProto_Map* Arena::CreateMaybeMessage< ::onnx::TypeProto_Map >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TypeProto_Map >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TypeProto_Opaque* Arena::CreateMaybeMessage< ::onnx::TypeProto_Opaque >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TypeProto_Opaque >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TypeProto_SparseTensor* Arena::CreateMaybeMessage< ::onnx::TypeProto_SparseTensor >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TypeProto_SparseTensor >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::TypeProto* Arena::CreateMaybeMessage< ::onnx::TypeProto >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::TypeProto >(arena);
+}
+template<> PROTOBUF_NOINLINE ::onnx::OperatorSetIdProto* Arena::CreateMaybeMessage< ::onnx::OperatorSetIdProto >(Arena* arena) {
+  return Arena::CreateMessageInternal< ::onnx::OperatorSetIdProto >(arena);
+}
+PROTOBUF_NAMESPACE_CLOSE
+
+// @@protoc_insertion_point(global_scope)
+#include <google/protobuf/port_undef.inc>
diff --git a/src/3rd_party/onnx/protobuf/onnx-ml.pb.h b/src/3rd_party/onnx/protobuf/onnx-ml.pb.h
new file mode 100755
index 000000000..974631c29
--- /dev/null
+++ b/src/3rd_party/onnx/protobuf/onnx-ml.pb.h
@@ -0,0 +1,9773 @@
+// Generated by the protocol buffer compiler.  DO NOT EDIT!
+// source: onnx-ml.proto
+
+#ifndef GOOGLE_PROTOBUF_INCLUDED_onnx_2dml_2eproto
+#define GOOGLE_PROTOBUF_INCLUDED_onnx_2dml_2eproto
+
+#include <limits>
+#include <string>
+
+#include <google/protobuf/port_def.inc>
+#if PROTOBUF_VERSION < 3012000
+#error This file was generated by a newer version of protoc which is
+#error incompatible with your Protocol Buffer headers.  Please update
+#error your headers.
+#endif
+#if 3012000 < PROTOBUF_MIN_PROTOC_VERSION
+#error This file was generated by an older version of protoc which is
+#error incompatible with your Protocol Buffer headers.  Please
+#error regenerate this file with a newer version of protoc.
+#endif
+
+#include <google/protobuf/port_undef.inc>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/arena.h>
+#include <google/protobuf/arenastring.h>
+#include <google/protobuf/generated_message_table_driven.h>
+#include <google/protobuf/generated_message_util.h>
+#include <google/protobuf/inlined_string_field.h>
+#include <google/protobuf/metadata_lite.h>
+#include <google/protobuf/generated_message_reflection.h>
+#include <google/protobuf/message.h>
+#include <google/protobuf/repeated_field.h>  // IWYU pragma: export
+#include <google/protobuf/extension_set.h>  // IWYU pragma: export
+#include <google/protobuf/generated_enum_reflection.h>
+#include <google/protobuf/unknown_field_set.h>
+// @@protoc_insertion_point(includes)
+#include <google/protobuf/port_def.inc>
+#define PROTOBUF_INTERNAL_EXPORT_onnx_2dml_2eproto
+PROTOBUF_NAMESPACE_OPEN
+namespace internal {
+class AnyMetadata;
+}  // namespace internal
+PROTOBUF_NAMESPACE_CLOSE
+
+// Internal implementation detail -- do not use these members.
+struct TableStruct_onnx_2dml_2eproto {
+  static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[]
+    PROTOBUF_SECTION_VARIABLE(protodesc_cold);
+  static const ::PROTOBUF_NAMESPACE_ID::internal::AuxillaryParseTableField aux[]
+    PROTOBUF_SECTION_VARIABLE(protodesc_cold);
+  static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[17]
+    PROTOBUF_SECTION_VARIABLE(protodesc_cold);
+  static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[];
+  static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[];
+  static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[];
+};
+extern const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_onnx_2dml_2eproto;
+namespace onnx {
+class AttributeProto;
+class AttributeProtoDefaultTypeInternal;
+extern AttributeProtoDefaultTypeInternal _AttributeProto_default_instance_;
+class GraphProto;
+class GraphProtoDefaultTypeInternal;
+extern GraphProtoDefaultTypeInternal _GraphProto_default_instance_;
+class ModelProto;
+class ModelProtoDefaultTypeInternal;
+extern ModelProtoDefaultTypeInternal _ModelProto_default_instance_;
+class NodeProto;
+class NodeProtoDefaultTypeInternal;
+extern NodeProtoDefaultTypeInternal _NodeProto_default_instance_;
+class OperatorSetIdProto;
+class OperatorSetIdProtoDefaultTypeInternal;
+extern OperatorSetIdProtoDefaultTypeInternal _OperatorSetIdProto_default_instance_;
+class StringStringEntryProto;
+class StringStringEntryProtoDefaultTypeInternal;
+extern StringStringEntryProtoDefaultTypeInternal _StringStringEntryProto_default_instance_;
+class TensorProto;
+class TensorProtoDefaultTypeInternal;
+extern TensorProtoDefaultTypeInternal _TensorProto_default_instance_;
+class TensorProto_Segment;
+class TensorProto_SegmentDefaultTypeInternal;
+extern TensorProto_SegmentDefaultTypeInternal _TensorProto_Segment_default_instance_;
+class TensorShapeProto;
+class TensorShapeProtoDefaultTypeInternal;
+extern TensorShapeProtoDefaultTypeInternal _TensorShapeProto_default_instance_;
+class TensorShapeProto_Dimension;
+class TensorShapeProto_DimensionDefaultTypeInternal;
+extern TensorShapeProto_DimensionDefaultTypeInternal _TensorShapeProto_Dimension_default_instance_;
+class TypeProto;
+class TypeProtoDefaultTypeInternal;
+extern TypeProtoDefaultTypeInternal _TypeProto_default_instance_;
+class TypeProto_Map;
+class TypeProto_MapDefaultTypeInternal;
+extern TypeProto_MapDefaultTypeInternal _TypeProto_Map_default_instance_;
+class TypeProto_Opaque;
+class TypeProto_OpaqueDefaultTypeInternal;
+extern TypeProto_OpaqueDefaultTypeInternal _TypeProto_Opaque_default_instance_;
+class TypeProto_Sequence;
+class TypeProto_SequenceDefaultTypeInternal;
+extern TypeProto_SequenceDefaultTypeInternal _TypeProto_Sequence_default_instance_;
+class TypeProto_SparseTensor;
+class TypeProto_SparseTensorDefaultTypeInternal;
+extern TypeProto_SparseTensorDefaultTypeInternal _TypeProto_SparseTensor_default_instance_;
+class TypeProto_Tensor;
+class TypeProto_TensorDefaultTypeInternal;
+extern TypeProto_TensorDefaultTypeInternal _TypeProto_Tensor_default_instance_;
+class ValueInfoProto;
+class ValueInfoProtoDefaultTypeInternal;
+extern ValueInfoProtoDefaultTypeInternal _ValueInfoProto_default_instance_;
+}  // namespace onnx
+PROTOBUF_NAMESPACE_OPEN
+template<> ::onnx::AttributeProto* Arena::CreateMaybeMessage<::onnx::AttributeProto>(Arena*);
+template<> ::onnx::GraphProto* Arena::CreateMaybeMessage<::onnx::GraphProto>(Arena*);
+template<> ::onnx::ModelProto* Arena::CreateMaybeMessage<::onnx::ModelProto>(Arena*);
+template<> ::onnx::NodeProto* Arena::CreateMaybeMessage<::onnx::NodeProto>(Arena*);
+template<> ::onnx::OperatorSetIdProto* Arena::CreateMaybeMessage<::onnx::OperatorSetIdProto>(Arena*);
+template<> ::onnx::StringStringEntryProto* Arena::CreateMaybeMessage<::onnx::StringStringEntryProto>(Arena*);
+template<> ::onnx::TensorProto* Arena::CreateMaybeMessage<::onnx::TensorProto>(Arena*);
+template<> ::onnx::TensorProto_Segment* Arena::CreateMaybeMessage<::onnx::TensorProto_Segment>(Arena*);
+template<> ::onnx::TensorShapeProto* Arena::CreateMaybeMessage<::onnx::TensorShapeProto>(Arena*);
+template<> ::onnx::TensorShapeProto_Dimension* Arena::CreateMaybeMessage<::onnx::TensorShapeProto_Dimension>(Arena*);
+template<> ::onnx::TypeProto* Arena::CreateMaybeMessage<::onnx::TypeProto>(Arena*);
+template<> ::onnx::TypeProto_Map* Arena::CreateMaybeMessage<::onnx::TypeProto_Map>(Arena*);
+template<> ::onnx::TypeProto_Opaque* Arena::CreateMaybeMessage<::onnx::TypeProto_Opaque>(Arena*);
+template<> ::onnx::TypeProto_Sequence* Arena::CreateMaybeMessage<::onnx::TypeProto_Sequence>(Arena*);
+template<> ::onnx::TypeProto_SparseTensor* Arena::CreateMaybeMessage<::onnx::TypeProto_SparseTensor>(Arena*);
+template<> ::onnx::TypeProto_Tensor* Arena::CreateMaybeMessage<::onnx::TypeProto_Tensor>(Arena*);
+template<> ::onnx::ValueInfoProto* Arena::CreateMaybeMessage<::onnx::ValueInfoProto>(Arena*);
+PROTOBUF_NAMESPACE_CLOSE
+namespace onnx {
+
+enum AttributeProto_AttributeType : int {
+  AttributeProto_AttributeType_UNDEFINED = 0,
+  AttributeProto_AttributeType_FLOAT = 1,
+  AttributeProto_AttributeType_INT = 2,
+  AttributeProto_AttributeType_STRING = 3,
+  AttributeProto_AttributeType_TENSOR = 4,
+  AttributeProto_AttributeType_GRAPH = 5,
+  AttributeProto_AttributeType_FLOATS = 6,
+  AttributeProto_AttributeType_INTS = 7,
+  AttributeProto_AttributeType_STRINGS = 8,
+  AttributeProto_AttributeType_TENSORS = 9,
+  AttributeProto_AttributeType_GRAPHS = 10
+};
+bool AttributeProto_AttributeType_IsValid(int value);
+constexpr AttributeProto_AttributeType AttributeProto_AttributeType_AttributeType_MIN = AttributeProto_AttributeType_UNDEFINED;
+constexpr AttributeProto_AttributeType AttributeProto_AttributeType_AttributeType_MAX = AttributeProto_AttributeType_GRAPHS;
+constexpr int AttributeProto_AttributeType_AttributeType_ARRAYSIZE = AttributeProto_AttributeType_AttributeType_MAX + 1;
+
+const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* AttributeProto_AttributeType_descriptor();
+template<typename T>
+inline const std::string& AttributeProto_AttributeType_Name(T enum_t_value) {
+  static_assert(::std::is_same<T, AttributeProto_AttributeType>::value ||
+    ::std::is_integral<T>::value,
+    "Incorrect type passed to function AttributeProto_AttributeType_Name.");
+  return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum(
+    AttributeProto_AttributeType_descriptor(), enum_t_value);
+}
+inline bool AttributeProto_AttributeType_Parse(
+    const std::string& name, AttributeProto_AttributeType* value) {
+  return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum<AttributeProto_AttributeType>(
+    AttributeProto_AttributeType_descriptor(), name, value);
+}
+enum TensorProto_DataType : int {
+  TensorProto_DataType_UNDEFINED = 0,
+  TensorProto_DataType_FLOAT = 1,
+  TensorProto_DataType_UINT8 = 2,
+  TensorProto_DataType_INT8 = 3,
+  TensorProto_DataType_UINT16 = 4,
+  TensorProto_DataType_INT16 = 5,
+  TensorProto_DataType_INT32 = 6,
+  TensorProto_DataType_INT64 = 7,
+  TensorProto_DataType_STRING = 8,
+  TensorProto_DataType_BOOL = 9,
+  TensorProto_DataType_FLOAT16 = 10,
+  TensorProto_DataType_DOUBLE = 11,
+  TensorProto_DataType_UINT32 = 12,
+  TensorProto_DataType_UINT64 = 13,
+  TensorProto_DataType_COMPLEX64 = 14,
+  TensorProto_DataType_COMPLEX128 = 15,
+  TensorProto_DataType_BFLOAT16 = 16
+};
+bool TensorProto_DataType_IsValid(int value);
+constexpr TensorProto_DataType TensorProto_DataType_DataType_MIN = TensorProto_DataType_UNDEFINED;
+constexpr TensorProto_DataType TensorProto_DataType_DataType_MAX = TensorProto_DataType_BFLOAT16;
+constexpr int TensorProto_DataType_DataType_ARRAYSIZE = TensorProto_DataType_DataType_MAX + 1;
+
+const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* TensorProto_DataType_descriptor();
+template<typename T>
+inline const std::string& TensorProto_DataType_Name(T enum_t_value) {
+  static_assert(::std::is_same<T, TensorProto_DataType>::value ||
+    ::std::is_integral<T>::value,
+    "Incorrect type passed to function TensorProto_DataType_Name.");
+  return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum(
+    TensorProto_DataType_descriptor(), enum_t_value);
+}
+inline bool TensorProto_DataType_Parse(
+    const std::string& name, TensorProto_DataType* value) {
+  return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum<TensorProto_DataType>(
+    TensorProto_DataType_descriptor(), name, value);
+}
+enum Version : int {
+  _START_VERSION = 0,
+  IR_VERSION_2017_10_10 = 1,
+  IR_VERSION_2017_10_30 = 2,
+  IR_VERSION = 3
+};
+bool Version_IsValid(int value);
+constexpr Version Version_MIN = _START_VERSION;
+constexpr Version Version_MAX = IR_VERSION;
+constexpr int Version_ARRAYSIZE = Version_MAX + 1;
+
+const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* Version_descriptor();
+template<typename T>
+inline const std::string& Version_Name(T enum_t_value) {
+  static_assert(::std::is_same<T, Version>::value ||
+    ::std::is_integral<T>::value,
+    "Incorrect type passed to function Version_Name.");
+  return ::PROTOBUF_NAMESPACE_ID::internal::NameOfEnum(
+    Version_descriptor(), enum_t_value);
+}
+inline bool Version_Parse(
+    const std::string& name, Version* value) {
+  return ::PROTOBUF_NAMESPACE_ID::internal::ParseNamedEnum<Version>(
+    Version_descriptor(), name, value);
+}
+// ===================================================================
+
+class AttributeProto PROTOBUF_FINAL :
+    public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.AttributeProto) */ {
+ public:
+  inline AttributeProto() : AttributeProto(nullptr) {};
+  virtual ~AttributeProto();
+
+  AttributeProto(const AttributeProto& from);
+  AttributeProto(AttributeProto&& from) noexcept
+    : AttributeProto() {
+    *this = ::std::move(from);
+  }
+
+  inline AttributeProto& operator=(const AttributeProto& from) {
+    CopyFrom(from);
+    return *this;
+  }
+  inline AttributeProto& operator=(AttributeProto&& from) noexcept {
+    if (GetArena() == from.GetArena()) {
+      if (this != &from) InternalSwap(&from);
+    } else {
+      CopyFrom(from);
+    }
+    return *this;
+  }
+
+  inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const {
+    return 
_internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const AttributeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const AttributeProto* internal_default_instance() { + return reinterpret_cast( + &_AttributeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 0; + + friend void swap(AttributeProto& a, AttributeProto& b) { + a.Swap(&b); + } + inline void Swap(AttributeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(AttributeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline AttributeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + AttributeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const AttributeProto& from); + void MergeFrom(const AttributeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(AttributeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.AttributeProto"; + } + protected: + explicit AttributeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef AttributeProto_AttributeType AttributeType; + static constexpr AttributeType UNDEFINED = + AttributeProto_AttributeType_UNDEFINED; + static constexpr AttributeType FLOAT = + 
AttributeProto_AttributeType_FLOAT; + static constexpr AttributeType INT = + AttributeProto_AttributeType_INT; + static constexpr AttributeType STRING = + AttributeProto_AttributeType_STRING; + static constexpr AttributeType TENSOR = + AttributeProto_AttributeType_TENSOR; + static constexpr AttributeType GRAPH = + AttributeProto_AttributeType_GRAPH; + static constexpr AttributeType FLOATS = + AttributeProto_AttributeType_FLOATS; + static constexpr AttributeType INTS = + AttributeProto_AttributeType_INTS; + static constexpr AttributeType STRINGS = + AttributeProto_AttributeType_STRINGS; + static constexpr AttributeType TENSORS = + AttributeProto_AttributeType_TENSORS; + static constexpr AttributeType GRAPHS = + AttributeProto_AttributeType_GRAPHS; + static inline bool AttributeType_IsValid(int value) { + return AttributeProto_AttributeType_IsValid(value); + } + static constexpr AttributeType AttributeType_MIN = + AttributeProto_AttributeType_AttributeType_MIN; + static constexpr AttributeType AttributeType_MAX = + AttributeProto_AttributeType_AttributeType_MAX; + static constexpr int AttributeType_ARRAYSIZE = + AttributeProto_AttributeType_AttributeType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + AttributeType_descriptor() { + return AttributeProto_AttributeType_descriptor(); + } + template + static inline const std::string& AttributeType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function AttributeType_Name."); + return AttributeProto_AttributeType_Name(enum_t_value); + } + static inline bool AttributeType_Parse(const std::string& name, + AttributeType* value) { + return AttributeProto_AttributeType_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kFloatsFieldNumber = 7, + kIntsFieldNumber = 8, + kStringsFieldNumber = 9, + kTensorsFieldNumber = 10, + kGraphsFieldNumber = 11, + kNameFieldNumber = 1, + kSFieldNumber = 4, + kDocStringFieldNumber = 13, + kRefAttrNameFieldNumber = 21, + kTFieldNumber = 5, + kGFieldNumber = 6, + kIFieldNumber = 3, + kFFieldNumber = 2, + kTypeFieldNumber = 20, + }; + // repeated float floats = 7; + int floats_size() const; + private: + int _internal_floats_size() const; + public: + void clear_floats(); + private: + float _internal_floats(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_floats() const; + void _internal_add_floats(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_floats(); + public: + float floats(int index) const; + void set_floats(int index, float value); + void add_floats(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + floats() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_floats(); + + // repeated int64 ints = 8; + int ints_size() const; + private: + int _internal_ints_size() const; + public: + void clear_ints(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_ints(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_ints() const; + void _internal_add_ints(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_ints(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 ints(int index) const; + void set_ints(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_ints(::PROTOBUF_NAMESPACE_ID::int64 value); + const 
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + ints() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_ints(); + + // repeated bytes strings = 9; + int strings_size() const; + private: + int _internal_strings_size() const; + public: + void clear_strings(); + const std::string& strings(int index) const; + std::string* mutable_strings(int index); + void set_strings(int index, const std::string& value); + void set_strings(int index, std::string&& value); + void set_strings(int index, const char* value); + void set_strings(int index, const void* value, size_t size); + std::string* add_strings(); + void add_strings(const std::string& value); + void add_strings(std::string&& value); + void add_strings(const char* value); + void add_strings(const void* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& strings() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_strings(); + private: + const std::string& _internal_strings(int index) const; + std::string* _internal_add_strings(); + public: + + // repeated .onnx.TensorProto tensors = 10; + int tensors_size() const; + private: + int _internal_tensors_size() const; + public: + void clear_tensors(); + ::onnx::TensorProto* mutable_tensors(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* + mutable_tensors(); + private: + const ::onnx::TensorProto& _internal_tensors(int index) const; + ::onnx::TensorProto* _internal_add_tensors(); + public: + const ::onnx::TensorProto& tensors(int index) const; + ::onnx::TensorProto* add_tensors(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& + tensors() const; + + // repeated .onnx.GraphProto graphs = 11; + int graphs_size() const; + private: + int _internal_graphs_size() const; + public: + void clear_graphs(); + ::onnx::GraphProto* mutable_graphs(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >* + mutable_graphs(); + private: + const ::onnx::GraphProto& _internal_graphs(int index) const; + ::onnx::GraphProto* _internal_add_graphs(); + public: + const ::onnx::GraphProto& graphs(int index) const; + ::onnx::GraphProto* add_graphs(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >& + graphs() const; + + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_name( + std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional bytes s = 4; + bool has_s() const; + private: + bool _internal_has_s() const; + public: + void clear_s(); + const std::string& s() const; + void set_s(const std::string& value); + void set_s(std::string&& 
value); + void set_s(const char* value); + void set_s(const void* value, size_t size); + std::string* mutable_s(); + std::string* release_s(); + void set_allocated_s(std::string* s); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_s(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_s( + std::string* s); + private: + const std::string& _internal_s() const; + void _internal_set_s(const std::string& value); + std::string* _internal_mutable_s(); + public: + + // optional string doc_string = 13; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_doc_string(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_doc_string( + std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional string ref_attr_name = 21; + bool has_ref_attr_name() const; + private: + bool _internal_has_ref_attr_name() const; + public: + void clear_ref_attr_name(); + const std::string& ref_attr_name() const; + void set_ref_attr_name(const std::string& value); + void set_ref_attr_name(std::string&& value); + void set_ref_attr_name(const char* value); + void set_ref_attr_name(const char* value, size_t size); + std::string* mutable_ref_attr_name(); + std::string* release_ref_attr_name(); + void set_allocated_ref_attr_name(std::string* ref_attr_name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_ref_attr_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_ref_attr_name( + std::string* ref_attr_name); + private: + const std::string& _internal_ref_attr_name() const; + void _internal_set_ref_attr_name(const std::string& value); + std::string* _internal_mutable_ref_attr_name(); + public: + + // optional .onnx.TensorProto t = 5; + bool has_t() const; + private: + bool _internal_has_t() const; + public: + void clear_t(); + const ::onnx::TensorProto& t() const; + ::onnx::TensorProto* release_t(); + ::onnx::TensorProto* mutable_t(); + void set_allocated_t(::onnx::TensorProto* t); + private: + const ::onnx::TensorProto& _internal_t() const; + ::onnx::TensorProto* _internal_mutable_t(); + public: + void unsafe_arena_set_allocated_t( + ::onnx::TensorProto* t); + ::onnx::TensorProto* 
unsafe_arena_release_t(); + + // optional .onnx.GraphProto g = 6; + bool has_g() const; + private: + bool _internal_has_g() const; + public: + void clear_g(); + const ::onnx::GraphProto& g() const; + ::onnx::GraphProto* release_g(); + ::onnx::GraphProto* mutable_g(); + void set_allocated_g(::onnx::GraphProto* g); + private: + const ::onnx::GraphProto& _internal_g() const; + ::onnx::GraphProto* _internal_mutable_g(); + public: + void unsafe_arena_set_allocated_g( + ::onnx::GraphProto* g); + ::onnx::GraphProto* unsafe_arena_release_g(); + + // optional int64 i = 3; + bool has_i() const; + private: + bool _internal_has_i() const; + public: + void clear_i(); + ::PROTOBUF_NAMESPACE_ID::int64 i() const; + void set_i(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_i() const; + void _internal_set_i(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional float f = 2; + bool has_f() const; + private: + bool _internal_has_f() const; + public: + void clear_f(); + float f() const; + void set_f(float value); + private: + float _internal_f() const; + void _internal_set_f(float value); + public: + + // optional .onnx.AttributeProto.AttributeType type = 20; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + ::onnx::AttributeProto_AttributeType type() const; + void set_type(::onnx::AttributeProto_AttributeType value); + private: + ::onnx::AttributeProto_AttributeType _internal_type() const; + void _internal_set_type(::onnx::AttributeProto_AttributeType value); + public: + + // @@protoc_insertion_point(class_scope:onnx.AttributeProto) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > floats_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > ints_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string> strings_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto > tensors_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto > graphs_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr s_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr ref_attr_name_; + ::onnx::TensorProto* t_; + ::onnx::GraphProto* g_; + ::PROTOBUF_NAMESPACE_ID::int64 i_; + float f_; + int type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class ValueInfoProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.ValueInfoProto) */ { + public: + inline ValueInfoProto() : ValueInfoProto(nullptr) {}; + virtual ~ValueInfoProto(); + + ValueInfoProto(const ValueInfoProto& from); + ValueInfoProto(ValueInfoProto&& from) noexcept + : ValueInfoProto() { + *this = ::std::move(from); + } + + inline ValueInfoProto& operator=(const ValueInfoProto& from) { + CopyFrom(from); + return *this; + } + inline ValueInfoProto& operator=(ValueInfoProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const
::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ValueInfoProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ValueInfoProto* internal_default_instance() { + return reinterpret_cast<const ValueInfoProto*>( + &_ValueInfoProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 1; + + friend void swap(ValueInfoProto& a, ValueInfoProto& b) { + a.Swap(&b); + } + inline void Swap(ValueInfoProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ValueInfoProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ValueInfoProto* New() const final { + return CreateMaybeMessage<ValueInfoProto>(nullptr); + } + + ValueInfoProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<ValueInfoProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ValueInfoProto& from); + void MergeFrom(const ValueInfoProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ValueInfoProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.ValueInfoProto"; + } + protected: + explicit ValueInfoProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNameFieldNumber = 1, + kDocStringFieldNumber = 3, +
kTypeFieldNumber = 2, + }; + // optional string name = 1; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_name( + std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string doc_string = 3; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_doc_string(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_doc_string( + std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional .onnx.TypeProto type = 2; + bool has_type() const; + private: + bool _internal_has_type() const; + public: + void clear_type(); + const ::onnx::TypeProto& type() const; + ::onnx::TypeProto* release_type(); + ::onnx::TypeProto* mutable_type(); + void set_allocated_type(::onnx::TypeProto* type); + private: + const ::onnx::TypeProto& _internal_type() const; + ::onnx::TypeProto* _internal_mutable_type(); + public: + void unsafe_arena_set_allocated_type( + ::onnx::TypeProto* type); + ::onnx::TypeProto* unsafe_arena_release_type(); + + // @@protoc_insertion_point(class_scope:onnx.ValueInfoProto) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::onnx::TypeProto* type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class NodeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.NodeProto) */ { + public: + inline NodeProto() :
NodeProto(nullptr) {}; + virtual ~NodeProto(); + + NodeProto(const NodeProto& from); + NodeProto(NodeProto&& from) noexcept + : NodeProto() { + *this = ::std::move(from); + } + + inline NodeProto& operator=(const NodeProto& from) { + CopyFrom(from); + return *this; + } + inline NodeProto& operator=(NodeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const NodeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const NodeProto* internal_default_instance() { + return reinterpret_cast<const NodeProto*>( + &_NodeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 2; + + friend void swap(NodeProto& a, NodeProto& b) { + a.Swap(&b); + } + inline void Swap(NodeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(NodeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline NodeProto* New() const final { + return CreateMaybeMessage<NodeProto>(nullptr); + } + + NodeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<NodeProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const NodeProto& from); + void MergeFrom(const NodeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(NodeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.NodeProto"; + } + protected: + explicit NodeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { +
::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kInputFieldNumber = 1, + kOutputFieldNumber = 2, + kAttributeFieldNumber = 5, + kNameFieldNumber = 3, + kOpTypeFieldNumber = 4, + kDocStringFieldNumber = 6, + kDomainFieldNumber = 7, + }; + // repeated string input = 1; + int input_size() const; + private: + int _internal_input_size() const; + public: + void clear_input(); + const std::string& input(int index) const; + std::string* mutable_input(int index); + void set_input(int index, const std::string& value); + void set_input(int index, std::string&& value); + void set_input(int index, const char* value); + void set_input(int index, const char* value, size_t size); + std::string* add_input(); + void add_input(const std::string& value); + void add_input(std::string&& value); + void add_input(const char* value); + void add_input(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& input() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* mutable_input(); + private: + const std::string& _internal_input(int index) const; + std::string* _internal_add_input(); + public: + + // repeated string output = 2; + int output_size() const; + private: + int _internal_output_size() const; + public: + void clear_output(); + const std::string& output(int index) const; + std::string* mutable_output(int index); + void set_output(int index, const std::string& value); + void set_output(int index, std::string&& value); + void set_output(int index, const char* value); + void set_output(int index, const char* value, size_t size); + std::string* add_output(); + void add_output(const std::string& value); + void add_output(std::string&& value); + void add_output(const char* value); + void add_output(const char* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& output() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* mutable_output(); + private: + const std::string& _internal_output(int index) const; + std::string* _internal_add_output(); + public: + + // repeated .onnx.AttributeProto attribute = 5; + int attribute_size() const; + private: + int _internal_attribute_size() const; + public: + void clear_attribute(); + ::onnx::AttributeProto* mutable_attribute(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >* + mutable_attribute(); + private: + const ::onnx::AttributeProto& _internal_attribute(int index) const; + ::onnx::AttributeProto* _internal_add_attribute(); + public: + const ::onnx::AttributeProto& attribute(int index) const; + ::onnx::AttributeProto* add_attribute(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >& + attribute() const; + + // optional string name = 3; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be
removed in a" + " future release.") + std::string* unsafe_arena_release_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_name( + std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string op_type = 4; + bool has_op_type() const; + private: + bool _internal_has_op_type() const; + public: + void clear_op_type(); + const std::string& op_type() const; + void set_op_type(const std::string& value); + void set_op_type(std::string&& value); + void set_op_type(const char* value); + void set_op_type(const char* value, size_t size); + std::string* mutable_op_type(); + std::string* release_op_type(); + void set_allocated_op_type(std::string* op_type); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_op_type(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_op_type( + std::string* op_type); + private: + const std::string& _internal_op_type() const; + void _internal_set_op_type(const std::string& value); + std::string* _internal_mutable_op_type(); + public: + + // optional string doc_string = 6; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_doc_string(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_doc_string( + std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional string domain = 7; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_domain(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_domain( + std::string* domain); + private: + 
const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // @@protoc_insertion_point(class_scope:onnx.NodeProto) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string> input_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string> output_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto > attribute_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr op_type_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class ModelProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.ModelProto) */ { + public: + inline ModelProto() : ModelProto(nullptr) {}; + virtual ~ModelProto(); + + ModelProto(const ModelProto& from); + ModelProto(ModelProto&& from) noexcept + : ModelProto() { + *this = ::std::move(from); + } + + inline ModelProto& operator=(const ModelProto& from) { + CopyFrom(from); + return *this; + } + inline ModelProto& operator=(ModelProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const ModelProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const ModelProto* internal_default_instance() { + return reinterpret_cast<const ModelProto*>( + &_ModelProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 3; + + friend void swap(ModelProto& a, ModelProto& b) { + a.Swap(&b); + } + inline void Swap(ModelProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(ModelProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline ModelProto* New() const final { + return CreateMaybeMessage<ModelProto>(nullptr); + } + + ModelProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<ModelProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
+ void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const ModelProto& from); + void MergeFrom(const ModelProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(ModelProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.ModelProto"; + } + protected: + explicit ModelProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kOpsetImportFieldNumber = 8, + kMetadataPropsFieldNumber = 14, + kProducerNameFieldNumber = 2, + kProducerVersionFieldNumber = 3, + kDomainFieldNumber = 4, + kDocStringFieldNumber = 6, + kGraphFieldNumber = 7, + kIrVersionFieldNumber = 1, + kModelVersionFieldNumber = 5, + }; + // repeated .onnx.OperatorSetIdProto opset_import = 8; + int opset_import_size() const; + private: + int _internal_opset_import_size() const; + public: + void clear_opset_import(); + ::onnx::OperatorSetIdProto* mutable_opset_import(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >* + mutable_opset_import(); + private: + const ::onnx::OperatorSetIdProto& _internal_opset_import(int index) const; + ::onnx::OperatorSetIdProto* _internal_add_opset_import(); + public: + const ::onnx::OperatorSetIdProto& opset_import(int index) const; + ::onnx::OperatorSetIdProto* add_opset_import(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >& + opset_import() const; + + // repeated .onnx.StringStringEntryProto metadata_props = 14; + int metadata_props_size() const; + private: + int _internal_metadata_props_size() const; + public: + void clear_metadata_props(); + ::onnx::StringStringEntryProto* mutable_metadata_props(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* + mutable_metadata_props(); + private: + const ::onnx::StringStringEntryProto& _internal_metadata_props(int index) const; + ::onnx::StringStringEntryProto* _internal_add_metadata_props(); + public: + const ::onnx::StringStringEntryProto& metadata_props(int index) const; + ::onnx::StringStringEntryProto* add_metadata_props(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& + metadata_props() const; + + // optional string producer_name = 2; + bool has_producer_name() const; + private: + bool 
_internal_has_producer_name() const; + public: + void clear_producer_name(); + const std::string& producer_name() const; + void set_producer_name(const std::string& value); + void set_producer_name(std::string&& value); + void set_producer_name(const char* value); + void set_producer_name(const char* value, size_t size); + std::string* mutable_producer_name(); + std::string* release_producer_name(); + void set_allocated_producer_name(std::string* producer_name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_producer_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_producer_name( + std::string* producer_name); + private: + const std::string& _internal_producer_name() const; + void _internal_set_producer_name(const std::string& value); + std::string* _internal_mutable_producer_name(); + public: + + // optional string producer_version = 3; + bool has_producer_version() const; + private: + bool _internal_has_producer_version() const; + public: + void clear_producer_version(); + const std::string& producer_version() const; + void set_producer_version(const std::string& value); + void set_producer_version(std::string&& value); + void set_producer_version(const char* value); + void set_producer_version(const char* value, size_t size); + std::string* mutable_producer_version(); + std::string* release_producer_version(); + void set_allocated_producer_version(std::string* producer_version); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_producer_version(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_producer_version( + std::string* producer_version); + private: + const std::string& _internal_producer_version() const; + void _internal_set_producer_version(const std::string& value); + std::string* _internal_mutable_producer_version(); + public: + + // optional string domain = 4; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_domain(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_domain( + std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // optional string doc_string = 6; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const 
std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_doc_string(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_doc_string( + std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional .onnx.GraphProto graph = 7; + bool has_graph() const; + private: + bool _internal_has_graph() const; + public: + void clear_graph(); + const ::onnx::GraphProto& graph() const; + ::onnx::GraphProto* release_graph(); + ::onnx::GraphProto* mutable_graph(); + void set_allocated_graph(::onnx::GraphProto* graph); + private: + const ::onnx::GraphProto& _internal_graph() const; + ::onnx::GraphProto* _internal_mutable_graph(); + public: + void unsafe_arena_set_allocated_graph( + ::onnx::GraphProto* graph); + ::onnx::GraphProto* unsafe_arena_release_graph(); + + // optional int64 ir_version = 1; + bool has_ir_version() const; + private: + bool _internal_has_ir_version() const; + public: + void clear_ir_version(); + ::PROTOBUF_NAMESPACE_ID::int64 ir_version() const; + void set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_ir_version() const; + void _internal_set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 model_version = 5; + bool has_model_version() const; + private: + bool _internal_has_model_version() const; + public: + void clear_model_version(); + ::PROTOBUF_NAMESPACE_ID::int64 model_version() const; + void set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_model_version() const; + void _internal_set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.ModelProto) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto > opset_import_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto > metadata_props_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr producer_name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr producer_version_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::onnx::GraphProto* graph_; + ::PROTOBUF_NAMESPACE_ID::int64 ir_version_; + ::PROTOBUF_NAMESPACE_ID::int64 model_version_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class StringStringEntryProto PROTOBUF_FINAL : + public
::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.StringStringEntryProto) */ { + public: + inline StringStringEntryProto() : StringStringEntryProto(nullptr) {}; + virtual ~StringStringEntryProto(); + + StringStringEntryProto(const StringStringEntryProto& from); + StringStringEntryProto(StringStringEntryProto&& from) noexcept + : StringStringEntryProto() { + *this = ::std::move(from); + } + + inline StringStringEntryProto& operator=(const StringStringEntryProto& from) { + CopyFrom(from); + return *this; + } + inline StringStringEntryProto& operator=(StringStringEntryProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const StringStringEntryProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const StringStringEntryProto* internal_default_instance() { + return reinterpret_cast<const StringStringEntryProto*>( + &_StringStringEntryProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 4; + + friend void swap(StringStringEntryProto& a, StringStringEntryProto& b) { + a.Swap(&b); + } + inline void Swap(StringStringEntryProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(StringStringEntryProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline StringStringEntryProto* New() const final { + return CreateMaybeMessage<StringStringEntryProto>(nullptr); + } + + StringStringEntryProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<StringStringEntryProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const StringStringEntryProto& from); + void MergeFrom(const StringStringEntryProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(StringStringEntryProto* other); + friend class
::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.StringStringEntryProto"; + } + protected: + explicit StringStringEntryProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kKeyFieldNumber = 1, + kValueFieldNumber = 2, + }; + // optional string key = 1; + bool has_key() const; + private: + bool _internal_has_key() const; + public: + void clear_key(); + const std::string& key() const; + void set_key(const std::string& value); + void set_key(std::string&& value); + void set_key(const char* value); + void set_key(const char* value, size_t size); + std::string* mutable_key(); + std::string* release_key(); + void set_allocated_key(std::string* key); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_key(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_key( + std::string* key); + private: + const std::string& _internal_key() const; + void _internal_set_key(const std::string& value); + std::string* _internal_mutable_key(); + public: + + // optional string value = 2; + bool has_value() const; + private: + bool _internal_has_value() const; + public: + void clear_value(); + const std::string& value() const; + void set_value(const std::string& value); + void set_value(std::string&& value); + void set_value(const char* value); + void set_value(const char* value, size_t size); + std::string* mutable_value(); + std::string* release_value(); + void set_allocated_value(std::string* value); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_value(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_value( + std::string* value); + private: + const std::string& _internal_value() const; + void _internal_set_value(const std::string& value); + std::string* _internal_mutable_value(); + public: + + // @@protoc_insertion_point(class_scope:onnx.StringStringEntryProto) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr key_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr value_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +//
------------------------------------------------------------------- + +class GraphProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.GraphProto) */ { + public: + inline GraphProto() : GraphProto(nullptr) {}; + virtual ~GraphProto(); + + GraphProto(const GraphProto& from); + GraphProto(GraphProto&& from) noexcept + : GraphProto() { + *this = ::std::move(from); + } + + inline GraphProto& operator=(const GraphProto& from) { + CopyFrom(from); + return *this; + } + inline GraphProto& operator=(GraphProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const GraphProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const GraphProto* internal_default_instance() { + return reinterpret_cast<const GraphProto*>( + &_GraphProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + friend void swap(GraphProto& a, GraphProto& b) { + a.Swap(&b); + } + inline void Swap(GraphProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(GraphProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline GraphProto* New() const final { + return CreateMaybeMessage<GraphProto>(nullptr); + } + + GraphProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<GraphProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const GraphProto& from); + void MergeFrom(const GraphProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(GraphProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.GraphProto"; + } + protected: + explicit GraphProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); +
private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kNodeFieldNumber = 1, + kInitializerFieldNumber = 5, + kInputFieldNumber = 11, + kOutputFieldNumber = 12, + kValueInfoFieldNumber = 13, + kNameFieldNumber = 2, + kDocStringFieldNumber = 10, + }; + // repeated .onnx.NodeProto node = 1; + int node_size() const; + private: + int _internal_node_size() const; + public: + void clear_node(); + ::onnx::NodeProto* mutable_node(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >* + mutable_node(); + private: + const ::onnx::NodeProto& _internal_node(int index) const; + ::onnx::NodeProto* _internal_add_node(); + public: + const ::onnx::NodeProto& node(int index) const; + ::onnx::NodeProto* add_node(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >& + node() const; + + // repeated .onnx.TensorProto initializer = 5; + int initializer_size() const; + private: + int _internal_initializer_size() const; + public: + void clear_initializer(); + ::onnx::TensorProto* mutable_initializer(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* + mutable_initializer(); + private: + const ::onnx::TensorProto& _internal_initializer(int index) const; + ::onnx::TensorProto* _internal_add_initializer(); + public: + const ::onnx::TensorProto& initializer(int index) const; + ::onnx::TensorProto* add_initializer(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& + initializer() const; + + // repeated .onnx.ValueInfoProto input = 11; + int input_size() const; + private: + int _internal_input_size() const; + public: + void clear_input(); + ::onnx::ValueInfoProto* mutable_input(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* + mutable_input(); + private: + const ::onnx::ValueInfoProto& _internal_input(int index) const; + ::onnx::ValueInfoProto* _internal_add_input(); + public: + const ::onnx::ValueInfoProto& input(int index) const; + ::onnx::ValueInfoProto* add_input(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& + input() const; + + // repeated .onnx.ValueInfoProto output = 12; + int output_size() const; + private: + int _internal_output_size() const; + public: + void clear_output(); + ::onnx::ValueInfoProto* mutable_output(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* + mutable_output(); + private: + const ::onnx::ValueInfoProto& _internal_output(int index) const; + ::onnx::ValueInfoProto* _internal_add_output(); + public: + const ::onnx::ValueInfoProto& output(int index) const; + ::onnx::ValueInfoProto* add_output(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& + output() const; + + // repeated .onnx.ValueInfoProto value_info = 13; + int value_info_size() const; + private: + int _internal_value_info_size() const; + public: + void clear_value_info(); + ::onnx::ValueInfoProto* mutable_value_info(int index); + 
::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* + mutable_value_info(); + private: + const ::onnx::ValueInfoProto& _internal_value_info(int index) const; + ::onnx::ValueInfoProto* _internal_add_value_info(); + public: + const ::onnx::ValueInfoProto& value_info(int index) const; + ::onnx::ValueInfoProto* add_value_info(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& + value_info() const; + + // optional string name = 2; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_name( + std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional string doc_string = 10; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void set_allocated_doc_string(std::string* doc_string); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_doc_string(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_doc_string( + std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // @@protoc_insertion_point(class_scope:onnx.GraphProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto > node_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto > initializer_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto > input_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto > output_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto > value_info_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// 
------------------------------------------------------------------- + +class TensorProto_Segment PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorProto.Segment) */ { + public: + inline TensorProto_Segment() : TensorProto_Segment(nullptr) {}; + virtual ~TensorProto_Segment(); + + TensorProto_Segment(const TensorProto_Segment& from); + TensorProto_Segment(TensorProto_Segment&& from) noexcept + : TensorProto_Segment() { + *this = ::std::move(from); + } + + inline TensorProto_Segment& operator=(const TensorProto_Segment& from) { + CopyFrom(from); + return *this; + } + inline TensorProto_Segment& operator=(TensorProto_Segment&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorProto_Segment& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorProto_Segment* internal_default_instance() { + return reinterpret_cast( + &_TensorProto_Segment_default_instance_); + } + static constexpr int kIndexInFileMessages = + 6; + + friend void swap(TensorProto_Segment& a, TensorProto_Segment& b) { + a.Swap(&b); + } + inline void Swap(TensorProto_Segment* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorProto_Segment* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorProto_Segment* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorProto_Segment* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorProto_Segment& from); + void MergeFrom(const TensorProto_Segment& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorProto_Segment* other); + friend 
class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorProto.Segment"; + } + protected: + explicit TensorProto_Segment(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kBeginFieldNumber = 1, + kEndFieldNumber = 2, + }; + // optional int64 begin = 1; + bool has_begin() const; + private: + bool _internal_has_begin() const; + public: + void clear_begin(); + ::PROTOBUF_NAMESPACE_ID::int64 begin() const; + void set_begin(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_begin() const; + void _internal_set_begin(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // optional int64 end = 2; + bool has_end() const; + private: + bool _internal_has_end() const; + public: + void clear_end(); + ::PROTOBUF_NAMESPACE_ID::int64 end() const; + void set_end(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_end() const; + void _internal_set_end(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TensorProto.Segment) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::int64 begin_; + ::PROTOBUF_NAMESPACE_ID::int64 end_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorProto) */ { + public: + inline TensorProto() : TensorProto(nullptr) {}; + virtual ~TensorProto(); + + TensorProto(const TensorProto& from); + TensorProto(TensorProto&& from) noexcept + : TensorProto() { + *this = ::std::move(from); + } + + inline TensorProto& operator=(const TensorProto& from) { + CopyFrom(from); + return *this; + } + inline TensorProto& operator=(TensorProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return 
GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorProto* internal_default_instance() { + return reinterpret_cast( + &_TensorProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 7; + + friend void swap(TensorProto& a, TensorProto& b) { + a.Swap(&b); + } + inline void Swap(TensorProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorProto& from); + void MergeFrom(const TensorProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorProto"; + } + protected: + explicit TensorProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TensorProto_Segment Segment; + + typedef TensorProto_DataType DataType; + static constexpr DataType UNDEFINED = + TensorProto_DataType_UNDEFINED; + static constexpr DataType FLOAT = + TensorProto_DataType_FLOAT; + static constexpr DataType UINT8 = + TensorProto_DataType_UINT8; + static constexpr DataType INT8 = + TensorProto_DataType_INT8; + static constexpr DataType UINT16 = + TensorProto_DataType_UINT16; + static constexpr DataType INT16 = + TensorProto_DataType_INT16; + static constexpr DataType INT32 = + TensorProto_DataType_INT32; + static constexpr DataType INT64 = + TensorProto_DataType_INT64; + static constexpr DataType STRING = + TensorProto_DataType_STRING; + static constexpr DataType BOOL = + TensorProto_DataType_BOOL; + static 
constexpr DataType FLOAT16 = + TensorProto_DataType_FLOAT16; + static constexpr DataType DOUBLE = + TensorProto_DataType_DOUBLE; + static constexpr DataType UINT32 = + TensorProto_DataType_UINT32; + static constexpr DataType UINT64 = + TensorProto_DataType_UINT64; + static constexpr DataType COMPLEX64 = + TensorProto_DataType_COMPLEX64; + static constexpr DataType COMPLEX128 = + TensorProto_DataType_COMPLEX128; + static constexpr DataType BFLOAT16 = + TensorProto_DataType_BFLOAT16; + static inline bool DataType_IsValid(int value) { + return TensorProto_DataType_IsValid(value); + } + static constexpr DataType DataType_MIN = + TensorProto_DataType_DataType_MIN; + static constexpr DataType DataType_MAX = + TensorProto_DataType_DataType_MAX; + static constexpr int DataType_ARRAYSIZE = + TensorProto_DataType_DataType_ARRAYSIZE; + static inline const ::PROTOBUF_NAMESPACE_ID::EnumDescriptor* + DataType_descriptor() { + return TensorProto_DataType_descriptor(); + } + template + static inline const std::string& DataType_Name(T enum_t_value) { + static_assert(::std::is_same::value || + ::std::is_integral::value, + "Incorrect type passed to function DataType_Name."); + return TensorProto_DataType_Name(enum_t_value); + } + static inline bool DataType_Parse(const std::string& name, + DataType* value) { + return TensorProto_DataType_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + enum : int { + kDimsFieldNumber = 1, + kFloatDataFieldNumber = 4, + kInt32DataFieldNumber = 5, + kStringDataFieldNumber = 6, + kInt64DataFieldNumber = 7, + kDoubleDataFieldNumber = 10, + kUint64DataFieldNumber = 11, + kNameFieldNumber = 8, + kRawDataFieldNumber = 9, + kDocStringFieldNumber = 12, + kSegmentFieldNumber = 3, + kDataTypeFieldNumber = 2, + }; + // repeated int64 dims = 1; + int dims_size() const; + private: + int _internal_dims_size() const; + public: + void clear_dims(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_dims(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_dims() const; + void _internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_dims(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 dims(int index) const; + void set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_dims(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + dims() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_dims(); + + // repeated float float_data = 4 [packed = true]; + int float_data_size() const; + private: + int _internal_float_data_size() const; + public: + void clear_float_data(); + private: + float _internal_float_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + _internal_float_data() const; + void _internal_add_float_data(float value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + _internal_mutable_float_data(); + public: + float float_data(int index) const; + void set_float_data(int index, float value); + void add_float_data(float value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& + float_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* + mutable_float_data(); + + // repeated int32 int32_data = 5 [packed = true]; + int int32_data_size() const; + private: + int _internal_int32_data_size() const; + public: + void 
clear_int32_data(); + private: + ::PROTOBUF_NAMESPACE_ID::int32 _internal_int32_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + _internal_int32_data() const; + void _internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + _internal_mutable_int32_data(); + public: + ::PROTOBUF_NAMESPACE_ID::int32 int32_data(int index) const; + void set_int32_data(int index, ::PROTOBUF_NAMESPACE_ID::int32 value); + void add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& + int32_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* + mutable_int32_data(); + + // repeated bytes string_data = 6; + int string_data_size() const; + private: + int _internal_string_data_size() const; + public: + void clear_string_data(); + const std::string& string_data(int index) const; + std::string* mutable_string_data(int index); + void set_string_data(int index, const std::string& value); + void set_string_data(int index, std::string&& value); + void set_string_data(int index, const char* value); + void set_string_data(int index, const void* value, size_t size); + std::string* add_string_data(); + void add_string_data(const std::string& value); + void add_string_data(std::string&& value); + void add_string_data(const char* value); + void add_string_data(const void* value, size_t size); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField& string_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField* mutable_string_data(); + private: + const std::string& _internal_string_data(int index) const; + std::string* _internal_add_string_data(); + public: + + // repeated int64 int64_data = 7 [packed = true]; + int int64_data_size() const; + private: + int _internal_int64_data_size() const; + public: + void clear_int64_data(); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_int64_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + _internal_int64_data() const; + void _internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + _internal_mutable_int64_data(); + public: + ::PROTOBUF_NAMESPACE_ID::int64 int64_data(int index) const; + void set_int64_data(int index, ::PROTOBUF_NAMESPACE_ID::int64 value); + void add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& + int64_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* + mutable_int64_data(); + + // repeated double double_data = 10 [packed = true]; + int double_data_size() const; + private: + int _internal_double_data_size() const; + public: + void clear_double_data(); + private: + double _internal_double_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + _internal_double_data() const; + void _internal_add_double_data(double value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + _internal_mutable_double_data(); + public: + double double_data(int index) const; + void set_double_data(int index, double value); + void add_double_data(double value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& + double_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* + mutable_double_data(); + + // repeated uint64 uint64_data 
= 11 [packed = true]; + int uint64_data_size() const; + private: + int _internal_uint64_data_size() const; + public: + void clear_uint64_data(); + private: + ::PROTOBUF_NAMESPACE_ID::uint64 _internal_uint64_data(int index) const; + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + _internal_uint64_data() const; + void _internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value); + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + _internal_mutable_uint64_data(); + public: + ::PROTOBUF_NAMESPACE_ID::uint64 uint64_data(int index) const; + void set_uint64_data(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value); + void add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value); + const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& + uint64_data() const; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* + mutable_uint64_data(); + + // optional string name = 8; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_name( + std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // optional bytes raw_data = 9; + bool has_raw_data() const; + private: + bool _internal_has_raw_data() const; + public: + void clear_raw_data(); + const std::string& raw_data() const; + void set_raw_data(const std::string& value); + void set_raw_data(std::string&& value); + void set_raw_data(const char* value); + void set_raw_data(const void* value, size_t size); + std::string* mutable_raw_data(); + std::string* release_raw_data(); + void set_allocated_raw_data(std::string* raw_data); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_raw_data(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_raw_data( + std::string* raw_data); + private: + const std::string& _internal_raw_data() const; + void _internal_set_raw_data(const std::string& value); + std::string* _internal_mutable_raw_data(); + public: + + // optional string doc_string = 12; + bool has_doc_string() const; + private: + bool _internal_has_doc_string() const; + public: + void clear_doc_string(); + const std::string& doc_string() const; + void set_doc_string(const std::string& value); + void set_doc_string(std::string&& value); + void set_doc_string(const char* value); + void set_doc_string(const char* value, size_t size); + std::string* mutable_doc_string(); + std::string* release_doc_string(); + void 
set_allocated_doc_string(std::string* doc_string); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_doc_string(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_doc_string( + std::string* doc_string); + private: + const std::string& _internal_doc_string() const; + void _internal_set_doc_string(const std::string& value); + std::string* _internal_mutable_doc_string(); + public: + + // optional .onnx.TensorProto.Segment segment = 3; + bool has_segment() const; + private: + bool _internal_has_segment() const; + public: + void clear_segment(); + const ::onnx::TensorProto_Segment& segment() const; + ::onnx::TensorProto_Segment* release_segment(); + ::onnx::TensorProto_Segment* mutable_segment(); + void set_allocated_segment(::onnx::TensorProto_Segment* segment); + private: + const ::onnx::TensorProto_Segment& _internal_segment() const; + ::onnx::TensorProto_Segment* _internal_mutable_segment(); + public: + void unsafe_arena_set_allocated_segment( + ::onnx::TensorProto_Segment* segment); + ::onnx::TensorProto_Segment* unsafe_arena_release_segment(); + + // optional .onnx.TensorProto.DataType data_type = 2; + bool has_data_type() const; + private: + bool _internal_has_data_type() const; + public: + void clear_data_type(); + ::onnx::TensorProto_DataType data_type() const; + void set_data_type(::onnx::TensorProto_DataType value); + private: + ::onnx::TensorProto_DataType _internal_data_type() const; + void _internal_set_data_type(::onnx::TensorProto_DataType value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TensorProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > dims_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< float > float_data_; + mutable std::atomic _float_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 > int32_data_; + mutable std::atomic _int32_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField string_data_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > int64_data_; + mutable std::atomic _int64_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< double > double_data_; + mutable std::atomic _double_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 > uint64_data_; + mutable std::atomic _uint64_data_cached_byte_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr raw_data_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr doc_string_; + ::onnx::TensorProto_Segment* segment_; + int data_type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto_Dimension PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorShapeProto.Dimension) */ { + public: + inline TensorShapeProto_Dimension() : 
TensorShapeProto_Dimension(nullptr) {}; + virtual ~TensorShapeProto_Dimension(); + + TensorShapeProto_Dimension(const TensorShapeProto_Dimension& from); + TensorShapeProto_Dimension(TensorShapeProto_Dimension&& from) noexcept + : TensorShapeProto_Dimension() { + *this = ::std::move(from); + } + + inline TensorShapeProto_Dimension& operator=(const TensorShapeProto_Dimension& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto_Dimension& operator=(TensorShapeProto_Dimension&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto_Dimension& default_instance(); + + enum ValueCase { + kDimValue = 1, + kDimParam = 2, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto_Dimension* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_Dimension_default_instance_); + } + static constexpr int kIndexInFileMessages = + 8; + + friend void swap(TensorShapeProto_Dimension& a, TensorShapeProto_Dimension& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto_Dimension* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorShapeProto_Dimension* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto_Dimension* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto_Dimension* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto_Dimension& from); + void MergeFrom(const TensorShapeProto_Dimension& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto_Dimension* other); + friend class 
::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorShapeProto.Dimension"; + } + protected: + explicit TensorShapeProto_Dimension(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDenotationFieldNumber = 3, + kDimValueFieldNumber = 1, + kDimParamFieldNumber = 2, + }; + // optional string denotation = 3; + bool has_denotation() const; + private: + bool _internal_has_denotation() const; + public: + void clear_denotation(); + const std::string& denotation() const; + void set_denotation(const std::string& value); + void set_denotation(std::string&& value); + void set_denotation(const char* value); + void set_denotation(const char* value, size_t size); + std::string* mutable_denotation(); + std::string* release_denotation(); + void set_allocated_denotation(std::string* denotation); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_denotation(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_denotation( + std::string* denotation); + private: + const std::string& _internal_denotation() const; + void _internal_set_denotation(const std::string& value); + std::string* _internal_mutable_denotation(); + public: + + // int64 dim_value = 1; + bool has_dim_value() const; + private: + bool _internal_has_dim_value() const; + public: + void clear_dim_value(); + ::PROTOBUF_NAMESPACE_ID::int64 dim_value() const; + void set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_dim_value() const; + void _internal_set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // string dim_param = 2; + bool has_dim_param() const; + private: + bool _internal_has_dim_param() const; + public: + void clear_dim_param(); + const std::string& dim_param() const; + void set_dim_param(const std::string& value); + void set_dim_param(std::string&& value); + void set_dim_param(const char* value); + void set_dim_param(const char* value, size_t size); + std::string* mutable_dim_param(); + std::string* release_dim_param(); + void set_allocated_dim_param(std::string* dim_param); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_dim_param(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_dim_param( + std::string* dim_param); + private: + const std::string& _internal_dim_param() const; + void _internal_set_dim_param(const std::string& 
value); + std::string* _internal_mutable_dim_param(); + public: + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:onnx.TensorShapeProto.Dimension) + private: + class _Internal; + void set_has_dim_value(); + void set_has_dim_param(); + + inline bool has_value() const; + inline void clear_has_value(); + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr denotation_; + union ValueUnion { + ValueUnion() {} + ::PROTOBUF_NAMESPACE_ID::int64 dim_value_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr dim_param_; + } value_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TensorShapeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TensorShapeProto) */ { + public: + inline TensorShapeProto() : TensorShapeProto(nullptr) {}; + virtual ~TensorShapeProto(); + + TensorShapeProto(const TensorShapeProto& from); + TensorShapeProto(TensorShapeProto&& from) noexcept + : TensorShapeProto() { + *this = ::std::move(from); + } + + inline TensorShapeProto& operator=(const TensorShapeProto& from) { + CopyFrom(from); + return *this; + } + inline TensorShapeProto& operator=(TensorShapeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TensorShapeProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TensorShapeProto* internal_default_instance() { + return reinterpret_cast( + &_TensorShapeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 9; + + friend void swap(TensorShapeProto& a, TensorShapeProto& b) { + a.Swap(&b); + } + inline void Swap(TensorShapeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TensorShapeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TensorShapeProto* New() const final { + return CreateMaybeMessage(nullptr); + } + + TensorShapeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return 
CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TensorShapeProto& from); + void MergeFrom(const TensorShapeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TensorShapeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TensorShapeProto"; + } + protected: + explicit TensorShapeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TensorShapeProto_Dimension Dimension; + + // accessors ------------------------------------------------------- + + enum : int { + kDimFieldNumber = 1, + }; + // repeated .onnx.TensorShapeProto.Dimension dim = 1; + int dim_size() const; + private: + int _internal_dim_size() const; + public: + void clear_dim(); + ::onnx::TensorShapeProto_Dimension* mutable_dim(int index); + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >* + mutable_dim(); + private: + const ::onnx::TensorShapeProto_Dimension& _internal_dim(int index) const; + ::onnx::TensorShapeProto_Dimension* _internal_add_dim(); + public: + const ::onnx::TensorShapeProto_Dimension& dim(int index) const; + ::onnx::TensorShapeProto_Dimension* add_dim(); + const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >& + dim() const; + + // @@protoc_insertion_point(class_scope:onnx.TensorShapeProto) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension > dim_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Tensor PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Tensor) */ { + public: + inline TypeProto_Tensor() : TypeProto_Tensor(nullptr) {}; + virtual ~TypeProto_Tensor(); + + TypeProto_Tensor(const TypeProto_Tensor& from); + TypeProto_Tensor(TypeProto_Tensor&& from) noexcept + : TypeProto_Tensor() { + *this = ::std::move(from); + } + + inline TypeProto_Tensor& operator=(const 
TypeProto_Tensor& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Tensor& operator=(TypeProto_Tensor&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Tensor& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Tensor* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Tensor_default_instance_); + } + static constexpr int kIndexInFileMessages = + 10; + + friend void swap(TypeProto_Tensor& a, TypeProto_Tensor& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Tensor* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_Tensor* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Tensor* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Tensor* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Tensor& from); + void MergeFrom(const TypeProto_Tensor& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Tensor* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Tensor"; + } + protected: + explicit TypeProto_Tensor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + 
return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kShapeFieldNumber = 2, + kElemTypeFieldNumber = 1, + }; + // optional .onnx.TensorShapeProto shape = 2; + bool has_shape() const; + private: + bool _internal_has_shape() const; + public: + void clear_shape(); + const ::onnx::TensorShapeProto& shape() const; + ::onnx::TensorShapeProto* release_shape(); + ::onnx::TensorShapeProto* mutable_shape(); + void set_allocated_shape(::onnx::TensorShapeProto* shape); + private: + const ::onnx::TensorShapeProto& _internal_shape() const; + ::onnx::TensorShapeProto* _internal_mutable_shape(); + public: + void unsafe_arena_set_allocated_shape( + ::onnx::TensorShapeProto* shape); + ::onnx::TensorShapeProto* unsafe_arena_release_shape(); + + // optional .onnx.TensorProto.DataType elem_type = 1; + bool has_elem_type() const; + private: + bool _internal_has_elem_type() const; + public: + void clear_elem_type(); + ::onnx::TensorProto_DataType elem_type() const; + void set_elem_type(::onnx::TensorProto_DataType value); + private: + ::onnx::TensorProto_DataType _internal_elem_type() const; + void _internal_set_elem_type(::onnx::TensorProto_DataType value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Tensor) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TensorShapeProto* shape_; + int elem_type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Sequence PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Sequence) */ { + public: + inline TypeProto_Sequence() : TypeProto_Sequence(nullptr) {}; + virtual ~TypeProto_Sequence(); + + TypeProto_Sequence(const TypeProto_Sequence& from); + TypeProto_Sequence(TypeProto_Sequence&& from) noexcept + : TypeProto_Sequence() { + *this = ::std::move(from); + } + + inline TypeProto_Sequence& operator=(const TypeProto_Sequence& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Sequence& operator=(TypeProto_Sequence&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Sequence& default_instance(); + + static void InitAsDefaultInstance(); // FOR 
INTERNAL USE ONLY + static inline const TypeProto_Sequence* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Sequence_default_instance_); + } + static constexpr int kIndexInFileMessages = + 11; + + friend void swap(TypeProto_Sequence& a, TypeProto_Sequence& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Sequence* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_Sequence* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Sequence* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Sequence* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Sequence& from); + void MergeFrom(const TypeProto_Sequence& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Sequence* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Sequence"; + } + protected: + explicit TypeProto_Sequence(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kElemTypeFieldNumber = 1, + }; + // optional .onnx.TypeProto elem_type = 1; + bool has_elem_type() const; + private: + bool _internal_has_elem_type() const; + public: + void clear_elem_type(); + const ::onnx::TypeProto& elem_type() const; + ::onnx::TypeProto* release_elem_type(); + ::onnx::TypeProto* mutable_elem_type(); + void set_allocated_elem_type(::onnx::TypeProto* elem_type); + private: + const ::onnx::TypeProto& _internal_elem_type() const; + ::onnx::TypeProto* _internal_mutable_elem_type(); + public: + void unsafe_arena_set_allocated_elem_type( + ::onnx::TypeProto* elem_type); + ::onnx::TypeProto* unsafe_arena_release_elem_type(); + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Sequence) + private: + class _Internal; + + template friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + 
typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TypeProto* elem_type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Map PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Map) */ { + public: + inline TypeProto_Map() : TypeProto_Map(nullptr) {}; + virtual ~TypeProto_Map(); + + TypeProto_Map(const TypeProto_Map& from); + TypeProto_Map(TypeProto_Map&& from) noexcept + : TypeProto_Map() { + *this = ::std::move(from); + } + + inline TypeProto_Map& operator=(const TypeProto_Map& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Map& operator=(TypeProto_Map&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Map& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Map* internal_default_instance() { + return reinterpret_cast( + &_TypeProto_Map_default_instance_); + } + static constexpr int kIndexInFileMessages = + 12; + + friend void swap(TypeProto_Map& a, TypeProto_Map& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Map* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_Map* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Map* New() const final { + return CreateMaybeMessage(nullptr); + } + + TypeProto_Map* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Map& from); + void MergeFrom(const TypeProto_Map& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + 
inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Map* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Map"; + } + protected: + explicit TypeProto_Map(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kValueTypeFieldNumber = 2, + kKeyTypeFieldNumber = 1, + }; + // optional .onnx.TypeProto value_type = 2; + bool has_value_type() const; + private: + bool _internal_has_value_type() const; + public: + void clear_value_type(); + const ::onnx::TypeProto& value_type() const; + ::onnx::TypeProto* release_value_type(); + ::onnx::TypeProto* mutable_value_type(); + void set_allocated_value_type(::onnx::TypeProto* value_type); + private: + const ::onnx::TypeProto& _internal_value_type() const; + ::onnx::TypeProto* _internal_mutable_value_type(); + public: + void unsafe_arena_set_allocated_value_type( + ::onnx::TypeProto* value_type); + ::onnx::TypeProto* unsafe_arena_release_value_type(); + + // optional .onnx.TensorProto.DataType key_type = 1; + bool has_key_type() const; + private: + bool _internal_has_key_type() const; + public: + void clear_key_type(); + ::onnx::TensorProto_DataType key_type() const; + void set_key_type(::onnx::TensorProto_DataType value); + private: + ::onnx::TensorProto_DataType _internal_key_type() const; + void _internal_set_key_type(::onnx::TensorProto_DataType value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Map) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TypeProto* value_type_; + int key_type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_Opaque PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.Opaque) */ { + public: + inline TypeProto_Opaque() : TypeProto_Opaque(nullptr) {}; + virtual ~TypeProto_Opaque(); + + TypeProto_Opaque(const TypeProto_Opaque& from); + TypeProto_Opaque(TypeProto_Opaque&& from) noexcept + : TypeProto_Opaque() { + *this = ::std::move(from); + } + + inline TypeProto_Opaque& operator=(const TypeProto_Opaque& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_Opaque& operator=(TypeProto_Opaque&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return 
_internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_Opaque& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_Opaque* internal_default_instance() { + return reinterpret_cast<const TypeProto_Opaque*>( + &_TypeProto_Opaque_default_instance_); + } + static constexpr int kIndexInFileMessages = + 13; + + friend void swap(TypeProto_Opaque& a, TypeProto_Opaque& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_Opaque* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_Opaque* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_Opaque* New() const final { + return CreateMaybeMessage<TypeProto_Opaque>(nullptr); + } + + TypeProto_Opaque* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<TypeProto_Opaque>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_Opaque& from); + void MergeFrom(const TypeProto_Opaque& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_Opaque* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.Opaque"; + } + protected: + explicit TypeProto_Opaque(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDomainFieldNumber = 1, + kNameFieldNumber = 2, + }; + // optional string domain = 1; + bool 
has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_domain(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_domain( + std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // optional string name = 2; + bool has_name() const; + private: + bool _internal_has_name() const; + public: + void clear_name(); + const std::string& name() const; + void set_name(const std::string& value); + void set_name(std::string&& value); + void set_name(const char* value); + void set_name(const char* value, size_t size); + std::string* mutable_name(); + std::string* release_name(); + void set_allocated_name(std::string* name); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_name(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_name( + std::string* name); + private: + const std::string& _internal_name() const; + void _internal_set_name(const std::string& value); + std::string* _internal_mutable_name(); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.Opaque) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr name_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto_SparseTensor PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto.SparseTensor) */ { + public: + inline TypeProto_SparseTensor() : TypeProto_SparseTensor(nullptr) {}; + virtual ~TypeProto_SparseTensor(); + + TypeProto_SparseTensor(const TypeProto_SparseTensor& from); + TypeProto_SparseTensor(TypeProto_SparseTensor&& from) noexcept + : TypeProto_SparseTensor() { + *this = ::std::move(from); + } + + inline TypeProto_SparseTensor& operator=(const TypeProto_SparseTensor& from) { + CopyFrom(from); + return *this; + } + inline TypeProto_SparseTensor& operator=(TypeProto_SparseTensor&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() 
const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto_SparseTensor& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto_SparseTensor* internal_default_instance() { + return reinterpret_cast<const TypeProto_SparseTensor*>( + &_TypeProto_SparseTensor_default_instance_); + } + static constexpr int kIndexInFileMessages = + 14; + + friend void swap(TypeProto_SparseTensor& a, TypeProto_SparseTensor& b) { + a.Swap(&b); + } + inline void Swap(TypeProto_SparseTensor* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto_SparseTensor* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto_SparseTensor* New() const final { + return CreateMaybeMessage<TypeProto_SparseTensor>(nullptr); + } + + TypeProto_SparseTensor* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<TypeProto_SparseTensor>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto_SparseTensor& from); + void MergeFrom(const TypeProto_SparseTensor& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto_SparseTensor* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto.SparseTensor"; + } + protected: + explicit TypeProto_SparseTensor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { 
+ kShapeFieldNumber = 2, + kElemTypeFieldNumber = 1, + }; + // optional .onnx.TensorShapeProto shape = 2; + bool has_shape() const; + private: + bool _internal_has_shape() const; + public: + void clear_shape(); + const ::onnx::TensorShapeProto& shape() const; + ::onnx::TensorShapeProto* release_shape(); + ::onnx::TensorShapeProto* mutable_shape(); + void set_allocated_shape(::onnx::TensorShapeProto* shape); + private: + const ::onnx::TensorShapeProto& _internal_shape() const; + ::onnx::TensorShapeProto* _internal_mutable_shape(); + public: + void unsafe_arena_set_allocated_shape( + ::onnx::TensorShapeProto* shape); + ::onnx::TensorShapeProto* unsafe_arena_release_shape(); + + // optional .onnx.TensorProto.DataType elem_type = 1; + bool has_elem_type() const; + private: + bool _internal_has_elem_type() const; + public: + void clear_elem_type(); + ::onnx::TensorProto_DataType elem_type() const; + void set_elem_type(::onnx::TensorProto_DataType value); + private: + ::onnx::TensorProto_DataType _internal_elem_type() const; + void _internal_set_elem_type(::onnx::TensorProto_DataType value); + public: + + // @@protoc_insertion_point(class_scope:onnx.TypeProto.SparseTensor) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::onnx::TensorShapeProto* shape_; + int elem_type_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class TypeProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.TypeProto) */ { + public: + inline TypeProto() : TypeProto(nullptr) {}; + virtual ~TypeProto(); + + TypeProto(const TypeProto& from); + TypeProto(TypeProto&& from) noexcept + : TypeProto() { + *this = ::std::move(from); + } + + inline TypeProto& operator=(const TypeProto& from) { + CopyFrom(from); + return *this; + } + inline TypeProto& operator=(TypeProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const TypeProto& default_instance(); + + enum ValueCase { + kTensorType = 1, + kSequenceType = 4, + kMapType = 5, + kOpaqueType = 7, + kSparseTensorType = 8, + VALUE_NOT_SET = 0, + }; + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const TypeProto* internal_default_instance() { + return reinterpret_cast<const TypeProto*>( + &_TypeProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 15; + + friend void swap(TypeProto& a, TypeProto& b) { 
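+ // Delegates to Swap(), which swaps in place when both messages share an + // arena and falls back to GenericSwap() (a copy-based exchange) otherwise. 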
+ a.Swap(&b); + } + inline void Swap(TypeProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(TypeProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline TypeProto* New() const final { + return CreateMaybeMessage<TypeProto>(nullptr); + } + + TypeProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<TypeProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const TypeProto& from); + void MergeFrom(const TypeProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(TypeProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.TypeProto"; + } + protected: + explicit TypeProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + typedef TypeProto_Tensor Tensor; + typedef TypeProto_Sequence Sequence; + typedef TypeProto_Map Map; + typedef TypeProto_Opaque Opaque; + typedef TypeProto_SparseTensor SparseTensor; + + // accessors ------------------------------------------------------- + + enum : int { + kDenotationFieldNumber = 6, + kTensorTypeFieldNumber = 1, + kSequenceTypeFieldNumber = 4, + kMapTypeFieldNumber = 5, + kOpaqueTypeFieldNumber = 7, + kSparseTensorTypeFieldNumber = 8, + }; + // optional string denotation = 6; + bool has_denotation() const; + private: + bool _internal_has_denotation() const; + public: + void clear_denotation(); + const std::string& denotation() const; + void set_denotation(const std::string& value); + void set_denotation(std::string&& value); + void set_denotation(const char* value); + void set_denotation(const char* value, size_t size); + std::string* mutable_denotation(); + std::string* release_denotation(); + void set_allocated_denotation(std::string* denotation); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_denotation(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are 
deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_denotation( + std::string* denotation); + private: + const std::string& _internal_denotation() const; + void _internal_set_denotation(const std::string& value); + std::string* _internal_mutable_denotation(); + public: + + // .onnx.TypeProto.Tensor tensor_type = 1; + bool has_tensor_type() const; + private: + bool _internal_has_tensor_type() const; + public: + void clear_tensor_type(); + const ::onnx::TypeProto_Tensor& tensor_type() const; + ::onnx::TypeProto_Tensor* release_tensor_type(); + ::onnx::TypeProto_Tensor* mutable_tensor_type(); + void set_allocated_tensor_type(::onnx::TypeProto_Tensor* tensor_type); + private: + const ::onnx::TypeProto_Tensor& _internal_tensor_type() const; + ::onnx::TypeProto_Tensor* _internal_mutable_tensor_type(); + public: + void unsafe_arena_set_allocated_tensor_type( + ::onnx::TypeProto_Tensor* tensor_type); + ::onnx::TypeProto_Tensor* unsafe_arena_release_tensor_type(); + + // .onnx.TypeProto.Sequence sequence_type = 4; + bool has_sequence_type() const; + private: + bool _internal_has_sequence_type() const; + public: + void clear_sequence_type(); + const ::onnx::TypeProto_Sequence& sequence_type() const; + ::onnx::TypeProto_Sequence* release_sequence_type(); + ::onnx::TypeProto_Sequence* mutable_sequence_type(); + void set_allocated_sequence_type(::onnx::TypeProto_Sequence* sequence_type); + private: + const ::onnx::TypeProto_Sequence& _internal_sequence_type() const; + ::onnx::TypeProto_Sequence* _internal_mutable_sequence_type(); + public: + void unsafe_arena_set_allocated_sequence_type( + ::onnx::TypeProto_Sequence* sequence_type); + ::onnx::TypeProto_Sequence* unsafe_arena_release_sequence_type(); + + // .onnx.TypeProto.Map map_type = 5; + bool has_map_type() const; + private: + bool _internal_has_map_type() const; + public: + void clear_map_type(); + const ::onnx::TypeProto_Map& map_type() const; + ::onnx::TypeProto_Map* release_map_type(); + ::onnx::TypeProto_Map* mutable_map_type(); + void set_allocated_map_type(::onnx::TypeProto_Map* map_type); + private: + const ::onnx::TypeProto_Map& _internal_map_type() const; + ::onnx::TypeProto_Map* _internal_mutable_map_type(); + public: + void unsafe_arena_set_allocated_map_type( + ::onnx::TypeProto_Map* map_type); + ::onnx::TypeProto_Map* unsafe_arena_release_map_type(); + + // .onnx.TypeProto.Opaque opaque_type = 7; + bool has_opaque_type() const; + private: + bool _internal_has_opaque_type() const; + public: + void clear_opaque_type(); + const ::onnx::TypeProto_Opaque& opaque_type() const; + ::onnx::TypeProto_Opaque* release_opaque_type(); + ::onnx::TypeProto_Opaque* mutable_opaque_type(); + void set_allocated_opaque_type(::onnx::TypeProto_Opaque* opaque_type); + private: + const ::onnx::TypeProto_Opaque& _internal_opaque_type() const; + ::onnx::TypeProto_Opaque* _internal_mutable_opaque_type(); + public: + void unsafe_arena_set_allocated_opaque_type( + ::onnx::TypeProto_Opaque* opaque_type); + ::onnx::TypeProto_Opaque* unsafe_arena_release_opaque_type(); + + // .onnx.TypeProto.SparseTensor sparse_tensor_type = 8; + bool has_sparse_tensor_type() const; + private: + bool _internal_has_sparse_tensor_type() const; + public: + void clear_sparse_tensor_type(); + const ::onnx::TypeProto_SparseTensor& sparse_tensor_type() const; + ::onnx::TypeProto_SparseTensor* release_sparse_tensor_type(); + ::onnx::TypeProto_SparseTensor* mutable_sparse_tensor_type(); + void 
set_allocated_sparse_tensor_type(::onnx::TypeProto_SparseTensor* sparse_tensor_type); + private: + const ::onnx::TypeProto_SparseTensor& _internal_sparse_tensor_type() const; + ::onnx::TypeProto_SparseTensor* _internal_mutable_sparse_tensor_type(); + public: + void unsafe_arena_set_allocated_sparse_tensor_type( + ::onnx::TypeProto_SparseTensor* sparse_tensor_type); + ::onnx::TypeProto_SparseTensor* unsafe_arena_release_sparse_tensor_type(); + + void clear_value(); + ValueCase value_case() const; + // @@protoc_insertion_point(class_scope:onnx.TypeProto) + private: + class _Internal; + void set_has_tensor_type(); + void set_has_sequence_type(); + void set_has_map_type(); + void set_has_opaque_type(); + void set_has_sparse_tensor_type(); + + inline bool has_value() const; + inline void clear_has_value(); + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr denotation_; + union ValueUnion { + ValueUnion() {} + ::onnx::TypeProto_Tensor* tensor_type_; + ::onnx::TypeProto_Sequence* sequence_type_; + ::onnx::TypeProto_Map* map_type_; + ::onnx::TypeProto_Opaque* opaque_type_; + ::onnx::TypeProto_SparseTensor* sparse_tensor_type_; + } value_; + ::PROTOBUF_NAMESPACE_ID::uint32 _oneof_case_[1]; + + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// ------------------------------------------------------------------- + +class OperatorSetIdProto PROTOBUF_FINAL : + public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:onnx.OperatorSetIdProto) */ { + public: + inline OperatorSetIdProto() : OperatorSetIdProto(nullptr) {}; + virtual ~OperatorSetIdProto(); + + OperatorSetIdProto(const OperatorSetIdProto& from); + OperatorSetIdProto(OperatorSetIdProto&& from) noexcept + : OperatorSetIdProto() { + *this = ::std::move(from); + } + + inline OperatorSetIdProto& operator=(const OperatorSetIdProto& from) { + CopyFrom(from); + return *this; + } + inline OperatorSetIdProto& operator=(OperatorSetIdProto&& from) noexcept { + if (GetArena() == from.GetArena()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + + inline const ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet& unknown_fields() const { + return _internal_metadata_.unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(::PROTOBUF_NAMESPACE_ID::UnknownFieldSet::default_instance); + } + inline ::PROTOBUF_NAMESPACE_ID::UnknownFieldSet* mutable_unknown_fields() { + return _internal_metadata_.mutable_unknown_fields<::PROTOBUF_NAMESPACE_ID::UnknownFieldSet>(); + } + + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() { + return GetDescriptor(); + } + static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() { + return GetMetadataStatic().descriptor; + } + static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() { + return GetMetadataStatic().reflection; + } + static const OperatorSetIdProto& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const OperatorSetIdProto* internal_default_instance() { + return reinterpret_cast<const OperatorSetIdProto*>( + &_OperatorSetIdProto_default_instance_); + } + static constexpr int kIndexInFileMessages = + 16; + + friend void swap(OperatorSetIdProto& a, OperatorSetIdProto& b) { + a.Swap(&b); + } + inline void 
Swap(OperatorSetIdProto* other) { + if (other == this) return; + if (GetArena() == other->GetArena()) { + InternalSwap(other); + } else { + ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other); + } + } + void UnsafeArenaSwap(OperatorSetIdProto* other) { + if (other == this) return; + GOOGLE_DCHECK(GetArena() == other->GetArena()); + InternalSwap(other); + } + + // implements Message ---------------------------------------------- + + inline OperatorSetIdProto* New() const final { + return CreateMaybeMessage<OperatorSetIdProto>(nullptr); + } + + OperatorSetIdProto* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final { + return CreateMaybeMessage<OperatorSetIdProto>(arena); + } + void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final; + void CopyFrom(const OperatorSetIdProto& from); + void MergeFrom(const OperatorSetIdProto& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final; + ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize( + ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + inline void SharedCtor(); + inline void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(OperatorSetIdProto* other); + friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata; + static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() { + return "onnx.OperatorSetIdProto"; + } + protected: + explicit OperatorSetIdProto(::PROTOBUF_NAMESPACE_ID::Arena* arena); + private: + static void ArenaDtor(void* object); + inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena); + public: + + ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final; + private: + static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() { + ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_onnx_2dml_2eproto); + return ::descriptor_table_onnx_2dml_2eproto.file_level_metadata[kIndexInFileMessages]; + } + + public: + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + enum : int { + kDomainFieldNumber = 1, + kVersionFieldNumber = 2, + }; + // optional string domain = 1; + bool has_domain() const; + private: + bool _internal_has_domain() const; + public: + void clear_domain(); + const std::string& domain() const; + void set_domain(const std::string& value); + void set_domain(std::string&& value); + void set_domain(const char* value); + void set_domain(const char* value, size_t size); + std::string* mutable_domain(); + std::string* release_domain(); + void set_allocated_domain(std::string* domain); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + std::string* unsafe_arena_release_domain(); + GOOGLE_PROTOBUF_RUNTIME_DEPRECATED("The unsafe_arena_ accessors for" + " string fields are deprecated and will be removed in a" + " future release.") + void unsafe_arena_set_allocated_domain( + std::string* domain); + private: + const std::string& _internal_domain() const; + void _internal_set_domain(const std::string& value); + std::string* _internal_mutable_domain(); + public: + + // optional int64 version = 2; 
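+ // Together with 'domain' above, 'version' identifies one ONNX operator set; + // the empty domain denotes the default ONNX domain. A small illustrative + // sketch of the accessors that follow: + // + //   onnx::OperatorSetIdProto opset; + //   opset.set_domain("");   // default ONNX operator set + //   opset.set_version(12); + //   assert(opset.has_version()); 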
+ bool has_version() const; + private: + bool _internal_has_version() const; + public: + void clear_version(); + ::PROTOBUF_NAMESPACE_ID::int64 version() const; + void set_version(::PROTOBUF_NAMESPACE_ID::int64 value); + private: + ::PROTOBUF_NAMESPACE_ID::int64 _internal_version() const; + void _internal_set_version(::PROTOBUF_NAMESPACE_ID::int64 value); + public: + + // @@protoc_insertion_point(class_scope:onnx.OperatorSetIdProto) + private: + class _Internal; + + template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper; + typedef void InternalArenaConstructable_; + typedef void DestructorSkippable_; + ::PROTOBUF_NAMESPACE_ID::internal::HasBits<1> _has_bits_; + mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_; + ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr domain_; + ::PROTOBUF_NAMESPACE_ID::int64 version_; + friend struct ::TableStruct_onnx_2dml_2eproto; +}; +// =================================================================== + + +// =================================================================== + +#ifdef __GNUC__ + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ +// AttributeProto + +// optional string name = 1; +inline bool AttributeProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool AttributeProto::has_name() const { + return _internal_has_name(); +} +inline void AttributeProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& AttributeProto::name() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.name) + return _internal_name(); +} +inline void AttributeProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.name) +} +inline std::string* AttributeProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.name) + return _internal_mutable_name(); +} +inline const std::string& AttributeProto::_internal_name() const { + return name_.Get(); +} +inline void AttributeProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void AttributeProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.name) +} +inline void AttributeProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.name) +} +inline void AttributeProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.name) +} +inline std::string* AttributeProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* 
AttributeProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void AttributeProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.name) +} +inline std::string* AttributeProto::unsafe_arena_release_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.AttributeProto.name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void AttributeProto::unsafe_arena_set_allocated_name( + std::string* name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.AttributeProto.name) +} + +// optional string ref_attr_name = 21; +inline bool AttributeProto::_internal_has_ref_attr_name() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool AttributeProto::has_ref_attr_name() const { + return _internal_has_ref_attr_name(); +} +inline void AttributeProto::clear_ref_attr_name() { + ref_attr_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& AttributeProto::ref_attr_name() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.ref_attr_name) + return _internal_ref_attr_name(); +} +inline void AttributeProto::set_ref_attr_name(const std::string& value) { + _internal_set_ref_attr_name(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.ref_attr_name) +} +inline std::string* AttributeProto::mutable_ref_attr_name() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.ref_attr_name) + return _internal_mutable_ref_attr_name(); +} +inline const std::string& AttributeProto::_internal_ref_attr_name() const { + return ref_attr_name_.Get(); +} +inline void AttributeProto::_internal_set_ref_attr_name(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void AttributeProto::set_ref_attr_name(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.ref_attr_name) +} +inline void AttributeProto::set_ref_attr_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.ref_attr_name) +} +inline void AttributeProto::set_ref_attr_name(const char* value, + 
size_t size) { + _has_bits_[0] |= 0x00000008u; + ref_attr_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.ref_attr_name) +} +inline std::string* AttributeProto::_internal_mutable_ref_attr_name() { + _has_bits_[0] |= 0x00000008u; + return ref_attr_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* AttributeProto::release_ref_attr_name() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.ref_attr_name) + if (!_internal_has_ref_attr_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return ref_attr_name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void AttributeProto::set_allocated_ref_attr_name(std::string* ref_attr_name) { + if (ref_attr_name != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + ref_attr_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ref_attr_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.ref_attr_name) +} +inline std::string* AttributeProto::unsafe_arena_release_ref_attr_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.AttributeProto.ref_attr_name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000008u; + return ref_attr_name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void AttributeProto::unsafe_arena_set_allocated_ref_attr_name( + std::string* ref_attr_name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (ref_attr_name != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + ref_attr_name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ref_attr_name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.AttributeProto.ref_attr_name) +} + +// optional string doc_string = 13; +inline bool AttributeProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool AttributeProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void AttributeProto::clear_doc_string() { + doc_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& AttributeProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.doc_string) + return _internal_doc_string(); +} +inline void AttributeProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.doc_string) +} +inline std::string* AttributeProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& AttributeProto::_internal_doc_string() const { + return doc_string_.Get(); +} +inline void AttributeProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void AttributeProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 
0x00000004u; + doc_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.doc_string) +} +inline void AttributeProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.doc_string) +} +inline void AttributeProto::set_doc_string(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.doc_string) +} +inline std::string* AttributeProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000004u; + return doc_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* AttributeProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return doc_string_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void AttributeProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.doc_string) +} +inline std::string* AttributeProto::unsafe_arena_release_doc_string() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.AttributeProto.doc_string) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000004u; + return doc_string_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void AttributeProto::unsafe_arena_set_allocated_doc_string( + std::string* doc_string) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + doc_string, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.AttributeProto.doc_string) +} + +// optional .onnx.AttributeProto.AttributeType type = 20; +inline bool AttributeProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000100u) != 0; + return value; +} +inline bool AttributeProto::has_type() const { + return _internal_has_type(); +} +inline void AttributeProto::clear_type() { + type_ = 0; + _has_bits_[0] &= ~0x00000100u; +} +inline ::onnx::AttributeProto_AttributeType AttributeProto::_internal_type() const { + return static_cast< ::onnx::AttributeProto_AttributeType >(type_); +} +inline ::onnx::AttributeProto_AttributeType AttributeProto::type() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.type) + return _internal_type(); +} +inline void AttributeProto::_internal_set_type(::onnx::AttributeProto_AttributeType value) { + assert(::onnx::AttributeProto_AttributeType_IsValid(value)); + 
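// The assert above rejects out-of-range enum values in debug builds only; + // presence is then recorded in the has-bits word before the raw int is stored. + 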
_has_bits_[0] |= 0x00000100u; + type_ = value; +} +inline void AttributeProto::set_type(::onnx::AttributeProto_AttributeType value) { + _internal_set_type(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.type) +} + +// optional float f = 2; +inline bool AttributeProto::_internal_has_f() const { + bool value = (_has_bits_[0] & 0x00000080u) != 0; + return value; +} +inline bool AttributeProto::has_f() const { + return _internal_has_f(); +} +inline void AttributeProto::clear_f() { + f_ = 0; + _has_bits_[0] &= ~0x00000080u; +} +inline float AttributeProto::_internal_f() const { + return f_; +} +inline float AttributeProto::f() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.f) + return _internal_f(); +} +inline void AttributeProto::_internal_set_f(float value) { + _has_bits_[0] |= 0x00000080u; + f_ = value; +} +inline void AttributeProto::set_f(float value) { + _internal_set_f(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.f) +} + +// optional int64 i = 3; +inline bool AttributeProto::_internal_has_i() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool AttributeProto::has_i() const { + return _internal_has_i(); +} +inline void AttributeProto::clear_i() { + i_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::_internal_i() const { + return i_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::i() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.i) + return _internal_i(); +} +inline void AttributeProto::_internal_set_i(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000040u; + i_ = value; +} +inline void AttributeProto::set_i(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_i(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.i) +} + +// optional bytes s = 4; +inline bool AttributeProto::_internal_has_s() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool AttributeProto::has_s() const { + return _internal_has_s(); +} +inline void AttributeProto::clear_s() { + s_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& AttributeProto::s() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.s) + return _internal_s(); +} +inline void AttributeProto::set_s(const std::string& value) { + _internal_set_s(value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.s) +} +inline std::string* AttributeProto::mutable_s() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.s) + return _internal_mutable_s(); +} +inline const std::string& AttributeProto::_internal_s() const { + return s_.Get(); +} +inline void AttributeProto::_internal_set_s(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + s_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void AttributeProto::set_s(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + s_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.AttributeProto.s) +} +inline void AttributeProto::set_s(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + s_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // 
@@protoc_insertion_point(field_set_char:onnx.AttributeProto.s) +} +inline void AttributeProto::set_s(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + s_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.s) +} +inline std::string* AttributeProto::_internal_mutable_s() { + _has_bits_[0] |= 0x00000002u; + return s_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* AttributeProto::release_s() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.s) + if (!_internal_has_s()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return s_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void AttributeProto::set_allocated_s(std::string* s) { + if (s != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + s_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), s, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.s) +} +inline std::string* AttributeProto::unsafe_arena_release_s() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.AttributeProto.s) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return s_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void AttributeProto::unsafe_arena_set_allocated_s( + std::string* s) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (s != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + s_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + s, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.AttributeProto.s) +} + +// optional .onnx.TensorProto t = 5; +inline bool AttributeProto::_internal_has_t() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + PROTOBUF_ASSUME(!value || t_ != nullptr); + return value; +} +inline bool AttributeProto::has_t() const { + return _internal_has_t(); +} +inline void AttributeProto::clear_t() { + if (t_ != nullptr) t_->Clear(); + _has_bits_[0] &= ~0x00000010u; +} +inline const ::onnx::TensorProto& AttributeProto::_internal_t() const { + const ::onnx::TensorProto* p = t_; + return p != nullptr ? 
*p : *reinterpret_cast<const ::onnx::TensorProto*>( + &::onnx::_TensorProto_default_instance_); +} +inline const ::onnx::TensorProto& AttributeProto::t() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.t) + return _internal_t(); +} +inline void AttributeProto::unsafe_arena_set_allocated_t( + ::onnx::TensorProto* t) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(t_); + } + t_ = t; + if (t) { + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.AttributeProto.t) +} +inline ::onnx::TensorProto* AttributeProto::release_t() { + auto temp = unsafe_arena_release_t(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TensorProto* AttributeProto::unsafe_arena_release_t() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.t) + _has_bits_[0] &= ~0x00000010u; + ::onnx::TensorProto* temp = t_; + t_ = nullptr; + return temp; +} +inline ::onnx::TensorProto* AttributeProto::_internal_mutable_t() { + _has_bits_[0] |= 0x00000010u; + if (t_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorProto>(GetArena()); + t_ = p; + } + return t_; +} +inline ::onnx::TensorProto* AttributeProto::mutable_t() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.t) + return _internal_mutable_t(); +} +inline void AttributeProto::set_allocated_t(::onnx::TensorProto* t) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete t_; + } + if (t) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(t); + if (message_arena != submessage_arena) { + t = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, t, submessage_arena); + } + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + t_ = t; + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.t) +} + +// optional .onnx.GraphProto g = 6; +inline bool AttributeProto::_internal_has_g() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + PROTOBUF_ASSUME(!value || g_ != nullptr); + return value; +} +inline bool AttributeProto::has_g() const { + return _internal_has_g(); +} +inline void AttributeProto::clear_g() { + if (g_ != nullptr) g_->Clear(); + _has_bits_[0] &= ~0x00000020u; +} +inline const ::onnx::GraphProto& AttributeProto::_internal_g() const { + const ::onnx::GraphProto* p = g_; + return p != nullptr ? 
*p : *reinterpret_cast<const ::onnx::GraphProto*>( + &::onnx::_GraphProto_default_instance_); +} +inline const ::onnx::GraphProto& AttributeProto::g() const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.g) + return _internal_g(); +} +inline void AttributeProto::unsafe_arena_set_allocated_g( + ::onnx::GraphProto* g) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(g_); + } + g_ = g; + if (g) { + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.AttributeProto.g) +} +inline ::onnx::GraphProto* AttributeProto::release_g() { + auto temp = unsafe_arena_release_g(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::GraphProto* AttributeProto::unsafe_arena_release_g() { + // @@protoc_insertion_point(field_release:onnx.AttributeProto.g) + _has_bits_[0] &= ~0x00000020u; + ::onnx::GraphProto* temp = g_; + g_ = nullptr; + return temp; +} +inline ::onnx::GraphProto* AttributeProto::_internal_mutable_g() { + _has_bits_[0] |= 0x00000020u; + if (g_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::GraphProto>(GetArena()); + g_ = p; + } + return g_; +} +inline ::onnx::GraphProto* AttributeProto::mutable_g() { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.g) + return _internal_mutable_g(); +} +inline void AttributeProto::set_allocated_g(::onnx::GraphProto* g) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete g_; + } + if (g) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(g); + if (message_arena != submessage_arena) { + g = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, g, submessage_arena); + } + _has_bits_[0] |= 0x00000020u; + } else { + _has_bits_[0] &= ~0x00000020u; + } + g_ = g; + // @@protoc_insertion_point(field_set_allocated:onnx.AttributeProto.g) +} + +// repeated float floats = 7; +inline int AttributeProto::_internal_floats_size() const { + return floats_.size(); +} +inline int AttributeProto::floats_size() const { + return _internal_floats_size(); +} +inline void AttributeProto::clear_floats() { + floats_.Clear(); +} +inline float AttributeProto::_internal_floats(int index) const { + return floats_.Get(index); +} +inline float AttributeProto::floats(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.floats) + return _internal_floats(index); +} +inline void AttributeProto::set_floats(int index, float value) { + floats_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.floats) +} +inline void AttributeProto::_internal_add_floats(float value) { + floats_.Add(value); +} +inline void AttributeProto::add_floats(float value) { + _internal_add_floats(value); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.floats) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +AttributeProto::_internal_floats() const { + return floats_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +AttributeProto::floats() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.floats) + return _internal_floats(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +AttributeProto::_internal_mutable_floats() { + return &floats_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +AttributeProto::mutable_floats() { + // 
@@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.floats) + return _internal_mutable_floats(); +} + +// repeated int64 ints = 8; +inline int AttributeProto::_internal_ints_size() const { + return ints_.size(); +} +inline int AttributeProto::ints_size() const { + return _internal_ints_size(); +} +inline void AttributeProto::clear_ints() { + ints_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::_internal_ints(int index) const { + return ints_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 AttributeProto::ints(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.ints) + return _internal_ints(index); +} +inline void AttributeProto::set_ints(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + ints_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.AttributeProto.ints) +} +inline void AttributeProto::_internal_add_ints(::PROTOBUF_NAMESPACE_ID::int64 value) { + ints_.Add(value); +} +inline void AttributeProto::add_ints(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_ints(value); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.ints) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +AttributeProto::_internal_ints() const { + return ints_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +AttributeProto::ints() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.ints) + return _internal_ints(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +AttributeProto::_internal_mutable_ints() { + return &ints_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +AttributeProto::mutable_ints() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.ints) + return _internal_mutable_ints(); +} + +// repeated bytes strings = 9; +inline int AttributeProto::_internal_strings_size() const { + return strings_.size(); +} +inline int AttributeProto::strings_size() const { + return _internal_strings_size(); +} +inline void AttributeProto::clear_strings() { + strings_.Clear(); +} +inline std::string* AttributeProto::add_strings() { + // @@protoc_insertion_point(field_add_mutable:onnx.AttributeProto.strings) + return _internal_add_strings(); +} +inline const std::string& AttributeProto::_internal_strings(int index) const { + return strings_.Get(index); +} +inline const std::string& AttributeProto::strings(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.strings) + return _internal_strings(index); +} +inline std::string* AttributeProto::mutable_strings(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.strings) + return strings_.Mutable(index); +} +inline void AttributeProto::set_strings(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.AttributeProto.strings) + strings_.Mutable(index)->assign(value); +} +inline void AttributeProto::set_strings(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.AttributeProto.strings) + strings_.Mutable(index)->assign(std::move(value)); +} +inline void AttributeProto::set_strings(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + strings_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.AttributeProto.strings) +} +inline void AttributeProto::set_strings(int index, const void* value, size_t size) { + strings_.Mutable(index)->assign( + 
reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.AttributeProto.strings) +} +inline std::string* AttributeProto::_internal_add_strings() { + return strings_.Add(); +} +inline void AttributeProto::add_strings(const std::string& value) { + strings_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.strings) +} +inline void AttributeProto::add_strings(std::string&& value) { + strings_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.AttributeProto.strings) +} +inline void AttributeProto::add_strings(const char* value) { + GOOGLE_DCHECK(value != nullptr); + strings_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.AttributeProto.strings) +} +inline void AttributeProto::add_strings(const void* value, size_t size) { + strings_.Add()->assign(reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.AttributeProto.strings) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& +AttributeProto::strings() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.strings) + return strings_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* +AttributeProto::mutable_strings() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.strings) + return &strings_; +} + +// repeated .onnx.TensorProto tensors = 10; +inline int AttributeProto::_internal_tensors_size() const { + return tensors_.size(); +} +inline int AttributeProto::tensors_size() const { + return _internal_tensors_size(); +} +inline void AttributeProto::clear_tensors() { + tensors_.Clear(); +} +inline ::onnx::TensorProto* AttributeProto::mutable_tensors(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.tensors) + return tensors_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* +AttributeProto::mutable_tensors() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.tensors) + return &tensors_; +} +inline const ::onnx::TensorProto& AttributeProto::_internal_tensors(int index) const { + return tensors_.Get(index); +} +inline const ::onnx::TensorProto& AttributeProto::tensors(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.tensors) + return _internal_tensors(index); +} +inline ::onnx::TensorProto* AttributeProto::_internal_add_tensors() { + return tensors_.Add(); +} +inline ::onnx::TensorProto* AttributeProto::add_tensors() { + // @@protoc_insertion_point(field_add:onnx.AttributeProto.tensors) + return _internal_add_tensors(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& +AttributeProto::tensors() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.tensors) + return tensors_; +} + +// repeated .onnx.GraphProto graphs = 11; +inline int AttributeProto::_internal_graphs_size() const { + return graphs_.size(); +} +inline int AttributeProto::graphs_size() const { + return _internal_graphs_size(); +} +inline void AttributeProto::clear_graphs() { + graphs_.Clear(); +} +inline ::onnx::GraphProto* AttributeProto::mutable_graphs(int index) { + // @@protoc_insertion_point(field_mutable:onnx.AttributeProto.graphs) + return graphs_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >* +AttributeProto::mutable_graphs() { + // @@protoc_insertion_point(field_mutable_list:onnx.AttributeProto.graphs) + return &graphs_; +} +inline const ::onnx::GraphProto&
AttributeProto::_internal_graphs(int index) const { + return graphs_.Get(index); +} +inline const ::onnx::GraphProto& AttributeProto::graphs(int index) const { + // @@protoc_insertion_point(field_get:onnx.AttributeProto.graphs) + return _internal_graphs(index); +} +inline ::onnx::GraphProto* AttributeProto::_internal_add_graphs() { + return graphs_.Add(); +} +inline ::onnx::GraphProto* AttributeProto::add_graphs() { + // @@protoc_insertion_point(field_add:onnx.AttributeProto.graphs) + return _internal_add_graphs(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::GraphProto >& +AttributeProto::graphs() const { + // @@protoc_insertion_point(field_list:onnx.AttributeProto.graphs) + return graphs_; +} + +// ------------------------------------------------------------------- + +// ValueInfoProto + +// optional string name = 1; +inline bool ValueInfoProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ValueInfoProto::has_name() const { + return _internal_has_name(); +} +inline void ValueInfoProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ValueInfoProto::name() const { + // @@protoc_insertion_point(field_get:onnx.ValueInfoProto.name) + return _internal_name(); +} +inline void ValueInfoProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.ValueInfoProto.name) +} +inline std::string* ValueInfoProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.ValueInfoProto.name) + return _internal_mutable_name(); +} +inline const std::string& ValueInfoProto::_internal_name() const { + return name_.Get(); +} +inline void ValueInfoProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ValueInfoProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.ValueInfoProto.name) +} +inline void ValueInfoProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.ValueInfoProto.name) +} +inline void ValueInfoProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.ValueInfoProto.name) +} +inline std::string* ValueInfoProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ValueInfoProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.ValueInfoProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ValueInfoProto::set_allocated_name(std::string* name) { + if
(name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.ValueInfoProto.name) +} +inline std::string* ValueInfoProto::unsafe_arena_release_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.ValueInfoProto.name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void ValueInfoProto::unsafe_arena_set_allocated_name( + std::string* name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ValueInfoProto.name) +} + +// optional .onnx.TypeProto type = 2; +inline bool ValueInfoProto::_internal_has_type() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + PROTOBUF_ASSUME(!value || type_ != nullptr); + return value; +} +inline bool ValueInfoProto::has_type() const { + return _internal_has_type(); +} +inline void ValueInfoProto::clear_type() { + if (type_ != nullptr) type_->Clear(); + _has_bits_[0] &= ~0x00000004u; +} +inline const ::onnx::TypeProto& ValueInfoProto::_internal_type() const { + const ::onnx::TypeProto* p = type_; + return p != nullptr ? *p : *reinterpret_cast<const ::onnx::TypeProto*>( + &::onnx::_TypeProto_default_instance_); +} +inline const ::onnx::TypeProto& ValueInfoProto::type() const { + // @@protoc_insertion_point(field_get:onnx.ValueInfoProto.type) + return _internal_type(); +} +inline void ValueInfoProto::unsafe_arena_set_allocated_type( + ::onnx::TypeProto* type) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(type_); + } + type_ = type; + if (type) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ValueInfoProto.type) +} +inline ::onnx::TypeProto* ValueInfoProto::release_type() { + auto temp = unsafe_arena_release_type(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TypeProto* ValueInfoProto::unsafe_arena_release_type() { + // @@protoc_insertion_point(field_release:onnx.ValueInfoProto.type) + _has_bits_[0] &= ~0x00000004u; + ::onnx::TypeProto* temp = type_; + type_ = nullptr; + return temp; +} +inline ::onnx::TypeProto* ValueInfoProto::_internal_mutable_type() { + _has_bits_[0] |= 0x00000004u; + if (type_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TypeProto>(GetArena()); + type_ = p; + } + return type_; +} +inline ::onnx::TypeProto* ValueInfoProto::mutable_type() { + // @@protoc_insertion_point(field_mutable:onnx.ValueInfoProto.type) + return _internal_mutable_type(); +} +inline void ValueInfoProto::set_allocated_type(::onnx::TypeProto* type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete type_; + } + if (type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(type); + if (message_arena != submessage_arena) { + type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena,
type, submessage_arena); + } + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + type_ = type; + // @@protoc_insertion_point(field_set_allocated:onnx.ValueInfoProto.type) +} + +// optional string doc_string = 3; +inline bool ValueInfoProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool ValueInfoProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void ValueInfoProto::clear_doc_string() { + doc_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& ValueInfoProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.ValueInfoProto.doc_string) + return _internal_doc_string(); +} +inline void ValueInfoProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.ValueInfoProto.doc_string) +} +inline std::string* ValueInfoProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.ValueInfoProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& ValueInfoProto::_internal_doc_string() const { + return doc_string_.Get(); +} +inline void ValueInfoProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ValueInfoProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.ValueInfoProto.doc_string) +} +inline void ValueInfoProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.ValueInfoProto.doc_string) +} +inline void ValueInfoProto::set_doc_string(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.ValueInfoProto.doc_string) +} +inline std::string* ValueInfoProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000002u; + return doc_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ValueInfoProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.ValueInfoProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return doc_string_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ValueInfoProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + doc_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.ValueInfoProto.doc_string) +} +inline std::string* ValueInfoProto::unsafe_arena_release_doc_string() { + //
@@protoc_insertion_point(field_unsafe_arena_release:onnx.ValueInfoProto.doc_string) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return doc_string_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void ValueInfoProto::unsafe_arena_set_allocated_doc_string( + std::string* doc_string) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + doc_string_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + doc_string, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ValueInfoProto.doc_string) +} + +// ------------------------------------------------------------------- + +// NodeProto + +// repeated string input = 1; +inline int NodeProto::_internal_input_size() const { + return input_.size(); +} +inline int NodeProto::input_size() const { + return _internal_input_size(); +} +inline void NodeProto::clear_input() { + input_.Clear(); +} +inline std::string* NodeProto::add_input() { + // @@protoc_insertion_point(field_add_mutable:onnx.NodeProto.input) + return _internal_add_input(); +} +inline const std::string& NodeProto::_internal_input(int index) const { + return input_.Get(index); +} +inline const std::string& NodeProto::input(int index) const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.input) + return _internal_input(index); +} +inline std::string* NodeProto::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.input) + return input_.Mutable(index); +} +inline void NodeProto::set_input(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.input) + input_.Mutable(index)->assign(value); +} +inline void NodeProto::set_input(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.input) + input_.Mutable(index)->assign(std::move(value)); +} +inline void NodeProto::set_input(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + input_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.input) +} +inline void NodeProto::set_input(int index, const char* value, size_t size) { + input_.Mutable(index)->assign( + reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.input) +} +inline std::string* NodeProto::_internal_add_input() { + return input_.Add(); +} +inline void NodeProto::add_input(const std::string& value) { + input_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.NodeProto.input) +} +inline void NodeProto::add_input(std::string&& value) { + input_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.NodeProto.input) +} +inline void NodeProto::add_input(const char* value) { + GOOGLE_DCHECK(value != nullptr); + input_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.NodeProto.input) +} +inline void NodeProto::add_input(const char* value, size_t size) { + input_.Add()->assign(reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.NodeProto.input) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& +NodeProto::input() const { + // @@protoc_insertion_point(field_list:onnx.NodeProto.input) + return input_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* +NodeProto::mutable_input() { + //
@@protoc_insertion_point(field_mutable_list:onnx.NodeProto.input) + return &input_; +} + +// repeated string output = 2; +inline int NodeProto::_internal_output_size() const { + return output_.size(); +} +inline int NodeProto::output_size() const { + return _internal_output_size(); +} +inline void NodeProto::clear_output() { + output_.Clear(); +} +inline std::string* NodeProto::add_output() { + // @@protoc_insertion_point(field_add_mutable:onnx.NodeProto.output) + return _internal_add_output(); +} +inline const std::string& NodeProto::_internal_output(int index) const { + return output_.Get(index); +} +inline const std::string& NodeProto::output(int index) const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.output) + return _internal_output(index); +} +inline std::string* NodeProto::mutable_output(int index) { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.output) + return output_.Mutable(index); +} +inline void NodeProto::set_output(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.output) + output_.Mutable(index)->assign(value); +} +inline void NodeProto::set_output(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.NodeProto.output) + output_.Mutable(index)->assign(std::move(value)); +} +inline void NodeProto::set_output(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + output_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.output) +} +inline void NodeProto::set_output(int index, const char* value, size_t size) { + output_.Mutable(index)->assign( + reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.output) +} +inline std::string* NodeProto::_internal_add_output() { + return output_.Add(); +} +inline void NodeProto::add_output(const std::string& value) { + output_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.NodeProto.output) +} +inline void NodeProto::add_output(std::string&& value) { + output_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.NodeProto.output) +} +inline void NodeProto::add_output(const char* value) { + GOOGLE_DCHECK(value != nullptr); + output_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.NodeProto.output) +} +inline void NodeProto::add_output(const char* value, size_t size) { + output_.Add()->assign(reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.NodeProto.output) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& +NodeProto::output() const { + // @@protoc_insertion_point(field_list:onnx.NodeProto.output) + return output_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* +NodeProto::mutable_output() { + // @@protoc_insertion_point(field_mutable_list:onnx.NodeProto.output) + return &output_; +} + +// optional string name = 3; +inline bool NodeProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool NodeProto::has_name() const { + return _internal_has_name(); +} +inline void NodeProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& NodeProto::name() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.name) + return _internal_name(); +} +inline void NodeProto::set_name(const std::string& value) { + _internal_set_name(value); + //
@@protoc_insertion_point(field_set:onnx.NodeProto.name) +} +inline std::string* NodeProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.name) + return _internal_mutable_name(); +} +inline const std::string& NodeProto::_internal_name() const { + return name_.Get(); +} +inline void NodeProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.name) +} +inline void NodeProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.name) +} +inline void NodeProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.name) +} +inline std::string* NodeProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.name) +} +inline std::string* NodeProto::unsafe_arena_release_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.NodeProto.name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void NodeProto::unsafe_arena_set_allocated_name( + std::string* name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.NodeProto.name) +} + +// optional string op_type = 4; +inline bool NodeProto::_internal_has_op_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool NodeProto::has_op_type() const { + return _internal_has_op_type(); +} +inline void NodeProto::clear_op_type() { + op_type_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& NodeProto::op_type() const { + //
@@protoc_insertion_point(field_get:onnx.NodeProto.op_type) + return _internal_op_type(); +} +inline void NodeProto::set_op_type(const std::string& value) { + _internal_set_op_type(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.op_type) +} +inline std::string* NodeProto::mutable_op_type() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.op_type) + return _internal_mutable_op_type(); +} +inline const std::string& NodeProto::_internal_op_type() const { + return op_type_.Get(); +} +inline void NodeProto::_internal_set_op_type(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_op_type(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + op_type_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.op_type) +} +inline void NodeProto::set_op_type(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.op_type) +} +inline void NodeProto::set_op_type(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + op_type_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.op_type) +} +inline std::string* NodeProto::_internal_mutable_op_type() { + _has_bits_[0] |= 0x00000002u; + return op_type_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_op_type() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.op_type) + if (!_internal_has_op_type()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return op_type_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_op_type(std::string* op_type) { + if (op_type != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + op_type_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), op_type, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.op_type) +} +inline std::string* NodeProto::unsafe_arena_release_op_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.NodeProto.op_type) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return op_type_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void NodeProto::unsafe_arena_set_allocated_op_type( + std::string* op_type) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (op_type != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + op_type_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + op_type, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.NodeProto.op_type) +} + +// optional string domain = 7; +inline bool NodeProto::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool NodeProto::has_domain() const { + return
_internal_has_domain(); +} +inline void NodeProto::clear_domain() { + domain_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& NodeProto::domain() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.domain) + return _internal_domain(); +} +inline void NodeProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.domain) +} +inline std::string* NodeProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& NodeProto::_internal_domain() const { + return domain_.Get(); +} +inline void NodeProto::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + domain_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.domain) +} +inline void NodeProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.domain) +} +inline void NodeProto::set_domain(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000008u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.domain) +} +inline std::string* NodeProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000008u; + return domain_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_domain() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return domain_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + domain_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.domain) +} +inline std::string* NodeProto::unsafe_arena_release_domain() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.NodeProto.domain) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000008u; + return domain_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void NodeProto::unsafe_arena_set_allocated_domain( + std::string* domain) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (domain != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + domain_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + domain, GetArena()); + //
@@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.NodeProto.domain) +} + +// repeated .onnx.AttributeProto attribute = 5; +inline int NodeProto::_internal_attribute_size() const { + return attribute_.size(); +} +inline int NodeProto::attribute_size() const { + return _internal_attribute_size(); +} +inline void NodeProto::clear_attribute() { + attribute_.Clear(); +} +inline ::onnx::AttributeProto* NodeProto::mutable_attribute(int index) { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.attribute) + return attribute_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >* +NodeProto::mutable_attribute() { + // @@protoc_insertion_point(field_mutable_list:onnx.NodeProto.attribute) + return &attribute_; +} +inline const ::onnx::AttributeProto& NodeProto::_internal_attribute(int index) const { + return attribute_.Get(index); +} +inline const ::onnx::AttributeProto& NodeProto::attribute(int index) const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.attribute) + return _internal_attribute(index); +} +inline ::onnx::AttributeProto* NodeProto::_internal_add_attribute() { + return attribute_.Add(); +} +inline ::onnx::AttributeProto* NodeProto::add_attribute() { + // @@protoc_insertion_point(field_add:onnx.NodeProto.attribute) + return _internal_add_attribute(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::AttributeProto >& +NodeProto::attribute() const { + // @@protoc_insertion_point(field_list:onnx.NodeProto.attribute) + return attribute_; +} + +// optional string doc_string = 6; +inline bool NodeProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool NodeProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void NodeProto::clear_doc_string() { + doc_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& NodeProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.NodeProto.doc_string) + return _internal_doc_string(); +} +inline void NodeProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.NodeProto.doc_string) +} +inline std::string* NodeProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.NodeProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& NodeProto::_internal_doc_string() const { + return doc_string_.Get(); +} +inline void NodeProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void NodeProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.NodeProto.doc_string) +} +inline void NodeProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.NodeProto.doc_string) +} +inline void NodeProto::set_doc_string(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + 
doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.NodeProto.doc_string) +} +inline std::string* NodeProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000004u; + return doc_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* NodeProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.NodeProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return doc_string_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void NodeProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.NodeProto.doc_string) +} +inline std::string* NodeProto::unsafe_arena_release_doc_string() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.NodeProto.doc_string) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000004u; + return doc_string_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void NodeProto::unsafe_arena_set_allocated_doc_string( + std::string* doc_string) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + doc_string, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.NodeProto.doc_string) +} + +// ------------------------------------------------------------------- + +// ModelProto + +// optional int64 ir_version = 1; +inline bool ModelProto::_internal_has_ir_version() const { + bool value = (_has_bits_[0] & 0x00000020u) != 0; + return value; +} +inline bool ModelProto::has_ir_version() const { + return _internal_has_ir_version(); +} +inline void ModelProto::clear_ir_version() { + ir_version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000020u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::_internal_ir_version() const { + return ir_version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::ir_version() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.ir_version) + return _internal_ir_version(); +} +inline void ModelProto::_internal_set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000020u; + ir_version_ = value; +} +inline void ModelProto::set_ir_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_ir_version(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.ir_version) +} + +// repeated .onnx.OperatorSetIdProto opset_import = 8; +inline int ModelProto::_internal_opset_import_size() const { + return opset_import_.size(); +} +inline int ModelProto::opset_import_size() const { + return _internal_opset_import_size(); +} +inline void ModelProto::clear_opset_import() { + opset_import_.Clear(); +} +inline ::onnx::OperatorSetIdProto* ModelProto::mutable_opset_import(int index) { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.opset_import) + return
opset_import_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >* +ModelProto::mutable_opset_import() { + // @@protoc_insertion_point(field_mutable_list:onnx.ModelProto.opset_import) + return &opset_import_; +} +inline const ::onnx::OperatorSetIdProto& ModelProto::_internal_opset_import(int index) const { + return opset_import_.Get(index); +} +inline const ::onnx::OperatorSetIdProto& ModelProto::opset_import(int index) const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.opset_import) + return _internal_opset_import(index); +} +inline ::onnx::OperatorSetIdProto* ModelProto::_internal_add_opset_import() { + return opset_import_.Add(); +} +inline ::onnx::OperatorSetIdProto* ModelProto::add_opset_import() { + // @@protoc_insertion_point(field_add:onnx.ModelProto.opset_import) + return _internal_add_opset_import(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::OperatorSetIdProto >& +ModelProto::opset_import() const { + // @@protoc_insertion_point(field_list:onnx.ModelProto.opset_import) + return opset_import_; +} + +// optional string producer_name = 2; +inline bool ModelProto::_internal_has_producer_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool ModelProto::has_producer_name() const { + return _internal_has_producer_name(); +} +inline void ModelProto::clear_producer_name() { + producer_name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& ModelProto::producer_name() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.producer_name) + return _internal_producer_name(); +} +inline void ModelProto::set_producer_name(const std::string& value) { + _internal_set_producer_name(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.producer_name) +} +inline std::string* ModelProto::mutable_producer_name() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.producer_name) + return _internal_mutable_producer_name(); +} +inline const std::string& ModelProto::_internal_producer_name() const { + return producer_name_.Get(); +} +inline void ModelProto::_internal_set_producer_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + producer_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ModelProto::set_producer_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + producer_name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.producer_name) +} +inline void ModelProto::set_producer_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + producer_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.producer_name) +} +inline void ModelProto::set_producer_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + producer_name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.producer_name) +} +inline std::string* ModelProto::_internal_mutable_producer_name() { + _has_bits_[0] |= 0x00000001u; + return
producer_name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ModelProto::release_producer_name() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.producer_name) + if (!_internal_has_producer_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return producer_name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ModelProto::set_allocated_producer_name(std::string* producer_name) { + if (producer_name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + producer_name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), producer_name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.producer_name) +} +inline std::string* ModelProto::unsafe_arena_release_producer_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.ModelProto.producer_name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return producer_name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void ModelProto::unsafe_arena_set_allocated_producer_name( + std::string* producer_name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (producer_name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + producer_name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + producer_name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ModelProto.producer_name) +} + +// optional string producer_version = 3; +inline bool ModelProto::_internal_has_producer_version() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool ModelProto::has_producer_version() const { + return _internal_has_producer_version(); +} +inline void ModelProto::clear_producer_version() { + producer_version_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& ModelProto::producer_version() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.producer_version) + return _internal_producer_version(); +} +inline void ModelProto::set_producer_version(const std::string& value) { + _internal_set_producer_version(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.producer_version) +} +inline std::string* ModelProto::mutable_producer_version() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.producer_version) + return _internal_mutable_producer_version(); +} +inline const std::string& ModelProto::_internal_producer_version() const { + return producer_version_.Get(); +} +inline void ModelProto::_internal_set_producer_version(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + producer_version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ModelProto::set_producer_version(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + producer_version_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.producer_version) +} +inline void ModelProto::set_producer_version(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + 
producer_version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.producer_version) +} +inline void ModelProto::set_producer_version(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + producer_version_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.producer_version) +} +inline std::string* ModelProto::_internal_mutable_producer_version() { + _has_bits_[0] |= 0x00000002u; + return producer_version_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ModelProto::release_producer_version() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.producer_version) + if (!_internal_has_producer_version()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return producer_version_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ModelProto::set_allocated_producer_version(std::string* producer_version) { + if (producer_version != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + producer_version_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), producer_version, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.producer_version) +} +inline std::string* ModelProto::unsafe_arena_release_producer_version() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.ModelProto.producer_version) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return producer_version_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void ModelProto::unsafe_arena_set_allocated_producer_version( + std::string* producer_version) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (producer_version != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + producer_version_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + producer_version, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ModelProto.producer_version) +} + +// optional string domain = 4; +inline bool ModelProto::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool ModelProto::has_domain() const { + return _internal_has_domain(); +} +inline void ModelProto::clear_domain() { + domain_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& ModelProto::domain() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.domain) + return _internal_domain(); +} +inline void ModelProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.domain) +} +inline std::string* ModelProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& ModelProto::_internal_domain() const { + return domain_.Get(); +} +inline void ModelProto::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000004u; +
domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ModelProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + domain_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.domain) +} +inline void ModelProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.domain) +} +inline void ModelProto::set_domain(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.domain) +} +inline std::string* ModelProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000004u; + return domain_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ModelProto::release_domain() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return domain_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ModelProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + domain_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.domain) +} +inline std::string* ModelProto::unsafe_arena_release_domain() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.ModelProto.domain) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000004u; + return domain_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void ModelProto::unsafe_arena_set_allocated_domain( + std::string* domain) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (domain != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + domain_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + domain, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ModelProto.domain) +} + +// optional int64 model_version = 5; +inline bool ModelProto::_internal_has_model_version() const { + bool value = (_has_bits_[0] & 0x00000040u) != 0; + return value; +} +inline bool ModelProto::has_model_version() const { + return _internal_has_model_version(); +} +inline void ModelProto::clear_model_version() { + model_version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000040u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::_internal_model_version() const { + return model_version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 ModelProto::model_version() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.model_version) + return _internal_model_version(); +} +inline void ModelProto::_internal_set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000040u; + model_version_ = value; +} +inline
void ModelProto::set_model_version(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_model_version(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.model_version) +} + +// optional string doc_string = 6; +inline bool ModelProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + return value; +} +inline bool ModelProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void ModelProto::clear_doc_string() { + doc_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000008u; +} +inline const std::string& ModelProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.doc_string) + return _internal_doc_string(); +} +inline void ModelProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.ModelProto.doc_string) +} +inline std::string* ModelProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& ModelProto::_internal_doc_string() const { + return doc_string_.Get(); +} +inline void ModelProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000008u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void ModelProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000008u; + doc_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.ModelProto.doc_string) +} +inline void ModelProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000008u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.ModelProto.doc_string) +} +inline void ModelProto::set_doc_string(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000008u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.ModelProto.doc_string) +} +inline std::string* ModelProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000008u; + return doc_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* ModelProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000008u; + return doc_string_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void ModelProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + doc_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.doc_string) +} +inline std::string* ModelProto::unsafe_arena_release_doc_string() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.ModelProto.doc_string) + GOOGLE_DCHECK(GetArena() != nullptr); +
_has_bits_[0] &= ~0x00000008u; + return doc_string_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void ModelProto::unsafe_arena_set_allocated_doc_string( + std::string* doc_string) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + doc_string_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + doc_string, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ModelProto.doc_string) +} + +// optional .onnx.GraphProto graph = 7; +inline bool ModelProto::_internal_has_graph() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + PROTOBUF_ASSUME(!value || graph_ != nullptr); + return value; +} +inline bool ModelProto::has_graph() const { + return _internal_has_graph(); +} +inline void ModelProto::clear_graph() { + if (graph_ != nullptr) graph_->Clear(); + _has_bits_[0] &= ~0x00000010u; +} +inline const ::onnx::GraphProto& ModelProto::_internal_graph() const { + const ::onnx::GraphProto* p = graph_; + return p != nullptr ? *p : *reinterpret_cast<const ::onnx::GraphProto*>( + &::onnx::_GraphProto_default_instance_); +} +inline const ::onnx::GraphProto& ModelProto::graph() const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.graph) + return _internal_graph(); +} +inline void ModelProto::unsafe_arena_set_allocated_graph( + ::onnx::GraphProto* graph) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(graph_); + } + graph_ = graph; + if (graph) { + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.ModelProto.graph) +} +inline ::onnx::GraphProto* ModelProto::release_graph() { + auto temp = unsafe_arena_release_graph(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::GraphProto* ModelProto::unsafe_arena_release_graph() { + // @@protoc_insertion_point(field_release:onnx.ModelProto.graph) + _has_bits_[0] &= ~0x00000010u; + ::onnx::GraphProto* temp = graph_; + graph_ = nullptr; + return temp; +} +inline ::onnx::GraphProto* ModelProto::_internal_mutable_graph() { + _has_bits_[0] |= 0x00000010u; + if (graph_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::GraphProto>(GetArena()); + graph_ = p; + } + return graph_; +} +inline ::onnx::GraphProto* ModelProto::mutable_graph() { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.graph) + return _internal_mutable_graph(); +} +inline void ModelProto::set_allocated_graph(::onnx::GraphProto* graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete graph_; + } + if (graph) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(graph); + if (message_arena != submessage_arena) { + graph = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, graph, submessage_arena); + } + _has_bits_[0] |= 0x00000010u; + } else { + _has_bits_[0] &= ~0x00000010u; + } + graph_ = graph; + // @@protoc_insertion_point(field_set_allocated:onnx.ModelProto.graph) +} + +// repeated .onnx.StringStringEntryProto metadata_props = 14; +inline int ModelProto::_internal_metadata_props_size() const { + return metadata_props_.size(); +} +inline int ModelProto::metadata_props_size() const { + return 
_internal_metadata_props_size(); +} +inline void ModelProto::clear_metadata_props() { + metadata_props_.Clear(); +} +inline ::onnx::StringStringEntryProto* ModelProto::mutable_metadata_props(int index) { + // @@protoc_insertion_point(field_mutable:onnx.ModelProto.metadata_props) + return metadata_props_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >* +ModelProto::mutable_metadata_props() { + // @@protoc_insertion_point(field_mutable_list:onnx.ModelProto.metadata_props) + return &metadata_props_; +} +inline const ::onnx::StringStringEntryProto& ModelProto::_internal_metadata_props(int index) const { + return metadata_props_.Get(index); +} +inline const ::onnx::StringStringEntryProto& ModelProto::metadata_props(int index) const { + // @@protoc_insertion_point(field_get:onnx.ModelProto.metadata_props) + return _internal_metadata_props(index); +} +inline ::onnx::StringStringEntryProto* ModelProto::_internal_add_metadata_props() { + return metadata_props_.Add(); +} +inline ::onnx::StringStringEntryProto* ModelProto::add_metadata_props() { + // @@protoc_insertion_point(field_add:onnx.ModelProto.metadata_props) + return _internal_add_metadata_props(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::StringStringEntryProto >& +ModelProto::metadata_props() const { + // @@protoc_insertion_point(field_list:onnx.ModelProto.metadata_props) + return metadata_props_; +} + +// ------------------------------------------------------------------- + +// StringStringEntryProto + +// optional string key = 1; +inline bool StringStringEntryProto::_internal_has_key() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool StringStringEntryProto::has_key() const { + return _internal_has_key(); +} +inline void StringStringEntryProto::clear_key() { + key_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& StringStringEntryProto::key() const { + // @@protoc_insertion_point(field_get:onnx.StringStringEntryProto.key) + return _internal_key(); +} +inline void StringStringEntryProto::set_key(const std::string& value) { + _internal_set_key(value); + // @@protoc_insertion_point(field_set:onnx.StringStringEntryProto.key) +} +inline std::string* StringStringEntryProto::mutable_key() { + // @@protoc_insertion_point(field_mutable:onnx.StringStringEntryProto.key) + return _internal_mutable_key(); +} +inline const std::string& StringStringEntryProto::_internal_key() const { + return key_.Get(); +} +inline void StringStringEntryProto::_internal_set_key(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void StringStringEntryProto::set_key(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + key_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.StringStringEntryProto.key) +} +inline void StringStringEntryProto::set_key(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.StringStringEntryProto.key) +} +inline void StringStringEntryProto::set_key(const char* value, + size_t size) { + _has_bits_[0] 
|= 0x00000001u; + key_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.StringStringEntryProto.key) +} +inline std::string* StringStringEntryProto::_internal_mutable_key() { + _has_bits_[0] |= 0x00000001u; + return key_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* StringStringEntryProto::release_key() { + // @@protoc_insertion_point(field_release:onnx.StringStringEntryProto.key) + if (!_internal_has_key()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return key_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void StringStringEntryProto::set_allocated_key(std::string* key) { + if (key != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + key_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), key, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.StringStringEntryProto.key) +} +inline std::string* StringStringEntryProto::unsafe_arena_release_key() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.StringStringEntryProto.key) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return key_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void StringStringEntryProto::unsafe_arena_set_allocated_key( + std::string* key) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (key != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + key_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + key, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.StringStringEntryProto.key) +} + +// optional string value = 2; +inline bool StringStringEntryProto::_internal_has_value() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool StringStringEntryProto::has_value() const { + return _internal_has_value(); +} +inline void StringStringEntryProto::clear_value() { + value_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& StringStringEntryProto::value() const { + // @@protoc_insertion_point(field_get:onnx.StringStringEntryProto.value) + return _internal_value(); +} +inline void StringStringEntryProto::set_value(const std::string& value) { + _internal_set_value(value); + // @@protoc_insertion_point(field_set:onnx.StringStringEntryProto.value) +} +inline std::string* StringStringEntryProto::mutable_value() { + // @@protoc_insertion_point(field_mutable:onnx.StringStringEntryProto.value) + return _internal_mutable_value(); +} +inline const std::string& StringStringEntryProto::_internal_value() const { + return value_.Get(); +} +inline void StringStringEntryProto::_internal_set_value(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void StringStringEntryProto::set_value(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + value_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // 
@@protoc_insertion_point(field_set_rvalue:onnx.StringStringEntryProto.value) +} +inline void StringStringEntryProto::set_value(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.StringStringEntryProto.value) +} +inline void StringStringEntryProto::set_value(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + value_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.StringStringEntryProto.value) +} +inline std::string* StringStringEntryProto::_internal_mutable_value() { + _has_bits_[0] |= 0x00000002u; + return value_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* StringStringEntryProto::release_value() { + // @@protoc_insertion_point(field_release:onnx.StringStringEntryProto.value) + if (!_internal_has_value()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return value_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void StringStringEntryProto::set_allocated_value(std::string* value) { + if (value != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + value_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.StringStringEntryProto.value) +} +inline std::string* StringStringEntryProto::unsafe_arena_release_value() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.StringStringEntryProto.value) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return value_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void StringStringEntryProto::unsafe_arena_set_allocated_value( + std::string* value) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (value != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + value_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + value, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.StringStringEntryProto.value) +} + +// ------------------------------------------------------------------- + +// GraphProto + +// repeated .onnx.NodeProto node = 1; +inline int GraphProto::_internal_node_size() const { + return node_.size(); +} +inline int GraphProto::node_size() const { + return _internal_node_size(); +} +inline void GraphProto::clear_node() { + node_.Clear(); +} +inline ::onnx::NodeProto* GraphProto::mutable_node(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.node) + return node_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >* +GraphProto::mutable_node() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.node) + return &node_; +} +inline const ::onnx::NodeProto& GraphProto::_internal_node(int index) const { + return node_.Get(index); +} +inline const ::onnx::NodeProto& GraphProto::node(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.node) + return _internal_node(index); +} +inline ::onnx::NodeProto* 
GraphProto::_internal_add_node() { + return node_.Add(); +} +inline ::onnx::NodeProto* GraphProto::add_node() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.node) + return _internal_add_node(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::NodeProto >& +GraphProto::node() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.node) + return node_; +} + +// optional string name = 2; +inline bool GraphProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool GraphProto::has_name() const { + return _internal_has_name(); +} +inline void GraphProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& GraphProto::name() const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.name) + return _internal_name(); +} +inline void GraphProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.GraphProto.name) +} +inline std::string* GraphProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.name) + return _internal_mutable_name(); +} +inline const std::string& GraphProto::_internal_name() const { + return name_.Get(); +} +inline void GraphProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void GraphProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.GraphProto.name) +} +inline void GraphProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.GraphProto.name) +} +inline void GraphProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.GraphProto.name) +} +inline std::string* GraphProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* GraphProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.GraphProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void GraphProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.GraphProto.name) +} +inline std::string* GraphProto::unsafe_arena_release_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.GraphProto.name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return 
name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void GraphProto::unsafe_arena_set_allocated_name( + std::string* name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.GraphProto.name) +} + +// repeated .onnx.TensorProto initializer = 5; +inline int GraphProto::_internal_initializer_size() const { + return initializer_.size(); +} +inline int GraphProto::initializer_size() const { + return _internal_initializer_size(); +} +inline void GraphProto::clear_initializer() { + initializer_.Clear(); +} +inline ::onnx::TensorProto* GraphProto::mutable_initializer(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.initializer) + return initializer_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >* +GraphProto::mutable_initializer() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.initializer) + return &initializer_; +} +inline const ::onnx::TensorProto& GraphProto::_internal_initializer(int index) const { + return initializer_.Get(index); +} +inline const ::onnx::TensorProto& GraphProto::initializer(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.initializer) + return _internal_initializer(index); +} +inline ::onnx::TensorProto* GraphProto::_internal_add_initializer() { + return initializer_.Add(); +} +inline ::onnx::TensorProto* GraphProto::add_initializer() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.initializer) + return _internal_add_initializer(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorProto >& +GraphProto::initializer() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.initializer) + return initializer_; +} + +// optional string doc_string = 10; +inline bool GraphProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool GraphProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void GraphProto::clear_doc_string() { + doc_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& GraphProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.doc_string) + return _internal_doc_string(); +} +inline void GraphProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.GraphProto.doc_string) +} +inline std::string* GraphProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& GraphProto::_internal_doc_string() const { + return doc_string_.Get(); +} +inline void GraphProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void GraphProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + doc_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + 
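+  // This rvalue overload moves the caller's buffer into the field rather than
+  // copying it; GetArena() decides whether the payload is heap-owned or lives
+  // on the message's arena.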
// @@protoc_insertion_point(field_set_rvalue:onnx.GraphProto.doc_string) +} +inline void GraphProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.GraphProto.doc_string) +} +inline void GraphProto::set_doc_string(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.GraphProto.doc_string) +} +inline std::string* GraphProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000002u; + return doc_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* GraphProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.GraphProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return doc_string_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void GraphProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + doc_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.GraphProto.doc_string) +} +inline std::string* GraphProto::unsafe_arena_release_doc_string() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.GraphProto.doc_string) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return doc_string_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void GraphProto::unsafe_arena_set_allocated_doc_string( + std::string* doc_string) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + doc_string_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + doc_string, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.GraphProto.doc_string) +} + +// repeated .onnx.ValueInfoProto input = 11; +inline int GraphProto::_internal_input_size() const { + return input_.size(); +} +inline int GraphProto::input_size() const { + return _internal_input_size(); +} +inline void GraphProto::clear_input() { + input_.Clear(); +} +inline ::onnx::ValueInfoProto* GraphProto::mutable_input(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.input) + return input_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* +GraphProto::mutable_input() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.input) + return &input_; +} +inline const ::onnx::ValueInfoProto& GraphProto::_internal_input(int index) const { + return input_.Get(index); +} +inline const ::onnx::ValueInfoProto& GraphProto::input(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.input) + return _internal_input(index); +} +inline ::onnx::ValueInfoProto* GraphProto::_internal_add_input() { + return input_.Add(); +} +inline 
::onnx::ValueInfoProto* GraphProto::add_input() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.input) + return _internal_add_input(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& +GraphProto::input() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.input) + return input_; +} + +// repeated .onnx.ValueInfoProto output = 12; +inline int GraphProto::_internal_output_size() const { + return output_.size(); +} +inline int GraphProto::output_size() const { + return _internal_output_size(); +} +inline void GraphProto::clear_output() { + output_.Clear(); +} +inline ::onnx::ValueInfoProto* GraphProto::mutable_output(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.output) + return output_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* +GraphProto::mutable_output() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.output) + return &output_; +} +inline const ::onnx::ValueInfoProto& GraphProto::_internal_output(int index) const { + return output_.Get(index); +} +inline const ::onnx::ValueInfoProto& GraphProto::output(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.output) + return _internal_output(index); +} +inline ::onnx::ValueInfoProto* GraphProto::_internal_add_output() { + return output_.Add(); +} +inline ::onnx::ValueInfoProto* GraphProto::add_output() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.output) + return _internal_add_output(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& +GraphProto::output() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.output) + return output_; +} + +// repeated .onnx.ValueInfoProto value_info = 13; +inline int GraphProto::_internal_value_info_size() const { + return value_info_.size(); +} +inline int GraphProto::value_info_size() const { + return _internal_value_info_size(); +} +inline void GraphProto::clear_value_info() { + value_info_.Clear(); +} +inline ::onnx::ValueInfoProto* GraphProto::mutable_value_info(int index) { + // @@protoc_insertion_point(field_mutable:onnx.GraphProto.value_info) + return value_info_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >* +GraphProto::mutable_value_info() { + // @@protoc_insertion_point(field_mutable_list:onnx.GraphProto.value_info) + return &value_info_; +} +inline const ::onnx::ValueInfoProto& GraphProto::_internal_value_info(int index) const { + return value_info_.Get(index); +} +inline const ::onnx::ValueInfoProto& GraphProto::value_info(int index) const { + // @@protoc_insertion_point(field_get:onnx.GraphProto.value_info) + return _internal_value_info(index); +} +inline ::onnx::ValueInfoProto* GraphProto::_internal_add_value_info() { + return value_info_.Add(); +} +inline ::onnx::ValueInfoProto* GraphProto::add_value_info() { + // @@protoc_insertion_point(field_add:onnx.GraphProto.value_info) + return _internal_add_value_info(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::ValueInfoProto >& +GraphProto::value_info() const { + // @@protoc_insertion_point(field_list:onnx.GraphProto.value_info) + return value_info_; +} + +// ------------------------------------------------------------------- + +// TensorProto_Segment + +// optional int64 begin = 1; +inline bool TensorProto_Segment::_internal_has_begin() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool 
TensorProto_Segment::has_begin() const { + return _internal_has_begin(); +} +inline void TensorProto_Segment::clear_begin() { + begin_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000001u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::_internal_begin() const { + return begin_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::begin() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.Segment.begin) + return _internal_begin(); +} +inline void TensorProto_Segment::_internal_set_begin(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000001u; + begin_ = value; +} +inline void TensorProto_Segment::set_begin(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_begin(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.Segment.begin) +} + +// optional int64 end = 2; +inline bool TensorProto_Segment::_internal_has_end() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TensorProto_Segment::has_end() const { + return _internal_has_end(); +} +inline void TensorProto_Segment::clear_end() { + end_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::_internal_end() const { + return end_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto_Segment::end() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.Segment.end) + return _internal_end(); +} +inline void TensorProto_Segment::_internal_set_end(::PROTOBUF_NAMESPACE_ID::int64 value) { + _has_bits_[0] |= 0x00000002u; + end_ = value; +} +inline void TensorProto_Segment::set_end(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_end(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.Segment.end) +} + +// ------------------------------------------------------------------- + +// TensorProto + +// repeated int64 dims = 1; +inline int TensorProto::_internal_dims_size() const { + return dims_.size(); +} +inline int TensorProto::dims_size() const { + return _internal_dims_size(); +} +inline void TensorProto::clear_dims() { + dims_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_dims(int index) const { + return dims_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::dims(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.dims) + return _internal_dims(index); +} +inline void TensorProto::set_dims(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.dims) +} +inline void TensorProto::_internal_add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + dims_.Add(value); +} +inline void TensorProto::add_dims(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_dims(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.dims) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_dims() const { + return dims_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::dims() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.dims) + return _internal_dims(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::_internal_mutable_dims() { + return &dims_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_dims() { + // 
@@protoc_insertion_point(field_mutable_list:onnx.TensorProto.dims) + return _internal_mutable_dims(); +} + +// optional .onnx.TensorProto.DataType data_type = 2; +inline bool TensorProto::_internal_has_data_type() const { + bool value = (_has_bits_[0] & 0x00000010u) != 0; + return value; +} +inline bool TensorProto::has_data_type() const { + return _internal_has_data_type(); +} +inline void TensorProto::clear_data_type() { + data_type_ = 0; + _has_bits_[0] &= ~0x00000010u; +} +inline ::onnx::TensorProto_DataType TensorProto::_internal_data_type() const { + return static_cast< ::onnx::TensorProto_DataType >(data_type_); +} +inline ::onnx::TensorProto_DataType TensorProto::data_type() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.data_type) + return _internal_data_type(); +} +inline void TensorProto::_internal_set_data_type(::onnx::TensorProto_DataType value) { + assert(::onnx::TensorProto_DataType_IsValid(value)); + _has_bits_[0] |= 0x00000010u; + data_type_ = value; +} +inline void TensorProto::set_data_type(::onnx::TensorProto_DataType value) { + _internal_set_data_type(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.data_type) +} + +// optional .onnx.TensorProto.Segment segment = 3; +inline bool TensorProto::_internal_has_segment() const { + bool value = (_has_bits_[0] & 0x00000008u) != 0; + PROTOBUF_ASSUME(!value || segment_ != nullptr); + return value; +} +inline bool TensorProto::has_segment() const { + return _internal_has_segment(); +} +inline void TensorProto::clear_segment() { + if (segment_ != nullptr) segment_->Clear(); + _has_bits_[0] &= ~0x00000008u; +} +inline const ::onnx::TensorProto_Segment& TensorProto::_internal_segment() const { + const ::onnx::TensorProto_Segment* p = segment_; + return p != nullptr ? 
*p : *reinterpret_cast<const ::onnx::TensorProto_Segment*>( + &::onnx::_TensorProto_Segment_default_instance_); +} +inline const ::onnx::TensorProto_Segment& TensorProto::segment() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.segment) + return _internal_segment(); +} +inline void TensorProto::unsafe_arena_set_allocated_segment( + ::onnx::TensorProto_Segment* segment) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(segment_); + } + segment_ = segment; + if (segment) { + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TensorProto.segment) +} +inline ::onnx::TensorProto_Segment* TensorProto::release_segment() { + auto temp = unsafe_arena_release_segment(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TensorProto_Segment* TensorProto::unsafe_arena_release_segment() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.segment) + _has_bits_[0] &= ~0x00000008u; + ::onnx::TensorProto_Segment* temp = segment_; + segment_ = nullptr; + return temp; +} +inline ::onnx::TensorProto_Segment* TensorProto::_internal_mutable_segment() { + _has_bits_[0] |= 0x00000008u; + if (segment_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorProto_Segment>(GetArena()); + segment_ = p; + } + return segment_; +} +inline ::onnx::TensorProto_Segment* TensorProto::mutable_segment() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.segment) + return _internal_mutable_segment(); +} +inline void TensorProto::set_allocated_segment(::onnx::TensorProto_Segment* segment) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete segment_; + } + if (segment) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(segment); + if (message_arena != submessage_arena) { + segment = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, segment, submessage_arena); + } + _has_bits_[0] |= 0x00000008u; + } else { + _has_bits_[0] &= ~0x00000008u; + } + segment_ = segment; + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.segment) +} + +// repeated float float_data = 4 [packed = true]; +inline int TensorProto::_internal_float_data_size() const { + return float_data_.size(); +} +inline int TensorProto::float_data_size() const { + return _internal_float_data_size(); +} +inline void TensorProto::clear_float_data() { + float_data_.Clear(); +} +inline float TensorProto::_internal_float_data(int index) const { + return float_data_.Get(index); +} +inline float TensorProto::float_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.float_data) + return _internal_float_data(index); +} +inline void TensorProto::set_float_data(int index, float value) { + float_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.float_data) +} +inline void TensorProto::_internal_add_float_data(float value) { + float_data_.Add(value); +} +inline void TensorProto::add_float_data(float value) { + _internal_add_float_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.float_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::_internal_float_data() const { + return float_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >& +TensorProto::float_data() const { + // 
@@protoc_insertion_point(field_list:onnx.TensorProto.float_data) + return _internal_float_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::_internal_mutable_float_data() { + return &float_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< float >* +TensorProto::mutable_float_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.float_data) + return _internal_mutable_float_data(); +} + +// repeated int32 int32_data = 5 [packed = true]; +inline int TensorProto::_internal_int32_data_size() const { + return int32_data_.size(); +} +inline int TensorProto::int32_data_size() const { + return _internal_int32_data_size(); +} +inline void TensorProto::clear_int32_data() { + int32_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::_internal_int32_data(int index) const { + return int32_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int32 TensorProto::int32_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.int32_data) + return _internal_int32_data(index); +} +inline void TensorProto::set_int32_data(int index, ::PROTOBUF_NAMESPACE_ID::int32 value) { + int32_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.int32_data) +} +inline void TensorProto::_internal_add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value) { + int32_data_.Add(value); +} +inline void TensorProto::add_int32_data(::PROTOBUF_NAMESPACE_ID::int32 value) { + _internal_add_int32_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.int32_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::_internal_int32_data() const { + return int32_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >& +TensorProto::int32_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.int32_data) + return _internal_int32_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::_internal_mutable_int32_data() { + return &int32_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int32 >* +TensorProto::mutable_int32_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.int32_data) + return _internal_mutable_int32_data(); +} + +// repeated bytes string_data = 6; +inline int TensorProto::_internal_string_data_size() const { + return string_data_.size(); +} +inline int TensorProto::string_data_size() const { + return _internal_string_data_size(); +} +inline void TensorProto::clear_string_data() { + string_data_.Clear(); +} +inline std::string* TensorProto::add_string_data() { + // @@protoc_insertion_point(field_add_mutable:onnx.TensorProto.string_data) + return _internal_add_string_data(); +} +inline const std::string& TensorProto::_internal_string_data(int index) const { + return string_data_.Get(index); +} +inline const std::string& TensorProto::string_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.string_data) + return _internal_string_data(index); +} +inline std::string* TensorProto::mutable_string_data(int index) { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.string_data) + return string_data_.Mutable(index); +} +inline void TensorProto::set_string_data(int index, const std::string& value) { + // @@protoc_insertion_point(field_set:onnx.TensorProto.string_data) + string_data_.Mutable(index)->assign(value); +} +inline void 
TensorProto::set_string_data(int index, std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.TensorProto.string_data) + string_data_.Mutable(index)->assign(std::move(value)); +} +inline void TensorProto::set_string_data(int index, const char* value) { + GOOGLE_DCHECK(value != nullptr); + string_data_.Mutable(index)->assign(value); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.string_data) +} +inline void TensorProto::set_string_data(int index, const void* value, size_t size) { + string_data_.Mutable(index)->assign( + reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.string_data) +} +inline std::string* TensorProto::_internal_add_string_data() { + return string_data_.Add(); +} +inline void TensorProto::add_string_data(const std::string& value) { + string_data_.Add()->assign(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.string_data) +} +inline void TensorProto::add_string_data(std::string&& value) { + string_data_.Add(std::move(value)); + // @@protoc_insertion_point(field_add:onnx.TensorProto.string_data) +} +inline void TensorProto::add_string_data(const char* value) { + GOOGLE_DCHECK(value != nullptr); + string_data_.Add()->assign(value); + // @@protoc_insertion_point(field_add_char:onnx.TensorProto.string_data) +} +inline void TensorProto::add_string_data(const void* value, size_t size) { + string_data_.Add()->assign(reinterpret_cast<const char*>(value), size); + // @@protoc_insertion_point(field_add_pointer:onnx.TensorProto.string_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& +TensorProto::string_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.string_data) + return string_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* +TensorProto::mutable_string_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.string_data) + return &string_data_; +} + +// repeated int64 int64_data = 7 [packed = true]; +inline int TensorProto::_internal_int64_data_size() const { + return int64_data_.size(); +} +inline int TensorProto::int64_data_size() const { + return _internal_int64_data_size(); +} +inline void TensorProto::clear_int64_data() { + int64_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::_internal_int64_data(int index) const { + return int64_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorProto::int64_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.int64_data) + return _internal_int64_data(index); +} +inline void TensorProto::set_int64_data(int index, ::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.int64_data) +} +inline void TensorProto::_internal_add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value) { + int64_data_.Add(value); +} +inline void TensorProto::add_int64_data(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_add_int64_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.int64_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::_internal_int64_data() const { + return int64_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >& +TensorProto::int64_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.int64_data) + return _internal_int64_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* 
+TensorProto::_internal_mutable_int64_data() { + return &int64_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >* +TensorProto::mutable_int64_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.int64_data) + return _internal_mutable_int64_data(); +} + +// optional string name = 8; +inline bool TensorProto::_internal_has_name() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorProto::has_name() const { + return _internal_has_name(); +} +inline void TensorProto::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorProto::name() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.name) + return _internal_name(); +} +inline void TensorProto::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.name) +} +inline std::string* TensorProto::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.name) + return _internal_mutable_name(); +} +inline const std::string& TensorProto::_internal_name() const { + return name_.Get(); +} +inline void TensorProto::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorProto::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorProto.name) +} +inline void TensorProto::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.name) +} +inline void TensorProto::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.name) +} +inline std::string* TensorProto::_internal_mutable_name() { + _has_bits_[0] |= 0x00000001u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorProto::release_name() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorProto::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.name) +} +inline std::string* TensorProto::unsafe_arena_release_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TensorProto.name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return 
name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TensorProto::unsafe_arena_set_allocated_name( + std::string* name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (name != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TensorProto.name) +} + +// optional string doc_string = 12; +inline bool TensorProto::_internal_has_doc_string() const { + bool value = (_has_bits_[0] & 0x00000004u) != 0; + return value; +} +inline bool TensorProto::has_doc_string() const { + return _internal_has_doc_string(); +} +inline void TensorProto::clear_doc_string() { + doc_string_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000004u; +} +inline const std::string& TensorProto::doc_string() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.doc_string) + return _internal_doc_string(); +} +inline void TensorProto::set_doc_string(const std::string& value) { + _internal_set_doc_string(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.doc_string) +} +inline std::string* TensorProto::mutable_doc_string() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.doc_string) + return _internal_mutable_doc_string(); +} +inline const std::string& TensorProto::_internal_doc_string() const { + return doc_string_.Get(); +} +inline void TensorProto::_internal_set_doc_string(const std::string& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorProto::set_doc_string(std::string&& value) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorProto.doc_string) +} +inline void TensorProto::set_doc_string(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.doc_string) +} +inline void TensorProto::set_doc_string(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000004u; + doc_string_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.doc_string) +} +inline std::string* TensorProto::_internal_mutable_doc_string() { + _has_bits_[0] |= 0x00000004u; + return doc_string_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorProto::release_doc_string() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.doc_string) + if (!_internal_has_doc_string()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000004u; + return doc_string_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorProto::set_allocated_doc_string(std::string* doc_string) { + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + 
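+  // set_allocated_* takes ownership of the caller's heap std::string (nullptr
+  // clears the field); when the message lives on an arena, the pointer is
+  // additionally registered so it is freed together with the arena.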
doc_string_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), doc_string, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.doc_string) +} +inline std::string* TensorProto::unsafe_arena_release_doc_string() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TensorProto.doc_string) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000004u; + return doc_string_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TensorProto::unsafe_arena_set_allocated_doc_string( + std::string* doc_string) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (doc_string != nullptr) { + _has_bits_[0] |= 0x00000004u; + } else { + _has_bits_[0] &= ~0x00000004u; + } + doc_string_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + doc_string, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TensorProto.doc_string) +} + +// optional bytes raw_data = 9; +inline bool TensorProto::_internal_has_raw_data() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TensorProto::has_raw_data() const { + return _internal_has_raw_data(); +} +inline void TensorProto::clear_raw_data() { + raw_data_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& TensorProto::raw_data() const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.raw_data) + return _internal_raw_data(); +} +inline void TensorProto::set_raw_data(const std::string& value) { + _internal_set_raw_data(value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.raw_data) +} +inline std::string* TensorProto::mutable_raw_data() { + // @@protoc_insertion_point(field_mutable:onnx.TensorProto.raw_data) + return _internal_mutable_raw_data(); +} +inline const std::string& TensorProto::_internal_raw_data() const { + return raw_data_.Get(); +} +inline void TensorProto::_internal_set_raw_data(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorProto::set_raw_data(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + raw_data_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorProto.raw_data) +} +inline void TensorProto::set_raw_data(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TensorProto.raw_data) +} +inline void TensorProto::set_raw_data(const void* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + raw_data_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast<const char*>(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorProto.raw_data) +} +inline std::string* TensorProto::_internal_mutable_raw_data() { + _has_bits_[0] |= 0x00000002u; + return raw_data_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorProto::release_raw_data() { + // @@protoc_insertion_point(field_release:onnx.TensorProto.raw_data) + if 
(!_internal_has_raw_data()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return raw_data_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorProto::set_allocated_raw_data(std::string* raw_data) { + if (raw_data != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + raw_data_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), raw_data, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorProto.raw_data) +} +inline std::string* TensorProto::unsafe_arena_release_raw_data() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TensorProto.raw_data) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return raw_data_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TensorProto::unsafe_arena_set_allocated_raw_data( + std::string* raw_data) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (raw_data != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + raw_data_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + raw_data, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TensorProto.raw_data) +} + +// repeated double double_data = 10 [packed = true]; +inline int TensorProto::_internal_double_data_size() const { + return double_data_.size(); +} +inline int TensorProto::double_data_size() const { + return _internal_double_data_size(); +} +inline void TensorProto::clear_double_data() { + double_data_.Clear(); +} +inline double TensorProto::_internal_double_data(int index) const { + return double_data_.Get(index); +} +inline double TensorProto::double_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.double_data) + return _internal_double_data(index); +} +inline void TensorProto::set_double_data(int index, double value) { + double_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.double_data) +} +inline void TensorProto::_internal_add_double_data(double value) { + double_data_.Add(value); +} +inline void TensorProto::add_double_data(double value) { + _internal_add_double_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.double_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::_internal_double_data() const { + return double_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >& +TensorProto::double_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.double_data) + return _internal_double_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::_internal_mutable_double_data() { + return &double_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< double >* +TensorProto::mutable_double_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.double_data) + return _internal_mutable_double_data(); +} + +// repeated uint64 uint64_data = 11 [packed = true]; +inline int TensorProto::_internal_uint64_data_size() const { + return uint64_data_.size(); +} +inline int TensorProto::uint64_data_size() const { + return _internal_uint64_data_size(); +} +inline void TensorProto::clear_uint64_data() { + uint64_data_.Clear(); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::_internal_uint64_data(int index) const { + return 
uint64_data_.Get(index); +} +inline ::PROTOBUF_NAMESPACE_ID::uint64 TensorProto::uint64_data(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorProto.uint64_data) + return _internal_uint64_data(index); +} +inline void TensorProto::set_uint64_data(int index, ::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_data_.Set(index, value); + // @@protoc_insertion_point(field_set:onnx.TensorProto.uint64_data) +} +inline void TensorProto::_internal_add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value) { + uint64_data_.Add(value); +} +inline void TensorProto::add_uint64_data(::PROTOBUF_NAMESPACE_ID::uint64 value) { + _internal_add_uint64_data(value); + // @@protoc_insertion_point(field_add:onnx.TensorProto.uint64_data) +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::_internal_uint64_data() const { + return uint64_data_; +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >& +TensorProto::uint64_data() const { + // @@protoc_insertion_point(field_list:onnx.TensorProto.uint64_data) + return _internal_uint64_data(); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::_internal_mutable_uint64_data() { + return &uint64_data_; +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::uint64 >* +TensorProto::mutable_uint64_data() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorProto.uint64_data) + return _internal_mutable_uint64_data(); +} + +// ------------------------------------------------------------------- + +// TensorShapeProto_Dimension + +// int64 dim_value = 1; +inline bool TensorShapeProto_Dimension::_internal_has_dim_value() const { + return value_case() == kDimValue; +} +inline bool TensorShapeProto_Dimension::has_dim_value() const { + return _internal_has_dim_value(); +} +inline void TensorShapeProto_Dimension::set_has_dim_value() { + _oneof_case_[0] = kDimValue; +} +inline void TensorShapeProto_Dimension::clear_dim_value() { + if (_internal_has_dim_value()) { + value_.dim_value_ = PROTOBUF_LONGLONG(0); + clear_has_value(); + } +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dimension::_internal_dim_value() const { + if (_internal_has_dim_value()) { + return value_.dim_value_; + } + return PROTOBUF_LONGLONG(0); +} +inline void TensorShapeProto_Dimension::_internal_set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value) { + if (!_internal_has_dim_value()) { + clear_value(); + set_has_dim_value(); + } + value_.dim_value_ = value; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 TensorShapeProto_Dimension::dim_value() const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.Dimension.dim_value) + return _internal_dim_value(); +} +inline void TensorShapeProto_Dimension::set_dim_value(::PROTOBUF_NAMESPACE_ID::int64 value) { + _internal_set_dim_value(value); + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.dim_value) +} + +// string dim_param = 2; +inline bool TensorShapeProto_Dimension::_internal_has_dim_param() const { + return value_case() == kDimParam; +} +inline bool TensorShapeProto_Dimension::has_dim_param() const { + return _internal_has_dim_param(); +} +inline void TensorShapeProto_Dimension::set_has_dim_param() { + _oneof_case_[0] = kDimParam; +} +inline void TensorShapeProto_Dimension::clear_dim_param() { + if (_internal_has_dim_param()) { + value_.dim_param_.Destroy(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + 
clear_has_value(); + } +} +inline const std::string& TensorShapeProto_Dimension::dim_param() const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.Dimension.dim_param) + return _internal_dim_param(); +} +inline void TensorShapeProto_Dimension::set_dim_param(const std::string& value) { + _internal_set_dim_param(value); + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.dim_param) +} +inline std::string* TensorShapeProto_Dimension::mutable_dim_param() { + // @@protoc_insertion_point(field_mutable:onnx.TensorShapeProto.Dimension.dim_param) + return _internal_mutable_dim_param(); +} +inline const std::string& TensorShapeProto_Dimension::_internal_dim_param() const { + if (_internal_has_dim_param()) { + return value_.dim_param_.Get(); + } + return *&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(); +} +inline void TensorShapeProto_Dimension::_internal_set_dim_param(const std::string& value) { + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorShapeProto_Dimension::set_dim_param(std::string&& value) { + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.dim_param) + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorShapeProto.Dimension.dim_param) +} +inline void TensorShapeProto_Dimension::set_dim_param(const char* value) { + GOOGLE_DCHECK(value != nullptr); + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + ::std::string(value), GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TensorShapeProto.Dimension.dim_param) +} +inline void TensorShapeProto_Dimension::set_dim_param(const char* value, + size_t size) { + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + value_.dim_param_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), + GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorShapeProto.Dimension.dim_param) +} +inline std::string* TensorShapeProto_Dimension::_internal_mutable_dim_param() { + if (!_internal_has_dim_param()) { + clear_value(); + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + return value_.dim_param_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorShapeProto_Dimension::release_dim_param() { + // @@protoc_insertion_point(field_release:onnx.TensorShapeProto.Dimension.dim_param) + if (_internal_has_dim_param()) { + clear_has_value(); + return 
value_.dim_param_.Release(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void TensorShapeProto_Dimension::set_allocated_dim_param(std::string* dim_param) { + if (has_value()) { + clear_value(); + } + if (dim_param != nullptr) { + set_has_dim_param(); + value_.dim_param_.UnsafeSetDefault(dim_param); + } + // @@protoc_insertion_point(field_set_allocated:onnx.TensorShapeProto.Dimension.dim_param) +} +inline std::string* TensorShapeProto_Dimension::unsafe_arena_release_dim_param() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TensorShapeProto.Dimension.dim_param) + GOOGLE_DCHECK(GetArena() != nullptr); + if (_internal_has_dim_param()) { + clear_has_value(); + return value_.dim_param_.UnsafeArenaRelease( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + } else { + return nullptr; + } +} +inline void TensorShapeProto_Dimension::unsafe_arena_set_allocated_dim_param(std::string* dim_param) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (!_internal_has_dim_param()) { + value_.dim_param_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited()); + } + clear_value(); + if (dim_param) { + set_has_dim_param(); + value_.dim_param_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), dim_param, GetArena()); + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TensorShapeProto.Dimension.dim_param) +} + +// optional string denotation = 3; +inline bool TensorShapeProto_Dimension::_internal_has_denotation() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TensorShapeProto_Dimension::has_denotation() const { + return _internal_has_denotation(); +} +inline void TensorShapeProto_Dimension::clear_denotation() { + denotation_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TensorShapeProto_Dimension::denotation() const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.Dimension.denotation) + return _internal_denotation(); +} +inline void TensorShapeProto_Dimension::set_denotation(const std::string& value) { + _internal_set_denotation(value); + // @@protoc_insertion_point(field_set:onnx.TensorShapeProto.Dimension.denotation) +} +inline std::string* TensorShapeProto_Dimension::mutable_denotation() { + // @@protoc_insertion_point(field_mutable:onnx.TensorShapeProto.Dimension.denotation) + return _internal_mutable_denotation(); +} +inline const std::string& TensorShapeProto_Dimension::_internal_denotation() const { + return denotation_.Get(); +} +inline void TensorShapeProto_Dimension::_internal_set_denotation(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TensorShapeProto_Dimension::set_denotation(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TensorShapeProto.Dimension.denotation) +} +inline void TensorShapeProto_Dimension::set_denotation(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + 
GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TensorShapeProto.Dimension.denotation) +} +inline void TensorShapeProto_Dimension::set_denotation(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TensorShapeProto.Dimension.denotation) +} +inline std::string* TensorShapeProto_Dimension::_internal_mutable_denotation() { + _has_bits_[0] |= 0x00000001u; + return denotation_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TensorShapeProto_Dimension::release_denotation() { + // @@protoc_insertion_point(field_release:onnx.TensorShapeProto.Dimension.denotation) + if (!_internal_has_denotation()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return denotation_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TensorShapeProto_Dimension::set_allocated_denotation(std::string* denotation) { + if (denotation != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + denotation_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), denotation, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TensorShapeProto.Dimension.denotation) +} +inline std::string* TensorShapeProto_Dimension::unsafe_arena_release_denotation() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TensorShapeProto.Dimension.denotation) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return denotation_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TensorShapeProto_Dimension::unsafe_arena_set_allocated_denotation( + std::string* denotation) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (denotation != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + denotation_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + denotation, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TensorShapeProto.Dimension.denotation) +} + +inline bool TensorShapeProto_Dimension::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void TensorShapeProto_Dimension::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline TensorShapeProto_Dimension::ValueCase TensorShapeProto_Dimension::value_case() const { + return TensorShapeProto_Dimension::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// TensorShapeProto + +// repeated .onnx.TensorShapeProto.Dimension dim = 1; +inline int TensorShapeProto::_internal_dim_size() const { + return dim_.size(); +} +inline int TensorShapeProto::dim_size() const { + return _internal_dim_size(); +} +inline void TensorShapeProto::clear_dim() { + dim_.Clear(); +} +inline ::onnx::TensorShapeProto_Dimension* TensorShapeProto::mutable_dim(int index) { + // @@protoc_insertion_point(field_mutable:onnx.TensorShapeProto.dim) + return dim_.Mutable(index); +} +inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >* +TensorShapeProto::mutable_dim() { + // @@protoc_insertion_point(field_mutable_list:onnx.TensorShapeProto.dim) + return &dim_; +} +inline const 
::onnx::TensorShapeProto_Dimension& TensorShapeProto::_internal_dim(int index) const { + return dim_.Get(index); +} +inline const ::onnx::TensorShapeProto_Dimension& TensorShapeProto::dim(int index) const { + // @@protoc_insertion_point(field_get:onnx.TensorShapeProto.dim) + return _internal_dim(index); +} +inline ::onnx::TensorShapeProto_Dimension* TensorShapeProto::_internal_add_dim() { + return dim_.Add(); +} +inline ::onnx::TensorShapeProto_Dimension* TensorShapeProto::add_dim() { + // @@protoc_insertion_point(field_add:onnx.TensorShapeProto.dim) + return _internal_add_dim(); +} +inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField< ::onnx::TensorShapeProto_Dimension >& +TensorShapeProto::dim() const { + // @@protoc_insertion_point(field_list:onnx.TensorShapeProto.dim) + return dim_; +} + +// ------------------------------------------------------------------- + +// TypeProto_Tensor + +// optional .onnx.TensorProto.DataType elem_type = 1; +inline bool TypeProto_Tensor::_internal_has_elem_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_Tensor::has_elem_type() const { + return _internal_has_elem_type(); +} +inline void TypeProto_Tensor::clear_elem_type() { + elem_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::onnx::TensorProto_DataType TypeProto_Tensor::_internal_elem_type() const { + return static_cast< ::onnx::TensorProto_DataType >(elem_type_); +} +inline ::onnx::TensorProto_DataType TypeProto_Tensor::elem_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Tensor.elem_type) + return _internal_elem_type(); +} +inline void TypeProto_Tensor::_internal_set_elem_type(::onnx::TensorProto_DataType value) { + assert(::onnx::TensorProto_DataType_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + elem_type_ = value; +} +inline void TypeProto_Tensor::set_elem_type(::onnx::TensorProto_DataType value) { + _internal_set_elem_type(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.Tensor.elem_type) +} + +// optional .onnx.TensorShapeProto shape = 2; +inline bool TypeProto_Tensor::_internal_has_shape() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || shape_ != nullptr); + return value; +} +inline bool TypeProto_Tensor::has_shape() const { + return _internal_has_shape(); +} +inline void TypeProto_Tensor::clear_shape() { + if (shape_ != nullptr) shape_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TensorShapeProto& TypeProto_Tensor::_internal_shape() const { + const ::onnx::TensorShapeProto* p = shape_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_TensorShapeProto_default_instance_); +} +inline const ::onnx::TensorShapeProto& TypeProto_Tensor::shape() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Tensor.shape) + return _internal_shape(); +} +inline void TypeProto_Tensor::unsafe_arena_set_allocated_shape( + ::onnx::TensorShapeProto* shape) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(shape_); + } + shape_ = shape; + if (shape) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.Tensor.shape) +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::release_shape() { + auto temp = unsafe_arena_release_shape(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::unsafe_arena_release_shape() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Tensor.shape) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TensorShapeProto* temp = shape_; + shape_ = nullptr; + return temp; +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::_internal_mutable_shape() { + _has_bits_[0] |= 0x00000001u; + if (shape_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorShapeProto>(GetArena()); + shape_ = p; + } + return shape_; +} +inline ::onnx::TensorShapeProto* TypeProto_Tensor::mutable_shape() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Tensor.shape) + return _internal_mutable_shape(); +} +inline void TypeProto_Tensor::set_allocated_shape(::onnx::TensorShapeProto* shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete shape_; + } + if (shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(shape); + if (message_arena != submessage_arena) { + shape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, shape, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + shape_ = shape; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Tensor.shape) +} + +// ------------------------------------------------------------------- + +// TypeProto_Sequence + +// optional .onnx.TypeProto elem_type = 1; +inline bool TypeProto_Sequence::_internal_has_elem_type() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || elem_type_ != nullptr); + return value; +} +inline bool TypeProto_Sequence::has_elem_type() const { + return _internal_has_elem_type(); +} +inline void TypeProto_Sequence::clear_elem_type() { + if (elem_type_ != nullptr) elem_type_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TypeProto& TypeProto_Sequence::_internal_elem_type() const { + const ::onnx::TypeProto* p = elem_type_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_TypeProto_default_instance_); +} +inline const ::onnx::TypeProto& TypeProto_Sequence::elem_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Sequence.elem_type) + return _internal_elem_type(); +} +inline void TypeProto_Sequence::unsafe_arena_set_allocated_elem_type( + ::onnx::TypeProto* elem_type) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(elem_type_); + } + elem_type_ = elem_type; + if (elem_type) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.Sequence.elem_type) +} +inline ::onnx::TypeProto* TypeProto_Sequence::release_elem_type() { + auto temp = unsafe_arena_release_elem_type(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TypeProto* TypeProto_Sequence::unsafe_arena_release_elem_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Sequence.elem_type) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TypeProto* temp = elem_type_; + elem_type_ = nullptr; + return temp; +} +inline ::onnx::TypeProto* TypeProto_Sequence::_internal_mutable_elem_type() { + _has_bits_[0] |= 0x00000001u; + if (elem_type_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TypeProto>(GetArena()); + elem_type_ = p; + } + return elem_type_; +} +inline ::onnx::TypeProto* TypeProto_Sequence::mutable_elem_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Sequence.elem_type) + return _internal_mutable_elem_type(); +} +inline void TypeProto_Sequence::set_allocated_elem_type(::onnx::TypeProto* elem_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete elem_type_; + } + if (elem_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(elem_type); + if (message_arena != submessage_arena) { + elem_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, elem_type, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + elem_type_ = elem_type; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Sequence.elem_type) +} + +// ------------------------------------------------------------------- + +// TypeProto_Map + +// optional .onnx.TensorProto.DataType key_type = 1; +inline bool TypeProto_Map::_internal_has_key_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_Map::has_key_type() const { + return _internal_has_key_type(); +} +inline void TypeProto_Map::clear_key_type() { + key_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::onnx::TensorProto_DataType TypeProto_Map::_internal_key_type() const { + return static_cast< ::onnx::TensorProto_DataType >(key_type_); +} +inline ::onnx::TensorProto_DataType TypeProto_Map::key_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Map.key_type) + return _internal_key_type(); +} +inline void TypeProto_Map::_internal_set_key_type(::onnx::TensorProto_DataType value) { + assert(::onnx::TensorProto_DataType_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + key_type_ = value; +} +inline void TypeProto_Map::set_key_type(::onnx::TensorProto_DataType value) { + _internal_set_key_type(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.Map.key_type) +} + +// optional .onnx.TypeProto 
value_type = 2; +inline bool TypeProto_Map::_internal_has_value_type() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || value_type_ != nullptr); + return value; +} +inline bool TypeProto_Map::has_value_type() const { + return _internal_has_value_type(); +} +inline void TypeProto_Map::clear_value_type() { + if (value_type_ != nullptr) value_type_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TypeProto& TypeProto_Map::_internal_value_type() const { + const ::onnx::TypeProto* p = value_type_; + return p != nullptr ? *p : *reinterpret_cast( + &::onnx::_TypeProto_default_instance_); +} +inline const ::onnx::TypeProto& TypeProto_Map::value_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Map.value_type) + return _internal_value_type(); +} +inline void TypeProto_Map::unsafe_arena_set_allocated_value_type( + ::onnx::TypeProto* value_type) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(value_type_); + } + value_type_ = value_type; + if (value_type) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.Map.value_type) +} +inline ::onnx::TypeProto* TypeProto_Map::release_value_type() { + auto temp = unsafe_arena_release_value_type(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TypeProto* TypeProto_Map::unsafe_arena_release_value_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Map.value_type) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TypeProto* temp = value_type_; + value_type_ = nullptr; + return temp; +} +inline ::onnx::TypeProto* TypeProto_Map::_internal_mutable_value_type() { + _has_bits_[0] |= 0x00000001u; + if (value_type_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TypeProto>(GetArena()); + value_type_ = p; + } + return value_type_; +} +inline ::onnx::TypeProto* TypeProto_Map::mutable_value_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Map.value_type) + return _internal_mutable_value_type(); +} +inline void TypeProto_Map::set_allocated_value_type(::onnx::TypeProto* value_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete value_type_; + } + if (value_type) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(value_type); + if (message_arena != submessage_arena) { + value_type = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, value_type, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + value_type_ = value_type; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Map.value_type) +} + +// ------------------------------------------------------------------- + +// TypeProto_Opaque + +// optional string domain = 1; +inline bool TypeProto_Opaque::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TypeProto_Opaque::has_domain() const { + return _internal_has_domain(); +} +inline void TypeProto_Opaque::clear_domain() { + domain_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TypeProto_Opaque::domain() const { + // 
@@protoc_insertion_point(field_get:onnx.TypeProto.Opaque.domain) + return _internal_domain(); +} +inline void TypeProto_Opaque::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.Opaque.domain) +} +inline std::string* TypeProto_Opaque::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Opaque.domain) + return _internal_mutable_domain(); +} +inline const std::string& TypeProto_Opaque::_internal_domain() const { + return domain_.Get(); +} +inline void TypeProto_Opaque::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TypeProto_Opaque::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + domain_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TypeProto.Opaque.domain) +} +inline void TypeProto_Opaque::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TypeProto.Opaque.domain) +} +inline void TypeProto_Opaque::set_domain(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TypeProto.Opaque.domain) +} +inline std::string* TypeProto_Opaque::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000001u; + return domain_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TypeProto_Opaque::release_domain() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Opaque.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return domain_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TypeProto_Opaque::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + domain_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Opaque.domain) +} +inline std::string* TypeProto_Opaque::unsafe_arena_release_domain() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.Opaque.domain) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return domain_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TypeProto_Opaque::unsafe_arena_set_allocated_domain( + std::string* domain) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (domain != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + domain_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + domain, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.Opaque.domain) +} + +// optional string name = 2; +inline bool TypeProto_Opaque::_internal_has_name() const { + bool value = 
(_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_Opaque::has_name() const { + return _internal_has_name(); +} +inline void TypeProto_Opaque::clear_name() { + name_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000002u; +} +inline const std::string& TypeProto_Opaque::name() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.Opaque.name) + return _internal_name(); +} +inline void TypeProto_Opaque::set_name(const std::string& value) { + _internal_set_name(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.Opaque.name) +} +inline std::string* TypeProto_Opaque::mutable_name() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.Opaque.name) + return _internal_mutable_name(); +} +inline const std::string& TypeProto_Opaque::_internal_name() const { + return name_.Get(); +} +inline void TypeProto_Opaque::_internal_set_name(const std::string& value) { + _has_bits_[0] |= 0x00000002u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TypeProto_Opaque::set_name(std::string&& value) { + _has_bits_[0] |= 0x00000002u; + name_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TypeProto.Opaque.name) +} +inline void TypeProto_Opaque::set_name(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000002u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TypeProto.Opaque.name) +} +inline void TypeProto_Opaque::set_name(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000002u; + name_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TypeProto.Opaque.name) +} +inline std::string* TypeProto_Opaque::_internal_mutable_name() { + _has_bits_[0] |= 0x00000002u; + return name_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TypeProto_Opaque::release_name() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.Opaque.name) + if (!_internal_has_name()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000002u; + return name_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TypeProto_Opaque::set_allocated_name(std::string* name) { + if (name != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + name_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), name, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.Opaque.name) +} +inline std::string* TypeProto_Opaque::unsafe_arena_release_name() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.Opaque.name) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000002u; + return name_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TypeProto_Opaque::unsafe_arena_set_allocated_name( + std::string* name) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (name != nullptr) { + _has_bits_[0] |= 0x00000002u; + } else { + _has_bits_[0] &= ~0x00000002u; + } + 
name_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + name, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.Opaque.name) +} + +// ------------------------------------------------------------------- + +// TypeProto_SparseTensor + +// optional .onnx.TensorProto.DataType elem_type = 1; +inline bool TypeProto_SparseTensor::_internal_has_elem_type() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool TypeProto_SparseTensor::has_elem_type() const { + return _internal_has_elem_type(); +} +inline void TypeProto_SparseTensor::clear_elem_type() { + elem_type_ = 0; + _has_bits_[0] &= ~0x00000002u; +} +inline ::onnx::TensorProto_DataType TypeProto_SparseTensor::_internal_elem_type() const { + return static_cast< ::onnx::TensorProto_DataType >(elem_type_); +} +inline ::onnx::TensorProto_DataType TypeProto_SparseTensor::elem_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.SparseTensor.elem_type) + return _internal_elem_type(); +} +inline void TypeProto_SparseTensor::_internal_set_elem_type(::onnx::TensorProto_DataType value) { + assert(::onnx::TensorProto_DataType_IsValid(value)); + _has_bits_[0] |= 0x00000002u; + elem_type_ = value; +} +inline void TypeProto_SparseTensor::set_elem_type(::onnx::TensorProto_DataType value) { + _internal_set_elem_type(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.SparseTensor.elem_type) +} + +// optional .onnx.TensorShapeProto shape = 2; +inline bool TypeProto_SparseTensor::_internal_has_shape() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + PROTOBUF_ASSUME(!value || shape_ != nullptr); + return value; +} +inline bool TypeProto_SparseTensor::has_shape() const { + return _internal_has_shape(); +} +inline void TypeProto_SparseTensor::clear_shape() { + if (shape_ != nullptr) shape_->Clear(); + _has_bits_[0] &= ~0x00000001u; +} +inline const ::onnx::TensorShapeProto& TypeProto_SparseTensor::_internal_shape() const { + const ::onnx::TensorShapeProto* p = shape_; + return p != nullptr ? 
*p : *reinterpret_cast( + &::onnx::_TensorShapeProto_default_instance_); +} +inline const ::onnx::TensorShapeProto& TypeProto_SparseTensor::shape() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.SparseTensor.shape) + return _internal_shape(); +} +inline void TypeProto_SparseTensor::unsafe_arena_set_allocated_shape( + ::onnx::TensorShapeProto* shape) { + if (GetArena() == nullptr) { + delete reinterpret_cast<::PROTOBUF_NAMESPACE_ID::MessageLite*>(shape_); + } + shape_ = shape; + if (shape) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.SparseTensor.shape) +} +inline ::onnx::TensorShapeProto* TypeProto_SparseTensor::release_shape() { + auto temp = unsafe_arena_release_shape(); + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + return temp; +} +inline ::onnx::TensorShapeProto* TypeProto_SparseTensor::unsafe_arena_release_shape() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.SparseTensor.shape) + _has_bits_[0] &= ~0x00000001u; + ::onnx::TensorShapeProto* temp = shape_; + shape_ = nullptr; + return temp; +} +inline ::onnx::TensorShapeProto* TypeProto_SparseTensor::_internal_mutable_shape() { + _has_bits_[0] |= 0x00000001u; + if (shape_ == nullptr) { + auto* p = CreateMaybeMessage<::onnx::TensorShapeProto>(GetArena()); + shape_ = p; + } + return shape_; +} +inline ::onnx::TensorShapeProto* TypeProto_SparseTensor::mutable_shape() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.SparseTensor.shape) + return _internal_mutable_shape(); +} +inline void TypeProto_SparseTensor::set_allocated_shape(::onnx::TensorShapeProto* shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* message_arena = GetArena(); + if (message_arena == nullptr) { + delete shape_; + } + if (shape) { + ::PROTOBUF_NAMESPACE_ID::Arena* submessage_arena = + ::PROTOBUF_NAMESPACE_ID::Arena::GetArena(shape); + if (message_arena != submessage_arena) { + shape = ::PROTOBUF_NAMESPACE_ID::internal::GetOwnedMessage( + message_arena, shape, submessage_arena); + } + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + shape_ = shape; + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.SparseTensor.shape) +} + +// ------------------------------------------------------------------- + +// TypeProto + +// .onnx.TypeProto.Tensor tensor_type = 1; +inline bool TypeProto::_internal_has_tensor_type() const { + return value_case() == kTensorType; +} +inline bool TypeProto::has_tensor_type() const { + return _internal_has_tensor_type(); +} +inline void TypeProto::set_has_tensor_type() { + _oneof_case_[0] = kTensorType; +} +inline void TypeProto::clear_tensor_type() { + if (_internal_has_tensor_type()) { + if (GetArena() == nullptr) { + delete value_.tensor_type_; + } + clear_has_value(); + } +} +inline ::onnx::TypeProto_Tensor* TypeProto::release_tensor_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.tensor_type) + if (_internal_has_tensor_type()) { + clear_has_value(); + ::onnx::TypeProto_Tensor* temp = value_.tensor_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Tensor& TypeProto::_internal_tensor_type() const { + return _internal_has_tensor_type() + ? 
*value_.tensor_type_ + : *reinterpret_cast< ::onnx::TypeProto_Tensor*>(&::onnx::_TypeProto_Tensor_default_instance_); +} +inline const ::onnx::TypeProto_Tensor& TypeProto::tensor_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.tensor_type) + return _internal_tensor_type(); +} +inline ::onnx::TypeProto_Tensor* TypeProto::unsafe_arena_release_tensor_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.tensor_type) + if (_internal_has_tensor_type()) { + clear_has_value(); + ::onnx::TypeProto_Tensor* temp = value_.tensor_type_; + value_.tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_tensor_type(::onnx::TypeProto_Tensor* tensor_type) { + clear_value(); + if (tensor_type) { + set_has_tensor_type(); + value_.tensor_type_ = tensor_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.tensor_type) +} +inline ::onnx::TypeProto_Tensor* TypeProto::_internal_mutable_tensor_type() { + if (!_internal_has_tensor_type()) { + clear_value(); + set_has_tensor_type(); + value_.tensor_type_ = CreateMaybeMessage< ::onnx::TypeProto_Tensor >(GetArena()); + } + return value_.tensor_type_; +} +inline ::onnx::TypeProto_Tensor* TypeProto::mutable_tensor_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.tensor_type) + return _internal_mutable_tensor_type(); +} + +// .onnx.TypeProto.Sequence sequence_type = 4; +inline bool TypeProto::_internal_has_sequence_type() const { + return value_case() == kSequenceType; +} +inline bool TypeProto::has_sequence_type() const { + return _internal_has_sequence_type(); +} +inline void TypeProto::set_has_sequence_type() { + _oneof_case_[0] = kSequenceType; +} +inline void TypeProto::clear_sequence_type() { + if (_internal_has_sequence_type()) { + if (GetArena() == nullptr) { + delete value_.sequence_type_; + } + clear_has_value(); + } +} +inline ::onnx::TypeProto_Sequence* TypeProto::release_sequence_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.sequence_type) + if (_internal_has_sequence_type()) { + clear_has_value(); + ::onnx::TypeProto_Sequence* temp = value_.sequence_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.sequence_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Sequence& TypeProto::_internal_sequence_type() const { + return _internal_has_sequence_type() + ? 
*value_.sequence_type_ + : *reinterpret_cast< ::onnx::TypeProto_Sequence*>(&::onnx::_TypeProto_Sequence_default_instance_); +} +inline const ::onnx::TypeProto_Sequence& TypeProto::sequence_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.sequence_type) + return _internal_sequence_type(); +} +inline ::onnx::TypeProto_Sequence* TypeProto::unsafe_arena_release_sequence_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.sequence_type) + if (_internal_has_sequence_type()) { + clear_has_value(); + ::onnx::TypeProto_Sequence* temp = value_.sequence_type_; + value_.sequence_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_sequence_type(::onnx::TypeProto_Sequence* sequence_type) { + clear_value(); + if (sequence_type) { + set_has_sequence_type(); + value_.sequence_type_ = sequence_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.sequence_type) +} +inline ::onnx::TypeProto_Sequence* TypeProto::_internal_mutable_sequence_type() { + if (!_internal_has_sequence_type()) { + clear_value(); + set_has_sequence_type(); + value_.sequence_type_ = CreateMaybeMessage< ::onnx::TypeProto_Sequence >(GetArena()); + } + return value_.sequence_type_; +} +inline ::onnx::TypeProto_Sequence* TypeProto::mutable_sequence_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.sequence_type) + return _internal_mutable_sequence_type(); +} + +// .onnx.TypeProto.Map map_type = 5; +inline bool TypeProto::_internal_has_map_type() const { + return value_case() == kMapType; +} +inline bool TypeProto::has_map_type() const { + return _internal_has_map_type(); +} +inline void TypeProto::set_has_map_type() { + _oneof_case_[0] = kMapType; +} +inline void TypeProto::clear_map_type() { + if (_internal_has_map_type()) { + if (GetArena() == nullptr) { + delete value_.map_type_; + } + clear_has_value(); + } +} +inline ::onnx::TypeProto_Map* TypeProto::release_map_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.map_type) + if (_internal_has_map_type()) { + clear_has_value(); + ::onnx::TypeProto_Map* temp = value_.map_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.map_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Map& TypeProto::_internal_map_type() const { + return _internal_has_map_type() + ? 
*value_.map_type_ + : *reinterpret_cast< ::onnx::TypeProto_Map*>(&::onnx::_TypeProto_Map_default_instance_); +} +inline const ::onnx::TypeProto_Map& TypeProto::map_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.map_type) + return _internal_map_type(); +} +inline ::onnx::TypeProto_Map* TypeProto::unsafe_arena_release_map_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.map_type) + if (_internal_has_map_type()) { + clear_has_value(); + ::onnx::TypeProto_Map* temp = value_.map_type_; + value_.map_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_map_type(::onnx::TypeProto_Map* map_type) { + clear_value(); + if (map_type) { + set_has_map_type(); + value_.map_type_ = map_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.map_type) +} +inline ::onnx::TypeProto_Map* TypeProto::_internal_mutable_map_type() { + if (!_internal_has_map_type()) { + clear_value(); + set_has_map_type(); + value_.map_type_ = CreateMaybeMessage< ::onnx::TypeProto_Map >(GetArena()); + } + return value_.map_type_; +} +inline ::onnx::TypeProto_Map* TypeProto::mutable_map_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.map_type) + return _internal_mutable_map_type(); +} + +// .onnx.TypeProto.Opaque opaque_type = 7; +inline bool TypeProto::_internal_has_opaque_type() const { + return value_case() == kOpaqueType; +} +inline bool TypeProto::has_opaque_type() const { + return _internal_has_opaque_type(); +} +inline void TypeProto::set_has_opaque_type() { + _oneof_case_[0] = kOpaqueType; +} +inline void TypeProto::clear_opaque_type() { + if (_internal_has_opaque_type()) { + if (GetArena() == nullptr) { + delete value_.opaque_type_; + } + clear_has_value(); + } +} +inline ::onnx::TypeProto_Opaque* TypeProto::release_opaque_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.opaque_type) + if (_internal_has_opaque_type()) { + clear_has_value(); + ::onnx::TypeProto_Opaque* temp = value_.opaque_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.opaque_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_Opaque& TypeProto::_internal_opaque_type() const { + return _internal_has_opaque_type() + ? 
*value_.opaque_type_ + : *reinterpret_cast< ::onnx::TypeProto_Opaque*>(&::onnx::_TypeProto_Opaque_default_instance_); +} +inline const ::onnx::TypeProto_Opaque& TypeProto::opaque_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.opaque_type) + return _internal_opaque_type(); +} +inline ::onnx::TypeProto_Opaque* TypeProto::unsafe_arena_release_opaque_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.opaque_type) + if (_internal_has_opaque_type()) { + clear_has_value(); + ::onnx::TypeProto_Opaque* temp = value_.opaque_type_; + value_.opaque_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_opaque_type(::onnx::TypeProto_Opaque* opaque_type) { + clear_value(); + if (opaque_type) { + set_has_opaque_type(); + value_.opaque_type_ = opaque_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.opaque_type) +} +inline ::onnx::TypeProto_Opaque* TypeProto::_internal_mutable_opaque_type() { + if (!_internal_has_opaque_type()) { + clear_value(); + set_has_opaque_type(); + value_.opaque_type_ = CreateMaybeMessage< ::onnx::TypeProto_Opaque >(GetArena()); + } + return value_.opaque_type_; +} +inline ::onnx::TypeProto_Opaque* TypeProto::mutable_opaque_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.opaque_type) + return _internal_mutable_opaque_type(); +} + +// .onnx.TypeProto.SparseTensor sparse_tensor_type = 8; +inline bool TypeProto::_internal_has_sparse_tensor_type() const { + return value_case() == kSparseTensorType; +} +inline bool TypeProto::has_sparse_tensor_type() const { + return _internal_has_sparse_tensor_type(); +} +inline void TypeProto::set_has_sparse_tensor_type() { + _oneof_case_[0] = kSparseTensorType; +} +inline void TypeProto::clear_sparse_tensor_type() { + if (_internal_has_sparse_tensor_type()) { + if (GetArena() == nullptr) { + delete value_.sparse_tensor_type_; + } + clear_has_value(); + } +} +inline ::onnx::TypeProto_SparseTensor* TypeProto::release_sparse_tensor_type() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.sparse_tensor_type) + if (_internal_has_sparse_tensor_type()) { + clear_has_value(); + ::onnx::TypeProto_SparseTensor* temp = value_.sparse_tensor_type_; + if (GetArena() != nullptr) { + temp = ::PROTOBUF_NAMESPACE_ID::internal::DuplicateIfNonNull(temp); + } + value_.sparse_tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline const ::onnx::TypeProto_SparseTensor& TypeProto::_internal_sparse_tensor_type() const { + return _internal_has_sparse_tensor_type() + ? 
*value_.sparse_tensor_type_ + : *reinterpret_cast< ::onnx::TypeProto_SparseTensor*>(&::onnx::_TypeProto_SparseTensor_default_instance_); +} +inline const ::onnx::TypeProto_SparseTensor& TypeProto::sparse_tensor_type() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.sparse_tensor_type) + return _internal_sparse_tensor_type(); +} +inline ::onnx::TypeProto_SparseTensor* TypeProto::unsafe_arena_release_sparse_tensor_type() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.sparse_tensor_type) + if (_internal_has_sparse_tensor_type()) { + clear_has_value(); + ::onnx::TypeProto_SparseTensor* temp = value_.sparse_tensor_type_; + value_.sparse_tensor_type_ = nullptr; + return temp; + } else { + return nullptr; + } +} +inline void TypeProto::unsafe_arena_set_allocated_sparse_tensor_type(::onnx::TypeProto_SparseTensor* sparse_tensor_type) { + clear_value(); + if (sparse_tensor_type) { + set_has_sparse_tensor_type(); + value_.sparse_tensor_type_ = sparse_tensor_type; + } + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.sparse_tensor_type) +} +inline ::onnx::TypeProto_SparseTensor* TypeProto::_internal_mutable_sparse_tensor_type() { + if (!_internal_has_sparse_tensor_type()) { + clear_value(); + set_has_sparse_tensor_type(); + value_.sparse_tensor_type_ = CreateMaybeMessage< ::onnx::TypeProto_SparseTensor >(GetArena()); + } + return value_.sparse_tensor_type_; +} +inline ::onnx::TypeProto_SparseTensor* TypeProto::mutable_sparse_tensor_type() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.sparse_tensor_type) + return _internal_mutable_sparse_tensor_type(); +} + +// optional string denotation = 6; +inline bool TypeProto::_internal_has_denotation() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool TypeProto::has_denotation() const { + return _internal_has_denotation(); +} +inline void TypeProto::clear_denotation() { + denotation_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& TypeProto::denotation() const { + // @@protoc_insertion_point(field_get:onnx.TypeProto.denotation) + return _internal_denotation(); +} +inline void TypeProto::set_denotation(const std::string& value) { + _internal_set_denotation(value); + // @@protoc_insertion_point(field_set:onnx.TypeProto.denotation) +} +inline std::string* TypeProto::mutable_denotation() { + // @@protoc_insertion_point(field_mutable:onnx.TypeProto.denotation) + return _internal_mutable_denotation(); +} +inline const std::string& TypeProto::_internal_denotation() const { + return denotation_.Get(); +} +inline void TypeProto::_internal_set_denotation(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void TypeProto::set_denotation(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + denotation_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.TypeProto.denotation) +} +inline void TypeProto::set_denotation(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.TypeProto.denotation) +} +inline void 
TypeProto::set_denotation(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + denotation_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.TypeProto.denotation) +} +inline std::string* TypeProto::_internal_mutable_denotation() { + _has_bits_[0] |= 0x00000001u; + return denotation_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* TypeProto::release_denotation() { + // @@protoc_insertion_point(field_release:onnx.TypeProto.denotation) + if (!_internal_has_denotation()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return denotation_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void TypeProto::set_allocated_denotation(std::string* denotation) { + if (denotation != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + denotation_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), denotation, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.TypeProto.denotation) +} +inline std::string* TypeProto::unsafe_arena_release_denotation() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.TypeProto.denotation) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return denotation_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void TypeProto::unsafe_arena_set_allocated_denotation( + std::string* denotation) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (denotation != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + denotation_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + denotation, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.TypeProto.denotation) +} + +inline bool TypeProto::has_value() const { + return value_case() != VALUE_NOT_SET; +} +inline void TypeProto::clear_has_value() { + _oneof_case_[0] = VALUE_NOT_SET; +} +inline TypeProto::ValueCase TypeProto::value_case() const { + return TypeProto::ValueCase(_oneof_case_[0]); +} +// ------------------------------------------------------------------- + +// OperatorSetIdProto + +// optional string domain = 1; +inline bool OperatorSetIdProto::_internal_has_domain() const { + bool value = (_has_bits_[0] & 0x00000001u) != 0; + return value; +} +inline bool OperatorSetIdProto::has_domain() const { + return _internal_has_domain(); +} +inline void OperatorSetIdProto::clear_domain() { + domain_.ClearToEmpty(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); + _has_bits_[0] &= ~0x00000001u; +} +inline const std::string& OperatorSetIdProto::domain() const { + // @@protoc_insertion_point(field_get:onnx.OperatorSetIdProto.domain) + return _internal_domain(); +} +inline void OperatorSetIdProto::set_domain(const std::string& value) { + _internal_set_domain(value); + // @@protoc_insertion_point(field_set:onnx.OperatorSetIdProto.domain) +} +inline std::string* OperatorSetIdProto::mutable_domain() { + // @@protoc_insertion_point(field_mutable:onnx.OperatorSetIdProto.domain) + return _internal_mutable_domain(); +} +inline const std::string& OperatorSetIdProto::_internal_domain() const { + return domain_.Get(); +} +inline void 
OperatorSetIdProto::_internal_set_domain(const std::string& value) { + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value, GetArena()); +} +inline void OperatorSetIdProto::set_domain(std::string&& value) { + _has_bits_[0] |= 0x00000001u; + domain_.Set( + &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value), GetArena()); + // @@protoc_insertion_point(field_set_rvalue:onnx.OperatorSetIdProto.domain) +} +inline void OperatorSetIdProto::set_domain(const char* value) { + GOOGLE_DCHECK(value != nullptr); + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value), + GetArena()); + // @@protoc_insertion_point(field_set_char:onnx.OperatorSetIdProto.domain) +} +inline void OperatorSetIdProto::set_domain(const char* value, + size_t size) { + _has_bits_[0] |= 0x00000001u; + domain_.Set(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string( + reinterpret_cast(value), size), GetArena()); + // @@protoc_insertion_point(field_set_pointer:onnx.OperatorSetIdProto.domain) +} +inline std::string* OperatorSetIdProto::_internal_mutable_domain() { + _has_bits_[0] |= 0x00000001u; + return domain_.Mutable(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline std::string* OperatorSetIdProto::release_domain() { + // @@protoc_insertion_point(field_release:onnx.OperatorSetIdProto.domain) + if (!_internal_has_domain()) { + return nullptr; + } + _has_bits_[0] &= ~0x00000001u; + return domain_.ReleaseNonDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena()); +} +inline void OperatorSetIdProto::set_allocated_domain(std::string* domain) { + if (domain != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + domain_.SetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), domain, + GetArena()); + // @@protoc_insertion_point(field_set_allocated:onnx.OperatorSetIdProto.domain) +} +inline std::string* OperatorSetIdProto::unsafe_arena_release_domain() { + // @@protoc_insertion_point(field_unsafe_arena_release:onnx.OperatorSetIdProto.domain) + GOOGLE_DCHECK(GetArena() != nullptr); + _has_bits_[0] &= ~0x00000001u; + return domain_.UnsafeArenaRelease(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + GetArena()); +} +inline void OperatorSetIdProto::unsafe_arena_set_allocated_domain( + std::string* domain) { + GOOGLE_DCHECK(GetArena() != nullptr); + if (domain != nullptr) { + _has_bits_[0] |= 0x00000001u; + } else { + _has_bits_[0] &= ~0x00000001u; + } + domain_.UnsafeArenaSetAllocated(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), + domain, GetArena()); + // @@protoc_insertion_point(field_unsafe_arena_set_allocated:onnx.OperatorSetIdProto.domain) +} + +// optional int64 version = 2; +inline bool OperatorSetIdProto::_internal_has_version() const { + bool value = (_has_bits_[0] & 0x00000002u) != 0; + return value; +} +inline bool OperatorSetIdProto::has_version() const { + return _internal_has_version(); +} +inline void OperatorSetIdProto::clear_version() { + version_ = PROTOBUF_LONGLONG(0); + _has_bits_[0] &= ~0x00000002u; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 OperatorSetIdProto::_internal_version() const { + return version_; +} +inline ::PROTOBUF_NAMESPACE_ID::int64 OperatorSetIdProto::version() const { + // 
@@protoc_insertion_point(field_get:onnx.OperatorSetIdProto.version)
+  return _internal_version();
+}
+inline void OperatorSetIdProto::_internal_set_version(::PROTOBUF_NAMESPACE_ID::int64 value) {
+  _has_bits_[0] |= 0x00000002u;
+  version_ = value;
+}
+inline void OperatorSetIdProto::set_version(::PROTOBUF_NAMESPACE_ID::int64 value) {
+  _internal_set_version(value);
+  // @@protoc_insertion_point(field_set:onnx.OperatorSetIdProto.version)
+}
+
+#ifdef __GNUC__
+  #pragma GCC diagnostic pop
+#endif // __GNUC__
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+// -------------------------------------------------------------------
+
+
+// @@protoc_insertion_point(namespace_scope)
+
+} // namespace onnx
+
+PROTOBUF_NAMESPACE_OPEN
+
+template <> struct is_proto_enum< ::onnx::AttributeProto_AttributeType> : ::std::true_type {};
+template <>
+inline const EnumDescriptor* GetEnumDescriptor< ::onnx::AttributeProto_AttributeType>() {
+  return ::onnx::AttributeProto_AttributeType_descriptor();
+}
+template <> struct is_proto_enum< ::onnx::TensorProto_DataType> : ::std::true_type {};
+template <>
+inline const EnumDescriptor* GetEnumDescriptor< ::onnx::TensorProto_DataType>() {
+  return ::onnx::TensorProto_DataType_descriptor();
+}
+template <> struct is_proto_enum< ::onnx::Version> : ::std::true_type {};
+template <>
+inline const EnumDescriptor* GetEnumDescriptor< ::onnx::Version>() {
+  return ::onnx::Version_descriptor();
+}
+
+PROTOBUF_NAMESPACE_CLOSE
+
+// @@protoc_insertion_point(global_scope)
+
+#include <google/protobuf/port_undef.inc>
+#endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_onnx_2dml_2eproto
diff --git a/src/3rd_party/onnx/protobuf/onnx-ml.proto b/src/3rd_party/onnx/protobuf/onnx-ml.proto
new file mode 100644
index 000000000..8fe09da73
--- /dev/null
+++ b/src/3rd_party/onnx/protobuf/onnx-ml.proto
@@ -0,0 +1,506 @@
+//
+// WARNING: This file is automatically generated! Please edit onnx.in.proto.
+//
+
+
+// Copyright (c) Facebook Inc. and Microsoft Corporation.
+// Licensed under the MIT license.
+
+syntax = "proto2";
+
+package onnx;
+
+// Overview
+//
+// ONNX is an open specification that is comprised of the following components:
+//
+// 1) A definition of an extensible computation graph model.
+// 2) Definitions of standard data types.
+// 3) Definitions of built-in operators.
+//
+// This document describes the syntax of models and their computation graphs,
+// as well as the standard data types. Together, they are referred to as the ONNX
+// Intermediate Representation, or 'IR' for short.
+//
+// The normative semantic specification of the ONNX IR is found in docs/IR.md.
+// Definitions of the built-in neural network operators may be found in docs/Operators.md.
+// Definitions of the built-in classical machine learning operators may be found in
+// docs/Operators-ml.md.
+
+// Notes
+//
+// Release
+//
+// We are still in the very early stage of defining ONNX. The current
+// version of ONNX is a starting point. While we are actively working
+// towards a complete spec, we would like to get the community involved
+// by sharing our working version of ONNX.
+//
+// Protobuf compatibility
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+// that is compatible with both protobuf v2 and v3. This means that we do not use any
+// protobuf features that are only available in one of the two versions.
+//
+// Here are the most notable contortions we have to carry out to work around
+// these limitations:
+//
+//   - No 'map' (added in protobuf 3.0). We instead represent mappings as lists
+//     of key-value pairs, where order does not matter and duplicates
+//     are not allowed.
+
+
+// Versioning
+//
+// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+//
+// To be compatible with both proto2 and proto3, we will use a version number
+// that is not defined by the default value but an explicit enum number.
+enum Version {
+  // proto3 requires the first enum value to be zero.
+  // We add this just to appease the compiler.
+  _START_VERSION = 0;
+  // The version field is always serialized and we will use it to store the
+  // version that the graph is generated from. This helps us set up version
+  // control.
+  // For the IR, we are using simple numbers starting with 0x00000001,
+  // which was the version we published on Oct 10, 2017.
+  IR_VERSION_2017_10_10 = 0x0000000000000001;
+
+  // IR_VERSION 2 published on Oct 30, 2017
+  // - Added type discriminator to AttributeProto to support proto3 users
+  IR_VERSION_2017_10_30 = 0x0000000000000002;
+
+  // IR_VERSION 3 published on Nov 3, 2017
+  // - For operator versioning:
+  //    - Added new message OperatorSetIdProto
+  //    - Added opset_import in ModelProto
+  // - For vendor extensions, added domain in NodeProto
+  IR_VERSION = 0x0000000000000003;
+}
+
+// Attributes
+//
+// A named attribute containing either singular float, integer, string, graph,
+// and tensor values, or repeated float, integer, string, graph, and tensor values.
+// An AttributeProto MUST contain the name field, and *only one* of the
+// following content fields, effectively enforcing a C/C++ union equivalent.
+message AttributeProto {
+
+  // Note: this enum is structurally identical to the OpSchema::AttrType
+  // enum defined in schema.h. If you rev one, you likely need to rev the other.
+  enum AttributeType {
+    UNDEFINED = 0;
+    FLOAT = 1;
+    INT = 2;
+    STRING = 3;
+    TENSOR = 4;
+    GRAPH = 5;
+
+    FLOATS = 6;
+    INTS = 7;
+    STRINGS = 8;
+    TENSORS = 9;
+    GRAPHS = 10;
+  }
+
+  // The name field MUST be present for this version of the IR.
+  optional string name = 1;  // namespace Attribute
+
+  // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
+  // In this case, this AttributeProto does not contain data, and it's a reference of attribute
+  // in parent scope.
+  // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
+  optional string ref_attr_name = 21;
+
+  // A human-readable documentation for this attribute. Markdown is allowed.
+  optional string doc_string = 13;
+
+  // The type field MUST be present for this version of the IR.
+  // For 0.0.1 versions of the IR, this field was not defined, and
+  // implementations needed to use has_field heuristics to determine
+  // which value field was in use. For IR_VERSION 0.0.2 or later, this
+  // field MUST be set and match the f|i|s|t|... field in use. This
+  // change was made to accommodate proto3 implementations.
+  optional AttributeType type = 20;  // discriminator that indicates which field below is in use
+
+  // Exactly ONE of the following fields must be present for this version of the IR
+  optional float f = 2;          // float
+  optional int64 i = 3;          // int
+  optional bytes s = 4;          // UTF-8 string
+  optional TensorProto t = 5;    // tensor value
+  optional GraphProto g = 6;     // graph
+  // Do not use field below, it's deprecated.
+  // optional ValueProto v = 12;  // value - subsumes everything but graph
+
+  repeated float floats = 7;           // list of floats
+  repeated int64 ints = 8;             // list of ints
+  repeated bytes strings = 9;          // list of UTF-8 strings
+  repeated TensorProto tensors = 10;   // list of tensors
+  repeated GraphProto graphs = 11;     // list of graphs
+}
+
+// Defines information on value, including the name, the type, and
+// the shape of the value.
+message ValueInfoProto {
+  // This field MUST be present in this version of the IR.
+  optional string name = 1;  // namespace Value
+  // This field MUST be present in this version of the IR.
+  optional TypeProto type = 2;
+  // A human-readable documentation for this value. Markdown is allowed.
+  optional string doc_string = 3;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+  repeated string input = 1;   // namespace Value
+  repeated string output = 2;  // namespace Value
+
+  // An optional identifier for this node in a graph.
+  // This field MAY be absent in this version of the IR.
+  optional string name = 3;  // namespace Node
+
+  // The symbolic identifier of the Operator to execute.
+  optional string op_type = 4;  // namespace Operator
+  // The domain of the OperatorSet that specifies the operator named by op_type.
+  optional string domain = 7;  // namespace Domain
+
+  // Additional named attributes.
+  repeated AttributeProto attribute = 5;
+
+  // A human-readable documentation for this node. Markdown is allowed.
+  optional string doc_string = 6;
+}
+
+// Models
+//
+// ModelProto is a top-level file/container format for bundling a ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto.
+message ModelProto {
+  // The version of the IR this model targets. See Version enum above.
+  // This field MUST be present.
+  optional int64 ir_version = 1;
+
+  // The OperatorSets this model relies on.
+  // All ModelProtos MUST have at least one entry that
+  // specifies which version of the ONNX OperatorSet is
+  // being imported.
+  //
+  // All nodes in the ModelProto's graph will bind against the operator
+  // with the same-domain/same-op_type operator with the HIGHEST version
+  // in the referenced operator sets.
+  repeated OperatorSetIdProto opset_import = 8;
+
+  // The name of the framework or tool used to generate this model.
+  // This field SHOULD be present to indicate which implementation/tool/framework
+  // emitted the model.
+  optional string producer_name = 2;
+
+  // The version of the framework or tool used to generate this model.
+  // This field SHOULD be present to indicate which implementation/tool/framework
+  // emitted the model.
+  optional string producer_version = 3;
+
+  // Domain name of the model.
+  // We use reverse domain names as name space indicators. For example:
+  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+  //
+  // Together with `model_version` and GraphProto.name, this forms the unique identity of
+  // the graph.
+  optional string domain = 4;
+
+  // The version of the graph encoded. See Version enum above.
+  optional int64 model_version = 5;
+
+  // A human-readable documentation for this model. Markdown is allowed.
+  optional string doc_string = 6;
+
+  // The parameterized graph that is evaluated to execute the model.
+  optional GraphProto graph = 7;
+
+  // Named metadata values; keys should be distinct.
+  repeated StringStringEntryProto metadata_props = 14;
+};
+
+// StringStringEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message StringStringEntryProto {
+  optional string key = 1;
+  optional string value = 2;
+};
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is comprised of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+message GraphProto {
+  // The nodes in the graph, sorted topologically.
+  repeated NodeProto node = 1;
+
+  // The name of the graph.
+  optional string name = 2;  // namespace Graph
+
+  // A list of named tensor values, used to specify constant inputs of the graph.
+  // Each TensorProto entry must have a distinct name (within the list) that
+  // also appears in the input list.
+  repeated TensorProto initializer = 5;
+
+  // A human-readable documentation for this graph. Markdown is allowed.
+  optional string doc_string = 10;
+
+  // The inputs and outputs of the graph.
+  repeated ValueInfoProto input = 11;
+  repeated ValueInfoProto output = 12;
+
+  // Information for the values in the graph. The ValueInfoProto.name's
+  // must be distinct. It is optional for a value to appear in value_info list.
+  repeated ValueInfoProto value_info = 13;
+
+  // DO NOT USE the following fields, they were deprecated from earlier versions.
+  // repeated string input = 3;
+  // repeated string output = 4;
+  // optional int64 ir_version = 6;
+  // optional int64 producer_version = 7;
+  // optional string producer_tag = 8;
+  // optional string domain = 9;
+}
+
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+  enum DataType {
+    UNDEFINED = 0;
+    // Basic types.
+    FLOAT = 1;   // float
+    UINT8 = 2;   // uint8_t
+    INT8 = 3;    // int8_t
+    UINT16 = 4;  // uint16_t
+    INT16 = 5;   // int16_t
+    INT32 = 6;   // int32_t
+    INT64 = 7;   // int64_t
+    STRING = 8;  // string
+    BOOL = 9;    // bool
+
+    // IEEE754 half-precision floating-point format (16 bits wide).
+    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
+    FLOAT16 = 10;
+
+    DOUBLE = 11;
+    UINT32 = 12;
+    UINT64 = 13;
+    COMPLEX64 = 14;   // complex with float32 real and imaginary components
+    COMPLEX128 = 15;  // complex with float64 real and imaginary components
+
+    // Non-IEEE floating-point format based on IEEE754 single-precision
+    // floating-point number truncated to 16 bits.
+    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+    BFLOAT16 = 16;
+
+    // Future extensions go here.
+  }
+
+  // The shape of the tensor.
+  repeated int64 dims = 1;
+
+  // The data type of the tensor.
+  optional DataType data_type = 2;
+
+  // For very large tensors, we may want to store them in chunks, in which
+  // case the following fields will specify the segment that is stored in
+  // the current TensorProto.
+  message Segment {
+    optional int64 begin = 1;
+    optional int64 end = 2;
+  }
+  optional Segment segment = 3;
+
+  // Tensor content must be organized in row-major order.
+  //
+  // Depending on the data_type field, exactly one of the fields below with
+  // name ending in _data is used to store the elements of the tensor.
+
+  // For float and complex64 values
+  // Complex64 tensors are encoded as a single array of floats,
+  // with the real components appearing in odd numbered positions,
+  // and the corresponding imaginary component appearing in the
+  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+  // is encoded as [1.0, 2.0, 3.0, 4.0])
+  // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+  repeated float float_data = 4 [packed = true];
+
+  // For int32, uint8, int8, uint16, int16, bool, and float16 values
+  // float16 values must be bit-wise converted to a uint16_t prior
+  // to writing to the buffer.
+  // When this field is present, the data_type field MUST be
+  // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
+  repeated int32 int32_data = 5 [packed = true];
+
+  // For strings.
+  // Each element of string_data is a UTF-8 encoded Unicode
+  // string. No trailing null, no leading BOM. The protobuf "string"
+  // scalar type is not used to match ML community conventions.
+  // When this field is present, the data_type field MUST be STRING
+  repeated bytes string_data = 6;
+
+  // For int64.
+  // When this field is present, the data_type field MUST be INT64
+  repeated int64 int64_data = 7 [packed = true];
+
+  // Optionally, a name for the tensor.
+  optional string name = 8;  // namespace Value
+
+  // A human-readable documentation for this tensor. Markdown is allowed.
+  optional string doc_string = 12;
+
+  // Serializations can either use one of the fields above, or use this
+  // raw bytes field. The only exception is the string case, where one is
+  // required to store the content in the repeated bytes string_data field.
+  //
+  // When this raw_data field is used to store tensor value, elements MUST
+  // be stored as fixed-width, little-endian order.
+  // Floating-point data types MUST be stored in IEEE 754 format.
+  // Complex64 elements must be written as two consecutive FLOAT values, real component first.
+  // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+  // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+  //
+  // Note: the advantage of a specific field rather than the raw_data field is
+  // that in some cases (e.g. int data), protobuf does a better packing via
+  // variable length storage, and may lead to smaller binary footprint.
+  // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+  optional bytes raw_data = 9;
+
+  // For double
+  // Complex128 tensors are encoded as a single array of doubles,
+  // with the real components appearing in odd numbered positions,
+  // and the corresponding imaginary component appearing in the
+  // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+  // is encoded as [1.0, 2.0, 3.0, 4.0])
+  // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+  repeated double double_data = 10 [packed = true];
+
+  // For uint64 and uint32 values
+  // When this field is present, the data_type field MUST be
+  // UINT32 or UINT64
+  repeated uint64 uint64_data = 11 [packed = true];
+}
+
+// Defines a tensor shape. A dimension can be either an integer value
+// or a symbolic variable. A symbolic variable represents an unknown
+// dimension.
+message TensorShapeProto {
+  message Dimension {
+    oneof value {
+      int64 dim_value = 1;
+      string dim_param = 2;  // namespace Shape
+    };
+    // Standard denotation can optionally be used to denote tensor
+    // dimensions with standard semantic descriptions to ensure
+    // that operations are applied to the correct axis of a tensor.
+    // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
+    // for pre-defined dimension denotations.
+    optional string denotation = 3;
+  };
+  repeated Dimension dim = 1;
+}
+
+// Types
+//
+// The standard ONNX data types.
+message TypeProto {
+
+  message Tensor {
+    // This field MUST NOT have the value of UNDEFINED
+    // This field MUST be present for this version of the IR.
+    optional TensorProto.DataType elem_type = 1;
+    optional TensorShapeProto shape = 2;
+  }
+
+
+  // repeated T
+  message Sequence {
+    // The type and optional shape of each element of the sequence.
+    // This field MUST be present for this version of the IR.
+    optional TypeProto elem_type = 1;
+  };
+
+  // map<K,V>
+  message Map {
+    // This field MUST be present for this version of the IR.
+    // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
+    optional TensorProto.DataType key_type = 1;
+    // This field MUST be present for this version of the IR.
+    optional TypeProto value_type = 2;
+  };
+
+  message Opaque {
+    // When missing, the domain is the same as the model's.
+    optional string domain = 1;
+    // The name is optional but significant when provided.
+    optional string name = 2;
+    // parameters that help defining the type
+    // DEPRECATED do not use.
+    // repeated TypeProto parameters = 3;
+  }
+
+  message SparseTensor {
+    // This field MUST NOT have the value of UNDEFINED
+    // This field MUST be present for this version of the IR.
+    optional TensorProto.DataType elem_type = 1;
+    optional TensorShapeProto shape = 2;
+  }
+
+
+  oneof value {
+    // The type of a tensor.
+    Tensor tensor_type = 1;
+
+
+    // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
+    //       as input and output to graphs and nodes. These types are needed to naturally
+    //       support classical ML operators. DNN operators SHOULD restrict their input
+    //       and output types to tensors.
+
+    // The type of a sequence.
+    Sequence sequence_type = 4;
+
+    // The type of a map.
+    Map map_type = 5;
+
+    Opaque opaque_type = 7;
+
+    SparseTensor sparse_tensor_type = 8;
+
+  }
+
+  // An optional denotation can be used to denote the whole
+  // type with a standard semantic description as to what is
+  // stored inside.
Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +} diff --git a/src/3rd_party/phf/phf.cc b/src/3rd_party/phf/phf.cc index d00d5c5e2..7855565dc 100644 --- a/src/3rd_party/phf/phf.cc +++ b/src/3rd_party/phf/phf.cc @@ -249,8 +249,8 @@ namespace PHF { template<> int cmp(const phf_string_t *a, const phf_string_t *b) { - int cmp; - if ((cmp = memcmp(a->p, b->p, PHF_MIN(a->n, b->n)))) + int cmp = memcmp(a->p, b->p, PHF_MIN(a->n, b->n)); + if (cmp) return cmp; if (a->n > b->n) return -1; @@ -415,8 +415,8 @@ static inline uint32_t phf_f(uint32_t d, T k, uint32_t seed) { static inline uint32_t phf_g(uint64_t k, uint32_t seed) { uint32_t h1 = seed; - h1 = phf_round32(k, h1); - h1 = phf_round32(k >> 32, h1); + h1 = phf_round32(static_cast(k), h1); + h1 = phf_round32(static_cast(k >> 32), h1); return phf_mix32(h1); } /* phf_g() */ @@ -519,7 +519,8 @@ PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const uint32_t d_max = 0; /* maximum displacement value */ int error; - if ((phf->nodiv = nodiv)) { + phf->nodiv = nodiv; + if (phf->nodiv) { /* round to power-of-2 so we can use bit masks instead of modulo division */ r = phf_powerup(n1 / PHF_MIN(l1, n1)); m = phf_powerup((n1 * 100) / a1); @@ -532,24 +533,28 @@ PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const if (r == 0 || m == 0) return ERANGE; - if (!(B_k = static_cast *>(calloc(n1, sizeof *B_k)))) + B_k = static_cast *>(calloc(n1, sizeof *B_k)); + if (!B_k) goto syerr; - if (!(B_z = static_cast(calloc(r, sizeof *B_z)))) + + B_z = static_cast(calloc(r, sizeof *B_z)); + if (!B_z) goto syerr; for (size_t i = 0; i < n; i++) { - phf_hash_t g = phf_g_mod_r(k[i], seed, r); + phf_hash_t gt = phf_g_mod_r(k[i], seed, r); B_k[i].k = k[i]; - B_k[i].g = g; - B_k[i].n = &B_z[g]; + B_k[i].g = gt; + B_k[i].n = &B_z[gt]; ++*B_k[i].n; } qsort(B_k, n1, sizeof *B_k, reinterpret_cast(&phf_keycmp)); T_n = PHF_HOWMANY(m, PHF_BITS(*T)); - if (!(T = static_cast(calloc(T_n * 2, sizeof *T)))) + T = static_cast(calloc(T_n * 2, sizeof *T)); + if (!T) goto syerr; T_b = &T[T_n]; /* share single allocation */ @@ -563,7 +568,8 @@ PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const * end of the outer loop. 
*/ - if (!(g = static_cast(calloc(r, sizeof *g)))) + g = static_cast(calloc(r, sizeof *g)); + if (!g) goto syerr; B_p = B_k; @@ -579,12 +585,12 @@ PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const Bi_pe = B_p + *B_p->n; for (; Bi_p < Bi_pe; Bi_p++) { - f = phf_f_mod_m(d, Bi_p->k, seed, m); + f = phf_f_mod_m((uint32_t)d, Bi_p->k, (uint32_t)seed, m); if (phf_isset(T, f) || phf_isset(T_b, f)) { /* reset T_b[] */ for (Bi_p = B_p; Bi_p < Bi_pe; Bi_p++) { - f = phf_f_mod_m(d, Bi_p->k, seed, m); + f = phf_f_mod_m((uint32_t)d, Bi_p->k, (uint32_t)seed, m); phf_clrbit(T_b, f); } @@ -596,13 +602,13 @@ PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const /* commit to T[] */ for (Bi_p = B_p; Bi_p < Bi_pe; Bi_p++) { - f = phf_f_mod_m(d, Bi_p->k, seed, m); + f = phf_f_mod_m((uint32_t)d, Bi_p->k, (uint32_t)seed, m); phf_setbit(T, f); } /* commit to g[] */ - g[B_p->g] = d; - d_max = PHF_MAX(d, d_max); + g[B_p->g] = (uint32_t)d; + d_max = PHF_MAX((uint32_t)d, d_max); } phf->seed = seed; @@ -643,7 +649,7 @@ PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const template static inline void phf_memmove(dst_t *dst, src_t *src, size_t n) { for (size_t i = 0; i < n; i++) { - dst_t tmp = src[i]; + dst_t tmp = (dst_t)src[i]; dst[i] = tmp; } } /* phf_memmove() */ @@ -673,7 +679,8 @@ PHF_PUBLIC void PHF::compact(struct phf *phf) { } /* simply keep old array if realloc fails */ - if ((tmp = realloc(phf->g, phf->r * size))) + tmp = realloc(phf->g, phf->r * size); + if (tmp != 0) phf->g = static_cast(tmp); } /* PHF::compact() */ @@ -750,7 +757,7 @@ PHF_PUBLIC phf_hash_t PHF::hash(struct phf *phf, T k) { return phf_hash_(reinterpret_cast(phf->g), k, phf->seed, phf->r, phf->m); default: abort(); - return 0; + // return 0; } #endif } /* PHF::hash() */ diff --git a/src/3rd_party/zlib/CMakeLists.txt b/src/3rd_party/zlib/CMakeLists.txt index 503e0f1a0..63e0b08ea 100644 --- a/src/3rd_party/zlib/CMakeLists.txt +++ b/src/3rd_party/zlib/CMakeLists.txt @@ -2,11 +2,11 @@ file(GLOB ZLIB_SRC *.c) file(GLOB ZLIB_INC *.h) -# add sources of the wrapper as a "SQLiteCpp" static library +# add sources of the wrapper as a "zlib" static library add_library(zlib OBJECT ${ZLIB_SRC} ${ZLIB_INC}) if(MSVC) - target_compile_options(zlib PUBLIC /wd"4996" /wd"4267") + target_compile_options(zlib PUBLIC /wd4996 /wd4267) else() target_compile_options(zlib PUBLIC -Wno-implicit-function-declaration) endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b78431d6a..34db923c8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,7 +20,7 @@ add_library(marian STATIC common/config_validator.cpp common/options.cpp common/binary.cpp - common/build_info.cpp + ${CMAKE_CURRENT_BINARY_DIR}/common/build_info.cpp common/io.cpp common/filesystem.cpp common/file_stream.cpp @@ -41,6 +41,7 @@ add_library(marian STATIC 3rd_party/cnpy/cnpy.cpp 3rd_party/ExceptionWithCallStack.cpp + 3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.cpp 3rd_party/phf/phf.cc @@ -49,6 +50,7 @@ add_library(marian STATIC tensors/tensor.cpp tensors/cpu/device.cpp tensors/cpu/prod.cpp + tensors/cpu/topk.cpp tensors/cpu/tensor_operators.cpp tensors/cpu/sharp/int_gemm.cpp @@ -62,6 +64,9 @@ add_library(marian STATIC graph/node_operators.cpp graph/node_initializers.cpp + onnx/expression_graph_onnx_exporter.cpp + onnx/expression_graph_onnx_serialization.cpp + layers/convolution.cpp layers/generic.cpp layers/loss.cpp @@ -101,6 +106,11 @@ add_library(marian STATIC $ $ ) + +if(BLAS_FOUND) + 
target_sources(marian PRIVATE ${CMAKE_CURRENT_LIST_DIR}/layers/lsh.cpp $<TARGET_OBJECTS:faiss>)
+endif()
+
 target_compile_options(marian PUBLIC ${ALL_WARNINGS})
 
 # Generate git_revision.h to reflect current git revision information
@@ -108,21 +118,24 @@ target_compile_options(marian PUBLIC ${ALL_WARNINGS})
 # Git updates .git/logs/HEAD file whenever you pull or commit something.
 
 # If Marian is checked out as a submodule in another repository,
-# there's no .git directory in ${CMAKE_SOURCE_DIR}. Instead .git is a
-# file that specifies the relative path from ${CMAKE_SOURCE_DIR} to
-# .git/modules/<submodule path> in the root of the repository that
-# contains Marian as a submodule. We set MARIAN_GIT_DIR to the appropriate
-# path, depending on whether ${CMAKE_SOURCE_DIR}/.git is a directory or file.
-if(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git) # not a submodule
-  set(MARIAN_GIT_DIR ${CMAKE_SOURCE_DIR}/.git)
-else(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
-  file(READ ${CMAKE_SOURCE_DIR}/.git MARIAN_GIT_DIR)
+# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is not a directory but a file
+# that specifies the relative path from ${CMAKE_CURRENT_SOURCE_DIR}/..
+# to .git/modules/<submodule path> in the root of the checkout of
+# the project that contains Marian as a submodule.
+#
+# We set MARIAN_GIT_DIR to the appropriate path, depending on whether
+# ${CMAKE_CURRENT_SOURCE_DIR}/../.git is a directory or file.
+set(MARIAN_GIT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../.git)
+if(NOT IS_DIRECTORY ${MARIAN_GIT_DIR}) # i.e., it's a submodule
+  file(READ ${MARIAN_GIT_DIR} MARIAN_GIT_DIR)
   string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
-  get_filename_component(MARIAN_GIT_DIR "${CMAKE_SOURCE_DIR}/${MARIAN_GIT_DIR}" ABSOLUTE)
-endif(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
+  if(NOT IS_ABSOLUTE ${MARIAN_GIT_DIR})
+    get_filename_component(MARIAN_GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${MARIAN_GIT_DIR}" ABSOLUTE)
+  endif()
+endif(NOT IS_DIRECTORY ${MARIAN_GIT_DIR})
 
 add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
-  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
+  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
   COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
   DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
   VERBATIM
@@ -137,6 +150,7 @@ cuda_add_library(marian_cuda
   tensors/gpu/device.cu
   tensors/gpu/algorithm.cu
   tensors/gpu/prod.cpp
+  tensors/gpu/topk.cu
   tensors/gpu/element.cu
   tensors/gpu/add.cu
   tensors/gpu/add_all.cu
@@ -209,7 +223,13 @@ endif(USE_STATIC_LIBS)
 if(COMPILE_SERVER)
   add_executable(marian_server command/marian_server.cpp)
   set_target_properties(marian_server PROPERTIES OUTPUT_NAME marian-server)
-  target_compile_options(marian_server PUBLIC ${ALL_WARNINGS})
+  if(MSVC)
+    # Disable warnings from the SimpleWebSocketServer library needed for compilation of marian-server
+    target_compile_options(marian_server PUBLIC ${ALL_WARNINGS})
+  else(MSVC)
+    # -Wno-suggest-override disables warnings from Boost 1.69+
+    target_compile_options(marian_server PUBLIC ${ALL_WARNINGS} -Wno-suggest-override)
+  endif(MSVC)
   set(EXECUTABLES ${EXECUTABLES} marian_server)
 endif(COMPILE_SERVER)
diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp
index 97f83acf1..2e2c6c1fa 100644
--- a/src/command/marian_conv.cpp
+++ b/src/command/marian_conv.cpp
@@ -5,6 +5,7 @@
 #include <sstream>
 
 #include "tensors/cpu/fbgemm/expression_graph_packable.h"
+#include "onnx/expression_graph_onnx_exporter.h"
 
 int main(int argc, char** argv) {
   using namespace marian;
@@ -22,12 +23,17 @@ int main(int argc, char** argv) {
         " ./marian-conv -f model.npz -t model.bin --gemm-type packed16");
     cli->add<std::string>("--from,-f", "Input model", "model.npz");
     cli->add<std::string>("--to,-t", "Output model", "model.bin");
+    cli->add<std::string>("--export-as", "Kind of conversion: marian-bin or onnx-{encode,decoder-step,decoder-init,decoder-stop}", "marian-bin");
     cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512", "float32");
+    cli->add<std::vector<std::string>>("--vocabs,-V", "Vocabulary file, required for ONNX export");
     cli->parse(argc, argv);
     options->merge(config);
   }
   auto modelFrom = options->get<std::string>("from");
   auto modelTo = options->get<std::string>("to");
+
+  auto exportAs = options->get<std::string>("export-as");
+  auto vocabPaths = options->get<std::vector<std::string>>("vocabs");// , std::vector<std::string>());
 
   auto saveGemmTypeStr = options->get<std::string>("gemm-type", "float32");
   Type saveGemmType;
@@ -43,21 +49,38 @@ int main(int argc, char** argv) {
     ABORT("Unknown gemm-type: {}", saveGemmTypeStr);
   }
 
-  LOG(info, "Outputting {}", modelTo);
+  LOG(info, "Outputting {}, precision: {}", modelTo, saveGemmType);
 
   YAML::Node config;
   std::stringstream configStr;
   marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
   configStr << config;
 
-  auto graph = New<ExpressionGraphPackable>();
-  graph->setDevice(CPU0);
-  graph->getBackend()->setOptimized(false);
+  auto load = [&](Ptr<ExpressionGraph> graph) {
+    graph->setDevice(CPU0);
+    graph->getBackend()->setOptimized(false);
+
+    graph->load(modelFrom);
+    graph->forward(); // run the initializers
+  };
 
-  graph->load(modelFrom);
-  graph->forward();
-  // added a flag if the weights needs to be packed or not
-  graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
+  if (exportAs == "marian-bin") {
+    auto graph = New<ExpressionGraphPackable>();
+    load(graph);
+    // added a flag if the weights needs to be packed or not
+    graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
+  }
+#ifdef USE_ONNX
+  else if (exportAs == "onnx-encode") {
+    auto graph = New<ExpressionGraphONNXExporter>();
+    load(graph);
+    auto modelOptions = New<Options>(config)->with("vocabs", vocabPaths, "inference", true);
+
+    graph->exportToONNX(modelTo, modelOptions, vocabPaths);
+  }
+#endif // USE_ONNX
+  else
+    ABORT("Unknown --export-as value: {}", exportAs);
 
   // graph->saveBinary(vm["bin"].as<std::string>());
diff --git a/src/command/marian_server.cpp b/src/command/marian_server.cpp
index 2c3649407..467b564f4 100644
--- a/src/command/marian_server.cpp
+++ b/src/command/marian_server.cpp
@@ -14,6 +14,7 @@ int main(int argc, char **argv) {
   // Initialize translation task
   auto options = parseOptions(argc, argv, cli::mode::server, true);
   auto task = New<TranslateService<BeamSearch>>(options);
+  auto quiet = options->get<bool>("quiet-translation");
 
   // Initialize web server
   WSServer server;
@@ -21,8 +22,8 @@
 
   auto &translate = server.endpoint["^/translate/?$"];
 
-  translate.on_message = [&task](Ptr<WSServer::Connection> connection,
-                                 Ptr<WSServer::Message> message) {
+  translate.on_message = [&task, quiet](Ptr<WSServer::Connection> connection,
+                                        Ptr<WSServer::Message> message) {
     // Get input text
     auto inputText = message->string();
     auto sendStream = std::make_shared<WSServer::SendStream>();
@@ -30,9 +31,9 @@
     // Translate
     timer::Timer timer;
     auto outputText = task->run(inputText);
-    LOG(info, "Best translation: {}", outputText);
     *sendStream << outputText << std::endl;
-    LOG(info, "Translation took: {:.5f}s", timer.elapsed());
+    if(!quiet)
+      LOG(info, "Translation took: {:.5f}s", timer.elapsed());
 
     // Send translation back
     connection->send(sendStream, [](const SimpleWeb::error_code &ec) {
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 2f56d8870..2338cb45a 100755
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -646,6 +649,9 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
   cli.add<bool>("--output-sampling",
       "Noise output layer with gumbel noise",
       false);
+  cli.add<std::vector<int>>("--output-approx-knn",
+      "Use approximate knn search in output layer (currently only in transformer)")
+    ->implicit_val("100 1024");
 
 #if 0 // @TODO: Ask Hany if there are any decoding-time options
   // add ULR settings
diff --git a/src/common/definitions.h b/src/common/definitions.h
index 96eb6ed19..0bc14ba18 100755
--- a/src/common/definitions.h
+++ b/src/common/definitions.h
@@ -10,21 +10,6 @@
 #include
 #include
 
-// The macro MAYBE_UNUSED is used to selectively disable
-// unused-variable warnings. C++17 defines the attribute
-// [[maybe_unused]], but I don't think we're at C++17 yet. We can add it when we reach C++17.
-// The compilers gcc and clang (and maybe others) define
-// __has_attribute and support __attribute__(unused) in C++11,
-#if defined __has_attribute
-#  if __has_attribute(unused)
-#    define MAYBE_UNUSED __attribute__((unused))
-#  else
-#    define MAYBE_UNUSED
-#  endif
-#else
-#  define MAYBE_UNUSED
-#endif
-
 #define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is necessary, remove if no problems occur.
 #define NodeOp(op) [=]() { op; }
@@ -37,6 +22,61 @@
 #define DONT_OPTIMIZE // silently ignore on Visual Studio, where this is less of a problem
 #endif
 
+// Use these macros to enable faster floating-point math. Put them around one
+// or more functions.
+//
+// Usage:
+// MARIAN_FFAST_MATH_BEGIN
+// void LayerNormalization(float *arg) { *arg += 1.0; }
+// void SomethingElse() {}
+// MARIAN_FFAST_MATH_END
+//
+// ffast-math allows the compiler to assume associative arithmetic and finite
+// values.
+//
+// Associative arithmetic is particularly important to vectorize, e.g. a sum:
+//   for (const float f : range) sum += f;
+// Without ffast-math, the sum will be done one value at a time. On x86 it
+// still uses vector math, but only uses the first slot and wastes the rest.
+//
+// With ffast-math, the compiler can sum in batches of 4, 8, or 16 floats.
+// Also, it can run multiple adds in parallel e.g. vaddps has latency 4 and
+// throughput 0.5 on Skylake so multiple vector adds can run at once.
+//
+// On average, a vectorized sum is more numerically stable because it sums in
+// batches. Vectorized floats can still produce NaNs and infs (remember even
+// scalar operations are implemented with vector instructions).
+//
+// Allowing the compiler to assume finite values means functions like isnan or
+// isinf do not work as expected. Do not enable this for a function that
+// depends upon fully standard float behavior.
+//
+// It can also change the sign of zeros.
+//
+// Fast math also makes results more architecture dependent because different
+// register widths mean different results. They also depend on the compiler
+// and compiler version more. For example, clang <= 10 does not support the
+// float_control pragma below so it will still be conservative.
+//
+// There is a more conservative option for just associativity:
+// llvm introduced "#pragma clang fp reassociate" that goes inside a function.
+// However, llvm <11 considers that pragma an error so we'd need some ugly
+// version test (which they don't recommend) or a compilation test. Moreover,
+// it has to be in the function to keep scope.
+// gcc supports "-fassociative-math" that has to be outside a function.
+// I didn't find an MSVC equivalent.
+#if defined(_MSC_VER) +#define MARIAN_FFAST_MATH_BEGIN __pragma(float_control(precise, off, push)) +#define MARIAN_FFAST_MATH_END __pragma(float_control(pop)) +#elif defined(__clang__) +#define MARIAN_FFAST_MATH_BEGIN _Pragma("float_control(precise, off, push)") +#define MARIAN_FFAST_MATH_END _Pragma("float_control(pop)") +#elif defined(__GNUC__) +// Also available as __attribute__((optimize("-ffast-math"))) but done as pragmas for consistency +#define MARIAN_FFAST_MATH_BEGIN _Pragma("GCC push_options") _Pragma("GCC optimize(\"-ffast-math\")") +#define MARIAN_FFAST_MATH_END _Pragma("GCC pop_options") +#endif + namespace marian { // Type to be used for all index types, e.g. for integer tensors for rows operator. diff --git a/src/common/filesystem.cpp b/src/common/filesystem.cpp index 1abaeae43..42e183e79 100644 --- a/src/common/filesystem.cpp +++ b/src/common/filesystem.cpp @@ -14,7 +14,9 @@ namespace filesystem { // Pretend that Windows knows no named pipes. It does, by the way, but // they seem to be different from pipes on Unix / Linux. See // https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes -bool is_fifo(char const*) { return false; } +bool is_fifo(char const* /*path*/) { + return false; +} #else bool is_fifo(char const* path) { struct stat buf; diff --git a/src/common/io.cpp b/src/common/io.cpp index decc4acac..4e5f10fa9 100644 --- a/src/common/io.cpp +++ b/src/common/io.cpp @@ -136,6 +136,10 @@ void saveItemsNpz(const std::string& fileName, const std::vector& items) { type = cnpy::map_type(typeid(double)); else if(item.type == Type::int8) type = cnpy::map_type(typeid(char)); + else if(item.type == Type::int32) + type = cnpy::map_type(typeid(int32_t)); + else if (item.type == Type::uint32) + type = cnpy::map_type(typeid(uint32_t)); else ABORT("Other types not supported yet"); diff --git a/src/common/logging.h b/src/common/logging.h index 6f292f619..19920abbf 100644 --- a/src/common/logging.h +++ b/src/common/logging.h @@ -120,6 +120,13 @@ namespace marian { } \ } while(0) +#define ABORT_UNLESS(condition, ...) \ + do { \ + if(!(bool)(condition)) { \ + ABORT(__VA_ARGS__); \ + } \ + } while(0) + typedef std::shared_ptr Logger; Logger createStderrLogger(const std::string&, const std::string&, diff --git a/src/common/options.cpp b/src/common/options.cpp index e39b55222..59e8420a4 100644 --- a/src/common/options.cpp +++ b/src/common/options.cpp @@ -2,9 +2,9 @@ namespace marian { -Options::Options() +Options::Options() #if FASTOPT - : fastOptions_(options_) + : fastOptions_(options_) #endif {} diff --git a/src/common/options.h b/src/common/options.h index d288f0a3d..08c6a3ca9 100755 --- a/src/common/options.h +++ b/src/common/options.h @@ -68,6 +68,10 @@ class Options { Options(const std::string& key, T value, Args&&... moreArgs) : Options() { set(key, value, std::forward(moreArgs)...); } + + Options(const YAML::Node& node) : Options() { + merge(node); + } // constructor that clones and zero or more updates // options->with("var1", val1, "var2", val2, ...) 
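For clarity, here is a minimal sketch of how the MARIAN_FFAST_MATH_* guards added to common/definitions.h above are meant to be used. The summing function is hypothetical; only the macros come from this patch, and the sketch assumes one of the three compilers the macros are defined for (MSVC, clang, or GCC).

```cpp
// Hypothetical example; only MARIAN_FFAST_MATH_BEGIN/END come from this diff.
#include <cstddef>
#include "common/definitions.h"

MARIAN_FFAST_MATH_BEGIN
// Within the guarded region the compiler may reassociate the additions and
// sum in vector batches; do not rely on isnan()/isinf() semantics in here.
inline float sumAll(const float* data, size_t n) {
  float sum = 0.f;
  for(size_t i = 0; i < n; ++i)
    sum += data[i];
  return sum;
}
MARIAN_FFAST_MATH_END
```

As the comment block above explains, the guards should wrap whole function definitions, and only functions that tolerate non-IEEE float behavior.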
diff --git a/src/common/shape.h b/src/common/shape.h index fd86ef512..0b633e4d6 100644 --- a/src/common/shape.h +++ b/src/common/shape.h @@ -64,15 +64,13 @@ struct Shape { inline int& dim(int i) { if(i >= 0) { ABORT_IF(i >= (int)size(), - "Index {} is out of bounds, shape has {} dimension", - i, - size()); + "Index {} is out of bounds, shape {} has {} dimension", + i, std::string(*this), size()); return shape_[i]; } else { ABORT_IF((int)size() + i < 0, - "Negative index {} is out of bounds, shape has {} dimension", - i, - size()); + "Negative index {} is out of bounds, shape {} has {} dimension", + i, std::string(*this), size()); return shape_[size() + i]; } } diff --git a/src/common/timer.h b/src/common/timer.h index ac83b3634..d03c0cc75 100644 --- a/src/common/timer.h +++ b/src/common/timer.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace marian { namespace timer { diff --git a/src/common/utils.cpp b/src/common/utils.cpp index aded13c5d..442682591 100755 --- a/src/common/utils.cpp +++ b/src/common/utils.cpp @@ -127,7 +127,7 @@ std::string join(const std::vector& words, const std::string& del / std::string join(const std::vector& nums, const std::string& del /*= " "*/) { std::vector words(nums.size()); - std::transform(nums.begin(), nums.end(), words.begin(), [](int i) { return std::to_string(i); }); + std::transform(nums.begin(), nums.end(), words.begin(), [](size_t i) { return std::to_string(i); }); return join(words, del); } @@ -414,7 +414,8 @@ double parseNumber(std::string param) { } // we allow users to place commas in numbers (note: we are not actually verifying that they are in // the right place) - std::remove_if(param.begin(), param.end(), [](char c) { return c == ','; }); + auto it = std::remove_if(param.begin(), param.end(), [](char c) { return c == ','; }); // use return value for future-proofing against nodiscard warning + param.erase(it, param.end()); // since we have that iterator now, we might as well shrink to fit return factor * parseDouble(param); } diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h index 1a26baa04..7801073f4 100644 --- a/src/data/batch_generator.h +++ b/src/data/batch_generator.h @@ -272,7 +272,7 @@ class BatchGenerator : public RNGEngine { Ptr options, Ptr stats = nullptr) : data_(data), options_(options), stats_(stats), threadPool_(1) { - auto shuffle = options_->get("shuffle"); + auto shuffle = options_->get("shuffle", "none"); shuffleData_ = shuffle == "data"; shuffleBatches_ = shuffleData_ || shuffle == "batches"; } diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h old mode 100755 new mode 100644 index 5efd2211b..ba56929ec --- a/src/data/corpus_base.h +++ b/src/data/corpus_base.h @@ -299,7 +299,7 @@ class CorpusBatch : public Batch { size_t wordsTrg() const override { return subBatches_.back()->batchWords(); }; /** - * @brief The width of the target mini-batch. Num words + padded? + * @brief The target width (=max length) of the mini-batch. 
*/ size_t widthTrg() const override { return subBatches_.back()->batchWidth(); }; diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 177272df8..78cbccea2 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -3,9 +3,12 @@ #include "common/config.h" #include "common/definitions.h" #include "common/file_stream.h" +#include "data/corpus_base.h" +#include "data/types.h" #include #include +#include #include #include #include diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp index 5d68d9fb4..7c9df149e 100644 --- a/src/data/text_input.cpp +++ b/src/data/text_input.cpp @@ -23,8 +23,11 @@ const SentenceTuple& TextIterator::dereference() const { TextInput::TextInput(std::vector inputs, std::vector> vocabs, Ptr options) - : DatasetBase(inputs, options), vocabs_(vocabs) { - // note: inputs are automatically stored in the inherited variable named paths_, but these are + : DatasetBase(inputs, options), + vocabs_(vocabs), + maxLength_(options_->get("max-length")), + maxLengthCrop_(options_->get("max-length-crop")) { + // Note: inputs are automatically stored in the inherited variable named paths_, but these are // texts not paths! for(const auto& text : paths_) files_.emplace_back(new std::istringstream(text)); @@ -42,6 +45,10 @@ SentenceTuple TextInput::next() { std::string line; if(io::getline(*files_[i], line)) { Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_); + if(this->maxLengthCrop_ && words.size() > this->maxLength_) { + words.resize(maxLength_); + words.back() = vocabs_.back()->getEosId(); // note: this will not work with class-labels + } if(words.empty()) words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right? tup.push_back(words); diff --git a/src/data/text_input.h b/src/data/text_input.h index db99ef6ae..b08a4fdcc 100644 --- a/src/data/text_input.h +++ b/src/data/text_input.h @@ -33,6 +33,9 @@ class TextInput : public DatasetBase { size_t pos_{0}; + size_t maxLength_{0}; + bool maxLengthCrop_{false}; + public: typedef SentenceTuple Sample; diff --git a/src/functional/operators.h b/src/functional/operators.h index be6bb2e5c..25982009d 100755 --- a/src/functional/operators.h +++ b/src/functional/operators.h @@ -20,6 +20,7 @@ struct Ops { static HOST_DEVICE_INLINE T log(const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T exp(const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T abs(const T&) { ABORT("Unknown type"); } + static HOST_DEVICE_INLINE T sqr(const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T sqrt(const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T neg(const T&) { ABORT("Unknown type"); } static HOST_DEVICE_INLINE T sgn(const T&) { ABORT("Unknown type"); } @@ -73,6 +74,7 @@ struct Ops { static HOST_DEVICE_INLINE float log(const float& x) { return logf(x); } static HOST_DEVICE_INLINE float exp(const float& x) { return expf(x); } static HOST_DEVICE_INLINE float abs(const float& x) { return fabs(x); } + static HOST_DEVICE_INLINE float sqr(const float& x) { return x * x; } static HOST_DEVICE_INLINE float sqrt(const float& x) { return sqrtf(x); } static HOST_DEVICE_INLINE float neg(const float& x) { return -x; } static HOST_DEVICE_INLINE float sgn(const float& x) { return (float)((0 < x) - (x < 0)); } @@ -137,6 +139,7 @@ struct Ops { static HOST_DEVICE_INLINE double log(const double& x) { return std::log(x); } static HOST_DEVICE_INLINE double exp(const double& x) { return std::exp(x); } static HOST_DEVICE_INLINE 
double abs(const double& x) { return std::abs(x); } + static HOST_DEVICE_INLINE double sqr(const double& x) { return x * x; } static HOST_DEVICE_INLINE double sqrt(const double& x) { return std::sqrt(x); } static HOST_DEVICE_INLINE double neg(const double& x) { return -x; } static HOST_DEVICE_INLINE double sgn(const double& x) { return (0 < x) - (x < 0); } @@ -244,6 +247,7 @@ struct Ops { // @TODO: get rid of loop4 with proper intrisics static inline float32x4 abs(const float32x4& x) { return loop4(Ops::abs, x); } + static inline float32x4 sqr(const float32x4& x) { return _mm_mul_ps(x, x); } static inline float32x4 sqrt(const float32x4& x) { return _mm_sqrt_ps(x); } static inline float32x4 neg(const float32x4& x) { return sub(0.f, x); } @@ -369,6 +373,7 @@ struct Ops { // @TODO: get rid of loop8 with proper intrisics static inline float32x8 abs(const float32x8& x) { return loop8(Ops::abs, x); } + static inline float32x8 sqr(const float32x8& x) { return _mm256_mul_ps(x, x); } static inline float32x8 sqrt(const float32x8& x) { return _mm256_sqrt_ps(x); } static inline float32x8 neg(const float32x8& x) { return sub(0.f, x); } @@ -461,6 +466,7 @@ struct Ops { static DEVICE_INLINE half tan(const half& x) { return hsin(x) / hcos(x); } static DEVICE_INLINE half log(const half& x) { return hlog(x); } static DEVICE_INLINE half exp(const half& x) { return hexp(x); } + static DEVICE_INLINE half sqr(const half& x) { return x * x; } static DEVICE_INLINE half sqrt(const half& x) { return hsqrt(x); } static DEVICE_INLINE half neg(const half& x) { return -x; } @@ -567,6 +573,7 @@ UNARY(Tan, tan, Ops::tan(x)); UNARY(Log, log, Ops::log(x)); UNARY(Exp, exp, Ops::exp(x)); UNARY(Abs, abs, Ops::abs(x)); +UNARY(Sqr, sqr, Ops::sqr(x)); UNARY(Sqrt, sqrt, Ops::sqrt(x)); UNARY(Neg, operator-, Ops::neg(x)); UNARY(Sgn, sgn, Ops::sgn(x)); diff --git a/src/graph/chainable.h b/src/graph/chainable.h index b78eb485a..8d18c8c1a 100644 --- a/src/graph/chainable.h +++ b/src/graph/chainable.h @@ -63,7 +63,7 @@ class Chainable { virtual NodeOps forwardOps() = 0; virtual NodeOps backwardOps() = 0; - virtual size_t allocate() = 0; + virtual void allocate() = 0; virtual void free() = 0; virtual void init() = 0; virtual void init_dependent() {} diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp old mode 100755 new mode 100644 diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h old mode 100755 new mode 100644 index f7a781265..b4f0c1e29 --- a/src/graph/expression_graph.h +++ b/src/graph/expression_graph.h @@ -65,9 +65,9 @@ class Tensors { void free(const Tensor& tensor) { tensors_->free(tensor); } - // @TODO: get rid of this, not really used or can be done better - Ptr allocator() { return tensors_->allocator(); } - + Ptr getAllocator() { return tensors_->allocator(); } + Ptr getTensorAllocator() { return tensors_; } + Expr findOrRemember(Expr node) { size_t hash = node->hash(); // memoize constant nodes that are not parameters @@ -115,14 +115,16 @@ typedef std::map> ElementTypeParamsMap; // keep it sorted, class ExpressionGraph : public std::enable_shared_from_this { size_t count_{0}; + std::unordered_set topNodes_; // current set of roots. In the end, all but one must have been consumed. + +protected: // (these are protected, not private, for ONNX exporting) std::list nodesForward_; std::list nodesBackward_; - std::unordered_set topNodes_; // current set of roots. In the end, all but one must have been consumed. 
-
   // Holds memory and expressions that correspond to temporary expressions.
   // This gets cleared before a new graph is built.
   Ptr<Tensors> tensors_;
+
+private:
   std::unordered_map<size_t, std::vector<WExpr>> memoized_;
@@ -463,8 +465,11 @@ class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
     tensors_->free(tensor);
   }
 
-  // @TODO: get rid of this, not really used or can be done better
-  Ptr<Allocator> allocator() { return tensors_->allocator(); }
+  // Returns the memory allocator of the graph workspace, allocates raw unstructured memory (but 256-byte aligned)
+  Ptr<Allocator> allocator() { return tensors_->getAllocator(); } // @TODO: rename this to getAllocator();
+
+  // Returns the tensor allocator of the graph workspace, different from above as proper tensor objects are allocated
+  Ptr<TensorAllocator> getTensorAllocator() { return tensors_->getTensorAllocator(); }
 
   void clear() {
     // clear everything apart from parameters and memoized nodes
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
old mode 100755
new mode 100644
index f858d7309..d51914074
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -4,6 +4,7 @@
 #include "graph/node_operators.h"
 #include "graph/node_operators_binary.h"
 #include "graph/node_operators_unary.h"
+#include "graph/node_operators_tuple.h"
 
 #include "graph/auto_tuner.h"
 #include "tensors/cpu/int16.h"
@@ -25,6 +26,16 @@ Expr checkpoint(Expr a) {
   return a;
 }
 
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
+            LambdaNodeFunctor fwd) {
+  return Expression<LambdaNodeOp>(nodes, shape, type, fwd);
+}
+
+Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
+            LambdaNodeFunctor fwd, LambdaNodeFunctor bwd) {
+  return Expression<LambdaNodeOp>(nodes, shape, type, fwd, bwd);
+}
+
 // logistic function. Note: scipy name is expit()
 Expr sigmoid(Expr a) {
   return Expression<SigmoidNodeOp>(a);
@@ -57,6 +68,10 @@ Expr exp(Expr a) {
   return Expression<ExpNodeOp>(a);
 };
 
+Expr sin(Expr a) {
+  return Expression<SinNodeOp>(a);
+};
+
 Expr swish(Expr a) {
   return Expression<SwishNodeOp>(a);
 }
@@ -118,14 +133,54 @@ Expr logaddexp(Expr a, Expr b) {
   return Expression<LogAddExpNodeOp>(a, b);
 }
 
+Expr2 topk(Expr a, int k, int axis, bool descending) {
+  // only supports topk along last dimension, hence transpose if required
+  a = swapAxes(a, axis, -1);                                   // non-op if axes are the same
+  auto topkVal = Expression<TopKNodeOp>(a, k, -1, descending); // axis=-1 is OK now as we swapped
+  auto topkIdx = std::dynamic_pointer_cast<TopKNodeOp>(topkVal)->tupleView(); // get a view on the top-k indices
+  return std::make_tuple(swapAxes(topkVal, axis, -1), swapAxes(topkIdx, axis, -1)); // non-op if axes are the same
+}
+
+Expr2 argmax(Expr a, int axis) {
+  return topk(a, 1, axis, /*descending=*/true);
+}
+
+Expr2 argmin(Expr a, int axis) {
+  return topk(a, 1, axis, /*descending=*/false);
+}
+
 Expr maximum(Expr a, Expr b) {
   return Expression<MaximumNodeOp>(a, b);
 }
 
+// @TODO: implement version without constant
+Expr maximum(float a, Expr b) {
+  auto aExpr = b->graph()->constant({}, inits::fromValue(a));
+  return Expression<MaximumNodeOp>(aExpr, b);
+}
+
+Expr maximum(Expr a, float b) {
+  return maximum(b, a);
+}
+
 Expr minimum(Expr a, Expr b) {
   return Expression<MinimumNodeOp>(a, b);
 }
 
+// @TODO: implement version without constant
+Expr minimum(float a, Expr b) {
+  auto aExpr = b->graph()->constant({}, inits::fromValue(a));
+  return Expression<MinimumNodeOp>(aExpr, b);
+}
+
+Expr minimum(Expr a, float b) {
+  return minimum(b, a);
+}
+
+Expr abs(Expr a) {
+  return Expression<AbsNodeOp>(a);
+}
+
 Expr lt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1, false); }
 Expr eq(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  0, false); }
 Expr gt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  1, false); }
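The new topk operator above returns a 2-tuple of expressions (values and indices), mirroring PyTorch's topk. A hypothetical usage sketch follows; the graph and the scores constant are assumed to exist, and only topk/argmax/get<> come from this patch.

```cpp
// Hypothetical example; assumes an already-constructed ExpressionGraph 'graph'.
std::vector<float> v = {0.1f, 0.7f, 0.3f, 0.9f, 0.5f};
auto scores = graph->constant({1, 5}, inits::fromVector(v));

Expr2 valsAndIdx = topk(scores, /*k=*/2, /*axis=*/-1, /*descending=*/true);
Expr values  = get<0>(valsAndIdx); // top-k values:  [0.9, 0.7]
Expr indices = get<1>(valsAndIdx); // their indices: [3, 1]

// argmax(a, axis) is shorthand for topk(a, /*k=*/1, axis, /*descending=*/true)
Expr best = get<1>(argmax(scores, /*axis=*/-1)); // index of the maximum: [3]
```

The corresponding declarations and documentation live in the header diff that follows.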
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 7600d2be1..081efb2d5 100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -10,6 +10,10 @@ Expr checkpoint(Expr a);
 
 typedef Expr(ActivationFunction)(Expr);
 
+typedef std::function<void(Expr, const std::vector<Expr>&)> LambdaNodeFunctor;
+Expr lambda(const std::vector<Expr>&, Shape, Type, LambdaNodeFunctor);
+Expr lambda(const std::vector<Expr>&, Shape, Type, LambdaNodeFunctor, LambdaNodeFunctor);
+
 Expr plus(const std::vector<Expr>&);
 
 // TODO: should be logistic(), not sigmoid()
@@ -40,9 +44,12 @@ Expr prelu(Expr a, float alpha = 0.01);
 Expr prelu(const std::vector<Expr>&, float alpha = 0.01);
 
 Expr log(Expr a);
-
 Expr exp(Expr a);
 
+Expr sin(Expr a);
+Expr cos(Expr a);
+Expr tan(Expr a);
+
 Expr clip(Expr a, float c);
 
 Expr operator-(Expr a);
@@ -65,6 +72,8 @@ Expr operator/(Expr a, Expr b);
 Expr operator/(float a, Expr b);
 Expr operator/(Expr a, float b);
 
+Expr abs(Expr a);
+
 // Expr pow(Expr a, Expr b);
 // Expr pow(float a, Expr b);
 // Expr pow(Expr a, float b);
@@ -73,7 +82,30 @@ Expr logaddexp(Expr a, Expr b);
 
 // Note: Following numpy, minimum() is element-wise, while min() is along an axis in both Numpy and PyTorch.
 Expr maximum(Expr a, Expr b);
+Expr maximum(float a, Expr b);
+Expr maximum(Expr a, float b);
+
 Expr minimum(Expr a, Expr b);
+Expr minimum(float a, Expr b);
+Expr minimum(Expr a, float b);
+
+// Pair of expressions, currently used for topk nodes only
+typedef std::tuple<Expr, Expr> Expr2;
+
+// Marian pseudo-operator to access elements of a tuple, just the same as std::get<I>(tuple)
+template <int I>
+Expr get(Expr2 tuple) { return std::get<I>(tuple); }
+
+// PyTorch-like topk operator, returns a 2-tuple of nodes, where the first node holds the top-k values
+// and the second node the indices of these values according to the given axis. Order is descending
+// by default, outputs are ordered.
+Expr2 topk(Expr a, int k, int axis, bool descending = true);
+
+// Convenience operator that maps to topk(a, k=1, axis, descending=true)
+Expr2 argmax(Expr a, int axis);
+
+// Convenience operator that maps to topk(a, k=1, axis, descending=false)
+Expr2 argmin(Expr a, int axis);
 
 // Note: We cannot overload the relational operators, as they also mean something for Expr itself.
 // Note: These names follow PyTorch convention.
@@ -160,6 +192,30 @@ Expr stopGradient(Expr a);
 
 Expr gather(Expr a, int axis, Expr indices);
 
+#if 0
+  // reverse operation to gather: a is the expression into which values from b are inserted
+  // at positions indices along axis, with broadcasting
+
+  auto knn = get<0>(KNN->apply(query, k)); // [beam, time, batch, k]
+
+  auto W = reshape(gather(Wt_, -2, flatten(knn)), {beam * time * batch, k, dim});
+  auto b = reshape(gather(b_,  -1, flatten(knn)), {beam * time * batch, 1, k });
+  query = reshape(query, {beam * time * batch, 1, dim});
+  auto logits = bdot(query, W, false, true); // [beam * time * batch, 1, k]
+  logits = reshape(logits + b, {beam, time, batch, k}); // @TODO: add baffine node
+
+  auto shape = indices.shape();
+  shape.set(-1, 32000);
+  auto output = graph->constant(shape, inits::lowest(), logits->value_type());
+  output = scatter(output, -1, indices, logits);
+
+  // auto a = graph->constant({2,2,5,32000}, inits::fromValue(minimal))
+  // scatter(a, -1, indices, values)
+  // PyTorch does for out-of-place scatter: out = a.scatter(-1, indices, values)
+Expr scatter(Expr a, int axis, Expr indices, Expr b);
+
+#endif
+
 // Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr...
Expr index_select(Expr a, int axis, Expr indices); diff --git a/src/graph/node.cpp b/src/graph/node.cpp index c15c4eb6a..256f7623f 100755 --- a/src/graph/node.cpp +++ b/src/graph/node.cpp @@ -5,13 +5,10 @@ namespace marian { -size_t Node::allocate() { - size_t elements = 0; +void Node::allocate() { if(!val_) { graph()->allocateForward(this); - elements = val_->shape().elements(); } - return elements; } void Node::free() { diff --git a/src/graph/node.h b/src/graph/node.h index c017eeb2c..9c5382d40 100644 --- a/src/graph/node.h +++ b/src/graph/node.h @@ -85,7 +85,7 @@ class Node : public Chainable { virtual void setId(size_t id) override { id_ = id; } virtual size_t getId() override { return id_; } - + virtual void increaseEdges(size_t edges = 1) { edges_ += edges; }; virtual void decreaseEdges(size_t edges = 1) { edges_ -= edges; }; virtual size_t edges() { return edges_; }; @@ -100,7 +100,7 @@ class Node : public Chainable { virtual bool marked_for_debug() override { return markedForDebug_; } virtual const std::string& debug_message() override { return debugMessage_; } - virtual size_t allocate() override; + virtual void allocate() override; virtual void free() override; diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp index 27796d337..531cfaad0 100755 --- a/src/graph/node_initializers.cpp +++ b/src/graph/node_initializers.cpp @@ -233,6 +233,44 @@ Ptr sinusoidalPositionEmbeddings(int start) { }, Type::float32); } +// computes the equivalent of Python's range() +template +Ptr range(T begin, T end, T step) { + return fromLambda([begin, end, step](Tensor t) { + auto nElem = t->shape().elements(); + std::vector v; v.reserve(nElem); + for (T i = begin; i < end; i += step) + v.push_back(i); + ABORT_IF(nElem != v.size(), "range does not match constant shape"); + t->set(v); + }, typeId()); +} + +template Ptr range (float16 begin, float16 end, float16 step); +template Ptr range (float begin, float end, float step); +template Ptr range(IndexType begin, IndexType end, IndexType step); + } // namespace inits +} // namespace marian + + +#if BLAS_FOUND +#include "faiss/VectorTransform.h" +namespace marian { +namespace inits { + +Ptr randomRotation(size_t seed) { + auto rot = [=](Tensor t) { + int rows = t->shape()[-2]; + int cols = t->shape()[-1]; + faiss::RandomRotationMatrix rrot(cols, rows); // transposed in faiss + rrot.init((int)seed); + t->set(rrot.A); + }; + return fromLambda(rot, Type::float32); +} + +} // namespace inits } // namespace marian +#endif diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h index 21dcc95d3..9dc01a0d0 100755 --- a/src/graph/node_initializers.h +++ b/src/graph/node_initializers.h @@ -176,6 +176,21 @@ Ptr fromWord2vec(const std::string& file, */ Ptr sinusoidalPositionEmbeddings(int start); +/** + * Computes a random rotation matrix for LSH hashing. This is part + * of a hash function. The values are orthonormal and computed via + * QR decomposition. Same seed results in same random rotation. + */ +Ptr randomRotation(size_t seed = Config::seed); + +/** + * Computes a range from begin to end-1, like Python's range(). + * The constant being initialized must have one dimension that matches + * the number of elements being generated, while any other dimension must be 1. 
+ */ +template +Ptr range(T begin, T end, T step = (T)1); + } // namespace inits } // namespace marian diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp index 932fee888..5d9cc1bdb 100644 --- a/src/graph/node_operators.cpp +++ b/src/graph/node_operators.cpp @@ -16,13 +16,10 @@ ConstantNode::ConstantNode(Ptr graph, setTrainable(false); } -size_t ConstantNode::allocate() { - size_t elements = 0; +void ConstantNode::allocate() { if(!val_) { graph()->allocateForward(this); - elements = val_->shape().elements(); } - return elements; } void ConstantNode::init() { diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h index 7479bd691..c3dde73c9 100644 --- a/src/graph/node_operators.h +++ b/src/graph/node_operators.h @@ -14,7 +14,7 @@ struct ConstantNode : public Node { ~ConstantNode() {} - virtual size_t allocate() override; + virtual void allocate() override; virtual void init() override; const std::string type() override { return "const"; } @@ -50,9 +50,8 @@ struct ParamNode : public Node { ~ParamNode() {} - virtual size_t allocate() override { + virtual void allocate() override { ABORT_IF(!val_, "Parameters should be allocated by their graph. Parameter {} was not", name_); - return 0; } virtual void init() override; diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h old mode 100755 new mode 100644 index 63158ffae..82685d5fc --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -13,8 +13,65 @@ namespace marian { +class LambdaNodeOp : public NaryNodeOp { +private: + typedef const std::vector& Inputs; + typedef std::function LambdaNodeFunctor; + + std::unique_ptr forward_; + std::unique_ptr backward_; + +public: + LambdaNodeOp(Inputs inputs, Shape shape, Type type, + LambdaNodeFunctor forward) + : NaryNodeOp(inputs, shape, type), + forward_(new LambdaNodeFunctor(forward)) { + Node::trainable_ = !!backward_; + } + + LambdaNodeOp(Inputs inputs, Shape shape, Type type, + LambdaNodeFunctor forward, + LambdaNodeFunctor backward) + : NaryNodeOp(inputs, shape, type), + forward_(new LambdaNodeFunctor(forward)), + backward_(new LambdaNodeFunctor(backward)) { + } + + void forward() override { + (*forward_)(this, children_); + } + + void backward() override { + ABORT_IF(!backward_, "No backward lambda given?"); + (*backward_)(this, children_); + } + + const std::string type() override { return "lambda"; } + + virtual size_t hash() override { + size_t seed = NaryNodeOp::hash(); + util::hash_combine(seed, forward_.get()); + util::hash_combine(seed, backward_.get()); + return seed; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast(node); + if(!cnode) + return false; + if(forward_ != cnode->forward_) // pointer compare on purpose + return false; + if(backward_ != cnode->backward_) // pointer compare on purpose + return false; + return true; + } +}; + class DotNodeOp : public NaryNodeOp { private: + friend class SerializationHelpers; bool transA_; bool transB_; float scalar_; @@ -35,14 +92,14 @@ class DotNodeOp : public NaryNodeOp { auto shapeB = b->shape(); if(transB) { - shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]); + shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]); // @TODO: why not use negative indices? 
shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]); } Shape outShape = shapeA; outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]); ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2], - "Matrix product requires inner dimensions to match"); + "Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB); return outShape; } @@ -157,6 +214,7 @@ class DotNodeOp : public NaryNodeOp { class AffineNodeOp : public NaryNodeOp { private: + friend class SerializationHelpers; bool transA_; bool transB_; float scalar_; @@ -187,7 +245,7 @@ class AffineNodeOp : public NaryNodeOp { Shape outShape = shapeA; outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]); ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2], - "Matrix product requires inner dimensions to match"); + "Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB); return outShape; } @@ -324,6 +382,7 @@ class AffineNodeOp : public NaryNodeOp { class DotBatchedNodeOp : public NaryNodeOp { private: + friend class SerializationHelpers; bool transA_; bool transB_; float scalar_; @@ -351,7 +410,7 @@ class DotBatchedNodeOp : public NaryNodeOp { Shape outShape = shapeA; outShape.set(-1, shapeB[-1]); ABORT_IF(shapeA[-1] != shapeB[-2], - "Batched matrix product requires inner dimensions to match"); + "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB); return outShape; } @@ -674,12 +733,12 @@ struct GatherNodeOp : public NaryNodeOp { NodeOps forwardOps() override { return {NodeOp( - Select(val_, child(0)->val(), child(1)->val(), axis_))}; + Select(val_, child(0)->val(), child(1)->val(), axis_))}; } NodeOps backwardOps() override { return {NodeOp( - Insert(child(0)->grad(), adj_, child(1)->val(), axis_))}; + Insert(child(0)->grad(), adj_, child(1)->val(), axis_))}; } Shape newShape(Expr a, int axis, Expr indices) { @@ -722,6 +781,8 @@ struct GatherNodeOp : public NaryNodeOp { return true; } +private: + friend class SerializationHelpers; int axis_; }; @@ -817,7 +878,7 @@ struct MultNodeOp : public ElementBinaryNodeOp { NodeOp(Add(_1 * _2, child(1)->grad(), adj_, child(0)->val()))}; } - const std::string type() override { return "×"; } + const std::string type() override { return "*"; } }; struct DivNodeOp : public ElementBinaryNodeOp { @@ -842,7 +903,7 @@ struct DivNodeOp : public ElementBinaryNodeOp { child(1)->val()))}; } - const std::string type() override { return "÷"; } + const std::string type() override { return "/"; } }; // struct PowNodeOp : public ElementBinaryNodeOp { @@ -1047,19 +1108,19 @@ struct ConcatenateNodeOp : public NaryNodeOp { ABORT_IF(nodes.empty(), "No child nodes given"); Shape shape = nodes[0]->shape(); - ax_ = shape.axis(ax); + axis_ = shape.axis(ax); int sum = 0; auto checkShape = shape; for(auto child : nodes) { - checkShape.set(ax_, child->shape()[ax_]); // don't abort on different sizes on axis dim. + checkShape.set(axis_, child->shape()[axis_]); // don't abort on different sizes on axis dim. 
ABORT_IF(checkShape != child->shape(),
                "Child shapes {} and {} cannot be concatenated along axis {}",
                shape, child->shape(), ax);
-      sum += child->shape()[ax_];
+      sum += child->shape()[axis_];
     }
-    shape.set(ax_, sum);
+    shape.set(axis_, sum);
     return shape;
   }
@@ -1068,7 +1129,7 @@ struct ConcatenateNodeOp : public NaryNodeOp {
     std::vector<Tensor> concatenees;
     for(size_t i = 0; i < children_.size(); ++i)
       concatenees.push_back(child(i)->val());
-    Concatenate(val_, concatenees, ax_);
+    Concatenate(val_, concatenees, axis_);
   }

   void backward() override {
@@ -1078,12 +1139,12 @@ struct ConcatenateNodeOp : public NaryNodeOp {
       childPtr->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
       deconcatenees.push_back(childPtr->grad());
     }
-    Deconcatenate(deconcatenees, adj_, ax_);
+    Deconcatenate(deconcatenees, adj_, axis_);
   }

   virtual size_t hash() override {
     size_t seed = NaryNodeOp::hash();
-    util::hash_combine(seed, ax_);
+    util::hash_combine(seed, axis_);
     return seed;
   }
@@ -1093,20 +1154,25 @@ struct ConcatenateNodeOp : public NaryNodeOp {
     auto cnode = std::dynamic_pointer_cast<ConcatenateNodeOp>(node);
     if(!cnode)
       return false;
-    if(ax_ != cnode->ax_)
+    if(axis_ != cnode->axis_)
       return false;
     return true;
   }

   const std::string type() override { return "concat"; }

-  int ax_;
+private:
+  friend class SerializationHelpers;
+  int axis_;
 };

+// layer norm along last axis
 struct LayerNormalizationOp : public NaryNodeOp {
 public:
   LayerNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9)
-      : NaryNodeOp(nodes), eps_(eps) {}
+      : NaryNodeOp(nodes), eps_(eps) {
+    // @TODO: dimension check
+  }

   NodeOps forwardOps() override {
     return {NodeOp(
@@ -1117,6 +1183,7 @@ struct LayerNormalizationOp : public NaryNodeOp {
         eps_))};
   }

+  // @BUGBUG: backward has not been tested for broadcasting gamma/beta
   NodeOps backwardOps() override {
     return {NodeOp(
       LayerNormalizationGrad(
@@ -1152,6 +1219,7 @@ struct LayerNormalizationOp : public NaryNodeOp {
   }

 private:
+  friend class SerializationHelpers;
   // @TODO: use the same name for this as SqrtNodeOp
   float eps_;
 };
diff --git a/src/graph/node_operators_tuple.h b/src/graph/node_operators_tuple.h
new file mode 100644
index 000000000..c7a9531a1
--- /dev/null
+++ b/src/graph/node_operators_tuple.h
@@ -0,0 +1,167 @@
+#pragma once
+
+#include "graph/node_operators_unary.h"
+
+namespace marian {
+
+// Base class for a node that has more than one forward value tensor.
+// For now we only do one additional value.
+class TupleNode {
+protected:
+  Tensor tupleVal_; // the additional forward value tensor
+
+  friend struct TupleViewNodeOp;
+
+  // Force implementation of these functions. The node that inherits from this
+  // should use them during allocation/deallocation.
+  virtual void allocateTuple() = 0;
+  virtual void freeTuple() = 0;
+
+public:
+  // The specific node needs to implement what the view is actually looking at
+  virtual Expr tupleView() = 0;
+};
+
+// This is a view similar to ReshapeNodeOp or SliceViewNodeOp
+// that uses the additional value as its tensor without
+// allocating or destroying anything. Its purpose is to
+// create an Expression that can be put on the tape of the graph and
+// visited in correct topological order. It should also make sure that
+// the node that actually holds the memory persists via the reference
+// to tupleNode_.
+struct TupleViewNodeOp : public UnaryNodeOp {
+private:
+  Expr tupleNode_; // hold a reference to the actual node with the tuple data
+
+public:
+  TupleViewNodeOp(Expr origin, Shape shape, Type type)
+  : UnaryNodeOp(origin, shape, type),
+    tupleNode_(origin) {
+    Node::destroy_   = false; // should not be destroyed or freed, the origin node is handling that
+    Node::trainable_ = false; // for now this is not trainable
+  }
+
+  // make sure these functions don't actually do anything, origin node handles all this
+  ~TupleViewNodeOp() {}
+  void allocate() override {}
+  void free() override {}
+  void forward() override {}
+  void backward() override {}
+  void init_dependent() override {}
+  void set_zero_adjoint() override {}
+
+  // Return the additional tuple tensor as the val() tensor for the view
+  Tensor& val() override {
+    // we have to use a raw pointer cast here as we cannot add IntrusivePtr ref-counting to TupleNode,
+    // otherwise we would have ambiguous inheritance
+    auto tptr = dynamic_cast<TupleNode*>(tupleNode_.get());
+    ABORT_IF(!tptr, "Could not convert to tuple?");
+    return tptr->tupleVal_;
+  };
+
+  // Currently not trainable. We will see if that will be useful at some point
+  Tensor& grad() override {
+    ABORT("There should be no gradients for tuple values");
+  };
+
+  const std::string type() override { return "tupleView"; }
+  const std::string color() override { return "grey"; }
+};
+
+// This is an implementation of topk, similar to the PyTorch node.
+// At the moment we only handle axis=-1 in here, but do transposes
+// in the actual operator to handle other axes (inefficiently).
+// The normal forward values here are the top-k values per axis,
+// the additional value from the TupleNode contains the integer
+// indices of the top-k values.
+struct TopKNodeOp : public UnaryNodeOp,
+                    public TupleNode {
+private:
+  int k_;           // how many top-k results?
+  int axis_;        // on which axis
+  bool descending_; // sort-order, by default descending. PyTorch has a version without sorting, we always sort.
+
+public:
+  TopKNodeOp(Expr a, int k, int axis, bool descending = true)
+  : UnaryNodeOp(a, newShape(a, k, axis)),
+    k_{k}, descending_{descending} {}
+
+  Shape newShape(Expr a, int k, int axis) {
+    Shape shape = a->shape();
+    axis_ = shape.axis(axis);
+
+    shape.set(axis_, k);
+    return shape;
+  }
+
+  // implementation of TupleNode-specific pure-virtual functions for allocation
+  void allocateTuple() override final {
+    graph()->getTensorAllocator()->allocate(tupleVal_, shape(), Type::uint32);
+  }
+
+  // we override the normal allocation to include the TupleNode allocation
+  void allocate() override {
+    UnaryNodeOp::allocate();
+    allocateTuple();
+  }
+
+  // implementation of TupleNode-specific pure-virtual functions for de-allocation
+  void freeTuple() override final {
+    if(graph()) {
+      if(tupleVal_) {
+        graph()->free(tupleVal_);
+        tupleVal_ = nullptr;
+      }
+    }
+  }
+
+  // we override the normal de-allocation to include the TupleNode de-allocation
+  void free() override {
+    UnaryNodeOp::free();
+    freeTuple();
+  }
+
+  // Create and return a TupleView to the additional forward value
+  virtual Expr tupleView() override final {
+    return Expression<TupleViewNodeOp>(this, shape(), Type::uint32);
+  }
+
+  void forward() override {
+    TopK(/*out*/val_, /*out: topkIndices=*/tupleVal_,
+         graph()->allocator(),
+         child(0)->val(), k_, axis_, descending_);
+  }
+
+  void backward() override {
+    Insert(/*out*/child(0)->grad(), adj_, val_, axis_);
+  }
+
+  const std::string type() override { return "topk"; }
+
+  virtual size_t hash() override {
+    if(!hash_) {
+      hash_ = NaryNodeOp::hash();
+      util::hash_combine(hash_, k_);
+      util::hash_combine(hash_, axis_);
+      util::hash_combine(hash_, descending_);
+    }
+    return hash_;
+  }

+  virtual bool equal(Expr node) override {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<TopKNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(k_ != cnode->k_)
+      return false;
+    if(axis_ != cnode->axis_)
+      return false;
+    if(descending_ != cnode->descending_)
+      return false;
+    return true;
+  }
+};
+
+}
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
old mode 100755
new mode 100644
index a52ecf0e1..c565e0357
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -31,6 +31,7 @@ struct UnaryNodeOp : public NaryNodeOp {

 struct ScalarAddNodeOp : public UnaryNodeOp {
 private:
+  friend class SerializationHelpers;
   float scalar_{0};

 public:
@@ -88,6 +89,7 @@ struct CastNodeOp : public UnaryNodeOp {

 struct ScalarMultNodeOp : public UnaryNodeOp {
 private:
+  friend class SerializationHelpers;
   float scalar_{0};

 public:
@@ -462,6 +464,7 @@ enum class ReduceNodeOpCode {
 };

 struct ReduceNodeOp : public UnaryNodeOp {
+  friend class SerializationHelpers;
   int axis_;
   ReduceNodeOpCode opCode_;
   int reducedDim_; // dimension of axis being reduced, e.g. used in mean()
@@ -470,7 +473,10 @@ struct ReduceNodeOp : public UnaryNodeOp {
   ReduceNodeOp(Expr a, int axis, ReduceNodeOpCode opCode)
       : UnaryNodeOp(a, newShape(a, axis)), opCode_(opCode) {
     reducedDim_ = a->shape()[axis]; // e.g. used in mean()
-    ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(), "bug in determining reducedDim");
+    ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(),
+             "Bug in determining reducedDim {} != {}",
+             reducedDim_,
+             a->shape().elements() / shape().elements());
   }

   NodeOps forwardOps() override {
@@ -611,6 +617,54 @@ struct ExpNodeOp : public UnaryNodeOp {
   const std::string type() override { return "exp"; }
 };

+struct SinNodeOp : public UnaryNodeOp {
+  SinNodeOp(Expr a) : UnaryNodeOp(a) {}
+
+  NodeOps forwardOps() override {
+    using namespace functional;
+    return {NodeOp(Element(_1 = sin(_2), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() override {
+    using namespace functional;
+    return {NodeOp(Add(_1 * cos(_2), child(0)->grad(), adj_, child(0)->val()))};
+  }
+
+  const std::string type() override { return "sin"; }
+};
+
+struct CosNodeOp : public UnaryNodeOp {
+  CosNodeOp(Expr a) : UnaryNodeOp(a) {}
+
+  NodeOps forwardOps() override {
+    using namespace functional;
+    return {NodeOp(Element(_1 = cos(_2), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() override {
+    using namespace functional;
+    return {NodeOp(Add(_1 * -sin(_2), child(0)->grad(), adj_, child(0)->val()))};
+  }
+
+  const std::string type() override { return "cos"; }
+};
+
+struct TanNodeOp : public UnaryNodeOp {
+  TanNodeOp(Expr a) : UnaryNodeOp(a) {}
+
+  NodeOps forwardOps() override {
+    using namespace functional;
+    return {NodeOp(Element(_1 = tan(_2), val_, child(0)->val()))};
+  }
+
+  NodeOps backwardOps() override {
+    using namespace functional;
+    return {NodeOp(Add(_1 / sqr(cos(_2)), child(0)->grad(), adj_, child(0)->val()))};
+  }
+
+  const std::string type() override { return "tan"; }
+};
+
 struct SqrtNodeOp : public UnaryNodeOp {
   float epsilon_;
@@ -679,13 +733,10 @@ struct NegNodeOp : public UnaryNodeOp {
     return {NodeOp(Add(-_1, child(0)->grad(), adj_))};
   }

-  const std::string type() override { return "-"; }
+  const std::string type() override { return "negate"; }
 };

 struct TransposeNodeOp : public UnaryNodeOp {
-  std::vector<int> axes_;
-  std::vector<int> axesBw_;
-
   TransposeNodeOp(Expr a, const std::vector<int>& axes)
       : UnaryNodeOp(a, newShape(a, axes)), axes_{axes}, axesBw_(axes.size()) {
     for(int i = 0; i < axes_.size(); ++i)
@@ -736,10 +787,16 @@ struct TransposeNodeOp : public UnaryNodeOp {
   const std::string type() override { return "transpose"; }

   const std::string color() override { return "orange"; }
+
+private:
+  friend class SerializationHelpers;
+  std::vector<int> axes_;
+  std::vector<int> axesBw_;
 };

 class ReshapeNodeOp : public UnaryNodeOp {
 private:
+  friend class SerializationHelpers;
   Expr reshapee_;

 public:
@@ -751,7 +808,7 @@ class ReshapeNodeOp : public UnaryNodeOp {
   ~ReshapeNodeOp() {}

-  size_t allocate() override { return 0; }
+  void allocate() override {}
   void free() override {}

   void forward() override {}
@@ -817,7 +874,7 @@ class ClipGradientNodeOp : public UnaryNodeOp {
   ~ClipGradientNodeOp() {}

-  size_t allocate() override { return 0; }
+  void allocate() override {}
   void free() override {}

   void forward() override {}
@@ -874,6 +931,7 @@ class ClipGradientNodeOp : public UnaryNodeOp {
 // The resulting object must be consecutive in memory.
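A brief aside on the trigonometric nodes added above: each `backwardOps()` accumulates the incoming adjoint times the textbook derivative,

\[
\frac{d}{dx}\sin x = \cos x, \qquad
\frac{d}{dx}\cos x = -\sin x, \qquad
\frac{d}{dx}\tan x = \frac{1}{\cos^2 x},
\]

which is exactly what `Add(_1 * cos(_2), ...)`, `Add(_1 * -sin(_2), ...)` and `Add(_1 / sqr(cos(_2)), ...)` compute, with `_1` bound to the adjoint `adj_` and `_2` to the saved input `child(0)->val()`.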
class SliceViewNodeOp : public UnaryNodeOp { private: + friend class SerializationHelpers; Expr viewedNode_; // viewed underlying node Slice slice_; // index range int axis_; // and axis along which it is viewed @@ -903,7 +961,7 @@ class SliceViewNodeOp : public UnaryNodeOp { return outShape; } - size_t allocate() override { return 0; } + void allocate() override {} void free() override {} void forward() override {} @@ -999,10 +1057,28 @@ struct ShiftNodeOp : public UnaryNodeOp { return true; } +private: + friend class SerializationHelpers; Shape shift_; // shift offsets in each dimension float padValue_; // what value to shift in }; +struct AbsNodeOp : public UnaryNodeOp { + AbsNodeOp(Expr a) : UnaryNodeOp(a) {} + + NodeOps forwardOps() override { + using namespace functional; + return {NodeOp(Element(_1 = abs(_2), val_, child(0)->val()))}; + } + + NodeOps backwardOps() override { + using namespace functional; + return {NodeOp(Add(sgn(_1) * _2, child(0)->grad(), child(0)->val(), adj_))}; + } + + const std::string type() override { return "abs"; } +}; + #ifdef CUDNN class PoolingOp : public UnaryNodeOp { public: @@ -1068,6 +1144,7 @@ class PoolingWithMaskingOp : public UnaryNodeOp { const std::string type() override { return "layer_pooling"; } protected: + friend class SerializationHelpers; Expr mask_; int width_; bool isEven_; diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 768c041cc..10b00268b 100755 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -6,8 +6,7 @@ #include "data/factored_vocab.h" #include "rnn/types.h" // for State::select() #include "models/states.h" // for EncoderState - -//using std::size_t; // not sure why this is needed +#include "layers/lsh.h" namespace marian { Logits::Logits(Expr logits) : Logits(New(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count) @@ -24,7 +23,7 @@ namespace marian { ABORT_IF(empty(), "Attempted to read out logits on empty Logits object"); auto firstLogits = logits_.front()->loss(); - ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), + ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(), "Labels not matching logits shape ({} != {}, {})??", labels.size() * firstLogits->shape()[-1], firstLogits->shape().elements(), @@ -218,6 +217,17 @@ namespace marian { if (Wt_) return; + // this option is only set in the decoder + if(!lsh_ && options_->hasAndNotEmpty("output-approx-knn")) { +#if BLAS_FOUND + auto k = opt>("output-approx-knn")[0]; + auto nbits = opt>("output-approx-knn")[1]; + lsh_ = New(k, nbits); +#else + ABORT("Requires BLAS library"); +#endif + } + auto name = options_->get("prefix"); auto numOutputClasses = options_->get("dim"); @@ -257,6 +267,22 @@ namespace marian { Logits Output::applyAsLogits(Expr input) /*override final*/ { lazyConstruct(input->shape()[-1]); +#if BLAS_FOUND + auto affineOrLSH = [this](Expr x, Expr W, Expr b, bool transA, bool transB) { + if(lsh_) { + ABORT_IF( transA, "Transposed query not supported for LSH"); + ABORT_IF(!transB, "Untransposed indexed matrix not supported for LSH"); + return lsh_->apply(x, W, b); + } else { + return affine(x, W, b, transA, transB); + } + }; +#else + auto affineOrLSH = [](Expr x, Expr W, Expr b, bool transA, bool transB) { + return affine(x, W, b, transA, transB); + }; +#endif + if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed cachedShortWt_ = index_select(Wt_, 
isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
       cachedShortb_  = index_select(b_ ,                               -1, shortlist_->indices());
@@ -333,7 +359,12 @@ namespace marian {
         input1 = layerNorm(input1, name + "_ffn");
       }
       // @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a matrix
-      auto factorLogits = affine(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true, /*scale=*/1.0f); // [B... x U] factor logits
+      Expr factorLogits;
+      if(g == 0)
+        factorLogits = affineOrLSH(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
+      else
+        factorLogits = affine(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
+
       // optionally add lemma-dependent bias
       if (Plemma) { // [B... x U0]
         int lemmaVocabDim = Plemma->shape()[-1];
@@ -396,11 +427,11 @@ namespace marian {
         }
       }
       return Logits(std::move(allLogits), factoredVocab_);
+    } else if (shortlist_) {
+      return Logits(affineOrLSH(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+    } else {
+      return Logits(affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
     }
-    else if (shortlist_)
-      return Logits(affine(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
-    else
-      return Logits(affine(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
   }
 }
@@ -419,7 +450,7 @@ namespace marian {
     // Embedding layer initialization should depend only on embedding size, hence fanIn=false
     auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
-
+
     if (options_->has("embFile")) {
       std::string file = opt<std::string>("embFile");
       if (!file.empty()) {
@@ -493,6 +524,8 @@ namespace marian {
     auto batchMask = graph->constant({dimWidth, dimBatch, 1},
                                      inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
 #endif
+    // give the graph inputs readable names for debugging and ONNX
+    batchMask->set_name("data_" + std::to_string(/*batchIndex_=*/0) + "_mask");

     return std::make_tuple(batchEmbeddings, batchMask);
   }
@@ -510,8 +543,10 @@ namespace marian {
   Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
     ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
-    auto selectedEmbs = rows(E_, embIdx);        // [(B*W) x E]
-    selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+    auto embIdxExpr = E_->graph()->indices(embIdx);
+    embIdxExpr->set_name("data_" + std::to_string(/*batchIndex_=*/0)); // @TODO: how to know the batch index?
+    auto selectedEmbs = rows(E_, embIdxExpr);    // [(B*W) x E]
+    selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
     // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
     selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
     return selectedEmbs;
diff --git a/src/layers/generic.h b/src/layers/generic.h
index a3b9bac45..6233c6d91 100755
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -228,6 +228,12 @@ class Dense : public LayerBase, public IUnaryLayer {
   Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
 };

+}  // namespace mlp
+
+class LSH;
+
+namespace mlp {
+
 class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
 private:
   // parameters held by this layer
@@ -239,10 +245,11 @@ class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
   Expr cachedShortb_;  // these match the current value of shortlist_
   Expr cachedShortLemmaEt_;
   Ptr<FactoredVocab> factoredVocab_;
-
+
   // optional parameters set/updated after construction
   Expr tiedParam_;
   Ptr<data::Shortlist> shortlist_;
+  Ptr<LSH> lsh_;

   void lazyConstruct(int inputDim);
 public:
diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp
index 408f6fc73..66eabd8a6 100755
--- a/src/layers/loss.cpp
+++ b/src/layers/loss.cpp
@@ -32,8 +32,6 @@ Ptr newMultiLoss(Ptr options) {
     return New();
   else
     ABORT("Unknown multi-loss-type {}", multiLossType);
-
-  return nullptr;
 }

 } // namespace marian
diff --git a/src/layers/lsh.cpp b/src/layers/lsh.cpp
new file mode 100644
index 000000000..9195ae6c6
--- /dev/null
+++ b/src/layers/lsh.cpp
@@ -0,0 +1,118 @@
+#include "layers/lsh.h"
+#include "3rd_party/faiss/IndexLSH.h"
+#include "graph/expression_operators.h"
+#include "tensors/cpu/prod_blas.h"
+
+namespace marian {
+
+Expr LSH::apply(Expr input, Expr W, Expr b) {
+  auto idx = search(input, W);
+  return affine(idx, input, W, b);
+}
+
+Expr LSH::search(Expr query, Expr values) {
+  ABORT_IF(query->graph()->getDeviceId().type == DeviceType::gpu,
+           "LSH index (--output-approx-knn) currently not implemented for GPU");
+
+  auto kShape = query->shape();
+  kShape.set(-1, k_);
+
+  auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
+    auto query  = inputs[0];
+    auto values = inputs[1];
+
+    int dim = values->shape()[-1];
+
+    if(!index_ || indexHash_ != values->hash()) {
+      LOG(info, "Building LSH index for vector dim {} with hash size {} bits", dim, nbits_);
+      index_.reset(new faiss::IndexLSH(dim, nbits_,
+                                       /*rotate=*/dim != nbits_,
+                                       /*train_thresholds*/false));
+      int vRows = values->shape().elements() / dim;
+      index_->train(vRows, values->val()->data());
+      index_->add(  vRows, values->val()->data());
+      indexHash_ = values->hash();
+    }
+
+    int qRows = query->shape().elements() / dim;
+    std::vector<float> distances(qRows * k_);
+    std::vector<faiss::Index::idx_t> ids(qRows * k_);
+
+    index_->search(qRows, query->val()->data(), k_,
+                   distances.data(), ids.data());
+
+    std::vector<IndexType> vOut;
+    vOut.reserve(ids.size());
+    for(auto id : ids)
+      vOut.push_back((IndexType)id);
+
+    out->val()->set(vOut);
+  };
+
+  return lambda({query, values}, kShape, Type::uint32, forward);
+}
+
+Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) {
+  auto outShape = input->shape();
+  int dimVoc    = W->shape()[-2];
+  outShape.set(-1, dimVoc);
+
+  auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
+    auto lowest = NumericLimits<float>(out->value_type()).lowest;
+    out->val()->set(lowest);
+
+    int dimIn   = inputs[1]->shape()[-1];
+    int dimOut  = out->shape()[-1];
+    int dimRows = out->shape().elements() / dimOut;
+
+    auto outPtr   = out->val()->data();
+    auto idxPtr   = inputs[0]->val()->data();
+    auto queryPtr = inputs[1]->val()->data();
+    auto WPtr     = inputs[2]->val()->data();
+    auto bPtr     = inputs[3]->val()->data();
+
+    for(int row = 0; row < dimRows; ++row) {
+      auto currIdxPtr   = idxPtr   + row * k_;     // move to next batch of k entries
+      auto currQueryPtr = queryPtr + row * dimIn;  // move to next input query vector
+      auto currOutPtr   = outPtr   + row * dimOut; // move to next output position vector (of vocabulary size)
+      for(int k = 0; k < k_; k++) {
+        int relPos    = currIdxPtr[k];             // k-th best vocabulary item
+        auto currWPtr = WPtr + relPos * dimIn;     // offset for k-th best embedding
+        currOutPtr[relPos] = bPtr[relPos];         // write bias value to position
+
+        // proceed one vector product at a time, writing to the correct position;
+        // a 1x1 GEMM over the shared dimension dimIn is just the dot product of the
+        // query vector with the selected embedding row, accumulated onto the bias
+        sgemm(false, true, 1, 1, dimIn, 1.0f, currQueryPtr, dimIn, currWPtr, dimIn, 1.0f, &currOutPtr[relPos], 1);
+      }
+    }
+  };
+
+  return lambda({idx, input, W, b},
+                outShape,
+                input->value_type(),
+                forward);
+}
+
+// @TODO: alternative version which does the same as above with Marian operators, currently missing "scatter".
+// This uses more memory and is likely to be slower. It would make sense to have a scatter node that actually creates
+// the node instead of relying on an existing node, e.g. scatter(shape, defaultValue, axis, indices, values);
+#if 0
+Expr LSH::affine(Expr idx, Expr input, Expr W, Expr b) {
+  int dim = input->shape()[-1];
+  int bch = idx->shape().elements() / k;
+
+  auto W = reshape(rows(Wt_, flatten(idx)), {bch, k, dim}); // [rows, k, dim]
+  auto b = reshape(cols(b_,  flatten(idx)), {bch, 1, k});   // [rows, 1, k]
+
+  auto aff = reshape(bdot(reshape(input, {bch, 1, dim}), W, false, true) + b, idx->shape()); // [beam, time, batch, k]
+
+  int dimVoc  = Wt_->shape()[-2];
+  auto oShape = input->shape();
+  oShape.set(-1, dimVoc);
+  auto lowest = graph_->constant(oShape,
                                 inits::fromValue(NumericLimits<float>(input->value_type()).lowest),
                                 input->value_type());
+  return scatter(lowest, -1, idx, aff);
+}
+#endif
+
+} // namespace marian
\ No newline at end of file
diff --git a/src/layers/lsh.h b/src/layers/lsh.h
new file mode 100644
index 000000000..d9239fc33
--- /dev/null
+++ b/src/layers/lsh.h
@@ -0,0 +1,26 @@
+#include "graph/expression_graph.h"
+#include 
+
+namespace faiss {
+  struct IndexLSH;
+}
+
+namespace marian {
+
+class LSH {
+public:
+  LSH(int k, int nbits) : k_{k}, nbits_{nbits} {}
+  Expr apply(Expr query, Expr values, Expr bias);
+
+private:
+  Ptr<faiss::IndexLSH> index_;
+  size_t indexHash_{0};
+
+  int k_{100};
+  int nbits_{1024};
+
+  Expr search(Expr query, Expr values);
+  Expr affine(Expr idx, Expr query, Expr values, Expr bias);
+};
+
+}
\ No newline at end of file
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 2f7687aa4..1401eb2bf 100755
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -246,7 +246,7 @@ DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {

 // @TODO: clean up this code and unify it with marian-conv. The targetPrec parameter is not clear enough, etc.
 bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec) {
-  std::cout << "Converting from: " << inputFile << ", to: " << outputFile << std::endl;
+  std::cerr << "Converting from: " << inputFile << ", to: " << outputFile << ", precision: " << targetPrec << std::endl;

   YAML::Node config;
   std::stringstream configStr;
@@ -268,7 +268,7 @@ bool convertModel(std::string inputFile, std::string outputFile, int32_t targetP
   // added a flag if the weights needs to be packed or not
   graph->packAndSave(outputFile, configStr.str(), saveGemmType);

-  std::cout << "Conversion Finished."
<< std::endl; + std::cerr << "Conversion Finished." << std::endl; return true; } diff --git a/src/models/decoder.h b/src/models/decoder.h old mode 100755 new mode 100644 index 725573c19..5ddaa9643 --- a/src/models/decoder.h +++ b/src/models/decoder.h @@ -47,9 +47,9 @@ class DecoderBase : public EncoderDecoderLayerBase { ABORT_IF(shortlist_, "How did a shortlist make it into training?"); - auto yShifted = shift(y, {1, 0, 0}); + auto yDelayed = shift(y, {1, 0, 0}); // insert zero at front; first word gets predicted from a target embedding of 0 - state->setTargetHistoryEmbeddings(yShifted); + state->setTargetHistoryEmbeddings(yDelayed); state->setTargetMask(yMask); const Words& data = subBatch->data(); diff --git a/src/models/encoder.h b/src/models/encoder.h old mode 100755 new mode 100644 diff --git a/src/models/transformer.h b/src/models/transformer.h index 84f1cd651..4fea94d0c 100755 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -11,6 +11,8 @@ #include "models/states.h" #include "models/transformer_factory.h" #include "rnn/constructors.h" +#define _USE_MATH_DEFINES // enables math constants. We need M_PI_2 +#include namespace marian { @@ -26,7 +28,8 @@ class Transformer : public EncoderOrDecoderBase { protected: using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_; - std::unordered_map cache_; // caching transformation of the encoder that should not be created again + std::unordered_map cache_; // caching transformation of the encoder that should not be created again + mutable/*lazy*/ std::vector sinusoidalEmbeddingsFreq_, sinusoidalEmbeddingsOffs_; // cached contributions to sinusoidal embeddings // attention weights produced by step() // If enabled, it is set once per batch during training, and once per step during translation. @@ -84,8 +87,25 @@ class Transformer : public EncoderOrDecoderBase { // according to paper embeddings are scaled up by \sqrt(d_m) embeddings = std::sqrt((float)dimEmb) * embeddings; // embeddings were initialized to unit length; so norms will be in order of sqrt(dimEmb) +#ifdef USE_ONNX // TODO 'Sin' op and constant sine generate different result. So, use constant when 'USE_ONNX' is not defined for now. + // precompute the arguments to sin() (the cos(x) are expressed as sin(x+pi/2)) + if (sinusoidalEmbeddingsFreq_.empty()) { + auto numTimescales = dimEmb / 2; + for (size_t i = 0; i < dimEmb; i++) { + sinusoidalEmbeddingsFreq_.push_back((float)pow(1e-4, ((i % numTimescales) / (numTimescales - 1.0)))); // rotor frequency + sinusoidalEmbeddingsOffs_.push_back((float) ((i / numTimescales) * M_PI_2 )); // 0 (for sin) or pi/2 (for cos) + } + } + auto frequencies = graph_->constant({ dimEmb }, inits::fromVector(sinusoidalEmbeddingsFreq_)); + auto cosOffsets = graph_->constant({ dimEmb }, inits::fromVector(sinusoidalEmbeddingsOffs_)); + auto positionRange = graph_->constant({ dimWords, 1, 1 }, inits::range((float)start, (float)start + (float)dimWords)); + positionRange->set_name("data_" + std::to_string(batchIndex_) + "_posrange"); + auto signal = sin(positionRange * frequencies + cosOffsets); +#else // USE_ONNX auto signal = graph_->constant({dimWords, 1, dimEmb}, inits::sinusoidalPositionEmbeddings(start)); +#endif // USE_ONNX + embeddings = embeddings + signal; } @@ -313,6 +333,7 @@ class Transformer : public EncoderOrDecoderBase { const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim] const Expr& values, // ...? 
const Expr& mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length] + int dimHeads, bool cache = false, bool saveAttentionWeights = false) { int dimModel = input->shape()[-1]; @@ -321,10 +342,8 @@ class Transformer : public EncoderOrDecoderBase { auto opsPre = opt("transformer-preprocess"); auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb); - auto heads = opt("transformer-heads"); - // multi-head self-attention over previous input - output = MultiHead(prefix, dimModel, heads, output, keys, values, mask, cache, saveAttentionWeights); + output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights); auto opsPost = opt("transformer-postprocess"); output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb); @@ -347,7 +366,7 @@ class Transformer : public EncoderOrDecoderBase { decoderLayerState.output = values; return LayerAttention(prefix, input, values, values, selfMask, - /*cache=*/false); + opt("transformer-heads"), /*cache=*/false); } static inline @@ -532,7 +551,8 @@ class EncoderTransformer : public Transformer { layer, // query layer, // keys layer, // values - layerMask); // [batch size, num heads broadcast=1, max length broadcast=1, max length] + layerMask, // [batch size, num heads broadcast=1, max length broadcast=1, max length] + opt("transformer-heads")); layer = LayerFFN(prefix_ + "_l" + std::to_string(i) + "_ffn", layer); checkpoint(layer); // sets a manually specified checkpoint if gradient checkpointing is enabled, does nothing otherwise. } @@ -600,6 +620,7 @@ class DecoderTransformer : public Transformer { "prefix", prefix_ + "_ff_logit_out", "dim", dimTrgVoc, "vocab", opt>("vocabs")[batchIndex_], // for factored outputs + "output-approx-knn", opt>("output-approx-knn", {}), "lemma-dim-emb", opt("lemma-dim-emb", 0)); // for factored outputs if(opt("tied-embeddings") || opt("tied-embeddings-all")) @@ -621,6 +642,7 @@ class DecoderTransformer : public Transformer { int dim = opt("dim-emb"); auto start = graph->constant({1, 1, dimBatch, dim}, inits::zeros()); + start->set_name("decoder_start_state_" + std::to_string(batchIndex_)); rnn::States startStates(opt("dec-depth"), {start, start}); // don't use TransformerState for RNN layers @@ -675,6 +697,7 @@ class DecoderTransformer : public Transformer { selfMask = selfMask * decoderMask; } + // gather encoder contexts std::vector encoderContexts; std::vector encoderMasks; for(auto encoderState : state->getEncoderStates()) { @@ -741,7 +764,7 @@ class DecoderTransformer : public Transformer { checkpoint(query); - // source-target attention + // cross-attention (source-target) // Iterate over multiple encoders and simply stack the attention blocks if(encoderContexts.size() > 0) { for(size_t j = 0; j < encoderContexts.size(); ++j) { // multiple encoders are applied one after another @@ -772,6 +795,7 @@ class DecoderTransformer : public Transformer { encoderContexts[j], // keys encoderContexts[j], // values encoderMasks[j], + opt("transformer-heads"), /*cache=*/true, saveAttentionWeights); } diff --git a/src/onnx/expression_graph_onnx_exporter.cpp b/src/onnx/expression_graph_onnx_exporter.cpp new file mode 100644 index 000000000..d27f1360c --- /dev/null +++ b/src/onnx/expression_graph_onnx_exporter.cpp @@ -0,0 +1,180 @@ +#ifdef USE_ONNX + +#include "onnx/expression_graph_onnx_exporter.h" + +#include "models/model_factory.h" +#include "models/encoder_decoder.h" +#include "data/corpus_base.h" +#include 
"tensors/cpu/fbgemm/expression_graph_packable.h" + +#include + +namespace marian { + // The goal is to export three functions: + // - encode_source(): encodes the source + // output: encoder_state + // - decode_first(): resets decoder state and performs the first decoding step + // main output: log prob vector for step 0 + // - decode_next(): performs a subsequent decoding step (called repeatedly) + // main output: log prob vector + // This is done by generating the tape for encoding followed by the first two decoding steps. + // As we do this, we remember the Exprs on the tape that are the inputs and the outputs + // of the three functions. + // Since not all Marian operations have a 1:1 ONNX counterpart, the tape now get rewritten + // such that it only consists of operations that ONNX has. + // Now we cut out three sub-graphs from the tape. Each sub-graph represents one + // of the three functions. The sub-graph is delimited by the inputs and outputs we remembered above. + // Limitations: + // - Inner recurrences, e.g. an RNN encoder, are not supported, since we cannot export control flow to ONNX. + // - Dynamic objects that depend on the input are not supported. + // For example constants whose shape depends on the input length. + // That's why we had to change the sinusoidal embeddings from a constant to a computation. + // - The input length is represented by a "unique" dimension value (97). This brittle. + // That dimension value must not occur naturally in the model. + // That dimension must also not be used in dimension calculations. + // E.g. the exporter does not recognize if a constant is added to it, or if it gets multiplied. + void ExpressionGraphONNXExporter::exportToONNX(const std::string& modelToPrefix, Ptr modelOptions, const std::vector& vocabPaths) + { + auto graph = shared_from_this(); + + // get the model and the vocabularies + auto model = std::dynamic_pointer_cast(models::createModelFromOptions(modelOptions, models::usage::translation)); + std::vector> vocabs; + for (auto vocabPath : vocabPaths) { + Ptr vocab = New(modelOptions, vocabs.size()); + vocab->load(vocabPath, INT_MAX); + vocabs.emplace_back(vocab); + } + setInference(true); // note: must also set "inference" parameter on options + + // if we must suppress , we do that by patching the bias + const auto trgUnkId = vocabs.back()->getUnkId(); + int unkColId = -1; + if (trgUnkId != Word::NONE && !modelOptions->get("allow-unk", false)) { // do we need to suppress unk? + unkColId = trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector? + // find the bias + const std::string outputBiasName = "decoder_ff_logit_out_b"; + auto outputBias = graph->get(outputBiasName); + auto outputBiasVal = outputBias->val(); + std::vector outputBiasVec; + outputBiasVal->get(outputBiasVec); + outputBiasVec[unkColId] = -std::numeric_limits::infinity(); + outputBiasVal->set(outputBiasVec); + } + + // the input length is represented by a value that hopefully is not used elsewhere + const size_t sentinelDim = 97; // who uses prime numbers as dimensions anyways! 
+ size_t numEncoders = vocabs.size() - 1; // @TODO: test this exporter for >1 encoder + + // some helper functions + auto extractInputByName = [&](const std::string& name) { + auto expr = tryFindForwardNodeByName(name); + ABORT_IF(!expr, "Unexpectedly could not find input node named {}", name); + expr->set_name("none"); // and nuke the name, as it will be created again in step() + return std::make_pair(name, expr); + }; + auto extractEmbeddingInputs = [&](bool forEncoder) { + // embedding inputs must be found by name, since Marian does not clearly separate batch and Expr version of the batch + std::vector> embeddingInputs; + for (size_t i = 0; i < numEncoders; i++) { + // inputs must be found by name, since Marian does not clearly separate batch and Expr version of the batch + std::string inputName = "data_" + std::to_string(i); + embeddingInputs.push_back(extractInputByName(inputName)); + if (forEncoder) { + embeddingInputs.push_back(extractInputByName(inputName + "_mask")); + embeddingInputs.push_back(extractInputByName(inputName + "_posrange")); + } + } + return embeddingInputs; + }; + auto extractStates = [&](Ptr decoderState) { + std::vector states; // all decoder-state Exprs in a long list + for (const auto& d : decoderState->getStates()) { + states.push_back(d.output); + states.push_back(d.cell); + } + return states; + }; + + // run a fake batch through the encoder (this goes into encode_source()) and create decoder start states + // This adds the operations to the tape. + std::vector> subBatches; + for (size_t batchIndex = 0; batchIndex < numEncoders; batchIndex++) { + auto sb = New(1, sentinelDim, vocabs[batchIndex]); + // set word indices to random values + std::transform(sb->data().begin(), sb->data().end(), sb->data().begin(), + [&](Word) -> Word { return vocabs[batchIndex]->randWord(); }); + // mask: no items ask being masked out + std::fill(sb->mask().begin(), sb->mask().end(), 1.f); + subBatches.push_back(std::move(sb)); + } + auto batch = New(subBatches); + auto startState = model->startState(graph, batch); + + // fish out the embedding inputs by name and neutralize the names + // These constitute the inputs for the graph we are cutting out for encode_source(). + auto encoderEmbeddingInputs = extractEmbeddingInputs(/*forEncoder=*/true); + std::vector> encoderContexts; + for (const auto& e : startState->getEncoderStates()) + encoderContexts.push_back(std::make_pair("encoder_context_" + std::to_string(encoderContexts.size()), e->getContext())); + + // run it further until the first prediction --> decode_first() + // This adds more operations to the tape. + auto decodeFirstState = model->step(graph, startState, /*hypIndices=*/{}, + /*words=*/{}, /*batchIndices=*/{ 0 }, /*beamSize=*/1); + auto decodeFirstPosRangeInput = extractInputByName("data_" + std::to_string(numEncoders) + "_posrange"); + + // run it further until the next prediction --> decode_next() + // This adds more operations to the tape. 
+ auto decodeNextState = model->step(graph, decodeFirstState, /*hypIndices=*/{}, + /*words=*/{ vocabs.back()->randWord() }, /*batchIndices=*/{ 0 }, /*beamSize=*/1); + auto decodeNextEmbeddingInput = extractEmbeddingInputs(/*forEncoder=*/false); + auto decodeNextPosRangeInput = extractInputByName("data_" + std::to_string(numEncoders) + "_posrange"); + + ABORT_IF(encoderContexts.size() != numEncoders, "Unexpected mismatch in number of encoders??"); + + // create a descriptor for the three functions, which consists of + // - function name + // - list of inputs and outputs, as name-Expr pairs + FunctionDefs functionDefs; + + std::vector> inputs; + std::vector> outputs; + + // descriptor for encode_source(data_0, data_0_mask) -> encoder_context_0 + inputs = encoderEmbeddingInputs; + outputs = encoderContexts; + functionDefs["encode_source"] = std::make_pair(std::move(inputs), std::move(outputs)); + + // descriptor for decode_first(data_1_posrange, encoder_context_0, data_0_mask) -> logits, out_decoder_state_0, out_decoder_state_1, ... + inputs.emplace_back(decodeFirstPosRangeInput); + for (size_t i = 0; i < numEncoders; i++) { + inputs.emplace_back(encoderContexts[i]); + inputs.emplace_back(encoderEmbeddingInputs[1+2*i]); + } + outputs.emplace_back(std::make_pair("first_logits", decodeFirstState->getLogProbs().getLogits())); + for (const auto& dss : extractStates(decodeFirstState)) + outputs.emplace_back(std::make_pair("first_decoder_state_" + std::to_string(outputs.size()-1), dss)); + functionDefs["decode_first"] = std::make_pair(std::move(inputs), std::move(outputs)); + + // descriptor for decode_next(prev_word, data_1_posrange, encoder_context_0, data_0_mask, decoder_state_0, decoder_state_1, ...) -> logits, decoder_state_0, decoder_state_1, ... + inputs.emplace_back(std::make_pair("prev_word", decodeNextEmbeddingInput[0].second)); + inputs.emplace_back(decodeNextPosRangeInput); + for (size_t i = 0; i < numEncoders; i++) { + inputs.emplace_back(encoderContexts[i]); + inputs.emplace_back(encoderEmbeddingInputs[1 + 2 * i]); + } + for (const auto& dss : extractStates(decodeFirstState)) + inputs.emplace_back(std::make_pair("decoder_state_" + std::to_string(inputs.size() - (numEncoders*2 + 2)), dss)); + outputs.emplace_back(std::make_pair("next_logits", decodeNextState->getLogProbs().getLogits())); + for (const auto& dss : extractStates(decodeNextState)) + outputs.emplace_back(std::make_pair("next_decoder_state_" + std::to_string(outputs.size() - 1), dss)); + functionDefs["decode_next"] = std::make_pair(std::move(inputs), std::move(outputs)); + + // now export the sub-graph as given by the function descriptor + serializeToONNX(modelToPrefix, std::move(functionDefs), sentinelDim); + } +} + +#endif + diff --git a/src/onnx/expression_graph_onnx_exporter.h b/src/onnx/expression_graph_onnx_exporter.h new file mode 100644 index 000000000..31c752766 --- /dev/null +++ b/src/onnx/expression_graph_onnx_exporter.h @@ -0,0 +1,29 @@ +#include "graph/expression_graph.h" + +namespace marian { + // export of Marian models to ONNX + class ExpressionGraphONNXExporter : public ExpressionGraph { +#ifdef USE_ONNX + public: + // export a seq2seq model to a set of ONNX files + void exportToONNX(const std::string& modelToPrefix, Ptr modelOptions, const std::vector& vocabPaths); + + private: + // [name] -> (vector(name, Expr), vector(name, Expr)) + typedef std::map>, std::vector> >> FunctionDefs; + + // serialize the current nodesForward_ to an ONNX file. This operation is destructive. 
+ void serializeToONNX(const std::string& filename, FunctionDefs&& functionDefs, size_t sentinelDim); + + // find a node on the current forward tape + Expr tryFindForwardNodeByName(const std::string& nodeName) const; + + // helper to transform nodesForward_ to only use the subset of operations supported by ONNX + void expandMacroOpsForONNX(std::map>, std::vector> >>& functionDefs); + + // helper to build nodesForward_ from root nodes + void rebuildNodesForward(const struct InputsMap& inputsMap, + const std::vector>& outputDefs); +#endif // USE_ONNX + }; +} diff --git a/src/onnx/expression_graph_onnx_serialization.cpp b/src/onnx/expression_graph_onnx_serialization.cpp new file mode 100644 index 000000000..e8a837343 --- /dev/null +++ b/src/onnx/expression_graph_onnx_serialization.cpp @@ -0,0 +1,1093 @@ +#ifdef USE_ONNX + +#include "onnx/expression_graph_onnx_exporter.h" +#include "graph/expression_operators.h" +#include "graph/node_operators_unary.h" +#include "graph/node_operators_binary.h" +#include "common/version.h" +#define AuxillaryParseTableField AuxiliaryParseTableField // in protobuf 3.12, the generated source has a spelling error +#include "3rd_party/onnx/protobuf/onnx-ml.pb-wrapper.h" +#include +#include +#include +#include +#include + +namespace marian { + + // collection of helper functions for accessing and converting Expr properties + // This class is a friend of all node-op classes whose attributes we need to access. + class SerializationHelpers { + public: + // helper for accessing class members in Marian's polymorphic node classes + // If 'e' is of NNaryNodeOp then execute getFn() and return true. + template + static bool tryGetAttributes(Expr e, const F& getFn) { + auto np = std::dynamic_pointer_cast(e); + if (!np) + return false; + getFn(np); + return true; + } + + template + static bool tryGetScalarAttribute(Expr e, float& scalar) { + return tryGetAttributes(e, [&](IPtr np) { scalar = np->scalar_; }); + } + + template + static bool tryGetMatMulAttributes(Expr e, bool& transA, bool& transB, float& scalar) { + return tryGetAttributes(e, [&](IPtr np) { + transA = np->transA_; + transB = np->transB_; + scalar = np->scalar_; + }); + } + + template + static bool tryGetEpsilonAttribute(Expr e, float& eps) { + return tryGetAttributes(e, [&](IPtr np) { eps = np->eps_; }); + } + + template + static bool tryGetAxisAttribute(Expr e, size_t& axis) { + return tryGetAttributes(e, [&](IPtr np) { axis = (size_t)e->shape().axis(np->axis_); }); + } + + template + static bool tryGetAxesAttribute(Expr e, std::vector& axes) { + return tryGetAttributes(e, [&](IPtr np) { + axes.clear(); + for (auto ax : np->axes_) + axes.push_back((size_t)e->shape().axis(ax)); + }); + } + + template + static bool tryGetShiftAttributes(Expr e, std::vector& shift, float& padValue) { + return tryGetAttributes(e, [&](IPtr np) { + shift.assign(np->shift_.begin(), np->shift_.end()); + padValue = np->padValue_; + }); + } + + template + static bool tryGetSliceAttribute(Expr e, Slice& slice) { + return tryGetAttributes(e, [&](IPtr np) { slice = np->slice_; }); + } + + template + static bool tryGetReshapeeAttributePtr(Expr e, Expr*& ep) { + return tryGetAttributes(e, [&](IPtr np) { ep = &np->reshapee_; }); + } + + template + static bool tryGetStepNodeAttributePtr(Expr e, Expr*& ep) { + return tryGetAttributes(e, [&](IPtr np) { ep = &np->stepNode_; }); + } + + template + static bool tryGetMaskAttributePtr(Expr e, Expr*& ep) { + return tryGetAttributes(e, [&](IPtr np) { ep = &np->mask_; }); + } + + // call this for 
mandatory parameters, e.g. tryGetMaskAttributePtr(...) || fail("message", ...)
+    template
+    static bool fail(Args&&... args) {
+      ABORT(std::forward(args)...);
+    }
+    static bool fail() { return fail("an attempt to access a Marian node attribute unexpectedly failed due to a type mismatch"); }
+  };
+  using E = SerializationHelpers;
+
+  struct InputsMap : public std::map {
+    Expr operator()(Expr e) const {
+      auto iter = find(e); // redirect input if found
+      if (iter != end())
+        e = iter->second;
+      return e;
+    }
+  };
+
+  // helper for rebuildNodesForward()
+  static void addNodeAndChildren(Expr node, std::list& nodesForward, std::set& visited, const InputsMap& inputsMap)
+  {
+    // check if this is an input
+    // In that case, we generate a replacement node instead, which has no children and thus terminates the recursion.
+    // All nodes that reference this input are, however, unmodified.
+    // The tape is now inconsistent. The consumer of this tape must perform child mapping.
+    auto replacementNode = inputsMap(node);
+    if (replacementNode != node)
+      node = replacementNode;
+    // recursion terminates if we already visited a node
+    // (Input mapping is taken into account already.)
+    auto res = visited.insert(node);
+    if (!res.second) // already in visited set: done
+      return;
+    for (auto& child : node->children()) // children come before node itself
+      addNodeAndChildren(child, nodesForward, visited, inputsMap);
+    nodesForward.push_back(node);
+  }
+
+  // rebuild nodesForward_ from a graph given by its set of roots
+  // Also replaces the inputs by constants, but does not redirect references (leaving an invalid tape--must be corrected on the fly by the caller!).
+  void ExpressionGraphONNXExporter::rebuildNodesForward(const InputsMap& inputsMap,
+                                                        const std::vector>& outputDefs) {
+    nodesForward_.clear();
+    std::set visited;
+    for (auto& outputDef : outputDefs)
+      addNodeAndChildren(outputDef.second, nodesForward_, visited, inputsMap);
+  }
+
+  class NodeReferenceRedirector {
+    std::map nodeMap; // [orig node] -> replacement nodes
+
+  public:
+    void addRedirect(const Expr& whichNode, const Expr& withWhichNode) {
+      nodeMap[whichNode] = withWhichNode;
+    }
+
+    // in-place redirect an Expr reference, i.e. look up the redirect and replace the original with it
+    void redirectReference(Expr& child) const {
+      auto iter = nodeMap.find(child);
+      if (iter != nodeMap.end()) {
+        child = iter->second; // redirect child to the replacement node
+        ABORT_IF(nodeMap.find(child) != nodeMap.end(), "Nested macro expansion??");
+      }
+    };
+
+    // redirect all references (=children and more in special cases)
+    void redirectAllReferencesIn(Expr v) const {
+      // redirect all children
+      auto& children = v->children(); // this is a mutable reference
+      for (auto& child : children) { // child is a mutable reference
+        redirectReference(child);
+      }
+      // redirect additional references that some nodes hold
+      Expr* ep{};
+      if (E::tryGetReshapeeAttributePtr (v, ep) ||
+          //E::tryGetStepNodeAttributePtr (v, ep) || // @TODO: review all of these and update the names
+          E::tryGetMaskAttributePtr(v, ep)) {
+        redirectReference(*ep);
+      }
+    }
+  };
+
+  static Expr newConstant(Expr v, Shape shape, float val, std::string suffix) {
+    auto expr = v->graph()->constant(shape, inits::fromVector(std::vector(shape.elements(), val)));
+    expr->set_name("const_" + v->type() + "_" + std::to_string(v->getId()) + "_" + suffix);
+    // Note: By convention, all constants should be named const_ something (and all data inputs data_),
+    // to distinguish them from trainable weight tensors.
+    return expr;
+  }
+
+  // unroll higher-level operations for which no ONNX equivalent exists
+  // This updates the functionDefs' root nodes in-place.
+  // Note: This appends to nodesForward_ in-place. Some meta-information, like root node, is not updated correctly.
+  void ExpressionGraphONNXExporter::expandMacroOpsForONNX(std::map>, std::vector> >>& functionDefs) {
+    LOG(info, "[graph] Expanding macro ops into primitives. Current graph size is {}", nodesForward_.size());
+    NodeReferenceRedirector nodeReferenceRedirector;
+    // clear memoization cache, as it removes some children for ops that have not changed since last inference
+    tensors_->clearLongtermMemory();
+    // Note: expansions will add to the existing tape in-place. But we disallow nested expansions,
+    // i.e. disallow looping over newly created nodes, because otherwise the nodeReferenceRedirector
+    // becomes very complicated because those new nodes are no longer topo-sorted.
+    // The for loop below also iterates over newly-created nodes, but those may not
+    // trigger another expansion, which will be caught in redirectReference() above.
+    auto beg = nodesForward_.begin();
+    auto end = nodesForward_.end();
+    for (auto vi = beg; vi != end; ++vi) {
+      auto& v = *vi;
+      // redirect all children of this node, in case they got mapped in this process
+      nodeReferenceRedirector.redirectAllReferencesIn(v);
+      // expand macro ops
+      Expr n;
+#if 0 // For GC ONNX, some ops are still missing. Map these first.
+      // @BUGBUG: These operators are not up-to-date
+      if (v->type() == "highway") {
+        // Replace Sigmoid by Softmax. The only sigmoid in the system comes from highway.
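+        // For illustration (derivation only, not part of the rewrite itself): for a scalar t,
+        //   Sigmoid(t)      = 1 / (1 + e^-t) = e^t / (e^t + e^0)
+        //   Softmax([t, 0]) = (e^t, e^0) / (e^t + e^0) = (Sigmoid(t), 1 - Sigmoid(t))
+        // e.g. t = 0 gives Sigmoid(0) = 0.5 and Softmax([0, 0]) = (0.5, 0.5). So concatenating a
+        // zero row onto t and taking a 2-way Softmax recovers both the gate and its complement.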
+ auto y = v->child(0); // something like [B, H, T, dim] + auto x = v->child(1); + auto t = v->child(2); + auto shape = x->shape(); + ABORT_IF(y->shape() != shape || t->shape() != shape, "unexpected highway shapes??"); + // Softmax([x,0]) = (Sigmoid(x), 1-Sigmoid(x)) + // Softmax([x,y]) = e^x / (e^x + e^y) + // Sigmoid(x) = e^x / (e^x + e^0) + auto shape1 = Shape{shape.elements() / shape.back(), shape.back(), 1}; + t = reshape(t, shape1); + auto tAug = concatenate({t, newConstant(v, t->shape(), 0.0f, "zero_row")}, -1); // [(B*H*T, dim, 2)] + auto s = softmax(tAug, /*axis=*/-1); // = (Sigmoid(t), 1-Sigmoid(t)) : [(B*H*T, dim, 2)] + s = swapAxes(s, 0, -1); // step() only supports axis=0 + auto sy = step(s, 0, /*axis=*/0); + auto sx = step(s, 1, /*axis=*/0); + sy = swapAxes(sy, 0, -1); + sx = swapAxes(sx, 0, -1); + sy = reshape(sy, shape); + sx = reshape(sx, shape); + n = sy * y + sx * x; + //LOG(info, "OVERWRITING highway, {} -> {} -> {} -> back", std::string(shape), std::string(shape1), std::string(tAug->shape())); + } + else if (v->type() == "sum") { + // replace ReduceSum by a matrix product with a vector of ones + auto x = v->child(0); + auto shape = x->shape(); + size_t lastAxis = shape.size() - 1; + size_t axis; + E::tryGetAxisAttribute(v, axis) || E::fail(); + if (axis != lastAxis) // bring axis to be reduced into last dimension so that we can MatMul + x = swapAxes(x, (int)axis, (int)lastAxis); + auto ones = newConstant(v, {x->shape().back(), 1}, 1.0f, "ones"); + n = dot(x, ones); // [..., D] * [D, 1] = [..., 1] + if (axis != lastAxis) // and swap it back + n = swapAxes(n, (int)axis, (int)lastAxis); + //LOG(info, "OVERWRITING sum {}/{}, {} -> {} -> . -> {}", axis, lastAxis, std::string(shape), std::string(x->shape()), std::string(n->shape())); + } + else if (v->type() == "layer_normalization") { + // layerNorm along last axis + auto x = v->child(0); + auto s = v->child(1); + auto b = v->child(2); + auto vecDim = x->shape().back(); + // for summing up elements, we use MatMul + auto onesOverDim = newConstant(v, {vecDim, 1}, 1.0f / vecDim, "ones_over_dim"); + // compute mean and variance + auto mean = dot(x, onesOverDim); + auto x0 = x - mean; + auto var = dot(x0 * x0, onesOverDim); + // variance-normalize + float epsilon; + E::tryGetEpsilonAttribute(v, epsilon) || E::fail(); + auto sigma = sqrt(newConstant(v, {}, epsilon, "epsilon") + var); + auto xnorm = x0 / sigma; + // and final scale/bias + n = xnorm * s + b; + //LOG(info, "OVERWRITING layerNorm {} -> {}", std::string(x->shape()), std::string(mean->shape())); + } + else +#endif + if (v->type() == "scalar_add") { + float scalar{}; + E::tryGetScalarAttribute(v, scalar) || E::fail(); + n = v->child(0) + newConstant(v, {}, scalar, "scalar"); + } + else if (v->type() == "scalar_mult") { + float scalar{}; + E::tryGetScalarAttribute(v, scalar) || E::fail(); + n = v->child(0) * newConstant(v, {}, scalar, "scalar"); + } + else if (v->type() == "square") { + auto x = v->child(0); + n = x * x; + } +#if 0 // @BUGBUG: not supported for now, since we don't aim at training. This requires a function called select() which no longer exists. 
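+      // For illustration (walk-through of the offset trick used below; the numbers are made up):
+      // with vocabSize V = 3 and totalWords = 2, the logits flatten to a 6-element vector and
+      // indices y = [2, 0] get offsets [0*V, 1*V] = [0, 3] added, giving [2, 3]. A single
+      // non-batched gather over the flat vector then selects logit 2 of word 0 and logit 0 of word 1.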
+ else if (v->type() == "x-ent") { + auto x = v->child(0); // logits : some_shape + (num_classes,) + auto y = v->child(1); // indices: some_shape + (1,) + // C = sum_{v in V}(-logsoftmax(A) * delta(v, i) = -logsoftmax(A)[i] + auto xShape = x->shape(); + // note: indices are flattened into a vector + auto yShape = xShape; // true shape of y -> result shape + yShape.back() = 1; + auto nl = logsoftmax(x); + //nl->debug("nl"); +#if 1 // ONNX has no batched select/gather, so we must fake it. + // We first flatten the batch to a vector. + nl = flatten(nl); // now x: (totalWords, vocabSize), while y: (totalWords,) + // Then we create a constant with offsets into this vector + auto vocabSize = xShape.back(); + auto totalWords = xShape.elements() / vocabSize; // total batch size across batch and length dimension + std::vector offs; + for (size_t i = 0; i < totalWords; i++) + offs.push_back((unsigned int)(i * vocabSize)); + auto offsExpr = v->graph()->indices(offs); + offsExpr->set_name("const_" + v->type() + "_offsets_" + std::to_string(v->getId())); + // Now form indices into the flattened vector using the offsets + y = y + offsExpr; // -> [y0, y1 + V, y2 + 2V, ...] + // Now we can select with this. + n = -select(nl, y, /*axis=*/-1); + n = reshape(n, yShape); + //LOG(info, "x-ent: {}, {} -> {}", std::string(x->shape()), std::string(y->shape()), std::string(n->shape())); +#else // better version, but unfortunately neither Marian nor ONNX support batched select/gather + y = reshape(y, yShape); + n = -select(nl, y, /*axis=*/-1); // @TODO: update if we ever add axis_ to x-ent +#endif + } +#endif + else if (v->type() == "highway") { + auto y = v->child(0); + auto x = v->child(1); + auto t = v->child(2); + auto s = sigmoid(t); + auto oneExpr = newConstant(v, {}, 1.0f, "one"); + n = s * y + (oneExpr - s) * x; + } + else if ( v->type() == "bdot" || + (v->type() == "dot" /* && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2)*/) || + (v->type() == "affine" && (v->child(0)->shape().size() != 2 || v->child(1)->shape().size() != 2 || v->child(2)->shape().size() > 2))) { + // ONNX MatMul behaves like Numpy matmul, and therefore implements batched semantics. + // ONNX MatMul has no transA/B/scale parameters, so we must handle those as explicit operations. + // affine() could also be ONNX Gemm, but that does not support outer ranks, so we just expand it into dot(). + // @TODO: ^^ we can just reshape(). Code is already below, but ONNX Gemm always crashes, so this is disabled for now. 
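+        // For illustration (sketch of the rewrite below, in terms of the a, b defined next):
+        // since ONNX MatMul has no transA/transB/alpha, a Marian dot(a, b, /*transA=*/true,
+        // /*transB=*/false, /*scalar=*/s) ends up roughly as MatMul(swapAxes(a, -1, -2), b) * s,
+        // i.e. transposes become explicit swapAxes() nodes and the scale a trailing scalar Mul.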
+        auto a = v->child(0);
+        auto b = v->child(1);
+        bool transA{}, transB{}; float scalar{}; // (gcc complains without the initializers, which I think is a compiler bug)
+        E::tryGetMatMulAttributes (v, transA, transB, scalar) ||
+        E::tryGetMatMulAttributes(v, transA, transB, scalar) ||
+        E::tryGetMatMulAttributes (v, transA, transB, scalar) || E::fail();
+        //LOG(info, "{} {}={}x{} trans = {}, {} and scalar = {}",
+        //    v->type(), std::string(v->shape()), std::string(a->shape()), std::string(b->shape()), transA, transB, scalar);
+        if (transA || transB || scalar != 1.0f ||
+            (v->type() == "affine" && (a->shape().size() != 2 || b->shape().size() != 2 || v->child(2)->shape().size() > 2))) {
+          //LOG(info, "patching {} {}={}x{} due to trans = {}, {} and scalar = {}",
+          //    v->type(), std::string(v->shape()), std::string(a->shape()), std::string(b->shape()), transA, transB, scalar);
+          if (transA) { // note: we don't optimize for this since it does not happen in present models
+            a = swapAxes(a, -1, -2);
+            transA = false;
+          }
+          // @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
+          //if (v->type() != "bdot" && b->shape().size() == 2) { // [A,B,C,I,J] x [J,K] --> reshape into regular matrix product
+          //  ABORT_IF(transA, "Transposition not mapped away??");
+          //  a = reshape(a, Shape({ a->shape().elements() / a->shape()[-1], a->shape()[-1] })); // now it's a regular matrix product, can use Gemm
+          //}
+          /*else*/ if (transB) { // not a regular matrix product: cannot use Gemm, so must transpose manually
+            b = swapAxes(b, -1, -2);
+            transB = false;
+          }
+          float extraScalar = 1.0f;
+          if (v->type() == "bdot") { // this maps to ONNX MatMul
+            extraScalar = scalar; // must add extra scale operation at the end
+            scalar = 1.0f; // we cannot scale in ONNX MatMul
+            ABORT_IF(transA || transB || scalar != 1.0f, "Transposition and/or scalar not mapped away??");
+            n = bdot(a, b, transA, transB, scalar);
+          }
+          else { // dot, affine
+            // @BUGBUG: Gemm always crashes with ONNX runtime. So we can't do this optimization.
+            //if (a->shape().size() != 2 || b->shape().size() != 2) { // not ONNX MatMul: must use explicit scale operation
+            extraScalar = scalar;
+            scalar = 1.0f;
+            //}
+            n = dot(a, b, transA, transB, scalar);
+            //LOG(info, "{} {} x {} -> {}", v->type(), std::string(a->shape()), std::string(b->shape()), std::string(n->shape()));
+            if (v->type() == "affine")
+              n = n + v->child(2);
+          }
+          //if (v->type() == "affine")
+          //  LOG(info, "{} + {} -> {}", v->type(), std::string(v->child(2)->shape()), std::string(n->shape()));
+          if (extraScalar != 1.0f)
+            n = n * newConstant(v, {}, extraScalar, "scalar");
+          if (n->shape() != v->shape())
+            n = reshape(n, v->shape()); // if we did some shaping to get a regular matrix product, reshape it back
+        }
+      }
+      else if (v->type() == "affine" && v->children().size() > 3) {
+        // affine() may have a redundant vector of ones, which we strip here
+        // This then becomes Gemm.
+        v->children().resize(3);
+        ABORT("affine() cannot presently be stripped of its additional ones vector. Need to fix Marian first to run with this.");
+        // Note: Cannot recreate affine() as a new node, because that will get that fourth axis again.
+        // @BUGBUG: This will crash.
+      }
+#if 0 // @BUGBUG: select() no longer exists. Likely some other ops are missing now.
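+      // For illustration (example of the ONNX Gather semantics this disabled block targets):
+      //   data = [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], indices = [0, 2], axis = 0
+      //   output = [[1.0, 1.2], [4.5, 5.7]]
+      // Gather picks whole slices along one axis; per-row (batched) index lists are not
+      // expressible, hence the flatten()/swapAxes() gymnastics below.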
+      else if (v->type() == "select") {
+        // select maps to Gather, and is limited to non-batched and the last axis
+        size_t axis;
+        E::tryGetAxisAttribute(v, axis) || E::fail();
+        auto data = v->child(0);
+        auto indices = v->child(1);
+        auto dataShape = data->shape();
+        auto dataRank = dataShape.size();
+        auto indicesShape = indices->shape();
+        auto indicesRank = indicesShape.size();
+        auto indicesDim = indicesShape[(int)axis - (int)dataShape.size()];
+        ABORT_IF(indicesShape.elements() != indicesDim, "ONNX does not support batched select()");
+        if (indicesRank != 1 || axis != dataRank - 1) {
+          if (indicesRank != 1)
+            indices = flatten(indices); // (batched Gather is not supported)
+          if (axis != dataRank - 1)
+            data = swapAxes(data, (int)axis, (int)dataRank - 1); // swap select axis to back
+          n = select(data, indices, -1);
+          if (axis != dataRank - 1)
+            n = swapAxes(n, (int)axis, (int)dataRank - 1);
+        }
+      }
+#endif
+      else if (v->type() == "layer_normalization" &&
+               (v->child(0)->shape().size() != 3 || v->child(1)->shape().size() != 1 || (v->children().size() > 2 && v->child(2)->shape().size() != 1))) {
+        // ONNX InstanceNormalization is layer norm for shapes (N, C, D, ...) where N and C are
+        // batch dimensions, and D... all share normalization statistics ("mean and variance are
+        // computed per instance per channel").
+        // Marian layer_normalization normalizes along axis -1.
+        // Hence, if the input rank is != 3, we must temporarily reshape.
+        // Also, ONNX expects scale and bias to contain C values (one for each c), while Marian
+        // shares scale and bias along C but uses vectors of dim D. Hence, we must apply them manually.
+        // This op gets replaced by a sequence that includes the same op, but with
+        // gamma and beta being scalars, which is invalid for Marian.
+        // (This will fail if layerNorm is applied to a scalar, which makes no sense.)
+        auto x = v->child(0);
+        auto s = v->child(1);
+        auto b = v->children().size() > 2 ? v->child(2) : nullptr; // beta is optional
+        auto outShape = x->shape();
+        auto vecDim = outShape[-1];
+        x = reshape(x, {outShape.elements() / vecDim, 1, vecDim}); // -> (N, C, D)
+        ABORT_IF((s->shape().size() > 1 && s->shape()[-1] != s->shape().elements()) ||
+                 (b && b->shape().size() > 1 && b->shape()[-1] != b->shape().elements()),
+                 "scale and bias must be vectors or single rows");
+        s = flatten(s);
+        if (b)
+          b = flatten(b);
+        //LOG(info, "layer_normalization reshaped from {} to {}", std::string(outShape), std::string(x->shape()));
+        float epsilon;
+        E::tryGetEpsilonAttribute(v, epsilon) || E::fail();
+        //LOG(info, "LNORM {}, {}, {} vs. {}, {}", std::string(x->shape()), std::string(oneExpr->shape()), std::string(zeroExpr->shape()), std::string(s->shape()), std::string(b->shape()));
+        n = layerNorm(x, newConstant(v, {1}, 1.0f, "one"), newConstant(v, {1}, 0.0f, "zero"), epsilon);
+        n = n * s;
+        if (b)
+          n = n + b;
+        n = reshape(n, outShape);
+      }
+      else if (v->type() == "const" && v->name().find("dropout_mask_") == 0) {
+        // This is a randomly generated mask. We must replace this by RandomUniform.
+        // This is done in 3 steps:
+        //  - We expand v as (uniform < keepProb) * scale; but because Marian has no "<", we use "-" instead for now. @HACKHACK 1
+        //  - The uniform for now is a constant, which later gets converted to ONNX RandomUniform(0,1). @HACKHACK 2
+        //  - The "-" with left arg of v gets patched to become ONNX Less.
@HACKHACK 1 fix-up + auto pString = v->name(); + pString.erase(0, pString.find_last_of('_') + 1); + float dropProb = std::stof(pString); + //LOG(info, "Found dropProb constant {} -> {}", v->name(), dropProb); + float keepProb = 1.f - dropProb; + float scale = 1.f / keepProb; + auto uniformExpr = v->graph()->constant(v->shape(), inits::zeros()); + uniformExpr->set_name("opRandomUniform_" + std::to_string(v->getId())); // not using newConstant because of special node name + // (uniform(0,1) < keepProb) * scale + n = (uniformExpr - newConstant(v, {}, keepProb, "keepProb")) * newConstant(v, {}, scale, "scale"); + // @HACKHACK 1: Marian has no "less than", so we use "-" instead. Must patch that back later. + // @HACKHACK 2: We use a specially-named constant as the placeholder for uniform(0,1). + } + + if (n) { + // copy key properties + if (v->name() != n->name()) // (this tests for the empty name) + n->set_name(v->name() + "_expanded"); // (this branch is actually never taken presently) + n->setTrainable(v->trainable()); + // register mapping + nodeReferenceRedirector.addRedirect(v, n); + LOG(info, "[graph] Macro op {} expanded with new root op {}", v->type(), n->type()); + } + } + for (auto& functionDef : functionDefs) { + for (auto& output : functionDef.second.second) // redirect outputs: a root may also have been a macro op + nodeReferenceRedirector.redirectReference(output.second); + for (auto& output : functionDef.second.first) // redirect inputs: inputs may be the outputs of other functions + nodeReferenceRedirector.redirectReference(output.second); + } + + // Since we added the expanded ops to the end of nodesForward_, we must bring it + // back into topologically sorted order. + LOG(info, "[graph] After creating expanded nodes, we now have {} nodes", nodesForward_.size()); + } + + using namespace onnx; // all -Proto classes come from here + + const std::string LENGTH_AXIS_NAME = "SOURCE_LENGTH"; // the source length is a named (dynamic) axis with this name + + // C++ port of a subset of https://github.com/onnx/onnx/blob/master/onnx/helper.py + static ValueInfoProto makeValueInfoProto(std::string name, TensorProto_DataType dataType, std::vector shape, size_t sentinelDim) { + ValueInfoProto valueInfo; + valueInfo.set_name(name); + auto* valueInfoType = valueInfo.mutable_type(); + auto* valueInfoTensorType = valueInfoType->mutable_tensor_type(); + valueInfoTensorType->set_elem_type(dataType); + auto* valueInfoTensorTypeShape = valueInfoTensorType->mutable_shape(); + for (auto dim : shape) + if (dim == sentinelDim) + valueInfoTensorTypeShape->add_dim()->set_dim_param(LENGTH_AXIS_NAME); + else + valueInfoTensorTypeShape->add_dim()->set_dim_value(dim); + return valueInfo; + } + + template // note: for now, must pass the matching dataType (not checked) + static TensorProto makeTensorProto(std::string name, TensorProto_DataType dataType, std::vector shape, std::vector vals) { + TensorProto tensor; + tensor.set_name(name); + tensor.set_data_type(dataType); + for (auto dim : shape) + tensor.add_dims(dim); +#if 0 // @HACKHACK for debugging: keep files small during debugging, so that we can load and view those files easily + *tensor.mutable_raw_data() = std::string((char*)vals.data(), (char*)(vals.data() + std::min(size_t(10), vals.size()))); +#else + *tensor.mutable_raw_data() = std::string((char*)vals.data(), (char*)(vals.data() + vals.size())); +#endif + return tensor; + } + + static inline void addAttribute(NodeProto& node, std::string name, std::vector val) { + AttributeProto* 
attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); + for (auto i : val) + attribute->add_ints(i); + } + static inline void addAttribute(NodeProto& node, std::string name, std::vector val) { + AttributeProto* attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INTS); + for (auto i : val) + attribute->add_ints(i); + } + static inline void addAttribute(NodeProto& node, std::string name, std::string val) { + AttributeProto* attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_STRING); + attribute->set_s(val); + } + static inline void addAttribute(NodeProto& node, std::string name, float val) { + AttributeProto* attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT); + attribute->set_f(val); + } + static inline void addAttribute(NodeProto& node, std::string name, int val) { + AttributeProto* attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INT); + attribute->set_i(val); + } + static inline void addAttribute(NodeProto& node, std::string name, size_t val) { + AttributeProto* attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INT); + attribute->set_i(val); + } + static inline void addAttribute(NodeProto& node, std::string name, bool val) { + AttributeProto* attribute = node.add_attribute(); + attribute->set_name(name); + attribute->set_type(AttributeProto_AttributeType::AttributeProto_AttributeType_INT); + attribute->set_i(val ? 1 : 0); // bool is stored as int in ONNX + } + static void addAttributes(NodeProto&) { // end of recursion + } + template + static void addAttributes(NodeProto& node, std::string name, T val, Attributes&&... moreAttributes) { + addAttribute(node, name, val); + addAttributes(node, std::forward(moreAttributes)...); + } + + template + static NodeProto makeNode(std::string opType, std::string nodeName, + std::vector inputs, std::vector outputs, + Attributes&&... 
attributes) { + NodeProto node; + node.mutable_op_type()->assign(opType); + for (auto input : inputs) + node.add_input(input); + for (auto output : outputs) + node.add_output(output); + if (!nodeName.empty()) + node.set_name(nodeName); + addAttributes(node, std::forward(attributes)...); + return node; + } + + static GraphProto makeGraph(const std::vector& nodes, std::string name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& initializers, + const std::vector& valueInfos) { + GraphProto graph; + for (auto& node : nodes) + *graph.add_node() = node; + graph.set_name(name); + for (auto& input : inputs) + *graph.add_input() = input; + for (auto& output : outputs) + *graph.add_output() = output; + for (auto& initializer: initializers) + *graph.add_initializer() = initializer; + for (auto& valueInfo : valueInfos) +#if 0 // add some as explicit outputs for debugging + if (valueInfo.name() == "opReshape_292" || valueInfo.name() == "opPad_294") + *graph.add_output() = valueInfo; + else +#endif + *graph.add_value_info() = valueInfo; + valueInfos; + return graph; + } + + static ModelProto makeModel(const GraphProto& graph, std::string producerName) { + ModelProto model; + model.set_ir_version(IR_VERSION); + model.set_producer_name(producerName); + model.mutable_graph()->CopyFrom(graph); +#define OPSET_IMPORT_VERSION 9 // 9 is needed for some newer ops + model.add_opset_import()->set_version(OPSET_IMPORT_VERSION); + return model; + } + + static std::string mapExprOp(Expr e) { + const static std::map opMap = { + {"+" , "Add"}, + {"-" , "Sub"}, + {"*" , "Mul"}, + {"/" , "Div"}, + {"negate" , "Neg"}, + {"ReLU" , "Relu"}, + {"reshape" , "Reshape"}, + {"affine" , "Gemm"}, // @TODO: is this just a hack, or meant to be used for this? It is not really standard GEMM semantics. + {"bdot" , "MatMul"}, + {"dot" , "MatMul"}, + {"sigmoid" , "Sigmoid"}, + {"sqrt" , "Sqrt"}, + {"sin" , "Sin"}, + {"cos" , "Cos"}, + {"tan" , "Tan"}, + {"layer_normalization" , "InstanceNormalization"}, + {"softmax" , "Softmax"}, + {"logsoftmax" , "LogSoftmax"}, + {"sum" , "ReduceSum"}, + {"transpose" , "Transpose"}, + {"concat" , "Concat"}, + {"sliceView" , "Slice"}, + {"shift" , "Pad"}, + {"rows" , "Gather"}, + {"select" , "Gather"}, + // The following are never emitted to ONNX. Keep our original type names to avoid special-casing lots of code. + {"const" , "const"}, + {"param" , "param"} + }; + auto iter = opMap.find(e->type()); + ABORT_IF(iter == opMap.end(), "ONNX export of operation {} is presently not supported", e->type()); + return iter->second; + } + + // get a unique name for an Expr. Either an actual name, or OP_ID if not named. + // 'nameOverrides' overrides that name. This is used for inputs and outputs. + static std::string getExprName(Expr e, const std::map& nameOverrides) { + if (nameOverrides.find(e) != nameOverrides.end()) + return nameOverrides.at(e); + std::string name = e->name(); + if (name == "none") // Marian assigns "none" to denote an unassigned name + name = (e->type() == "const" ? "" : "op") + mapExprOp(e) + "_" + std::to_string(e->getId()); + // For 'const', do not prefix "op", so that all internal constants in the system + // (i.e. not input data) have a prefix "const_" to distinguish them from weight tensors. 
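+    // For illustration (examples of the resulting names; the ids are made up):
+    //   a parameter named "Wemb"          -> "Wemb"        (explicit names, incl. nameOverrides, win)
+    //   an unnamed "dot" node with id 42  -> "opMatMul_42" ("op" + mapped ONNX op + "_" + id)
+    //   an unnamed "const" node with id 7 -> "const_7"     (no "op" prefix, as explained above)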
+ return name; + } + + // convert Marian shape into vector + static std::vector getExprShape(Expr e) { + const auto& shape = e->shape(); + return std::vector(shape.begin(), shape.end()); + } + + // get TensorProto_DataType for an Expr + // Note: We map Marian uint32_t to ONNX signed integers because those are only used + // for indices for Gather operations, where Marian requires unsigned and ONNX signed. + static TensorProto_DataType getExprDataType(Expr expr) { + switch (expr->value_type()) { + case marian::Type::float32: return TensorProto_DataType::TensorProto_DataType_FLOAT; + //case marian::Type::uint32: //return TensorProto_DataType::TensorProto_DataType_UINT32; + case marian::Type::uint32: // uint32 becomes ONNX INT32 as well (see above) + case marian::Type::int32: return TensorProto_DataType::TensorProto_DataType_INT32; + default: ABORT("Tensor type not supported yet"); + } + } + + // convert a Marian constant to an ONNX TensorProto + static TensorProto makeExprTensorProto(Expr expr, const std::map& nameOverrides) { + auto dataType = getExprDataType(expr); + auto name = getExprName (expr, nameOverrides); + auto shape = getExprShape (expr); + switch(expr->value_type()) { + case marian::Type::float32: { // @TODO: template this? + std::vector valBuf; + expr->val()->get(valBuf); + return makeTensorProto(name, dataType, shape, valBuf); + } + case marian::Type::uint32: { + std::vector valBuf; // note: uint32_t still get passed to ONNX as signed INT32 (cf. getExprDataType()) + expr->val()->get(valBuf); + return makeTensorProto(name, dataType, shape, valBuf); + } + case marian::Type::int32: { + std::vector valBuf; + expr->val()->get(valBuf); + return makeTensorProto(name, dataType, shape, valBuf); + } + default: + ABORT("Tensor type not supported yet"); + } + } + + static void logNode(const NodeProto& node, const std::vector& shape, size_t sentinelDim) { + std::string s = node.name() + " = " + node.op_type() + "("; + auto addComma = [&]() { if (s.back() != '(' && s.back() != '[') s += ", "; }; + for (int i = 0; i < node.input_size(); i++) { + auto inputName = node.input(i); + addComma(); + s += inputName; + } + for (int i = 0; i < node.attribute_size(); i++) { + auto attribute = node.attribute(i); + addComma(); + s += attribute.name() + "=?"; + } + s += (") : ["); + for (auto dim : shape) { + addComma(); + if (dim == sentinelDim) + s += LENGTH_AXIS_NAME; + else + s += std::to_string(dim); + } + s.push_back(']'); + LOG(info, s); + } + + // convert a Marian Expr to an ONNX node + // This function needs inputs and initializers because the special case of Reshape needs + // to create an extra input with initializer. + static void addExprNode(Expr expr, std::vector& nodes, std::vector& inputs, + std::vector& initializers, + const std::map& nameOverrides, const InputsMap& inputsMap, + size_t sentinelDim) { + // get all children + // These may reference inputs, and hence must be mapped right here. + // The original child in this case is not on the tape. 
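+    // For illustration ("data_0" is a hypothetical input name): if an input expression was
+    // replaced by a placeholder constant named "data_0" via inputsMap, consumer nodes still
+    // hold the original child pointer; mapping each child through inputsMap() below repairs that.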
+ auto children = expr->children(); + for (auto& child : children) + child = inputsMap(child); + + // inputs are referenced by their node names (also when they are leaves) + std::vector inputNames; + for (const auto& child : children) + inputNames.push_back(getExprName(child, nameOverrides)); + + auto name = getExprName(expr, nameOverrides); // node name is used as both output name and node name + auto op = mapExprOp(expr); + + //if (op == "MatMul" && expr->child(0)->shape().size() == 2 && expr->child(1)->shape().size() == 2) { + // op = "Gemm"; + //} + +#if 1 // workaround for onnxruntime which does not handle Pad correctly + if (op == "Pad") { + // Implement Pad as Slice >> Concat + std::vector shifts; + float padValue{}; // (compiler bug: without initialization, I get an uninit warning, yet it is correctly set) + E::tryGetShiftAttributes(expr, shifts, padValue) || E::fail(); + ABORT_IF(shifts[0] != 1, "can only shift by one"); + for (size_t i = 1; i < shifts.size(); i++) + ABORT_IF(shifts[i] != 0, "can only shift along first axis"); + auto shape = getExprShape(children[0]); + // Slice [0:-1,:,:] + auto sliceName = name + "_Slice"; + auto sliceNode = makeNode("Slice", sliceName, inputNames, {sliceName}); + addAttribute(sliceNode, "axes", std::vector{0}); + addAttribute(sliceNode, "starts", std::vector{0}); + addAttribute(sliceNode, "ends", std::vector{shape[0] - 1}); // drop last step + nodes.push_back(sliceNode); + LOG(info, "Pad slice op {}", sliceName); + // create a padding constant + auto paddingName = "const_" + name + "_Padding"; + shape[0] = 1; + size_t n = 1; + for (auto& dim : shape) + n *= dim; + std::vector zeros(n); + inputs. push_back(makeValueInfoProto(paddingName, TensorProto_DataType::TensorProto_DataType_FLOAT, shape, sentinelDim)); + initializers.push_back(makeTensorProto (paddingName, TensorProto_DataType::TensorProto_DataType_FLOAT, shape, zeros)); + LOG(info, "Pad constant {}", paddingName); + // Concat([paddingNode, sliceNode], axis=0) + auto node = makeNode("Concat", name, {paddingName, sliceName}, {name}); + addAttribute(node, "axis", 0); + nodes.push_back(node); + LOG(info, "Pad concat op {}", name); + return; + } +#endif + + auto node = makeNode(op, name, inputNames, {name}); + //LOG(info, "NODE {} {} -> {}", name, expr->type(), E::mapExprOp(expr)); + + // add attributes needed by some operators + + // fix up inputs + if (node.op_type() == "Reshape") { // Reshape requires the shape itself to be a tensor. + auto shapeInputName = "const_" + getExprName(expr, {}) + "_shape_attr"; + *node.add_input() = shapeInputName; + // create a new input and a new initializer + auto shape = getExprShape(expr); + auto shape64 = std::vector(shape.begin(), shape.end()); + for (auto& dim : shape64) + if (dim == (int64_t)sentinelDim) + dim = -1; // means that this one is inferred at runtime + std::vector shapeShape{shape.size()}; // ONNX Reshape requires shape in INT64 + inputs. 
push_back(makeValueInfoProto(shapeInputName, TensorProto_DataType::TensorProto_DataType_INT64, shapeShape, sentinelDim));
+      initializers.push_back(makeTensorProto (shapeInputName, TensorProto_DataType::TensorProto_DataType_INT64, shapeShape, shape64));
+      std::string s = shapeInputName;
+      for (auto& dim : shape64)
+        s += " " + std::to_string(dim);
+      LOG(info, s);
+    }
+    // axis attribute
+    size_t axis;
+    std::vector axes;
+    if (E::tryGetAxisAttribute(expr, axis)// ||
+        //E::tryGetAxisAttribute(expr, axis)
+        ) { // axis_ -> 'axis'
+      addAttribute(node, "axis", axis);
+    }
+    else if (E::tryGetAxisAttribute(expr, axis) ||
+             E::tryGetAxisAttribute(expr, axis)) { // {axis_} -> 'axes'
+      addAttribute(node, "axes", std::vector{axis});
+    }
+    else if (E::tryGetAxesAttribute(expr, axes)) { // here, the axes are called 'perm'
+      addAttribute(node, "perm", axes);
+    }
+    else if (node.op_type() == "Softmax" || node.op_type() == "LogSoftmax") {
+      // Note: ONNX (Log)Softmax is not along an axis; rather along all axes >= given axis (they get flattened).
+      addAttribute(node, "axis", expr->shape().size()-1); // Marian softmax defaults to last axis. @TODO: update if we ever add an axis_ parameter.
+    }
+    else if (expr->type() == "rows") { // becomes Gather
+      // Example, adapted from ONNX docs:
+      //   axis = 0
+      //   data = [ [1.0, 1.2], [2.3, 3.4], [4.5, 5.7], ]
+      //   indices = [ 0, 1, 1, 2, ]
+      //   output = [ [1.0, 1.2], [2.3, 3.4], [2.3, 3.4], [4.5, 5.7], ]
+      ABORT_IF(expr->shape().size() != 2, "Unexpected input shape for rows()");
+      addAttribute(node, "axis", 0);
+    }
+    // slice attributes (starts, ends)
+    Slice slice;
+    if (E::tryGetSliceAttribute(expr, slice)) {
+      addAttribute(node, "starts", std::vector{(size_t)slice.begin});
+      addAttribute(node, "ends"  , std::vector{(size_t)slice.end});
+      addAttribute(node, "steps" , std::vector{(size_t)slice.stride});
+    }
+    // shift attributes (shift, padValue)
+    std::vector shifts;
+    float padValue{}; // (compiler bug: without initialization, I get an uninit warning, yet it is correctly set)
+    if (E::tryGetShiftAttributes(expr, shifts, padValue)) {
+      std::vector pads;
+      for (auto shift : shifts)
+        pads.push_back(shift); // shift = #padValues to insert at front (or, for shift < 0, to remove at front)
+      for (auto shift : shifts)
+        pads.push_back(-shift); // and #values to remove at end (or, for shift < 0, to insert at end)
+      ABORT_IF(pads.size() != 2 * expr->shape().size(), "Unexpected number of shift dimensions");
+      addAttribute(node, "pads", pads);
+      addAttribute(node, "value", padValue);
+      addAttribute(node, "mode", std::string("constant"));
+    }
+
+    // matmul attributes
+    bool transA, transB;
+    float scalar;
+    // @BUGBUG: I cannot get Gemm to work, ONNX runtime always crashes. So we will NEVER get here.
+    if (node.op_type() == "Gemm") { // we get here for affine() or dot()
+      // Note: We only get here if Gemm can implement this configuration.
+      ABORT_IF(children[0]->shape().size() != 2 || children[1]->shape().size() != 2 ||
+               (children.size() > 2 && children[2]->shape().size() > 2),
+               "Gemm unexpectedly used for non-matrix inputs");
+      E::tryGetMatMulAttributes(expr, transA, transB, scalar) ||
+      E::tryGetMatMulAttributes (expr, transA, transB, scalar) || E::fail();
+      /*if (transA) */ addAttribute(node, "transA", transA ? 1 : 0);
+      /*if (transB) */ addAttribute(node, "transB", transB ? 1 : 0);
+      /*if (scalar != 1.0f)*/ addAttribute(node, "alpha", scalar);
+      //addAttribute(node, "beta", 0.0f);
+    }
+    else if (E::tryGetMatMulAttributes (expr, transA, transB, scalar) ||
+             E::tryGetMatMulAttributes(expr, transA, transB, scalar)) {
+      // transpose/scalar not supported by ONNX MatMul, must have been expanded before we get here
+      ABORT_IF(transA || transB || scalar != 1.0f, "Unexpected transpose or scalar attributes for {}", expr->type());
+    }
+    // epsilon attribute
+    float epsilon;
+    if (E::tryGetEpsilonAttribute(expr, epsilon)) {
+      addAttribute(node, "epsilon", epsilon);
+    }
+    // dropout patches
+    if (node.op_type() == "Sub" && children[0]->type() == "const" && children[0]->name().find("opRandomUniform_") == 0) {
+      // @HACKHACK 1: For dropout, we route a "<" operation through a Marian "-" because it has no "<".
+      *node.mutable_op_type() = "Less";
+      // Note: Since this is a hack, we don't bother to fix up the node name, which is still opSub_ID.
+    }
+    else if (expr->type() == "const" && expr->name().find("opRandomUniform_") == 0) {
+      // @HACKHACK 2: The dropout weight, which is a 'const' in Marian, acts as a placeholder for
+      //              a RandomUniform operation. In place of a 'const', we generate a uniform(0,1) node
+      //              of the same shape.
+      *node.mutable_op_type() = "RandomUniform";
+      addAttribute(node, "shape", getExprShape(expr));
+    }
+    nodes.push_back(node);
+  }
+
+  // serialize the nodesForward_ of a graph right after build() into an ONNX-formatted file
+  // We declare this to be ONNX operator set 9. @TODO: Which ONNX version does this correspond to?
+  // The nodes must only contain operations supported by ONNX, so the caller must first call
+  // expandMacroOpsForONNX().
+  // One batch axis can be variable-length. It is recognized via a hack: by a special
+  // dimension value that otherwise never naturally occurs, e.g. a larger prime number.
+  // We will not recognize derivatives of this value, such as value+1 or value x another dimension.
+  // @TODO: This presently does not support variable batch dimensions. How does ONNX handle them?
+  // @TODO: How to handle guided alignment? That's another input. Name? Shape?
+  // This is based on the simple example in
+  // https://github.com/onnx/onnx/blob/master/onnx/examples/make_model.ipynb
+  void ExpressionGraphONNXExporter::serializeToONNX(const std::string& fileRoot, FunctionDefs&& functionDefs, size_t sentinelDim) {
+    GOOGLE_PROTOBUF_VERIFY_VERSION;
+
+    // @TODO: expansion must deal with multiple sub-tapes (encoder, init)
+    // expand Marian macro operations such as "highway" or "scalar_add" that ONNX does not have
+    // After this, nodesForward_ is not topologically sorted.
+    expandMacroOpsForONNX(functionDefs);
+
+    for (const auto& functionDef : functionDefs) {
+      const auto& graphName = functionDef.first;
+      const auto& inputDefs = functionDef.second.first;
+      const auto& outputDefs = functionDef.second.second;
+
+      // some stats
+      LOG(info, "[onnx] Exporting graph {}", graphName);
+
+      std::map nameOverrides; // we implant input and output names dynamically (instead of setting the name in Expr)
+
+      // clear memoization caches
+      tensors_->clearShorttermMemory();
+      tensors_->clearLongtermMemory();
+
+      // create new dummy const nodes for all function arguments
+      // These nodes will be replaced in rebuildNodesForward() and act as recursion stops.
+      // The actual child references are NOT replaced.
+      // Also, we collect the nameOverrides for all input and output nodes.
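+      // For illustration (the input name is hypothetical): an inputDef ("encoder_source", e)
+      // creates a zero-initialized placeholder constant c, records inputsMap[e] = c and
+      // nameOverrides[c] = "encoder_source", so the exported tensor carries the
+      // function-signature name instead of an internal node id.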
+      InputsMap inputsMap;
+      for (auto& inputDef : inputDefs) {
+        const auto& input = inputDef.second;
+        ABORT_IF(inputsMap.find(input) != inputsMap.end(), "Duplicate inputDef expr??");
+        auto arg = constant(input->shape(), inits::zeros(), input->value_type());
+        inputsMap[input] = arg;
+        nameOverrides[arg] = inputDef.first;
+      }
+      for (const auto& outputDef : outputDefs)
+        nameOverrides[inputsMap(outputDef.second)] = outputDef.first;
+
+      // regenerate nodesForward_ from the roots, only for the function under consideration
+      // This redirects all items in inputsMap in the graph and in outputDefs as well.
+      // I.e. actual inputs are already replaced by Constants on the tape, but other nodes'
+      // references are not!
+      // All references from this point on have to be run through inputsMap().
+      rebuildNodesForward(inputsMap, outputDefs);
+      LOG(info, "[graph] Topologically sorted, garbage-collected graph has size {}", nodesForward_.size());
+
+      // sanity check: is the tape consistent, assuming the inputsMap?
+      std::set nodesOnTape;
+      for (const auto& e : nodesForward_)
+        nodesOnTape.insert(e);
+      for (const auto& e : nodesForward_) for (const auto& c : e->children()) {
+        if (nodesOnTape.find(c) == nodesOnTape.end())
+          LOG(info, "Redirected child: {}, {}", c->getId(), c->name());
+        ABORT_IF(nodesOnTape.find(inputsMap(c)) == nodesOnTape.end(),
+                 "Node {} {} refers to child {} {} that is off tape??", e->getId(), e->name(), c->getId(), c->name());
+      }
+
+      // sanity check: did we consume all expected inputs?
+      std::set mappedInputSet; // set of replacement Exprs (those constants) for inputs
+      for (auto ee : inputsMap)
+        mappedInputSet.insert(ee.second);
+      std::set seenMappedInputs;
+      for (const auto& expr : nodesForward_) {
+        ABORT_IF(inputsMap.find(expr) != inputsMap.end(), "An input node (id={}) was not mapped??", expr->getId());
+        if (mappedInputSet.find(expr) != mappedInputSet.end())
+          seenMappedInputs.insert(expr);
+      }
+      for (auto e : mappedInputSet)
+        if (seenMappedInputs.find(e) == seenMappedInputs.end()) {
+          LOG(info, "WARNING: Input {} not consumed in input graph", nameOverrides[e]);
+          nodesForward_.push_back(e);
+        }
+        //ABORT_IF(seenMappedInputs.find(e) == seenMappedInputs.end(), "Input node {} not found in input graph??", nameOverrides[e]);
+
+      // output set -- these nodes are exported differently
+      std::set outputsSet;
+      for (const auto& outputDef : outputDefs)
+        outputsSet.insert(inputsMap(outputDef.second));
+
+      std::vector inputsParamsAndConstants; // parameters and constants are all considered inputs, just with initializers
+
+      // Create the nodes -> array of NodeProto
+      std::vector nodes;
+      std::vector initializers; // constants are inputs with initializers that hold their values. They go here.
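+      // For illustration (summary of the export loop below): every Expr lands in exactly one bucket:
+      //   data inputs            -> graph input (ValueInfoProto only, no initializer)
+      //   params / other consts  -> graph input plus a TensorProto initializer holding the value
+      //   all remaining ops      -> NodeProto; roots listed in outputDefs additionally become outputs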
+ std::vector shapeInfos; // expected shapes of operations (for diagnostics only) + std::vector outputs; // outputs' shapes + for(const auto& expr : nodesForward_) { + //LOG(info, "exporting node name {} op {} ({})", getExprName(expr), E::mapExprOp(expr), expr->children().size()); + if (expr->type() == "param" || + (expr->type() == "const" && expr->name().find("opRandomUniform_") != 0)) { // leaves are not nodes in ONNX (except for the uniform placeholder @HACKHACK 2) + //LOG(info, "exporting leaf name {} op {} ({})", getExprName(expr), E::mapExprOp(expr), expr->children().size()); + auto shape = getExprShape(expr); + inputsParamsAndConstants.push_back(makeValueInfoProto(getExprName(expr, nameOverrides), getExprDataType(expr), shape, sentinelDim)); + // don't create an initializers entry for inputs + if (std::any_of(inputsMap.begin(), inputsMap.end(), [&](const std::pair& inputMap) { + return inputMap.second == expr; + })) { // skip designated inputs + ABORT_IF(expr->type() != "const", "Data inputs must be 'const' nodes"); + //LOG(info, "No initializer for data-input node {}", getExprName(expr)); + continue; + } + // run initializers, to realize value of consts (params already got theirs) + expr->allocate(); + expr->init(); + expr->forward(); + ABORT_IF(!expr->val(), "Leaf '{}' of type {} unexpectedly lacks a value despite trying really hard", expr->name(), expr->type()); + initializers.push_back(makeExprTensorProto(expr, nameOverrides)); + continue; // parameters must become initializers, name=input name + } + addExprNode(expr, nodes, inputsParamsAndConstants, initializers, nameOverrides, inputsMap, sentinelDim); + logNode(nodes.back(), getExprShape(expr), sentinelDim); + + auto valueInfo = makeValueInfoProto(nodes.back().name(), getExprDataType(expr), getExprShape(expr), sentinelDim); + if (outputsSet.find(expr) != outputsSet.end()) + outputs.push_back(valueInfo); + //else // we add expected-shape information, to more easily be able to track down where it may fail + // shapeInfos.push_back(valueInfo); + } + + //LOG(info, "total nodes: {}, incl. {} inputs, {} op shapes", nodesForward_.size(), inputs.size(), shapeInfos.size()); + + // @TODO: write a log message with the inputs and output names (the function signature) + + // Create the graph -> GraphProto + auto graphDef = makeGraph(nodes, graphName, inputsParamsAndConstants, outputs, initializers, shapeInfos); + + // Create the model -> ModelProto + auto modelDef = makeModel(graphDef, /*producer_name=*/"Marian " + buildVersion()); + + // save it + auto filename = fileRoot + "." 
+ graphName + ".onnx"; + auto s = modelDef.SerializeAsString(); + ABORT_IF(s.empty(), "Failed to serialize ONNX graph to string buffer", filename); + std::ofstream o(filename, std::ios::binary); + ABORT_IF(o.fail(), "Failed to create ONNX model file {}", filename); + o.write(s.data(), s.size()); + o.close(); + ABORT_IF(o.fail(), "Failed to write ONNX model to {}", filename); + LOG(info, "[onnx] ONNX graph '{}' written to {}", graphName, filename); + } + + // tape has been destroyed many times, so clear it for good + nodesForward_.clear(); + } + + Expr ExpressionGraphONNXExporter::tryFindForwardNodeByName(const std::string& nodeName) const { + auto iter = std::find_if(nodesForward_.begin(), nodesForward_.end(), [&](Expr node) {return node->name() == nodeName; }); + if (iter == nodesForward_.end()) + return nullptr; + else + return *iter; + } + +} // namespace marian + +#endif // USE_ONNX + diff --git a/src/onnx/protobuf.cpp b/src/onnx/protobuf.cpp new file mode 100644 index 000000000..a0a38a104 --- /dev/null +++ b/src/onnx/protobuf.cpp @@ -0,0 +1,123 @@ +// This builds the runtime library for protobuf on Windows (on Linux it is installed in the OS). +// We include all CPP files from this CPP file. This way, we can find them via the include-path mechanism. +// You will need to set an environment variable PROTOBUF_RUNTIME_INC so that %PROTOBUF_RUNTIME_INC%\google\protobuf exists. + +#ifdef _MSC_VER +#ifdef USE_ONNX +// note: some of the below is the result of trial-and-error, not necessarily the minimal set +#include "google/protobuf/stubs/common.cc" +#include "google/protobuf/port_undef.inc" +#undef max +#include "google/protobuf/stubs/bytestream.cc" +#include "google/protobuf/stubs/int128.cc" +#include "google/protobuf/stubs/status.cc" +#include "google/protobuf/port_undef.inc" +#undef max +#include "google/protobuf/stubs/statusor.cc" +#undef min +#include "google/protobuf/stubs/stringpiece.cc" +#include "google/protobuf/stubs/stringprintf.cc" +#include "google/protobuf/stubs/structurally_valid.cc" +namespace google { namespace protobuf { const auto LOGLEVEL_0 = LogLevel::LOGLEVEL_INFO; } } +#include "google/protobuf/stubs/strutil.cc" +#include "google/protobuf/stubs/substitute.cc" +#undef GetCurrentTime +#include "google/protobuf/stubs/time.cc" + +#include "google/protobuf/io/coded_stream.cc" +#include "google/protobuf/io/gzip_stream.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/io/io_win32.cc" +#include "google/protobuf/io/printer.cc" +#include "google/protobuf/io/strtod.cc" +#include "google/protobuf/io/tokenizer.cc" +#include "google/protobuf/io/zero_copy_stream.cc" +#include "google/protobuf/io/zero_copy_stream_impl.cc" +#include "google/protobuf/io/zero_copy_stream_impl_lite.cc" + +#include "google/protobuf/any.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/any.pb.cc" +#include "google/protobuf/any_lite.cc" +#define schemas schemas1 +#define file_default_instances file_default_instances1 +#include "google/protobuf/api.pb.cc" +#include "google/protobuf/arena.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/descriptor.cc" +#include "google/protobuf/port_undef.inc" +#define schemas schemas2 +#define file_default_instances file_default_instances2 +#include "google/protobuf/descriptor.pb.cc" +#include "google/protobuf/descriptor_database.cc" +#define schemas schemas3 +#define file_default_instances file_default_instances3 +#include "google/protobuf/duration.pb.cc" +#include "google/protobuf/dynamic_message.cc" 
+#define schemas schemas4 +#define file_default_instances file_default_instances4 +#include "google/protobuf/empty.pb.cc" +#include "google/protobuf/extension_set_heavy.cc" +#include "google/protobuf/port_undef.inc" +#define cpp_type cpp_type1 +#define real_type real_type1 +#include "google/protobuf/extension_set.cc" +#undef real_type1 +#undef cpp_type +#include "google/protobuf/port_undef.inc" +#define schemas schemas5 +#define file_default_instances file_default_instances5 +#include "google/protobuf/field_mask.pb.cc" +#include "google/protobuf/generated_enum_util.cc" +#define IsMapFieldInApi IsMapFieldInApi1 +#undef schemas +#include "google/protobuf/generated_message_reflection.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/generated_message_table_driven.cc" +#define MutableUnknownFields MutableUnknownFields1 +#include "google/protobuf/generated_message_table_driven_lite.cc" +#undef MutableUnknownFields +#include "google/protobuf/generated_message_util.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/implicit_weak_message.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/map_field.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/message.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/message_lite.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/parse_context.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/reflection_ops.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/repeated_field.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/service.cc" +#define schemas schemas6 +#define file_default_instances file_default_instances6 +#define schemas schemas7 +#define file_default_instances file_default_instances7 +#include "google/protobuf/source_context.pb.cc" +#define schemas schemas8 +#define file_default_instances file_default_instances8 +#include "google/protobuf/struct.pb.cc" +#include "google/protobuf/text_format.cc" +#include "google/protobuf/port_undef.inc" +#define schemas schemas9 +#define file_default_instances file_default_instances9 +#include "google/protobuf/timestamp.pb.cc" +#define schemas schemasa +#define file_default_instances file_default_instancesa +#include "google/protobuf/type.pb.cc" +#include "google/protobuf/unknown_field_set.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/wire_format.cc" +#include "google/protobuf/port_undef.inc" +#include "google/protobuf/wire_format_lite.cc" +#include "google/protobuf/port_undef.inc" +#define schemas schemasb +#define file_default_instances file_default_instancesb +#include "google/protobuf/wrappers.pb.cc" +#endif +#endif diff --git a/src/rescorer/rescorer.h b/src/rescorer/rescorer.h old mode 100755 new mode 100644 diff --git a/src/rescorer/score_collector.cpp b/src/rescorer/score_collector.cpp index 1577feba7..32903a39e 100644 --- a/src/rescorer/score_collector.cpp +++ b/src/rescorer/score_collector.cpp @@ -74,7 +74,6 @@ std::string ScoreCollector::getAlignment(const data::SoftAlignment& align) { } else { ABORT("Unrecognized word alignment type"); } - return ""; } ScoreCollectorNBest::ScoreCollectorNBest(const Ptr& options) diff --git a/src/tensors/cpu/fbgemm/expanded_gemm.h b/src/tensors/cpu/fbgemm/expanded_gemm.h index 38c543c75..a5c93f6bc 100644 --- a/src/tensors/cpu/fbgemm/expanded_gemm.h +++ b/src/tensors/cpu/fbgemm/expanded_gemm.h @@ -97,7 +97,7 @@ struct 
FbgemmPacked16PackNodeOp : public UnaryNodeOp { const std::string type() override { return "packMatFp16"; } - Shape newShape(Expr MAYBE_UNUSED a, bool MAYBE_UNUSED transpose) { + Shape newShape(Expr a, bool transpose) { #if USE_FBGEMM auto shapeMat = a->shape(); // Should be 2D - weight matrix @@ -118,6 +118,7 @@ struct FbgemmPacked16PackNodeOp : public UnaryNodeOp { Shape outShape({(int)packsize_}); return outShape; #else + a; transpose; ABORT("Packed GEMM requires a build with USE_FBGEMM enabled"); return Shape(); #endif // USE_FBGEMM diff --git a/src/tensors/cpu/fbgemm/packed_gemm.cpp b/src/tensors/cpu/fbgemm/packed_gemm.cpp index 064c3c2be..65dca1f70 100644 --- a/src/tensors/cpu/fbgemm/packed_gemm.cpp +++ b/src/tensors/cpu/fbgemm/packed_gemm.cpp @@ -251,7 +251,7 @@ void fbgemmPacked8PackInfo(const marian::Shape& shape, // This function computes the offset values for each column which are used for compensating the remainders of quantized values // More detailed math is avilable in the FBGEMM's blog - https://engineering.fb.com/ml-applications/fbgemm/ -inline void col_offsets_with_zero_pt_s8acc32( +inline void colOffsetsWithZeroPtS8acc32( bool transpose, int K, int N, @@ -355,8 +355,8 @@ void fbgemmPacked8Pack(marian::Tensor out, int len = k * n; // 1. collect stats for each column - float* bqScale = new float[n]; - int32_t* bqZeropoint = new int32_t[n]; + float* quantScaleB = new float[n]; + int32_t* quantZeropointB = new int32_t[n]; const float* data = inData; float val = 0; @@ -367,8 +367,8 @@ void fbgemmPacked8Pack(marian::Tensor out, // This routine compute the quantization range for each column - either one of min/max range or quantRangeStdDevs sigma range. for (size_t jj = 0; jj < n; jj++) { // for each column, collect stats (min/max or mean/std.dev.) - float min = std::numeric_limits::max(), max = std::numeric_limits::min(); - double mean = 0, sqrsum = 0; + float min = std::numeric_limits::max(), max = std::numeric_limits::lowest(); + double mean = 0, sqrSum = 0; for (size_t ii = 0; ii < k; ii++) { // in a column, go throuhg all the rows and collect stats val = getVal2dArr(data, ii, jj, k, n, transpose); // If quantRangeStdDevs is 0.f, min/max values of the columns is used as a quantization range @@ -380,22 +380,22 @@ void fbgemmPacked8Pack(marian::Tensor out, } else { // Quantize by std.dev. range mean += val; - sqrsum += val * val; + sqrSum += val * val; } } // If a quantization range (in multiples of std. dev.) is given with a non-zero value, // it calculate the range for this column (different quantization scale/offset are used for each column) if(quantRangeStdDevs != 0.f) { mean /= k; - sqrsum /= k; - sqrsum -= mean * mean; - sqrsum = sqrt(sqrsum); - min = (float)(mean - quantRangeStdDevs * sqrsum); - max = (float)(mean + quantRangeStdDevs * sqrsum); + sqrSum /= k; + sqrSum -= mean * mean; + sqrSum = sqrt(sqrSum); + min = (float)(mean - quantRangeStdDevs * sqrSum); + max = (float)(mean + quantRangeStdDevs * sqrSum); } // based on the quantization range, this computes the scale and offset for the quantization - bqScale[jj] = (max - min) / quantizedRange; - bqZeropoint[jj] = (int32_t)(quantizedMax - max / bqScale[jj]); + quantScaleB[jj] = (max - min) / quantizedRange; + quantZeropointB[jj] = (int32_t)(quantizedMax - max / quantScaleB[jj]); } // 2. 
quantize
@@ -408,8 +408,8 @@ void fbgemmPacked8Pack(marian::Tensor out,
 #endif
   for (int jj = 0; jj < n; jj++) {
     TensorQuantizationParams bQuantParam;
-    bQuantParam.scale = bqScale[jj];
-    bQuantParam.zero_point = bqZeropoint[jj];
+    bQuantParam.scale = quantScaleB[jj];
+    bQuantParam.zero_point = quantZeropointB[jj];
     bQuantParam.precision = 7; // Use half of the quantization range to prevent overflow of VPMADDUBSW
     if (transpose)
@@ -422,13 +422,13 @@ void fbgemmPacked8Pack(marian::Tensor out,
   }

   // 3. compute column offsets
-  int32_t* col_offsets = new int32_t[n];
-  col_offsets_with_zero_pt_s8acc32(transpose, k, n, quantized, bqZeropoint, col_offsets, 1);
+  int32_t* colOffsets = new int32_t[n];
+  colOffsetsWithZeroPtS8acc32(transpose, k, n, quantized, quantZeropointB, colOffsets, 1);

-  int8_t* packedbuf = out->data();
+  int8_t* packedBuf = out->data();
   for(auto i = 0; i < packsize; i++) {
-    packedbuf[i] = 0;
+    packedBuf[i] = 0;
   }

   // 4. packing
@@ -436,23 +436,23 @@
   PackBMatrix packedBN(
       transpose ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
-      nrow, ncol, quantized, transpose ? nrow : ncol, packedbuf, 1, params);
+      nrow, ncol, quantized, transpose ? nrow : ncol, packedBuf, 1, params);

   // copy quantization scale
-  memcpy(packedbuf + (packsize - n * (sizeof(float) + sizeof(int32_t) + sizeof(int32_t))), bqScale, n * sizeof(float));
+  memcpy(packedBuf + (packsize - n * (sizeof(float) + sizeof(int32_t) + sizeof(int32_t))), quantScaleB, n * sizeof(float));
   // copy quantization offset
-  memcpy(packedbuf + (packsize - n * (sizeof(int32_t) + sizeof(int32_t))), bqZeropoint, n * sizeof(int32_t));
+  memcpy(packedBuf + (packsize - n * (sizeof(int32_t) + sizeof(int32_t))), quantZeropointB, n * sizeof(int32_t));
   // copy column offsets to the memory
-  memcpy(packedbuf + (packsize - n * sizeof(int32_t)), col_offsets, n * sizeof(int32_t));
+  memcpy(packedBuf + (packsize - n * sizeof(int32_t)), colOffsets, n * sizeof(int32_t));

 #ifdef _MSC_VER
   _aligned_free(quantized);
 #else
   free(quantized);
 #endif
-  delete[] col_offsets;
-  delete[] bqScale;
-  delete[] bqZeropoint;
+  delete[] colOffsets;
+  delete[] quantScaleB;
+  delete[] quantZeropointB;
 }

 // GEMM operation on the packed B matrix
@@ -549,73 +549,93 @@ void fbgemmPacked8Gemm(marian::Tensor C,
   const fbgemm::BlockingFactors* params = getBlockingFactors(packType);

-  if((packType == Type::packed8avx2 && fbgemmHasAvx512Support())
-     || (packType == Type::packed8avx512 && !fbgemmHasAvx512Support())) {
+  // Check if the packed format matches the available AVX instruction set on this machine
+  const bool avx512Support = fbgemmHasAvx512Support();
+  if((packType == Type::packed8avx2 && avx512Support)
+     || (packType == Type::packed8avx512 && !avx512Support)) {
     ABORT("FBGEMM doesn't allow to use {} packing order on {} CPUs",
           packType == Type::packed8avx2 ? "AVX2" : "AVX512",
-          fbgemmHasAvx512Support() ? "AVX512" : "AVX2");
+          avx512Support ? "AVX512" : "AVX2");
   }

   // compute range to quantize A (activations) - (min/max quantization)
-  float min_est = std::numeric_limits<float>::max(), max_est = std::numeric_limits<float>::min();
+  float minA = std::numeric_limits<float>::max(), maxA = std::numeric_limits<float>::lowest();

-  int elem = A->shape().elements();
-  float* data = A->data();
+  int elemA = A->shape().elements();
+  float* dataA = A->data();
   // AVX based find min/max
-  FindMinMax(data, &min_est, &max_est, elem);
-
-  float ascale = (max_est - min_est) / 255;
-  int32_t azeropoint = (int32_t)(255 - max_est / ascale);
-
-  std::vector row_offset_buf(PackAWithQuantRowOffset::rowOffsetBufferSize());
-  PackAWithQuantRowOffset packAN(
+  FindMinMax(dataA, &minA, &maxA, elemA);
+
+  float quantScaleA = (maxA - minA) / 255;
+  int32_t quantZeropointA = (int32_t)(255 - maxA / quantScaleA);
+
+  // To avoid repeated memory allocation and deallocation, make the scratch buffer variables static thread_local.
+  // In a multi-threaded run, the heap lock taken for each allocation/free
+  // can block all threads on one another (heap contention).
+  const size_t sizeBufA = params->KCB * params->MCB;
+  static thread_local std::vector packedBufA;
+  if (packedBufA.size() < sizeBufA)
+    packedBufA.resize(sizeBufA);
+  const size_t sizeRowOffsetBufA = PackAWithQuantRowOffset::rowOffsetBufferSize();
+  static thread_local std::vector rowOffsetBufA;
+  if (rowOffsetBufA.size() < sizeRowOffsetBufA)
+    rowOffsetBufA.resize(sizeRowOffsetBufA);
+
+  PackAWithQuantRowOffset packA(
       transA ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
       (int32_t)(transA ? k : m),
       (int32_t)(transA ? m : k),
       A->data(),
       (int32_t)(transA ? m : k),
-      nullptr, /*buffer for packed matrix*/
-      ascale,
-      azeropoint,
+      // buffer for packed matrix, pass a pre-allocated memory to avoid additional allocation/deallocation inside fbgemm
+      packedBufA.data(),
+      quantScaleA,
+      quantZeropointA,
      1, /*groups*/
-      row_offset_buf.data(),
+      rowOffsetBufA.data(),
      params);

   // packed matrix size of B
-  int bPackSize = PackMatrix, int8_t>::packedBufferSize((int32_t)k, (int32_t)n);
+  int packSizeB = PackMatrix, int8_t>::packedBufferSize((int32_t)k, (int32_t)n);

   // retrieve B matrix
-  int8_t* bdata = B->data();
-  float* bqScale = new float[n];
-  memcpy(bqScale, bdata + bPackSize, n * sizeof(float));
-
-  int32_t* bqZeropoint = new int32_t[n];
-  memcpy(bqZeropoint, bdata + bPackSize + n * sizeof(float), n * sizeof(int32_t));
-
-  int32_t* col_offsets = new int32_t[n];
-  memcpy(col_offsets, bdata + bPackSize + n * (sizeof(float) + sizeof(int32_t)), n * sizeof(int32_t));
+  int8_t* dataB = B->data();
+
+  // To avoid repeated memory allocation and deallocation, make the scratch buffer variables static thread_local.
+  // In a multi-threaded run, the heap lock taken for each allocation/free
+  // can block all threads on one another (heap contention).
+  static thread_local std::vector quantScaleB;
+  if (quantScaleB.size() < n)
+    quantScaleB.resize(n);
+  memcpy(quantScaleB.data(), dataB + packSizeB, n * sizeof(float));
+
+  static thread_local std::vector quantZeropointB;
+  if (quantZeropointB.size() < n)
+    quantZeropointB.resize(n);
+  memcpy(quantZeropointB.data(), dataB + packSizeB + n * sizeof(float), n * sizeof(int32_t));
+
+  static thread_local std::vector colOffsetsB;
+  if (colOffsetsB.size() < n)
+    colOffsetsB.resize(n);
+  memcpy(colOffsetsB.data(), dataB + packSizeB + n * (sizeof(float) + sizeof(int32_t)), n * sizeof(int32_t));

   DoNothing doNothingObj{};
   ReQuantizeForFloat outputProcObj(
      doNothingObj,
-      ascale,
-      bqScale,
-      azeropoint,
-      bqZeropoint,
-      packAN.getRowOffsetBuffer(),
-      col_offsets,
+      quantScaleA,
+      quantScaleB.data(),
+      quantZeropointA,
+      quantZeropointB.data(),
+      packA.getRowOffsetBuffer(),
+      colOffsetsB.data(),
      nullptr,
      (std::uint32_t) n);

-  PackBMatrix repackedBN(
-      transB ? matrix_op_t::Transpose : matrix_op_t::NoTranspose, (int32_t) k, (int32_t) n, bdata, (int32_t) (transB ? k : n), 1, params);
+  PackBMatrix repackedB(
+      transB ? matrix_op_t::Transpose : matrix_op_t::NoTranspose, (int32_t) k, (int32_t) n, dataB, (int32_t) (transB ? k : n), 1, params);

   // gemm computation
-  fbgemmPacked(packAN, repackedBN, C->data(), (int32_t*)C->data(), (int32_t) n, outputProcObj, 0, 1, params);
-
-  delete[] col_offsets;
-  delete[] bqZeropoint;
-  delete[] bqScale;
+  fbgemmPacked(packA, repackedB, C->data(), (int32_t*)C->data(), (int32_t) n, outputProcObj, 0, 1, params);
 }
 #endif // USE_FBGEMM
diff --git a/src/tensors/cpu/int16.h b/src/tensors/cpu/int16.h
index abb465b6d..f2bdd0a91 100644
--- a/src/tensors/cpu/int16.h
+++ b/src/tensors/cpu/int16.h
@@ -19,7 +19,6 @@ struct QuantizeNodeOp : public UnaryNodeOp {
   NodeOps backwardOps() override {
     ABORT("Only used for inference");
-    return {NodeOp(0)};
   }

   const std::string type() override { return "quantizeInt16"; }
@@ -54,7 +53,6 @@ class DotNodeOp : public NaryNodeOp {
   NodeOps backwardOps() override {
     ABORT("Only used for inference");
-    return {NodeOp(0)};
   }

   const std::string type() override { return "dotInt16"; }
@@ -92,7 +90,6 @@ class AffineNodeOp : public NaryNodeOp {
   NodeOps backwardOps() override {
     ABORT("Only used for inference");
-    return {NodeOp(0)};
   }

   const std::string type() override { return "affineInt16"; }
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index 9bdedd545..9a5d233dc 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -7,51 +7,13 @@
 #include "tensors/tensor.h"
 #include "tensors/tensor_allocator.h"

-#if MKL_FOUND
-#include <mkl.h>
-#else
-#if BLAS_FOUND
-#include <cblas.h>
-#endif
-#endif
-
+#include "prod_blas.h"
 #include "sharp/int_gemm.h"

 namespace marian {

 namespace cpu {

-#if BLAS_FOUND
-inline void sgemm(bool transA,
-                  bool transB,
-                  int rows_a,
-                  int rows_b,
-                  int width,
-                  float alpha,
-                  float* a,
-                  int lda,
-                  float* b,
-                  int ldb,
-                  float beta,
-                  float* c,
-                  int ldc) {
-  cblas_sgemm(CblasRowMajor,
-              transA ? CblasTrans : CblasNoTrans,
-              transB ? CblasTrans : CblasNoTrans,
-              rows_a,
-              rows_b,
-              width,
-              alpha,
-              a,
-              lda,
-              b,
-              ldb,
-              beta,
-              c,
-              ldc);
-}
-#endif
-
 void Prod(marian::Tensor C,
           const marian::Tensor& A,
           const marian::Tensor& B,
@@ -134,7 +96,7 @@ void ProdBatched(marian::Tensor C,
   auto strideC = n * m;
   auto batchC = std::max(batchA, batchB);

-#if MKL_FOUND
+#if 0 // TODO Accuracy regression. Batched GEMM generates different output. Investigating; disabled for now.
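(Aside: the grow-only thread_local scratch buffers introduced in fbgemmPacked8Gemm above follow a simple pattern worth isolating. A minimal, self-contained sketch; the function name is invented for illustration and is not Marian API:)

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns a per-thread scratch buffer of at least n bytes. The vector only
// grows and is reused across calls, so steady-state calls perform no heap
// allocation and threads do not serialize on the allocator's lock.
inline std::uint8_t* threadLocalScratch(std::size_t n) {
  static thread_local std::vector<std::uint8_t> buf;
  if(buf.size() < n)
    buf.resize(n);
  return buf.data();
}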
   CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
   CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;
diff --git a/src/tensors/cpu/prod_blas.h b/src/tensors/cpu/prod_blas.h
new file mode 100644
index 000000000..9f3080efa
--- /dev/null
+++ b/src/tensors/cpu/prod_blas.h
@@ -0,0 +1,38 @@
+#if MKL_FOUND
+#include <mkl.h>
+#else
+#if BLAS_FOUND
+#include <cblas.h>
+#endif
+#endif
+
+#if BLAS_FOUND
+inline void sgemm(bool transA,
+                  bool transB,
+                  int rows_a,
+                  int rows_b,
+                  int width,
+                  float alpha,
+                  float* a,
+                  int lda,
+                  float* b,
+                  int ldb,
+                  float beta,
+                  float* c,
+                  int ldc) {
+  cblas_sgemm(CblasRowMajor,
+              transA ? CblasTrans : CblasNoTrans,
+              transB ? CblasTrans : CblasNoTrans,
+              rows_a,
+              rows_b,
+              width,
+              alpha,
+              a,
+              lda,
+              b,
+              ldb,
+              beta,
+              c,
+              ldc);
+}
+#endif
\ No newline at end of file
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 5aa695204..5f286f74c 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -27,11 +27,18 @@ namespace cpu {
 template
 void CopyCastTo(To* out, const From* in, int length) {
   for(int i = 0; i < length; ++i)
-    out[i] = in[i];
+#ifdef _MSC_VER
+#pragma warning (push)
+#pragma warning (disable: 4244) // 'argument': conversion from 'const From' to 'float', possible loss of data
+#endif
+    out[i] = (To)in[i];
+#ifdef _MSC_VER
+#pragma warning (pop)
+#endif
 }

 // Casting has been factored into two functions "CopyCastFrom" and
-// "CopyCastTo". This only serves the purpuse to autmatically create
+// "CopyCastTo". This only serves the purpose to automatically create
 // the full Cartesian product of possible type casts via template magic.
 // Extending CopyCast and CopyCastFrom with a new branch in the "if" clause
 // adds all possible variants.
@@ -52,6 +59,8 @@ void CopyCast(Tensor out, const Tensor in) {
     CopyCastFrom(out, in->data(), (int)in->size());
   } else if(in->type() == Type::float16) {
     CopyCastFrom(out, in->data(), (int)in->size());
+  } else if(in->type() == Type::uint32) {
+    CopyCastFrom(out, in->data(), (int)in->size());
   } else {
     ABORT("CopyCastFrom from type {} not implemented", in->type());
   }
@@ -645,12 +654,15 @@ void PasteCols(Tensor out_,
   }
 }

+#if 0 // this version seems to actually be buggy, but also not used in decoding?
 // Optimized version of Select for axis=2
 // @TODO: make this generally fast without this special version
 void SelectAxis2(Tensor out,
                  const Tensor in,
                  const Tensor indices) {
+  std::cerr << indices->debug() << std::endl;
+
   matchOrAbort(indices->type());

   functional::Shape outShape = out->shape();
@@ -675,6 +687,7 @@ void Select,
       }
     }
   }
+#endif

 void Select(Tensor out,
             const Tensor in,
@@ -692,8 +705,10 @@ void Select(Tensor out,
   functional::Array dims;
   int axisCPU = (int)(axis + functional::Shape::size() - out->shape().size());

+#if 0 // buggy but not really used?
   if(axisCPU == 2 && outShape == idxShape) // specialization for axis==2 when there is no broadcasting, @TODO to be removed once we have a faster implementation below
     return SelectAxis2(out, in, indices);
+#endif

   for(int index = 0; index < length; ++index) {
     outShape.dims(index, dims); // compute dimension-based indices from global index;
@@ -1017,19 +1032,15 @@ void AttBack(Tensor gVa_,
   }
 }

-void LayerNormalization(Tensor out_,
-                        Tensor in_,
-                        Tensor gamma_,
-                        Tensor beta_,
-                        float eps) {
-  float* out = out_->data();
-  const float* in = in_->data();
-  const float* alpha = gamma_->data();
-  const float* beta = beta_ ? beta_->data() : nullptr;
-
-  int rows = in_->shape().elements() / in_->shape().back();
-  int cols = in_->shape().back();
-
+MARIAN_FFAST_MATH_BEGIN
+template
+void LayerNormalizationImpl(float* out,
+                            const float* in,
+                            const float* alpha,
+                            const float* beta,
+                            float eps,
+                            int rows,
+                            int cols) {
 #pragma omp parallel for
   for(int j = 0; j < rows; ++j) {
     float* so = out + j * cols;
@@ -1053,16 +1064,55 @@
 #pragma omp simd
     for(int i = 0; i < cols; ++i) {
-      float t = alpha[i] * ((sp[i] - mean) / sigma);
-      if(beta != nullptr) {
-        t += beta[i];
-      }
+      float t = alpha[alphaStride * i] * ((sp[i] - mean) / sigma);
+      if(hasBeta)
+        t += beta[betaStride * i];
       so[i] = t;
     }
   }
 }
+MARIAN_FFAST_MATH_END
+
+template
+inline void LayerNormalizationDispatchBeta(float* out,
+                                           const float* in,
+                                           const float* alpha,
+                                           Tensor beta,
+                                           float eps,
+                                           int rows,
+                                           int cols) {
+  if (beta) {
+    if (beta->shape().back() > 1) {
+      LayerNormalizationImpl(out, in, alpha, beta->data(), eps, rows, cols);
+    } else {
+      LayerNormalizationImpl(out, in, alpha, beta->data(), eps, rows, cols);
+    }
+  } else {
+    LayerNormalizationImpl(out, in, alpha, nullptr, eps, rows, cols);
+  }
+}
+
+void LayerNormalization(Tensor out_,
+                        Tensor in_,
+                        Tensor gamma_,
+                        Tensor beta,
+                        float eps) {
+  float* out = out_->data();
+  const float* in = in_->data();
+  const float* alpha = gamma_->data();
+  const int alphaStride = gamma_->shape().back() > 1; // broadcasting for alpha and beta
+  int rows = in_->shape().elements() / in_->shape().back();
+  int cols = in_->shape().back();
+  if (alphaStride == 0) {
+    LayerNormalizationDispatchBeta<0>(out, in, alpha, beta, eps, rows, cols);
+  } else {
+    LayerNormalizationDispatchBeta<1>(out, in, alpha, beta, eps, rows, cols);
+  }
+}
+
+MARIAN_FFAST_MATH_BEGIN
 void LayerNormalizationGrad(Tensor gradX_,
                             Tensor gradGamma_,
                             Tensor gradBeta_,
@@ -1080,6 +1130,10 @@ void LayerNormalizationGrad(Tensor gradX_,
   float* x = x_->data();
   float* gamma = gamma_->data();
   float* beta = beta_ ? beta_->data() : nullptr;
+  // @TODO: The CPU implementation supports scalar gamma and beta; this is a left-over,
+  // we should enable that in the GPU version as well.
+  const int gammaStride = gamma_->shape().back() > 1; // broadcasting for alpha and beta. 0 means it's a scalar
+  const int betaStride = beta_ && beta_->shape().back() > 1;

   size_t rows = y_->shape().elements() / y_->shape()[-1];
   size_t cols = y_->shape()[-1];
@@ -1100,7 +1154,7 @@ void LayerNormalizationGrad(Tensor gradX_,
 #pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
       for(size_t i = 0; i < cols; ++i) {
         sum_x += xRow[i];
-        sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[i] : 0.f)) / gamma[i];
+        sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
         sum_adj += adjRow[i];
       }
@@ -1115,15 +1169,15 @@
 #pragma omp simd
       for(size_t i = 0; i < cols; ++i) {
         float grad_x = 0.f;
-        float x_hat = (yRow[i] - beta[i]) / gamma[i];
+        float x_hat = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
         grad_x += cols * adjRow[i];
         grad_x -= sum_adj;
         grad_x -= sum_adj_x * x_hat;
         grad_x /= cols * sigma;

-        gradXRow[i] += gamma[i] * grad_x;
-        gradGamma[i] += adjRow[i] * x_hat;
-        gradBeta[i] += adjRow[i];
+        gradXRow[i] += gamma[gammaStride * i] * grad_x;
+        gradGamma[gammaStride * i] += adjRow[i] * x_hat;
+        gradBeta[betaStride * i] += adjRow[i];
       }
     }
   } else {
@@ -1142,7 +1196,8 @@
 #pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
       for(size_t i = 0; i < cols; ++i) {
         sum_x += xRow[i];
-        sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[i] : 0.f)) / gamma[i];
+        sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
+        // @TODO: beta is NULL here ^^
         sum_adj += adjRow[i];
       }
@@ -1157,25 +1212,26 @@
 #pragma omp simd
       for(size_t i = 0; i < cols; ++i) {
         float grad_x = 0.f;
-        float x_hat = yRow[i] / gamma[i];
+        float x_hat = yRow[i] / gamma[gammaStride * i];
         grad_x += cols * adjRow[i];
         grad_x -= sum_adj;
         grad_x -= sum_adj_x * x_hat;
         grad_x /= cols * sigma;

-        gradXRow[i] += gamma[i] * grad_x;
-        gradGamma[i] += adjRow[i] * x_hat;
+        gradXRow[i] += gamma[gammaStride * i] * grad_x;
+        gradGamma[gammaStride * i] += adjRow[i] * x_hat;
       }
     }
   }
 }
+MARIAN_FFAST_MATH_END

 void Shift(Tensor out_,
            Tensor in_,
            marian::Shape shift,
            float padValue,
            bool invert) {
-  int offset = 0;
+  int offset = 0; // out[i + offset] = in[i]; shift>0 inserts values at front, shifts back, pushes out
   for(int i = 0; i < shift.size(); ++i)
     offset += in_->shape().stride(i) * shift[i];
diff --git a/src/tensors/cpu/topk.cpp b/src/tensors/cpu/topk.cpp
new file mode 100644
index 000000000..92dcba591
--- /dev/null
+++ b/src/tensors/cpu/topk.cpp
@@ -0,0 +1,54 @@
+#include "tensors/tensor_operators.h"
+#include "tensors/allocator.h"
+#include <numeric>
+
+// CPU implementation of proper Marian top-k operator for TopkNodeOp
+// This file contains a lot of code-duplication with src/translator/nth_element.cpp;
+// the goal is to replace the beam-search specific topk search with this code.
+// Currently this is only used in the unit tests, but we will move forward and
+// make the beam-search more graph and operator-based.
+
+namespace marian {
+namespace cpu {
+
+void TopK(Tensor outVal, Tensor outInd, Ptr /*allocator*/, const Tensor in, int k, int axis, bool descending) {
+
+  ABORT_IF(axis != in->shape().size() - 1, "Currently only works for last axis");
+  ABORT_IF(in->type() != Type::float32, "Input should have type {}", Type::float32);
+  ABORT_IF(outInd->type() != Type::uint32, "Output should have type {}", Type::uint32);
+
+  int cols = in->shape()[axis];
+  int rows = in->shape().elements() / cols;
+
+  ABORT_IF(k > cols, "Cannot select more than {} elements for axis {}", cols, axis);
+
+  std::vector idxs(cols);
+  std::iota(idxs.begin(), idxs.end(), 0);
+
+  const float* inDataPtr = in->data();
+  IndexType* outIndPtr = outInd->data();
+  float* outValPtr = outVal->data();
+  for(int i = 0; i < rows; ++i) {
+    std::partial_sort(
+      // sorts the top N (beam size) idxs by score to the front
+      idxs.begin(),
+      idxs.begin() + k,
+      idxs.end(),
+      [&](int a, int b) {
+        return descending ?
inDataPtr[a] > inDataPtr[b] : inDataPtr[a] < inDataPtr[b]; + } + ); + + for(int j = 0; j < k; j++) { + outIndPtr[j] = idxs[j]; + outValPtr[j] = inDataPtr[idxs[j]]; + } + + outIndPtr += k; + outValPtr += k; + inDataPtr += cols; + } +} + +} +} diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc index 98723b9d0..64093253a 100755 --- a/src/tensors/gpu/add.inc +++ b/src/tensors/gpu/add.inc @@ -32,4 +32,7 @@ template void Add, template void Add, UnaryFunctor, Assignee<3>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor, UnaryFunctor, Assignee<3>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); template void Add, Assignee<2>>, Assignee<3>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor, Assignee<2>>, Assignee<3>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); template void Add, Capture>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor, Capture>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor); -template void Add, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); \ No newline at end of file +template void Add, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); +template void marian::gpu::Add >, marian::functional::Assignee<2> >, IntrusivePtr, IntrusivePtr >(marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr); +template void marian::gpu::Add, marian::functional::Assignee<2> > >, IntrusivePtr, IntrusivePtr >(marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr); +template void marian::gpu::Add,marian::functional::UnaryFunctor > >,class IntrusivePtr,class IntrusivePtr >(marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,class IntrusivePtr,class IntrusivePtr,class IntrusivePtr); diff --git a/src/tensors/gpu/add_all.cu b/src/tensors/gpu/add_all.cu index ad3ac2526..bc78709a3 100644 --- a/src/tensors/gpu/add_all.cu +++ b/src/tensors/gpu/add_all.cu @@ -18,7 +18,7 @@ void AggregateAllVar(Ptr allocator, AccType aggInit, AggFunctor aggFunctor, AccType scale, - Tensor out, + marian::Tensor out, const Tensors... tensors) { cudaSetDevice(out->getDeviceId().no); @@ -34,7 +34,7 @@ void AggregateAllVar(Ptr allocator, // The all_reduce kernel by nivida needs to perform multiple passes if the number of blocks needed to perform the reduction is larger than 1. // Here we allocate the memory for the intermediate reductions for each block. 
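(Aside: the per-block temporary memory allocated below follows a classic two-pass reduction. A small host-side sketch with invented names, showing why exactly `blocks` partial results are needed before the final pass:)

#include <algorithm>
#include <numeric>
#include <vector>

// Pass 1 writes one partial sum per block; pass 2 reduces the partials.
// This mirrors the role of the per-block intermediate buffer below.
float reduceTwoPass(const std::vector<float>& in, int blocks) {
  std::vector<float> partial(blocks, 0.f);            // analogue of blockMem
  int chunk = (int(in.size()) + blocks - 1) / blocks; // elements per block
  for(int b = 0; b < blocks; ++b) {
    int lo = b * chunk;
    int hi = std::min<int>(lo + chunk, int(in.size()));
    if(lo < hi)
      partial[b] = std::accumulate(in.begin() + lo, in.begin() + hi, 0.f);
  }
  return std::accumulate(partial.begin(), partial.end(), 0.f);
}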
- Tensor blockMem; + marian::Tensor blockMem; if(blocks > 1 || out->type() != typeId()) { // if the out tensor does not have elementType AccType we need to allocate and convert later MemoryPiece::PtrType temporaryMemory; if(allocator) { @@ -45,7 +45,7 @@ void AggregateAllVar(Ptr allocator, temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType) * blocks); // @TODO: consider implementing MemoryPiece::cudaMalloc(size) for managed memory } blockMem = TensorBase::New(temporaryMemory, - Shape({blocks}), + marian::Shape({blocks}), typeId(), out->getBackend()); blockMem->set(aggInit); // set temporary memory to aggInit @@ -81,8 +81,8 @@ void AggregateAll(Ptr allocator, AccType aggInit, AggFunctor aggFunctor, AccType scale, - Tensor out, - const Tensor in1) { + marian::Tensor out, + const marian::Tensor in1) { AggregateAllVar(allocator, functor, aggInit, aggFunctor, scale, out, in1); } @@ -92,9 +92,9 @@ void AggregateAll(Ptr allocator, AccType aggInit, AggFunctor aggFunctor, AccType scale, - Tensor out, - const Tensor in1, - const Tensor in2) { + marian::Tensor out, + const marian::Tensor in1, + const marian::Tensor in2) { AggregateAllVar(allocator, functor, aggInit, aggFunctor, scale, out, in1, in2); } @@ -104,10 +104,10 @@ void AggregateAll(Ptr allocator, AccType aggInit, AggFunctor aggFunctor, AccType scale, - Tensor out, - const Tensor in1, - const Tensor in2, - const Tensor in3) { + marian::Tensor out, + const marian::Tensor in1, + const marian::Tensor in2, + const marian::Tensor in3) { AggregateAllVar(allocator, functor, aggInit, aggFunctor, scale, out, in1, in2, in3); } diff --git a/src/tensors/gpu/add_all.inc b/src/tensors/gpu/add_all.inc index 73b0bda97..2147f2607 100644 --- a/src/tensors/gpu/add_all.inc +++ b/src/tensors/gpu/add_all.inc @@ -33,6 +33,9 @@ template void AggregateAll, Capture>, Assignee<2>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, Capture>, Assignee<2>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor); template void AggregateAll, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); template void AggregateAll, Assignee<1>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, Assignee<1>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor); +template void marian::AggregateAll >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr); +template void marian::AggregateAll, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr); +template void marian::AggregateAll,marian::functional::UnaryFunctor > >,marian::functional::BinaryFunctor,marian::functional::Assignee<2> > >(std::shared_ptr,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,marian::functional::BinaryFunctor,marian::functional::Assignee<2> 
>,float,IntrusivePtr,IntrusivePtr,IntrusivePtr); #if COMPILE_FP16 template void AggregateAll<__half, float, BinaryFunctor>, Assignee<2>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor>, Assignee<2>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor); @@ -68,4 +71,7 @@ template void AggregateAll<__half, float, BinaryFunctor, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, BinaryFunctor>, BinaryFunctor>>, BinaryFunctor>>>>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, Assignee<1>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor); template void AggregateAll<__half, float, BinaryFunctor, Assignee<1>>, BinaryFunctor, Assignee<2>>>(std::shared_ptr, BinaryFunctor, Assignee<1>>, float, BinaryFunctor, Assignee<2>>, float, marian::Tensor, marian::Tensor); -#endif \ No newline at end of file +template void marian::AggregateAll<__half, float, marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::BinaryFunctor >, marian::functional::Assignee<2> >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr); +template void marian::AggregateAll<__half, float, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, marian::functional::BinaryFunctor, marian::functional::Assignee<2> > >(std::shared_ptr, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > >, float, marian::functional::BinaryFunctor, marian::functional::Assignee<2> >, float, IntrusivePtr, IntrusivePtr, IntrusivePtr); +template void marian::AggregateAll<__half,float,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,marian::functional::BinaryFunctor,marian::functional::Assignee<2> > >(std::shared_ptr,marian::functional::BinaryFunctor,marian::functional::UnaryFunctor > >,float,marian::functional::BinaryFunctor,marian::functional::Assignee<2> >,float,IntrusivePtr,IntrusivePtr,IntrusivePtr); +#endif diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc index 00aff3d91..e2e74200d 100755 --- a/src/tensors/gpu/element.inc +++ b/src/tensors/gpu/element.inc @@ -59,6 +59,8 @@ template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > > >, marian::Tensor >(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > > > >, marian::Tensor, marian::Tensor); template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >>(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >, marian::Tensor); template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::BinaryFunctor, marian::functional::Capture>, marian::functional::BinaryFunctor > > > > >, IntrusivePtr >(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::BinaryFunctor, marian::functional::Capture>, marian::functional::BinaryFunctor > > > > >, marian::Tensor, marian::Tensor); +template void 
marian::gpu::Element, marian::functional::UnaryFunctor > >, IntrusivePtr >(marian::functional::Assign, marian::functional::UnaryFunctor > >, IntrusivePtr, IntrusivePtr); +template void marian::gpu::Element, marian::functional::UnaryFunctor > >, IntrusivePtr >(marian::functional::Assign, marian::functional::UnaryFunctor > >, IntrusivePtr, IntrusivePtr); // How to add new specializations: // When you use a new specialization, it will cause a link error of this form (example): @@ -67,4 +69,3 @@ template void marian::gpu::Element' with 'marian::Tensor' - diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu old mode 100755 new mode 100644 index eefcc405f..90a6bc667 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -139,6 +139,8 @@ void CopyCast(Tensor out, const Tensor in) { #endif } else if(in->type() == Type::float64) { CopyCastFrom(out, in->data(), (int)in->size()); + } else if(in->type() == Type::uint32) { + CopyCastFrom(out, in->data(), (int)in->size()); } else { ABORT("CopyCastFrom from type {} not implemented", in->type()); } @@ -476,6 +478,8 @@ void TransposeND(Tensor out, Tensor in, const std::vector& vAxis) { if(in->type() == Type::float32) { gTranspose0213<<>>(out->data(), in->data(), rows, cols, stride1, stride2); + } else if(in->type() == Type::uint32) { + gTranspose0213<<>>(out->data(), in->data(), rows, cols, stride1, stride2); #if COMPILE_FP16 } else if(in->type() == Type::float16) { gTranspose0213<<>>(out->data(), in->data(), rows, cols, stride1, stride2); @@ -499,6 +503,8 @@ void TransposeND(Tensor out, Tensor in, const std::vector& vAxis) { if(in->type() == Type::float32) { gTransposeND<<>>(out, in, axes); + } else if(in->type() == Type::uint32) { + gTransposeND<<>>(out, in, axes); #if COMPILE_FP16 } else if(in->type() == Type::float16) { gTransposeND<<>>(out, in, axes); @@ -1217,6 +1223,14 @@ void Select(Tensor out, indices->data(), indices->shape()); #endif + } else if(out->type() == Type::uint32) { + gSelect<<>>(out->data(), + out->shape(), + in->data(), + in->shape(), + axisGPU, + indices->data(), + indices->shape()); } else { ABORT("Select not implemented for type {}", out->type()); } diff --git a/src/tensors/gpu/topk.cu b/src/tensors/gpu/topk.cu new file mode 100644 index 000000000..94256fb7a --- /dev/null +++ b/src/tensors/gpu/topk.cu @@ -0,0 +1,382 @@ +#include "tensors/tensor_operators.h" +#include "tensors/gpu/cuda_helpers.h" +#include "tensors/allocator.h" + +#include + +// GPU implementation of proper Marian top-k operator for TopkNodeOp +// This file contains a lot of code-duplicaton with src/translator/nth_element.cu +// the goal is to replace the beam-search specific topk search with this code. +// Currently this is only used in the unit tests, but we will move forward and +// make the beam-search more graph and operator-based. 
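(Aside: a compact CPU reference of the strategy the two kernels in this file implement jointly: find the row maximum, record it, blank it out with the minimal value, repeat k times, and finally restore the blanked inputs. Illustrative sketch only, assuming k <= row.size(); this is not the file's actual code path:)

#include <algorithm>
#include <limits>
#include <vector>

void topkByBlanking(std::vector<float>& row, int k,
                    std::vector<int>& outIdx, std::vector<float>& outVal) {
  for(int i = 0; i < k; ++i) {
    auto it = std::max_element(row.begin(), row.end());
    outIdx.push_back(int(it - row.begin()));
    outVal.push_back(*it);
    *it = std::numeric_limits<float>::lowest(); // blank out the found maximum
  }
  for(int i = 0; i < k; ++i)  // restore inputs so the tensor stays invariant
    row[outIdx[i]] = outVal[i];
}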
+
+namespace marian {
+namespace gpu {
+
+const int MAX_BINS = 500;
+const int BLOCK_SIZE = 512;
+
+#define UNROLL_MAXARG_LOOP(n, max) \
+  if(tid < (n) && tid + (n) < (max)) { \
+    if(sharedValues[tid + (n)] > sharedValues[tid]) { \
+      sharedIndices[tid] = sharedIndices[tid + (n)]; \
+      sharedValues[tid] = sharedValues[tid + (n)]; \
+    } \
+  }
+
+// finds maximum element (first step)
+template
+__global__ void gMaxElement(IndexType* binIndices, // out: top-k positions
+                            T* binValues,          // out: top-k scores
+                            const T* inValues,     // this is the probs array, only one with type float or half
+                            int rows,              // we iterate over this many rows, row-major layout
+                            int cols,              // a row has that many columns, row-major layout
+                            float minimal,         // minimal is the smallest possible value. For simplicity we assume we look for the maximum.
+                            bool descending)       // This will be the largest possible value if the order is reversed (i.e. we look for the minimum).
+{
+  extern __shared__ float sharedValues[];
+  __shared__ IndexType sharedIndices[BLOCK_SIZE];
+
+  // id of current thread within block
+  int tid = threadIdx.x;
+
+  float flip = descending ? 1.f : -1.f;
+
+  // Roll over every row in row-major 2D representation of the data
+  for(int rowIdx = 0; rowIdx < rows; ++rowIdx) {
+    int begin = rowIdx * cols;       // start index of a row
+    int end   = rowIdx * cols + cols; // end index of a row
+
+    // We look at at most blockDim.x * 2 = 1024 values within a block, i.e. each thread reduces two values.
+    // Here we set the position to begin + blockId * 1024 + threadId. If a row has more values we
+    // partition the row according to blocks of 1024 values.
+    int i = begin + blockIdx.x * (blockDim.x * 2) + tid;
+
+    // Initialize shared values to minimal value.
+    sharedValues[tid] = minimal;
+
+    // Do first set of comparisons outside loop, saves one iteration.
+    if(i + blockDim.x < end) { // Are we in a position for which we can access and compare two values in a row partition (shifted by block size)?
+      // yes, hence compare:
+      float a = flip * (float)inValues[i];              // value from first half of row partition for this block
+      float b = flip * (float)inValues[i + blockDim.x]; // value from second half of row partition for this block
+      if(a > b) { // just a max
+        sharedIndices[tid] = i;
+        sharedValues[tid] = a;
+      } else {
+        sharedIndices[tid] = i + blockDim.x;
+        sharedValues[tid] = b;
+      }
+    } else if(i < end) { // Are we instead in a position that has access to one value in the row partition (shifting by block size would be out of bounds)?
+      // Yes, hence save the current value and index as new max, no need to compare.
+      sharedIndices[tid] = i;
+      sharedValues[tid] = flip * (float)inValues[i];
+    } // nothing else to do here
+
+    // We move to the next set of 1024 values shifted by block size times number of blocks
+    // and look at two of them according to thread id.
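    // (Illustrative aside with invented numbers: with blockDim.x = 512 and,
    // say, gridDim.x = 4, block b starts at i = begin + b*1024 + tid, compares
    // inValues[i] with inValues[i + 512], and then advances by 2*4*512 = 4096
    // elements per iteration, so the blocks tile the row without overlapping reads.)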
+    while(i + 2 * gridDim.x * blockDim.x < end) {
+      i += 2 * gridDim.x * blockDim.x;
+
+      // Check if first value is larger than what we have seen so far
+      float a = flip * (float)inValues[i];
+      if(a > sharedValues[tid]) {
+        // Yes, hence save index and value
+        sharedIndices[tid] = i;
+        sharedValues[tid] = a;
+      }
+
+      // Check if second value is larger than what we have seen so far
+      if(i + blockDim.x < end) {
+        float b = flip * (float)inValues[i + blockDim.x];
+        if(b > sharedValues[tid]) {
+          // Yes, hence save index and value
+          sharedIndices[tid] = i + blockDim.x;
+          sharedValues[tid] = b;
+        }
+      }
+    }
+
+    // We are done with the first sweep and have populated shared memory, time to wait for the other threads and reduce it all
+    __syncthreads();
+
+    // Reduce over shared memory, here per loop until we hit the last 32 unreduced elements
+    for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
+      if(tid < s && tid + s < end) {
+        if(sharedValues[tid + s] > sharedValues[tid]) {
+          // keep the max
+          sharedIndices[tid] = sharedIndices[tid + s];
+          sharedValues[tid] = sharedValues[tid + s];
+        }
+      }
+      __syncthreads();
+    }
+
+    // Reduce over shared memory, here per unrolled code for powers of 2 less than or equal to 32.
+    // Because we are at 32 (warp size) the threads run in lock-step and we can abandon syncing.
+    UNROLL_MAXARG_LOOP(32, end);
+    UNROLL_MAXARG_LOOP(16, end);
+    UNROLL_MAXARG_LOOP(8, end);
+    UNROLL_MAXARG_LOOP(4, end);
+    UNROLL_MAXARG_LOOP(2, end);
+    UNROLL_MAXARG_LOOP(1, end);
+
+    // OK, we are done with the reduction and in the first thread
+    if(tid == 0) {
+      // assign the final maximal value to the bin, one bin per row and block
+      binIndices[rowIdx * gridDim.x + blockIdx.x] = sharedIndices[0]; // [rows, num_blocks]
+      binValues[rowIdx * gridDim.x + blockIdx.x]  = sharedValues[0];  // [rows, num_blocks]
+    }
+    __syncthreads();
+  }
+}
+
+// This runs after the function above; we now have the maximum value per row and block and can look further
+// for the top-k results. As above we pretend this does only maximum search.
+// This runs restricted to one row (one row per block)
+template
+__global__ void gMaxElementUpdate(IndexType* binIndices, // memory for bin indices
+                                  T* binValues,          // memory for bin costs
+                                  IndexType* outIndices, // result indices
+                                  T* outValues,          // result costs
+                                  T* inValues,           // should work well enough with half, uses float everywhere else
+                                  const int cols,        // size of continuous memory we search over
+                                  const int K,           // how many top-K elements?
+                                  int numBlocks,         // number of blocks/bins used in above function (per row)
+                                  float minimal,         // value for minimal element
+                                  bool descending)
+{
+  extern __shared__ float sharedValues[];
+  __shared__ int sharedIndices[BLOCK_SIZE];
+  __shared__ float bestBinCost;
+  __shared__ int bestBinCostIdx;
+
+  const int tid = threadIdx.x;
+
+  float flip = descending ? 1.f : -1.f;
+
+  // we only look at one row in this kernel
+  const int rowIdx = blockIdx.x;        // index of the row corresponds to block index
+  const int begin  = rowIdx * cols;        // start offset for this row relative to inValues tensor start
+  const int end    = rowIdx * cols + cols; // end offset for this row relative to inValues tensor start
+
+  int num_bins = numBlocks; // why not just use numBlocks?
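  // (Illustrative aside with example sizes: for cols = 32000 and BLOCK_SIZE = 512,
  // gMaxElement ran with numBlocks = min(500, ceil(32000/1024)) = 32, so this
  // kernel starts from 32 bin maxima per row and re-reduces them once per k.)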
+
+  // iterate over top-k results
+  for(int k = 0; k < K; ++k) {
+
+    int kthOutIdx = rowIdx * K + k; // offset into output tensor relative to outIndices/outValues tensor start
+    int i = tid;
+
+    sharedValues[tid] = minimal; // initialize to smallest value, everything else will be larger
+
+    // as in the function above, the code here does a tree reduction over shared memory to find the single maximum element
+    if(i + blockDim.x < num_bins) {
+      float a = binValues[rowIdx * numBlocks + i];
+      float b = binValues[rowIdx * numBlocks + i + blockDim.x];
+      if(a > b) {
+        sharedValues[tid] = a;
+        sharedIndices[tid] = i;
+      } else {
+        sharedValues[tid] = b;
+        sharedIndices[tid] = i + blockDim.x;
+      }
+    } else if(i < num_bins) {
+      sharedValues[tid] = binValues[rowIdx * numBlocks + i];
+      sharedIndices[tid] = i;
+    }
+
+    while(i + 2 * blockDim.x < num_bins) {
+      i += 2 * blockDim.x;
+
+      float a = binValues[rowIdx * numBlocks + i];
+      if(a > sharedValues[tid]) {
+        sharedValues[tid] = a;
+        sharedIndices[tid] = i;
+      }
+
+      if(i + blockDim.x < num_bins) {
+        float b = binValues[rowIdx * numBlocks + i + blockDim.x];
+        if(b > sharedValues[tid]) {
+          sharedValues[tid] = b;
+          sharedIndices[tid] = i + blockDim.x;
+        }
+      }
+    }
+
+    __syncthreads();
+
+    for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
+      if(tid < s && tid + s < num_bins) {
+        if(sharedValues[tid + s] > sharedValues[tid]) {
+          sharedValues[tid] = sharedValues[tid + s];
+          sharedIndices[tid] = sharedIndices[tid + s];
+        }
+      }
+      __syncthreads();
+    }
+
+    UNROLL_MAXARG_LOOP(32, num_bins);
+    UNROLL_MAXARG_LOOP(16, num_bins);
+    UNROLL_MAXARG_LOOP(8, num_bins);
+    UNROLL_MAXARG_LOOP(4, num_bins);
+    UNROLL_MAXARG_LOOP(2, num_bins);
+    UNROLL_MAXARG_LOOP(1, num_bins);
+
+    if(tid == 0) {
+      bestBinCost = sharedValues[0];
+      bestBinCostIdx = rowIdx * numBlocks + sharedIndices[0];
+
+      inValues[binIndices[bestBinCostIdx]] = flip * minimal; // this is restored in the last lines of this function
+
+      outIndices[kthOutIdx] = binIndices[bestBinCostIdx] - begin; // relative to beginning of row hence subtract `begin`
+      outValues[kthOutIdx]  = flip * bestBinCost; // undo flip by flipping again
+    }
+
+    __syncthreads();
+
+    // Second part of the algorithm; why is it not replacing the first function call??
+    // Also shouldn't we skip here if k == K - 1?
+
+    // After marking the previously largest element with "flip * minimal" we populate
+    // shared memory again with the largest element, as in gMaxElement(...)
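    // (Illustrative aside: only the bin that supplied the maximum just blanked
    // out can have a new maximum, so instead of re-running gMaxElement over the
    // whole row, the block below re-reduces just that bin's slice. It starts at
    // the bin's first 2*blockDim.x partition and strides by num_bins*2*blockDim.x,
    // visiting exactly the positions that bin covered in gMaxElement.)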
+
+    if(k < K - 1) {
+      i = begin + (bestBinCostIdx - rowIdx * numBlocks) * (blockDim.x * 2) + tid;
+      const int dist = num_bins * 2 * blockDim.x;
+
+      sharedValues[tid] = minimal;
+
+      if(i + blockDim.x < end) {
+        float a = flip * (float)inValues[i];
+        float b = flip * (float)inValues[i + blockDim.x];
+        if(a > b) {
+          sharedIndices[tid] = i;
+          sharedValues[tid] = a;
+        } else {
+          sharedIndices[tid] = i + blockDim.x;
+          sharedValues[tid] = b;
+        }
+      } else if(i < end) {
+        sharedIndices[tid] = i;
+        sharedValues[tid] = flip * (float)inValues[i];
+      }
+
+      while(i + dist < end) {
+        i += dist;
+
+        float a = flip * (float)inValues[i];
+        if(a > sharedValues[tid]) {
+          sharedIndices[tid] = i;
+          sharedValues[tid] = a;
+        }
+
+        if(i + blockDim.x < end) {
+          float b = flip * (float)inValues[i + blockDim.x];
+          if(b > sharedValues[tid]) {
+            sharedIndices[tid] = i + blockDim.x;
+            sharedValues[tid] = b;
+          }
+        }
+      }
+
+      __syncthreads();
+
+      for(int s = (blockDim.x >> 1); s > 32; s >>= 1) {
+        if(tid < s && tid + s < end) {
+          if(sharedValues[tid + s] > sharedValues[tid]) {
+            sharedIndices[tid] = sharedIndices[tid + s];
+            sharedValues[tid] = sharedValues[tid + s];
+          }
+        }
+        __syncthreads();
+      }
+
+      UNROLL_MAXARG_LOOP(32, end);
+      UNROLL_MAXARG_LOOP(16, end);
+      UNROLL_MAXARG_LOOP(8, end);
+      UNROLL_MAXARG_LOOP(4, end);
+      UNROLL_MAXARG_LOOP(2, end);
+      UNROLL_MAXARG_LOOP(1, end);
+
+      if(tid == 0) {
+        binIndices[bestBinCostIdx] = sharedIndices[0];
+        binValues[bestBinCostIdx]  = sharedValues[0];
+      }
+      __syncthreads();
+    }
+  }
+
+  // final operation to restore blanked-out input values. They were blanked out for marking
+  // already found values. Since we want input values to be invariant we restore here.
+  // @TODO: The lack of constness here might be a problem for concurrent processing (which we currently don't have)
+  for(int k = tid; k < K; k += blockDim.x) {
+    int kthOutIdx = rowIdx * K + k;
+    inValues[begin + outIndices[kthOutIdx]] = outValues[kthOutIdx];
+  }
+}
+
+void TopK(Tensor outVal, Tensor outInd, Ptr allocator, const Tensor in, int k, int axis, bool descending) {
+
+  ABORT_IF(axis != in->shape().size() - 1, "Currently only works for last axis");
+  ABORT_IF(!isFloat(in->type()), "Input should be float type and not {}", in->type());
+  ABORT_IF(outInd->type() != Type::uint32, "Output should have type {}", Type::uint32);
+  ABORT_IF(outVal->type() != in->type(), "Output should have type {}", in->type());
+
+  cudaSetDevice(outInd->getDeviceId().no);
+
+  int cols = in->shape()[-1];                // e.g. in beam search that would be [beam * dimVoc]
+  int rows = in->shape().elements() / cols;  // e.g.
in beam search that would be [time * batch] + + ABORT_IF(k > cols, "Cannot select more than {} elements for axis {}", cols, axis); + + float minimal = NumericLimits(in->type()).lowest; // lowest if looking for max + + const int numBlocks = std::min(MAX_BINS, int(cols / (2 * BLOCK_SIZE)) + int(cols % (2 * BLOCK_SIZE) != 0)); + auto tempMemInd = allocator->alloc(rows * numBlocks); + + MemoryPiece::PtrType tempMemVal; + if(in->type() == Type::float32) { + tempMemVal = allocator->alloc(rows * numBlocks); + // first find the maximum value per row and block and save indices and values to temporary memory + gMaxElement<<>>( + tempMemInd->data(), tempMemVal->data(), + in->data(), rows, cols, minimal, descending); + gMaxElementUpdate<<>>( + tempMemInd->data(), tempMemVal->data(), + outInd->data(), outVal->data(), + in->data(), cols, k, numBlocks, minimal, descending); +#if COMPILE_FP16 + } else if(in->type() == Type::float16) { + tempMemVal = allocator->alloc<__half>(rows * numBlocks); + // first find the maximum value per row and block and save indices and values to temporary memory + gMaxElement<<>>( + tempMemInd->data(), tempMemVal->data<__half>(), + in->data<__half>(), rows, cols, minimal, descending); + gMaxElementUpdate<<>>( + tempMemInd->data(), tempMemVal->data<__half>(), + outInd->data(), outVal->data<__half>(), + in->data<__half>(), cols, k, numBlocks, minimal, descending); +#endif + } else { + ABORT("Topk not implemented for type {}", in->type()); + } + + allocator->free(tempMemInd); + allocator->free(tempMemVal); +} + +} +} diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h index 6f4c202ab..4c214b20a 100755 --- a/src/tensors/tensor.h +++ b/src/tensors/tensor.h @@ -36,12 +36,24 @@ class TensorBase { Ptr backend) : memory_(memory), shape_(shape), type_(type), backend_(backend) {} - TensorBase(MemoryPiece::PtrType memory, Shape shape, Ptr backend) + TensorBase(MemoryPiece::PtrType memory, + Shape shape, + Ptr backend) : memory_(memory), shape_(shape), type_(Type::float32), backend_(backend) {} + // Wraps existing memory + template + TensorBase(T* rawMemory, + size_t rawMemoryNum, + Shape shape, + Type type, + Ptr backend) + : memory_(MemoryPiece::New((uint8_t*)rawMemory, rawMemoryNum * sizeof(T))), + shape_(shape), type_(type), backend_(backend) {} + public: // Use this whenever pointing to MemoryPiece typedef IPtr PtrType; diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index 227bc9533..77a86c717 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -25,13 +25,16 @@ namespace marian { template -void copy(Ptr& MAYBE_UNUSED backend, const InIt beg, const InIt end, OutIt it) { +void copy(Ptr& backend, const InIt beg, const InIt end, OutIt it) { #ifdef CUDA_FOUND if(backend->getDeviceId().type == DeviceType::gpu) gpu::copy(backend, beg, end, it); else -#endif std::copy(beg, end, it); +#else + backend; + std::copy(beg, end, it); +#endif } DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor); @@ -190,7 +193,7 @@ void LayerNormalizationGrad(Tensor gradX, } static inline void LayerNormalizationGrad( - Ptr MAYBE_UNUSED allocator, + Ptr allocator, Tensor gradX, Tensor gradGamma, Tensor gradBeta, @@ -220,6 +223,8 @@ DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const marian::Tensor) DISPATCH4(Select, marian::Tensor, const marian::Tensor, const marian::Tensor, int) DISPATCH4(Insert, marian::Tensor, const marian::Tensor, const marian::Tensor, int) +DISPATCH7(TopK, marian::Tensor, marian::Tensor, Ptr, const 
marian::Tensor, int, int, bool); + DISPATCH2(LSTMCellForward, marian::Tensor, std::vector) DISPATCH2(LSTMOutputForward, marian::Tensor, std::vector); // clang-format on diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index e31ed1d4e..ccf8cc72d 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -1,25 +1,26 @@ # Unit tests add_subdirectory(units) +if(NOT MSVC) + # Testing apps + set(APP_TESTS + logger + dropout + sqlite + prod + cli + pooling + ) -# Testing apps -set(APP_TESTS - logger - dropout - sqlite - prod - cli - pooling -) + foreach(test ${APP_TESTS}) + add_executable("test_${test}" "${test}.cpp") -foreach(test ${APP_TESTS}) - add_executable("test_${test}" "${test}.cpp") + if(CUDA_FOUND) + target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS}) + else(CUDA_FOUND) + target_link_libraries("test_${test}" marian ${EXT_LIBS}) + endif(CUDA_FOUND) - if(CUDA_FOUND) - target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS}) - else(CUDA_FOUND) - target_link_libraries("test_${test}" marian ${EXT_LIBS}) - endif(CUDA_FOUND) - - set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") -endforeach(test) + set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") + endforeach(test) +endif(NOT MSVC) \ No newline at end of file diff --git a/src/tests/units/CMakeLists.txt b/src/tests/units/CMakeLists.txt index 654355b7f..afc7349be 100644 --- a/src/tests/units/CMakeLists.txt +++ b/src/tests/units/CMakeLists.txt @@ -17,5 +17,10 @@ foreach(test ${UNIT_TESTS}) target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch) endif(CUDA_FOUND) + if(MSVC) + # Disable C4305: truncation from 'double' to '_Ty' + target_compile_options("run_${test}" PUBLIC /wd4305) + endif(MSVC) + add_test(NAME ${test} COMMAND "run_${test}") endforeach(test) diff --git a/src/tests/units/fastopt_tests.cpp b/src/tests/units/fastopt_tests.cpp index cec6ab902..5bac08a41 100644 --- a/src/tests/units/fastopt_tests.cpp +++ b/src/tests/units/fastopt_tests.cpp @@ -5,8 +5,6 @@ using namespace marian; TEST_CASE("FastOpt can be constructed from a YAML node", "[fastopt]") { - YAML::Node node; - SECTION("from a simple node") { YAML::Node node = YAML::Load("{foo: bar}"); const FastOpt o(node); diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index 581cd05c7..9f3a3a12e 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -33,20 +33,37 @@ void tests(DeviceType device, Type floatType = Type::float32) { std::vector values, values2; - SECTION("scalar multiplication") { + SECTION("elementwise unary and binary operators with scalars") { graph->clear(); values.clear(); - std::vector vB({1, 2, 3, 4, 5, 6}); - auto B = graph->param("B", {3, 2}, inits::fromVector(vB)); - auto B2 = B * 2.0f; - graph->forward(); + std::vector vA({1, -2, 3, -4}); + auto a = graph->constant({2, 2, 1}, inits::fromVector(vA)); - CHECK(B2->shape() == Shape({3, 2})); - B2->val()->get(values); + auto compare = [&](Expr res, std::function f) -> bool { + if (res->shape() != Shape({ 2, 2, 1 })) + return false; + res->val()->get(values); + std::vector ref{f(vA[0]), f(vA[1]), f(vA[2]), f(vA[3])}; + return std::equal(values.begin(), values.end(), ref.begin(), floatEqual); + }; + + // @TODO: add all operators and scalar variants here for completeness + auto rsmult = 2.f * a; + auto rabs = abs(a); + auto rmax1 = maximum(a, 1); + auto 
rmax2 = maximum(1, a); + auto rmin1 = minimum(a, 1); + auto rmin2 = minimum(1, a); - std::vector vB2({2, 4, 6, 8, 10, 12}); - CHECK(values == vB2); + graph->forward(); + + CHECK(compare(rsmult, [](float a) {return 2.f * a;})); + CHECK(compare(rabs, [](float a) {return std::abs(a);})); + CHECK(compare(rmax1, [](float a) {return std::max(a, 1.f);})); + CHECK(compare(rmax2, [](float a) {return std::max(1.f, a);})); + CHECK(compare(rmin1, [](float a) {return std::min(a, 1.f);})); + CHECK(compare(rmin2, [](float a) {return std::min(1.f, a);})); } SECTION("elementwise binary operators with broadcasting") { @@ -59,12 +76,23 @@ void tests(DeviceType device, Type floatType = Type::float32) { auto a = graph->constant({2, 2, 1}, inits::fromVector(vA)); auto b = graph->constant({2, 1}, inits::fromVector(vB)); - auto compare = [&](Expr res, std::function f, bool exactMatch) -> bool { + // Two lambdas below differ in the use of floatEqual or floatApprox and + // are not merged because MSVC compiler returns C2446: no conversion from + // lambda_x to lambda_y + auto compare = [&](Expr res, std::function f) -> bool { if (res->shape() != Shape({ 2, 2, 1 })) return false; res->val()->get(values); std::vector ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])}; - return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox); + return std::equal(values.begin(), values.end(), ref.begin(), floatEqual); + }; + + auto compareApprox = [&](Expr res, std::function f) -> bool { + if(res->shape() != Shape({2, 2, 1})) + return false; + res->val()->get(values); + std::vector ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])}; + return std::equal(values.begin(), values.end(), ref.begin(), floatApprox); }; auto rplus = a + b; @@ -83,19 +111,19 @@ void tests(DeviceType device, Type floatType = Type::float32) { graph->forward(); - CHECK(compare(rplus, [](float a, float b) {return a + b;}, true)); - CHECK(compare(rminus, [](float a, float b) {return a - b;}, true)); - CHECK(compare(rmult, [](float a, float b) {return a * b;}, true)); - CHECK(compare(rdiv, [](float a, float b) {return a / b;}, false)); - CHECK(compare(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}, false)); - CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}, true)); - CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}, true)); - CHECK(compare(rlt, [](float a, float b) {return a < b;}, true)); - CHECK(compare(req, [](float a, float b) {return a == b;}, true)); - CHECK(compare(rgt, [](float a, float b) {return a > b;}, true)); - CHECK(compare(rge, [](float a, float b) {return a >= b;}, true)); - CHECK(compare(rne, [](float a, float b) {return a != b;}, true)); - CHECK(compare(rle, [](float a, float b) {return a <= b;}, true)); + CHECK(compare(rplus, [](float a, float b) {return a + b;})); + CHECK(compare(rminus, [](float a, float b) {return a - b;})); + CHECK(compare(rmult, [](float a, float b) {return a * b;})); + CHECK(compareApprox(rdiv, [](float a, float b) {return a / b;})); + CHECK(compareApprox(rlae, [](float a, float b) {return logf(expf(a) + expf(b));})); + CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);})); + CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);})); + CHECK(compare(rlt, [](float a, float b) {return a < b;})); + CHECK(compare(req, [](float a, float b) {return a == b;})); + CHECK(compare(rgt, [](float a, float b) {return a > b;})); + CHECK(compare(rge, [](float a, float b) {return a >= 
b;})); + CHECK(compare(rne, [](float a, float b) {return a != b;})); + CHECK(compare(rle, [](float a, float b) {return a <= b;})); } SECTION("transposing and reshaping") { @@ -382,8 +410,8 @@ void tests(DeviceType device, Type floatType = Type::float32) { std::vector SV; // create CSR version of S std::vector SI, SO; SO.push_back((IndexType)SI.size()); - for (IndexType i = 0; i < S->shape()[0]; i++) { - for (IndexType j = 0; j < S->shape()[1]; j++) { + for (IndexType i = 0; i < (IndexType)S->shape()[0]; i++) { + for (IndexType j = 0; j < (IndexType)S->shape()[1]; j++) { auto k = 4 * i + j; if (vS[k] != 0) { SV.push_back(vS[k]); @@ -460,7 +488,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { aff1->val()->get(values); CHECK(values == vAff); - std::vector values2; + values2.clear(); CHECK(aff2->shape() == aff1->shape()); aff2->val()->get(values2); CHECK(values2 == values); @@ -636,7 +664,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { SECTION("relation of rows and columns selection using transpose") { graph->clear(); values.clear(); - std::vector values2; + values2.clear(); std::vector vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5}); std::vector idx({0, 1}); @@ -746,7 +774,7 @@ void tests(DeviceType device, Type floatType = Type::float32) { SECTION("rows/cols as gather operations") { graph->clear(); values.clear(); - std::vector values2; + values2.clear(); std::vector vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5}); @@ -770,6 +798,80 @@ void tests(DeviceType device, Type floatType = Type::float32) { C2->val()->get(values2); CHECK( values == values2 ); } + + SECTION("topk operations") { + graph->clear(); + values.clear(); + + std::vector vA({ 0, .3333, -.2, + -.3, 0, 4.5, + 5.2, -10, 101.45, + -100.05, 0, 1.05e-5}); + + auto a = graph->constant({2, 2, 3}, inits::fromVector(vA)); + + // get top-k indices and values as a tuple + auto rtopk1 = topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/true); + auto rval1 = get<0>(rtopk1); // values from top-k + auto ridx1 = get<1>(rtopk1); // indices from top-k + auto gval1 = gather(a, -1, ridx1); // get the same values via gather and indices + + auto ridx2 = get<1>(topk(a, /*k=*/2, /*axis=*/-1, /*descending=*/false)); + auto gval2 = gather(a, -1, ridx2); // get the same values via gather and indices + + auto ridx3 = get<1>(argmin(a, -1)); + auto ridx3_ = slice(ridx2, -1, 0); // slice and cast now support uint32_t/IndexType + + // @TODO: add integer types to more operators + auto eq3 = eq(cast(ridx3, floatType), cast(ridx3_, floatType)); + + auto rtopk4 = argmax(a, /*axis=*/-2); // axes other than -1 are currently implemented via inefficient transpose + auto rval4 = get<0>(rtopk4); + auto ridx4 = get<1>(rtopk4); + auto gval4 = gather(a, -2, ridx4); + + graph->forward(); + + CHECK(rval1 != gval1); + CHECK(rval1->shape() == gval1->shape()); + CHECK(ridx1->shape() == gval1->shape()); + + std::vector vval1 = { 0.3333, 0, + 4.5, 0, + 101.45, 5.2, + 1.05e-5, 0 }; + + std::vector rvalues; + std::vector gvalues; + rval1->val()->get(rvalues); + gval1->val()->get(gvalues); + CHECK( rvalues == gvalues ); + CHECK( rvalues == vval1 ); + + std::vector vval2 = { -0.2, 0, + -0.3, 0, + -10.0, 5.2, + -100.05, 0 }; + gval2->val()->get(values); + CHECK( values == vval2 ); + + eq3->val()->get(values); + CHECK( values == std::vector({1, 1, 1, 1}) ); + + std::vector vidx4; + ridx4->val()->get(vidx4); + CHECK( ridx4->shape() == Shape({2, 1, 3}) ); + CHECK( vidx4 == std::vector({0, 0, 1, + 0, 1, 
0}) ); + + std::vector vval4 = { 0, 0.3333, 4.5, + 5.2, 0, 101.45 }; + rval4->val()->get(values); + CHECK( values == vval4 ); + + gval4->val()->get(values); + CHECK( values == vval4 ); + } } #ifdef CUDA_FOUND @@ -795,7 +897,7 @@ TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") TEST_CASE("Compare aggregate operator", "[graph]") { auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).margin(0.001f); }; - + Config::seed = 1234; std::vector initc; @@ -817,7 +919,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") { SECTION("initializing with zero (cpu)") { std::vector values1; std::vector values2; - + auto graph1 = New(); graph1->setDevice({0, DeviceType::cpu}); graph1->reserveWorkspaceMB(40); @@ -825,7 +927,7 @@ TEST_CASE("Compare aggregate operator", "[graph]") { auto graph2 = New(); graph2->setDevice({0, DeviceType::gpu}); graph2->reserveWorkspaceMB(40); - + auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc)); auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita)); auto prod1 = scalar_product(chl1, adj1, -1); @@ -844,4 +946,4 @@ TEST_CASE("Compare aggregate operator", "[graph]") { } #endif - #endif \ No newline at end of file + #endif diff --git a/src/tests/units/rnn_tests.cpp b/src/tests/units/rnn_tests.cpp index 6405ef7a6..5d67729b3 100644 --- a/src/tests/units/rnn_tests.cpp +++ b/src/tests/units/rnn_tests.cpp @@ -181,7 +181,7 @@ void tests(DeviceType type, Type floatType = Type::float32) { auto context = concatenate({rnnFw.construct(graph)->transduce(input, mask), rnnBw.construct(graph)->transduce(input, mask)}, - /*axis =*/ input->shape().size() - 1); + /*axis =*/ (int)input->shape().size() - 1); if(second > 0) { // add more layers (unidirectional) by transducing the output of the diff --git a/src/training/communicator.h b/src/training/communicator.h index 472744910..9c2c8ebad 100644 --- a/src/training/communicator.h +++ b/src/training/communicator.h @@ -10,6 +10,7 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wsuggest-override" #endif +#undef HOST #include "mpi.h" #ifdef __GNUC__ #pragma GCC diagnostic pop diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h index 4ea17d8b1..049a8ef5f 100755 --- a/src/translator/beam_search.h +++ b/src/translator/beam_search.h @@ -5,6 +5,7 @@ #include "translator/history.h" #include "translator/scorers.h" #include "data/factored_vocab.h" +#include "data/shortlist.h" #include "translator/helpers.h" #include "translator/nth_element.h" @@ -68,21 +69,21 @@ class BeamSearch { // They can be between 0 and (vocabSize * nBestBeamSize * batchSize)-1. // (beamHypIdx refers to the GPU tensors, *not* the beams[] array; they are not the same in case of purging) const auto key = nBestKeys[i]; - + // decompose key into individual indices (batchIdx, beamHypIdx, wordIdx) const auto beamHypIdx = (key / vocabSize) % nBestBeamSize; const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize; const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0; - + WordIndex wordIdx; if(dropHyp) { // if we force=drop the hypothesis, assign EOS, otherwise the expected word id. 
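        // (Illustrative aside on the key decomposition above, with invented numbers:
        // for vocabSize = 32000 and nBestBeamSize = 5, key = 416017 gives
        // key / vocabSize = 13, so beamHypIdx = 13 % 5 = 3 and currentBatchIdx = 13 / 5 = 2;
        // the remaining word index is key % vocabSize = 17.)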
          if(factoredVocab) { // when using factoredVocab, extract the EOS lemma index from the word id; we are predicting factors one by one here, hence lemma only
            std::vector eosFactors;
            factoredVocab->word2factors(factoredVocab->getEosId(), eosFactors);
            wordIdx = (WordIndex)eosFactors[0];
-         } else { // without factoredVocab lemma index and word index are the same. Safe cruising.
+         } else { // without factoredVocab lemma index and word index are the same. Safe cruising.
            wordIdx = trgVocab_->getEosId().toWordIndex();
          }
        } else { // we are not dropping anything, just assign the normal index
@@ -90,9 +91,9 @@ class BeamSearch {
        }

        // @TODO: We currently assign a log probability of 0 to all beam entries of the dropped batch entry; instead it might be a good idea to use
-       // the per Hyp pathScore without the current expansion (a bit hard to obtain).
-       // For the case where we drop empty inputs, 0 is fine. For other use cases like a forced stop, the penultimate pathScore might be better.
-       // For the empty hyp this would naturally result in 0, too.
+       // the per Hyp pathScore without the current expansion (a bit hard to obtain).
+       // For the case where we drop empty inputs, 0 is fine. For other use cases like a forced stop, the penultimate pathScore might be better.
+       // For the empty hyp this would naturally result in 0, too.
        const float pathScore = dropHyp ? 0.f : nBestPathScores[i]; // 0 (Prob = 1, maximum score) if dropped or expanded path score for (batchIdx, beamHypIdx, word)

        const auto& beam = beams[origBatchIdx];
@@ -102,7 +103,7 @@ class BeamSearch {
          continue;
        if(pathScore == INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
          continue;
-
+
        ABORT_IF(pathScore < INVALID_PATH_SCORE, "Actual pathScore ({}) is lower than INVALID_PATH_SCORE ({})??", pathScore, INVALID_PATH_SCORE); // This should not happen in valid situations. Currently the only smaller value would be -inf (effect of overflow in summation?)
        ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??"); // effectively this is equivalent to ABORT_IF(beams[origBatchIdx].empty(), ...)
diff --git a/src/translator/history.h b/src/translator/history.h
index 463c75a1c..17b642c65 100644
--- a/src/translator/history.h
+++ b/src/translator/history.h
@@ -37,7 +37,15 @@ class History {

   size_t size() const { return history_.size(); } // number of time steps

-  NBestList nBest(size_t n) const {
+  /* return n best hypotheses
+   * @param n size of n-best list
+   * @param skipEmpty skip empty hypotheses (see also: https://arxiv.org/abs/1908.10090)
+   * @return at most max(n, beamSize) translation hypotheses
+   * Note: if n is equal to the beam size, skipEmpty is true, and the empty hypothesis is in
+   * the top-n translations, the function will return fewer than n candidates (e.g. with
+   * n = beamSize = 5, if the empty hypothesis is among the top 5, only 4 come back). It is up to
+   * the caller to check the number of returned hypotheses.
diff --git a/src/translator/history.h b/src/translator/history.h
index 463c75a1c..17b642c65 100644
--- a/src/translator/history.h
+++ b/src/translator/history.h
@@ -37,7 +37,15 @@ class History {
 
   size_t size() const { return history_.size(); } // number of time steps
 
-  NBestList nBest(size_t n) const {
+  /* return n best hypotheses
+   * @param n size of n-best list
+   * @param skipEmpty skip empty hypotheses (see also: https://arxiv.org/abs/1908.10090)
+   * @return at most n translation hypotheses
+   * Note: if n is equal to the beam size, skipEmpty is true, and the empty hypothesis is in
+   * the top-n translations, the function will return fewer than n candidates. It is up to
+   * the caller to check the number of returned hypotheses.
+   */
+  NBestList nBest(size_t n, bool skipEmpty = false) const {
     NBestList nbest;
     for (auto topHypsCopy = topHyps_; nbest.size() < n && !topHypsCopy.empty(); topHypsCopy.pop()) {
       auto bestHypCoord = topHypsCopy.top();
@@ -48,17 +56,18 @@ class History {
 
       // trace back best path
       Words targetWords = bestHyp->tracebackWords();
-      
+      if (skipEmpty && targetWords.size() == 0)
+        continue; // skip empty translation
 
       // note: bestHyp->getPathScore() is not normalized, while bestHypCoord.normalizedPathScore is
       nbest.emplace_back(targetWords, bestHyp, bestHypCoord.normalizedPathScore);
     }
     return nbest;
   }
 
-  Result top() const { 
+  Result top() const {
     const NBestList& nbest = nBest(1);
     ABORT_IF(nbest.empty(), "No hypotheses in n-best list??");
-    return nbest[0]; 
+    return nbest[0];
   }
 
   size_t getLineNum() const { return lineNo_; }
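Reviewer note: as the new doc comment on History::nBest() warns, with skipEmpty=true the caller may get back fewer hypotheses than requested. The stand-alone sketch below mimics that contract with a simplified stand-in type (not marian's actual NBestList) to show the defensive size check a caller should make.

#include <cstddef>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Stand-in for an n-best entry: (target text, normalized path score).
using NBestEntry = std::pair<std::string, float>;

// Mimics nBest(n, skipEmpty): collect at most n candidates from a ranked
// list, dropping empty translations when skipEmpty is set, so the result
// may be shorter than requested.
std::vector<NBestEntry> nBest(const std::vector<NBestEntry>& ranked, std::size_t n, bool skipEmpty) {
  std::vector<NBestEntry> nbest;
  for(const auto& cand : ranked) {
    if(nbest.size() >= n)
      break;
    if(skipEmpty && cand.first.empty())
      continue;  // skip empty translation, cf. https://arxiv.org/abs/1908.10090
    nbest.push_back(cand);
  }
  return nbest;
}

int main() {
  std::vector<NBestEntry> ranked = {{"", -0.1f}, {"hello world", -0.5f}};
  auto nbest = nBest(ranked, 2, /*skipEmpty=*/true);
  // The caller must check the size: 2 were requested, only 1 survives.
  std::printf("requested 2, got %zu\n", nbest.size());
  return 0;
}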
diff --git a/src/translator/output_collector.cpp b/src/translator/output_collector.cpp
index 76bc4cbbc..078be232b 100644
--- a/src/translator/output_collector.cpp
+++ b/src/translator/output_collector.cpp
@@ -79,13 +79,14 @@ void OutputCollector::Write(long sourceId,
   }
 }
 
-StringCollector::StringCollector() : maxId_(-1) {}
+StringCollector::StringCollector(bool quiet /*=false*/) : maxId_(-1), quiet_(quiet) {}
 
 void StringCollector::add(long sourceId,
                           const std::string& best1,
                           const std::string& bestn) {
   std::lock_guard<std::mutex> lock(mutex_);
-  LOG(info, "Best translation {} : {}", sourceId, best1);
+  if(!quiet_)
+    LOG(info, "Best translation {} : {}", sourceId, best1);
   outputs_[sourceId] = std::make_pair(best1, bestn);
   if(maxId_ <= sourceId)
     maxId_ = sourceId;
diff --git a/src/translator/output_collector.h b/src/translator/output_collector.h
index ffcbd2d50..0e6bfc9f8 100644
--- a/src/translator/output_collector.h
+++ b/src/translator/output_collector.h
@@ -74,14 +74,15 @@ class OutputCollector {
 
 class StringCollector {
 public:
-  StringCollector();
+  StringCollector(bool quiet = false);
   StringCollector(const StringCollector&) = delete;
 
   void add(long sourceId, const std::string& best1, const std::string& bestn);
   std::vector<std::string> collect(bool nbest);
 
 protected:
-  long maxId_;
+  long maxId_;  // the largest index of the translated source sentences
+  bool quiet_;  // if true do not log best translations
   std::mutex mutex_;
   typedef std::map<long, std::pair<std::string, std::string>> Outputs;
diff --git a/src/translator/translator.h b/src/translator/translator.h
index cc68a4f01..15eb98702 100755
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <sstream>
+
 #include "data/batch_generator.h"
 #include "data/corpus.h"
 #include "data/shortlist.h"
@@ -245,10 +247,14 @@ class TranslateService : public ModelServiceTask {
   }
 
   std::string run(const std::string& input) override {
-    auto corpus_ = New<data::TextInput>(std::vector<std::string>({input}), srcVocabs_, options_);
+    // split tab-separated input into fields if necessary
+    auto inputs = options_->get<bool>("tsv", false)
+                      ? convertTsvToLists(input, options_->get<size_t>("tsv-fields", 1))
+                      : std::vector<std::string>({input});
+    auto corpus_ = New<data::TextInput>(inputs, srcVocabs_, options_);
     data::BatchGenerator<data::TextInput> batchGenerator(corpus_, options_);
 
-    auto collector = New<StringCollector>();
+    auto collector = New<StringCollector>(options_->get<bool>("quiet-translation", false));
     auto printer = New<OutputPrinter>(options_, trgVocab_);
     size_t batchId = 0;
 
@@ -258,7 +264,6 @@ class TranslateService : public ModelServiceTask {
     ThreadPool threadPool_(numDevices_, numDevices_);
 
     for(auto batch : batchGenerator) {
-
       auto task = [=](size_t id) {
         thread_local Ptr<ExpressionGraph> graph;
         thread_local std::vector<Ptr<Scorer>> scorers;
@@ -287,5 +292,30 @@ class TranslateService : public ModelServiceTask {
     auto translations = collector->collect(options_->get<bool>("n-best"));
     return utils::join(translations, "\n");
   }
+
+private:
+  // Converts a multi-line input with tab-separated source(s) and target sentences into separate lists
+  // of sentences from source(s) and target sides, e.g.
+  // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
+  std::vector<std::string> convertTsvToLists(const std::string& inputText, size_t numFields) {
+    std::vector<std::string> outputFields(numFields);
+
+    std::string line;
+    std::vector<std::string> lineFields(numFields);
+    std::istringstream inputStream(inputText);
+    bool first = true;
+    while(std::getline(inputStream, line)) {
+      utils::splitTsv(line, lineFields, numFields);
+      for(size_t i = 0; i < numFields; ++i) {
+        if(!first)
+          outputFields[i] += "\n";  // join sentences with a newline character
+        outputFields[i] += lineFields[i];
+      }
+      if(first)
+        first = false;
+    }
+
+    return outputFields;
+  }
 };
 } // namespace marian
diff --git a/vs/Marian.vcxproj b/vs/Marian.vcxproj old mode 100755 new mode 100644 index 0cb4a5de0..bde4404e1 --- a/vs/Marian.vcxproj +++ b/vs/Marian.vcxproj @@ -55,7 +55,7 @@ - $(MSMPI_INC); $(MSMPI_INC)\x64 + $(MSMPI_INC); $(MSMPI_INC)\x64; $(PROTOBUF_RUNTIME_INC) $(OutDir);$(SolutionDir)$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration);$(MSMPI_LIB64) @@ -70,7 +70,7 @@ Level4 Disabled - USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) + USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) false true /bigobj /arch:AVX %(AdditionalOptions) @@ -107,7 +107,7 @@ MaxSpeed true true - USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) + USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; FINTEGER=uint64_t; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions) false Speed /d2Zi+ /bigobj /arch:AVX %(AdditionalOptions) @@ -141,6 +141,34 @@ + + TurnOffAllWarnings + TurnOffAllWarnings + + + TurnOffAllWarnings + TurnOffAllWarnings + + + TurnOffAllWarnings + TurnOffAllWarnings + + + TurnOffAllWarnings + TurnOffAllWarnings + + + TurnOffAllWarnings + TurnOffAllWarnings + + + TurnOffAllWarnings + TurnOffAllWarnings + + + TurnOffAllWarnings + TurnOffAllWarnings + true true @@ -805,6 +833,16 @@ TurnOffAllWarnings TurnOffAllWarnings + + false + false + $(MSMPI_INC); 
$(MSMPI_INC)\x64; $(PROTOBUF_RUNTIME_INC) + $(MSMPI_INC); $(MSMPI_INC)\x64; $(PROTOBUF_RUNTIME_INC) + + + true + true + true true @@ -1010,6 +1048,14 @@ + + + + + + + + @@ -1206,6 +1252,14 @@ + + false + false + + + false + false + true true @@ -1341,6 +1395,7 @@ + @@ -1373,11 +1428,18 @@ + false + + + + TurnOffAllWarnings + TurnOffAllWarnings + @@ -1398,49 +1460,58 @@ + - + true true - + true true - + true true - + true true - + true true - + true true - + true true - + + true + true + + + true + true + + true true - + true true - + true true - + true true @@ -1485,6 +1556,9 @@ + + + @@ -1493,6 +1567,7 @@ + @@ -1733,6 +1808,7 @@ + @@ -1771,6 +1847,8 @@ + + @@ -1860,6 +1938,10 @@ true + + true + true + @@ -1868,6 +1950,7 @@ true Document + false @@ -1897,10 +1980,19 @@ false false + + + true + true + + + true + true + @@ -1924,6 +2016,7 @@ + true @@ -1931,6 +2024,7 @@ + diff --git a/vs/Marian.vcxproj.filters b/vs/Marian.vcxproj.filters old mode 100755 new mode 100644 index bb6080ae8..ef6fd9acd --- a/vs/Marian.vcxproj.filters +++ b/vs/Marian.vcxproj.filters @@ -250,6 +250,15 @@ common + + 3rd_party\onnx\protobuf + + + 3rd_party\onnx\protobuf + + + rescorer + tensors\cpu\sharp @@ -427,39 +436,6 @@ 3rd_party\pathie-cpp\src - - tests - - - tests - - - tests - - - tests - - - tests - - - tests - - - tests - - - tests - - - tests - - - tests - - - tests - examples\mnist @@ -877,6 +853,84 @@ training + + common + + + tensors\cpu + + + onnx + + + onnx + + + onnx + + + tests + + + tests + + + tests + + + tests + + + tests + + + tests + + + tests\units + + + tests\units + + + tests\units + + + tests\units + + + tests\units + + + tests\units + + + tests\units + + + 3rd_party\faiss + + + 3rd_party\faiss + + + 3rd_party\faiss + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + layers + @@ -1687,6 +1741,12 @@ tensors\gpu + + 3rd_party\onnx\protobuf + + + 3rd_party\onnx\protobuf + 3rd_party\sentencepiece\src @@ -2251,6 +2311,43 @@ tensors\cpu\fbgemm + + common + + + tensors\gpu + + + + onnx + + + 3rd_party\faiss + + + 3rd_party\faiss + + + 3rd_party\faiss + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + 3rd_party\faiss\utils + + + layers + @@ -2499,6 +2596,24 @@ {bf361868-f451-45b8-9695-570d67924972} + + {ba00e638-d55d-4722-9caa-c9e6e133a072} + + + {153199ee-2f29-4bd6-8187-4235ac020ccd} + + + {f3849101-b84e-48c1-af5d-27c79ce89e70} + + + {d145421b-1723-47f1-858d-2e49adb24b03} + + + {e479c8b0-9309-4df4-a28c-9af145c69b51} + + + {b100324b-a506-45fa-948e-40be75b239fc} + @@ -2627,6 +2742,18 @@ 3rd_party\half_float + + tensors\gpu + + + 3rd_party\onnx\protobuf + + + 3rd_party\faiss + + + 3rd_party\faiss + @@ -2641,6 +2768,12 @@ examples + + tests\units + + + 3rd_party\faiss + @@ -2670,5 +2803,14 @@ tensors\gpu + + tensors\gpu + + + tensors\gpu + + + tests + \ No newline at end of file
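Reviewer note: returning to the convertTsvToLists() helper added to src/translator/translator.h earlier in this diff, the row-to-column transposition it performs can be checked in isolation. Below is a self-contained sketch using only the standard library; the simplified splitTsv() here is a stand-in for marian's utils::splitTsv and is assumed to behave similarly for well-formed input.

#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for utils::splitTsv: split a line into at most
// numFields tab-separated fields (missing trailing fields stay empty).
static void splitTsv(const std::string& line, std::vector<std::string>& fields, std::size_t numFields) {
  fields.assign(numFields, "");
  std::istringstream in(line);
  for(std::size_t i = 0; i < numFields && std::getline(in, fields[i], '\t'); ++i)
    ;
}

// Same transposition as TranslateService::convertTsvToLists:
// rows of tab-separated fields -> one newline-joined column per field.
static std::vector<std::string> convertTsvToLists(const std::string& inputText, std::size_t numFields) {
  std::vector<std::string> outputFields(numFields);
  std::vector<std::string> lineFields(numFields);
  std::string line;
  std::istringstream inputStream(inputText);
  bool first = true;
  while(std::getline(inputStream, line)) {
    splitTsv(line, lineFields, numFields);
    for(std::size_t i = 0; i < numFields; ++i) {
      if(!first)
        outputFields[i] += "\n";  // join sentences with a newline character
      outputFields[i] += lineFields[i];
    }
    first = false;
  }
  return outputFields;
}

int main() {
  // "src1 \t trg1 \n src2 \t trg2" -> ["src1 \n src2", "trg1 \n trg2"]
  auto out = convertTsvToLists("src1\ttrg1\nsrc2\ttrg2", 2);
  assert(out[0] == "src1\nsrc2");
  assert(out[1] == "trg1\ntrg2");
  return 0;
}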