Skip to content

Commit

Permalink
[Meta Schedule][M3c] Argument Info (#9059)
Browse files Browse the repository at this point in the history
This PR is part of the meta schedule project (#8473) that adds metadata of each PrimFunc's argument.
This feature is necessary for dynamic shape auto-tuning.

Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>

Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
  • Loading branch information
7 people committed Sep 22, 2021
1 parent 8f39da1 commit 4c8531d
Show file tree
Hide file tree
Showing 10 changed files with 538 additions and 4 deletions.
111 changes: 111 additions & 0 deletions include/tvm/meta_schedule/arg_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_ARG_INFO_H_
#define TVM_META_SCHEDULE_ARG_INFO_H_

#include <tvm/node/node.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace meta_schedule {

/*! \brief The argument information. */
class ArgInfoNode : public runtime::Object {
public:
static constexpr const char* _type_key = "meta_schedule.ArgInfo";
TVM_DECLARE_BASE_OBJECT_INFO(ArgInfoNode, runtime::Object);

public:
/*! \brief Default destructor. */
virtual ~ArgInfoNode() = default;
/*! \brief Converts the ArgInfo to its corresponding JSON representation. */
virtual ObjectRef AsJSON() const = 0;
};

/*!
* \brief Managed reference to ArgInfoNode
* \sa ArgInfoNode
*/
class ArgInfo : public runtime::ObjectRef {
public:
/*!
* \brief Parse the argument information from a JSON object.
* \param json_obj The json object to parse.
* \return The argument information parsed.
*/
TVM_DLL static ArgInfo FromJSON(const ObjectRef& json_obj);
/*!
* \brief Extract a list of the argument information from PrimFunc.
* \param func The PrimFunc to get argument information from.
* \return An array of the argument information derived.
*/
TVM_DLL static Array<ArgInfo, void> FromPrimFunc(const tir::PrimFunc& func);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode);

protected:
ArgInfo() = default;
};

/*! \brief The tensor argument information. */
class TensorInfoNode : public ArgInfoNode {
public:
/*! \brief The data type of the tensor. */
runtime::DataType dtype;
/*! \brief The shape of the tensor. */
runtime::ShapeTuple shape;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
}

static constexpr const char* _type_key = "meta_schedule.TensorInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode);

public:
ObjectRef AsJSON() const;
};

/*!
* \brief Managed reference to TensorInfoNode
* \sa TensorInfoNode
*/
class TensorInfo : public ArgInfo {
public:
/*!
* \brief Constructor of TensorInfo.
* \param dtype The data type of the tensor argument.
* \param shape The shape tuple of the tensor argument.
*/
TVM_DLL explicit TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape);
/*!
* \brief Parse the argument information from a JSON object.
* \param json_obj The json object to parse.
* \return The argument information parsed.
*/
TVM_DLL static TensorInfo FromJSON(const ObjectRef& json_obj);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorInfo, ArgInfo, TensorInfoNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_ARG_INFO_H_
10 changes: 9 additions & 1 deletion include/tvm/runtime/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <utility>

#include "./base.h"
#include "./optional.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -1344,7 +1345,14 @@ class Map : public ObjectRef {
iterator end() const { return iterator(GetMapNode()->end()); }
/*! \return find the key and returns the associated iterator */
iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }

/*! \return The value associated with the key, NullOpt if not found */
Optional<V> Get(const K& key) const {
MapNode::iterator iter = GetMapNode()->find(key);
if (iter == GetMapNode()->end()) {
return NullOptType{};
}
return DowncastNoCheck<V>(iter->second);
}
void erase(const K& key) { CopyOnWrite()->erase(key); }

/*!
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
# under the License.
"""Package `tvm.meta_schedule`. The meta schedule infrastructure."""
from . import builder
from . import arg_info
from .tune_context import TuneContext
106 changes: 106 additions & 0 deletions python/tvm/meta_schedule/arg_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The argument information"""
from typing import Any, List, Union

from tvm._ffi import register_object
from tvm.runtime import DataType, Object, ShapeTuple
from tvm.tir import PrimFunc

from . import _ffi_api
from .utils import _json_de_tvm


@register_object("meta_schedule.ArgInfo")
class ArgInfo(Object):
"""Argument information"""

def as_json(self) -> Any:
"""Converts the ArgInfo to its corresponding JSON representation."""
return _json_de_tvm(_ffi_api.ArgInfoAsJSON(self)) # type: ignore # pylint: disable=no-member

@staticmethod
def from_json(json_obj: Any) -> "ArgInfo":
"""Parse the argument information from a JSON object.
Parameters
----------
json_obj : Any
The json object to parse.
Returns
-------
parsed : ArgInfo
The argument information parsed.
"""
return _ffi_api.ArgInfoFromJSON(json_obj) # type: ignore # pylint: disable=no-member

@staticmethod
def from_prim_func(func: PrimFunc) -> List["ArgInfo"]:
"""Extract a list of the argument information from PrimFunc.
Parameters
----------
func : PrimFunc
The PrimFunc to get argument information from.
Returns
-------
extracted : List[ArgInfo]
An array of the argument information derived.
"""
return _ffi_api.ArgInfoFromPrimFunc(func) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.TensorInfo")
class TensorInfo(ArgInfo):
"""Tensor argument information
Parameters
----------
dtype : DataType
The data type of the tensor.
shape : ShapeTuple
The shape of the tensor.
"""

dtype: DataType
shape: ShapeTuple

def __init__(
self,
dtype: DataType,
shape: Union[ShapeTuple, List[int]],
) -> None:
"""Constructor
Parameters
----------
dtype : DataType
The data type of the tensor.
shape : ShapeTuple
The shape of the tensor.
"""
if isinstance(shape, ShapeTuple):
shape_tuple = shape
else:
shape_tuple = ShapeTuple(shape)
self.__init_handle_by_constructor__(
_ffi_api.TensorInfo, # type: ignore # pylint: disable=no-member
dtype,
shape_tuple,
)
33 changes: 32 additions & 1 deletion python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
"""Utilities for meta schedule"""
import os
import shutil
from typing import Callable, Union
from typing import Any, Callable, Union

import psutil

from tvm._ffi import get_global_func, register_func
from tvm.error import TVMError
from tvm.ir import Array, Map
from tvm.runtime import String
from tvm.tir import FloatImm, IntImm


@register_func("meta_schedule.cpu_count")
Expand Down Expand Up @@ -95,3 +98,31 @@ def get_global_func_with_default_on_worker(
def remove_build_dir(artifact_path: str) -> None:
"""Clean up the build directory"""
shutil.rmtree(os.path.dirname(artifact_path))


def _json_de_tvm(obj: Any) -> Any:
"""Unpack a TVM nested container to a JSON object in python.
Parameters
----------
obj : Any
The TVM nested container to be unpacked.
Returns
-------
result : Any
The unpacked json object.
"""
if obj is None:
return None
if isinstance(obj, (int, float)):
return obj
if isinstance(obj, (IntImm, FloatImm)):
return obj.value
if isinstance(obj, (str, String)):
return str(obj)
if isinstance(obj, Array):
return [_json_de_tvm(i) for i in obj]
if isinstance(obj, Map):
return {_json_de_tvm(k): _json_de_tvm(v) for k, v in obj.items()}
raise TypeError("Not supported type: " + str(type(obj)))
2 changes: 1 addition & 1 deletion python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@
from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
from .module import load_module, enabled, system_lib
from .container import String
from .container import String, ShapeTuple
from .params import save_param_dict, load_param_dict
Loading

0 comments on commit 4c8531d

Please sign in to comment.