diff --git a/src/runtime/hexagon/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon/hexagon_user_dma.cc
index f943dfd5abb1..6e286ae8b3f4 100644
--- a/src/runtime/hexagon/hexagon/hexagon_user_dma.cc
+++ b/src/runtime/hexagon/hexagon/hexagon_user_dma.cc
@@ -39,7 +39,7 @@ int init_hexagon_user_dma() {
   return DMA_SUCCESS;
 }
 
-int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) {
+int hexagon_user_dma_1d_sync_helper(void* dst, void* src, uint32_t length) {
 #if defined(__hexagon__) && __HEXAGON_ARCH__ >= 68
   static int config_dma = init_hexagon_user_dma();
   if (config_dma != DMA_SUCCESS) {
@@ -114,6 +114,34 @@ int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) {
 #endif
 }
 
+// Copy `length` bytes from `src` to `dst` via hexagon_user_dma_1d_sync_helper.
+// A single helper call is only ever given at most DESC_LENGTH_MASK bytes, so
+// longer requests are issued as consecutive chunks of at most that size.
+// Returns DMA_SUCCESS when every chunk succeeds; otherwise returns the
+// helper's status for the first failing chunk and issues no further chunks
+// (chunks already transferred are not undone).
+int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) {
+  // One DMA transfer can copy at most DESC_LENGTH_MASK bytes.
+  // Make the common case quick.
+  if (length <= DESC_LENGTH_MASK) return hexagon_user_dma_1d_sync_helper(dst, src, length);
+
+  // Split big transfers into smaller transfers.
+  char* cast_src = static_cast<char*>(src);
+  char* cast_dst = static_cast<char*>(dst);
+  for (uint32_t i = 0; i < length;) {
+    // length - i cannot underflow: the loop condition guarantees i < length.
+    uint32_t cur_len = std::min<uint32_t>(length - i, DESC_LENGTH_MASK);
+    int ret_val = hexagon_user_dma_1d_sync_helper(&cast_dst[i], &cast_src[i], cur_len);
+    if (ret_val != DMA_SUCCESS) return ret_val;
+    // i + cur_len never overflows and never exceeds length:
+    // 1. length - i <= DESC_LENGTH_MASK: cur_len = length - i, so the new i
+    //    equals length and the loop terminates.
+    // 2. length - i > DESC_LENGTH_MASK: cur_len = DESC_LENGTH_MASK, so the new i
+    //    is i + DESC_LENGTH_MASK < length <= UINT32_MAX and the loop continues.
+    i += cur_len;
+  }
+  return DMA_SUCCESS;
+}
 }  // namespace hexagon
 }  // namespace runtime
 }  // namespace tvm