Skip to content

Commit

Permalink
[PYTHON][FFI] Cythonize NDArray.copyto (#4549)
Browse files Browse the repository at this point in the history
* [PYTHON][FFI] Cythonize NDArray.copyto

* Cythonize the shape property
  • Loading branch information
tqchen authored and ZihengJiang committed Dec 20, 2019
1 parent ce0b6d5 commit bc5367a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
10 changes: 10 additions & 0 deletions python/tvm/_ffi/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ def __del__(self):
def _tvm_handle(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value

def _copyto(self, target_nd):
"""Internal function that implements copy to target ndarray."""
check_call(_LIB.TVMArrayCopyFromTo(self.handle, target_nd.handle, None))
return target_nd

@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))

def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/_ffi/_cython/ndarray.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ cdef class NDArrayBase:
def __set__(self, value):
self._set_handle(value)

@property
def shape(self):
"""Shape of this array"""
return tuple(self.chandle.shape[i] for i in range(self.chandle.ndim))

def __init__(self, handle, is_view):
self._set_handle(handle)
self.c_is_view = is_view
Expand All @@ -76,6 +81,11 @@ cdef class NDArrayBase:
if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle))

def _copyto(self, target_nd):
"""Internal function that implements copy to target ndarray."""
CALL(TVMArrayCopyFromTo(self.chandle, (<NDArrayBase>target_nd).chandle, NULL))
return target_nd

def to_dlpack(self):
"""Produce an array from a DLPack Tensor without copying memory
Expand Down
17 changes: 6 additions & 11 deletions python/tvm/_ffi/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,6 @@ def from_dlpack(dltensor):

class NDArrayBase(_NDArrayBase):
"""A simple Device/CPU Array object in runtime."""
@property
def shape(self):
"""Shape of this array"""
return tuple(self.handle.contents.shape[i] for i in range(self.handle.contents.ndim))

@property
def dtype(self):
Expand Down Expand Up @@ -240,6 +236,7 @@ def copyfrom(self, source_array):
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))

t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
Expand Down Expand Up @@ -294,14 +291,12 @@ def copyto(self, target):
target : NDArray
The target array to be copied, must have same shape as this array.
"""
if isinstance(target, TVMContext):
target = empty(self.shape, self.dtype, target)
if isinstance(target, NDArrayBase):
check_call(_LIB.TVMArrayCopyFromTo(
self.handle, target.handle, None))
else:
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
return self._copyto(target)
elif isinstance(target, TVMContext):
res = empty(self.shape, self.dtype, target)
return self._copyto(res)
raise ValueError("Unsupported target type %s" % str(type(target)))


def free_extension_handle(handle, type_code):
Expand Down

0 comments on commit bc5367a

Please sign in to comment.