Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for using Python APIs from Rust. #7085

Merged
merged 6 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
from tvm.ir import IRModule
from tvm.relay import transform, build_module
from tvm.runtime.ndarray import cpu
from ...ir import IRModule
from ...relay import transform, build_module
from ...runtime.ndarray import cpu

from . import _ffi_api
from .feature import Feature
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/analysis/annotated_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Regions used in Relay."""

from tvm.runtime import Object
from ...runtime import Object
from . import _ffi_api


Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/analysis/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""

from tvm.ir import IRModule
from tvm.runtime import Object
from ...ir import IRModule
from ...runtime import Object
from ..expr import GlobalVar
from . import _ffi_api

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.runtime import ndarray


class GraphRuntimeFactoryModule(object):
class GraphRuntimeFactoryModule:
"""Graph runtime factory module.
This is a module of graph runtime factory

Expand Down
16 changes: 15 additions & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from tvm.ir.transform import PassContext
from tvm.tir import expr as tvm_expr
from .. import nd as _nd, autotvm
from .. import nd as _nd, autotvm, register_func
from ..target import Target
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
Expand Down Expand Up @@ -194,6 +194,20 @@ def get_params(self):
return ret


@register_func("tvm.relay.module_export_library")
def _module_export(module, file_name): # fcompile, addons, kwargs?
return module.export_library(file_name)


@register_func("tvm.relay.build")
def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"):
"""A wrapper around build which discards the Python GraphFactoryRuntime.
This wrapper is suitable to be used from other programming languages as
the runtime::Module can be freely passed between language boundaries.
"""
return build(mod, target, target_host, params, mod_name).module


def build(mod, target=None, target_host=None, params=None, mod_name="default"):
# fmt: off
# pylint: disable=line-too-long
Expand Down
3 changes: 0 additions & 3 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
Contains the model importers currently defined
for Relay.
"""

from __future__ import absolute_import

from .mxnet import from_mxnet
from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var
from .keras import from_keras
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,10 +914,11 @@ def _impl(inputs, attr, params, mod):


def _sparse_tensor_dense_matmul():
# Sparse utility from scipy
from scipy.sparse import csr_matrix

def _impl(inputs, attr, params, mod):
# Loading this by default causes TVM to not be loadable from other languages.
# Sparse utility from scipy
from scipy.sparse import csr_matrix

assert len(inputs) == 4, "There should be 4 input tensors"

indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

# pylint: disable=redefined-builtin, wildcard-import
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs

from .conv1d import *
from .conv1d_transpose_ncw import *
from .conv2d import *
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

# TVM Runtime Support

This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime.
Currently this is tested on `1.42.0` and above.
This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime,
see [here](https://github.com/apache/tvm/blob/main/rust/tvm/README.md) for more details.

## What Does This Crate Offer?

Expand Down
28 changes: 27 additions & 1 deletion rust/tvm-rt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ pub mod map;
pub mod module;
pub mod ndarray;
mod to_function;
pub mod value;

/// Outputs the current TVM version.
pub fn version() -> &'static str {
Expand All @@ -112,6 +111,8 @@ pub fn version() -> &'static str {
#[cfg(test)]
mod tests {
use super::*;
use crate::{ByteArray, Context, DataType};
use std::{convert::TryInto, str::FromStr};

#[test]
fn print_version() {
Expand All @@ -127,4 +128,29 @@ mod tests {
errors::NDArrayError::EmptyArray.to_string()
);
}

#[test]
fn bytearray() {
let w = vec![1u8, 2, 3, 4, 5];
let v = ByteArray::from(w.as_slice());
let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
assert_eq!(
tvm.data(),
w.iter().copied().collect::<Vec<u8>>().as_slice()
);
}

#[test]
fn ty() {
let t = DataType::from_str("int32").unwrap();
let tvm: DataType = RetValue::from(t).try_into().unwrap();
assert_eq!(tvm, t);
}

#[test]
fn ctx() {
let c = Context::from_str("gpu").unwrap();
let tvm: Context = RetValue::from(c).try_into().unwrap();
assert_eq!(tvm, c);
}
}
12 changes: 12 additions & 0 deletions rust/tvm-rt/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ where
let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?;
oref.downcast()
}

pub fn empty() -> Self {
Self::from_iter(vec![].into_iter())
}

//(@jroesch): I don't think this is a correct implementation.
pub fn null() -> Self {
Map {
object: ObjectRef::null(),
_data: PhantomData,
}
}
}

pub struct IntoIter<K, V> {
Expand Down
58 changes: 30 additions & 28 deletions rust/tvm-rt/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,24 @@ use std::{
ptr,
};

use crate::object::Object;
use tvm_macros::Object;
use tvm_sys::ffi;

use crate::errors::Error;
use crate::String as TString;
use crate::{errors, function::Function};

const ENTRY_FUNC: &str = "__tvm_main__";

/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
///
/// [`entry_func`]:struct.Module.html#method.entry_func
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this comment still true?

#[derive(Debug, Clone)]
pub struct Module {
pub(crate) handle: ffi::TVMModuleHandle,
entry_func: Option<Function>,
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "Module"]
#[type_key = "runtime.Module"]
pub struct ModuleNode {
base: Object,
}

crate::external! {
Expand All @@ -49,21 +52,18 @@ crate::external! {

#[name("runtime.ModuleLoadFromFile")]
fn load_from_file(file_name: CString, format: CString) -> Module;

#[name("runtime.ModuleSaveToFile")]
fn save_to_file(module: Module, name: TString, fmt: TString);

// TODO(@jroesch): we need to refactor this
#[name("tvm.relay.module_export_library")]
fn export_library(module: Module, file_name: TString);
}

impl Module {
pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
Self {
handle,
entry_func: None,
}
}

pub fn entry(&mut self) -> Option<Function> {
if self.entry_func.is_none() {
self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
}
self.entry_func.clone()
pub fn default_fn(&mut self) -> Result<Function, Error> {
self.get_function("default", true)
}

/// Gets a function by name from a registered module.
Expand All @@ -72,7 +72,7 @@ impl Module {
let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;

check_call!(ffi::TVMModGetFunction(
self.handle,
self.handle(),
name.as_ptr() as *const c_char,
query_import as c_int,
&mut fhandle as *mut _
Expand All @@ -87,7 +87,7 @@ impl Module {

/// Imports a dependent module such as `.ptx` for gpu.
pub fn import_module(&self, dependent_module: Module) {
check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
check_call!(ffi::TVMModImport(self.handle(), dependent_module.handle()))
}

/// Loads a module shared library from path.
Expand All @@ -110,6 +110,14 @@ impl Module {
Ok(module)
}

pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> {
save_to_file(self.clone(), name.into(), fmt.into())
}

pub fn export_library(&self, name: String) -> Result<(), Error> {
export_library(self.clone(), name.into())
}

/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
let target = CString::new(target).unwrap();
Expand All @@ -118,13 +126,7 @@ impl Module {
}

/// Returns the underlying module handle.
pub fn handle(&self) -> ffi::TVMModuleHandle {
self.handle
}
}

impl Drop for Module {
fn drop(&mut self) {
check_call!(ffi::TVMModFree(self.handle));
pub unsafe fn handle(&self) -> ffi::TVMModuleHandle {
self.0.clone().unwrap().into_raw() as *mut _
}
}
13 changes: 11 additions & 2 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ impl<T: IsObject> ObjectPtr<T> {
Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY))
}
}

pub unsafe fn into_raw(self) -> *mut T {
self.ptr.as_ptr()
}
}

impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
Expand Down Expand Up @@ -300,7 +304,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
use crate::ndarray::NDArrayContainer;

match ret_value {
RetValue::ObjectHandle(handle) => {
RetValue::ObjectHandle(handle) | RetValue::ModuleHandle(handle) => {
let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
optr.downcast()
Expand Down Expand Up @@ -329,6 +333,11 @@ impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
assert!(!raw_ptr.is_null());
ArgValue::NDArrayHandle(raw_ptr)
}
"runtime.Module" => {
let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void;
assert!(!raw_ptr.is_null());
ArgValue::ModuleHandle(raw_ptr)
}
_ => {
let raw_ptr = ObjectPtr::leak(object_ptr) as *mut Object as *mut std::ffi::c_void;
assert!(!raw_ptr.is_null());
Expand All @@ -346,7 +355,7 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
use crate::ndarray::NDArrayContainer;

match arg_value {
ArgValue::ObjectHandle(handle) => {
ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => {
let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
optr.downcast()
Expand Down
1 change: 1 addition & 0 deletions rust/tvm-rt/src/to_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ impl_typed_and_to_function!(2; A, B);
impl_typed_and_to_function!(3; A, B, C);
impl_typed_and_to_function!(4; A, B, C, D);
impl_typed_and_to_function!(5; A, B, C, D, E);
impl_typed_and_to_function!(6; A, B, C, D, E, G);

#[cfg(test)]
mod tests {
Expand Down
Loading