diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h new file mode 100644 index 000000000000..9528be2a85ad --- /dev/null +++ b/include/tvm/meta_schedule/space_generator.h @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_META_SCHEDULE_SPACE_GENERATOR_H_ +#define TVM_META_SCHEDULE_SPACE_GENERATOR_H_ + +#include +#include + +namespace tvm { +namespace meta_schedule { + +// Forward declaration +class TuneContext; + +/*! \brief The abstract class for design space generation. */ +class SpaceGeneratorNode : public Object { + public: + /*! \brief Default destructor */ + virtual ~SpaceGeneratorNode() = default; + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& tune_context) = 0; + + /*! + * \brief Generate design spaces given a module. + * \param mod The module used for design space generation. + * \return The generated design spaces, i.e., schedules. + */ + virtual Array GenerateDesignSpace(const IRModule& mod) = 0; + + static constexpr const char* _type_key = "meta_schedule.SpaceGenerator"; + TVM_DECLARE_BASE_OBJECT_INFO(SpaceGeneratorNode, Object); +}; + +/*! \brief The design space generator with customized methods on the python-side. */ +class PySpaceGeneratorNode : public SpaceGeneratorNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `GenerateDesignSpace` method. + * \param mod The module used for design space generation. + * \return The generated design spaces, i.e., schedules. + */ + using FGenerateDesignSpace = runtime::TypedPackedFunc(const IRModule&)>; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `GenerateDesignSpace` function. */ + FGenerateDesignSpace f_generate_design_space; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_generate_design_space` is not visited + } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + f_initialize_with_tune_context(tune_context); + } + + Array GenerateDesignSpace(const IRModule& mod) final { + return f_generate_design_space(mod); + } + + static constexpr const char* _type_key = "meta_schedule.PySpaceGenerator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PySpaceGeneratorNode, SpaceGeneratorNode); +}; + +/*! + * \brief Managed reference to SpaceGeneratorNode. + * \sa SpaceGeneratorNode + */ +class SpaceGenerator : public ObjectRef { + protected: + SpaceGenerator() = default; + + public: + /*! + * \brief Create a design space generator with customized methods on the python-side. + * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. + * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator PySpaceGenerator( + PySpaceGeneratorNode::FInitializeWithTuneContext initialize_with_tune_context_func, + PySpaceGeneratorNode::FGenerateDesignSpace generate_design_space_func); + + /*! + * \brief Create a design space generator that is union of multiple design space generators. + * \param space_generators An array of design space generators to be unioned. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SPACE_GENERATOR_H_ diff --git a/src/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h similarity index 89% rename from src/meta_schedule/tune_context.h rename to include/tvm/meta_schedule/tune_context.h index 454b8095aabc..87a3a491c8f3 100644 --- a/src/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,6 +20,7 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include #include #include @@ -33,6 +34,8 @@ class TuneContextNode : public runtime::Object { Optional mod; /*! \brief The target to be tuned for. */ Optional target; + /*! \brief The design space generator. */ + Optional space_generator; /*! \brief The name of the tuning task. */ Optional task_name; /*! \brief The random state. */ @@ -43,6 +46,7 @@ class TuneContextNode : public runtime::Object { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); v->Visit("target", &target); + v->Visit("space_generator", &space_generator); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); @@ -62,12 +66,14 @@ class TuneContext : public runtime::ObjectRef { * \brief Constructor. * \param mod The workload to be tuned. * \param target The target to be tuned for. + * \param space_generator The design space generator. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. */ TVM_DLL explicit TuneContext(Optional mod, // Optional target, // + Optional space_generator, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index f0e8af223511..c07b28b4fc9f 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -17,4 +17,5 @@ """Package `tvm.meta_schedule`. The meta schedule infrastructure.""" from . import builder from . import arg_info +from . import space_generator from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py new file mode 100644 index 000000000000..af759d43b34a --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.space_generator package. +Meta Schedule design space generators that generates design +space for generation of measure candidates. +""" + +from .space_generator import SpaceGenerator, PySpaceGenerator +from .space_generator_union import SpaceGeneratorUnion +from .schedule_fn import ScheduleFn diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py new file mode 100644 index 000000000000..64edd9e0bf8c --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta schedule design space generators that generates design +space via a schedule function. +""" +from typing import TYPE_CHECKING, Callable, List, Union + +from tvm.ir import IRModule +from tvm.ir.container import Array +from tvm.tir.schedule import Schedule + +from .space_generator import PySpaceGenerator + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class ScheduleFn(PySpaceGenerator): + """A design space generator with design spaces specified by a schedule function.""" + + # Multiple cases of schedule functions supported + SCH_FN_TYPE = Union[ + Callable[[IRModule], None], # No output + Callable[[IRModule], Schedule], # Single output + Callable[[IRModule], List[Schedule]], # Multiple outputs + ] + + def __init__(self, sch_fn: SCH_FN_TYPE): + """Constructor. + + Parameters + ---------- + sch_fn : SCH_FN_TYPE + The schedule function. + """ + super().__init__() + self.sch_fn = sch_fn + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the design space generator with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the design space generator. + """ + + def generate_design_space(self, mod: IRModule) -> List[Schedule]: + """Generate design spaces given a module. + + Parameters + ---------- + mod : IRModule + The module used for design space generation. + + Returns + ------- + design_spaces : List[Schedule] + The generated design spaces, i.e., schedules. + """ + sch = Schedule(mod) # Make sure the schedule is traced + result = self.sch_fn(sch) # Call the schedule function + if result is None: # Case 1. No output + return [sch] + if isinstance(result, Schedule): # Case 2. Single output + return [result] + if isinstance(result, (list, tuple, Array)): # Case 3. Multiple outputs + for ret in result: # enumerate the outputs + if not isinstance(ret, Schedule): + raise TypeError( + "Wrong type of element in the list, expected Schedule got " + + f"'{type(ret)}': {ret}" + ) + return result + raise TypeError(f"Unexpected return type {type(result)}: {result}") diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py new file mode 100644 index 000000000000..798753d91345 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta Schedule design space generators that generates design +space for generation of measure candidates. +""" + +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.SpaceGenerator") +class SpaceGenerator(Object): + """The abstract design space generator interface.""" + + def initialize_with_tune_context( + self, + tune_context: "TuneContext", + ) -> None: + """Initialize the design space generator with tuning context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the design space generator. + """ + _ffi_api.SpaceGeneratorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def generate_design_space(self, mod: IRModule) -> List[Schedule]: + """Generate design spaces given a module. + + Parameters + ---------- + mod : IRModule + The module used for design space generation. + + Returns + ------- + design_spaces : List[Schedule] + The generated design spaces, i.e., schedules. + """ + return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PySpaceGenerator") +class PySpaceGenerator(SpaceGenerator): + """An abstract design space generator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + def f_generate_design_space(mod: IRModule) -> List[Schedule]: + return self.generate_design_space(mod) + + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorPySpaceGenerator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_generate_design_space, + ) + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + raise NotImplementedError + + def generate_design_space(self, mod: IRModule) -> List[Schedule]: + raise NotImplementedError diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py b/python/tvm/meta_schedule/space_generator/space_generator_union.py new file mode 100644 index 000000000000..5541ab0b5026 --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Union of meta Schedule design space generators.""" +from typing import List + +from tvm._ffi import register_object + +from .. import _ffi_api +from .space_generator import SpaceGenerator + + +@register_object("meta_schedule.SpaceGeneratorUnion") +class SpaceGeneratorUnion(SpaceGenerator): + """Union of design space generators.""" + + def __init__(self, space_generators: List[SpaceGenerator]): + """Constructor. + + Parameters + ---------- + space_generators : List[SpaceGenerator] + The list of design space generators to be unioned. + """ + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorSpaceGeneratorUnion, # type: ignore # pylint: disable=no-member + space_generators, + ) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index b2fee178ebd6..4c83b9afa289 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,7 +16,7 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional +from typing import Optional, TYPE_CHECKING from tvm import IRModule from tvm.runtime import Object @@ -26,6 +26,9 @@ from . import _ffi_api +if TYPE_CHECKING: + from .space_generator import SpaceGenerator + @register_object("meta_schedule.TuneContext") class TuneContext(Object): @@ -68,6 +71,7 @@ def __init__( self, mod: Optional[IRModule] = None, target: Optional[Target] = None, + space_generator: Optional["SpaceGenerator"] = None, task_name: Optional[str] = None, rand_state: int = -1, num_threads: Optional[int] = None, @@ -80,6 +84,8 @@ def __init__( The workload to be optimized. target : Optional[Target] = None The target to be optimized for. + space_generator : Optional[SpaceGenerator] = None + The design space generator. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -95,6 +101,7 @@ def __init__( _ffi_api.TuneContext, # type: ignore # pylint: disable=no-member mod, target, + space_generator, task_name, rand_state, num_threads, diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc new file mode 100644 index 000000000000..6df8da2f7aa1 --- /dev/null +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +SpaceGenerator SpaceGenerator::PySpaceGenerator( + PySpaceGeneratorNode::FInitializeWithTuneContext f_initialize_with_tune_context, + PySpaceGeneratorNode::FGenerateDesignSpace f_generate_design_space) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_generate_design_space = std::move(f_generate_design_space); + return SpaceGenerator(n); +} + +TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode); +TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorInitializeWithTuneContext") + .set_body_method(&SpaceGeneratorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorGenerateDesignSpace") + .set_body_method(&SpaceGeneratorNode::GenerateDesignSpace); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPySpaceGenerator") + .set_body_typed(SpaceGenerator::PySpaceGenerator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc new file mode 100644 index 000000000000..9c2e3eeabe09 --- /dev/null +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The union of design space generators. */ +class SpaceGeneratorUnionNode : public SpaceGeneratorNode { + public: + /*! \brief The array of design space generators unioned, could be recursive. */ + Array space_generators; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("space_generators", &space_generators); } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + // Initialize each space generator. + for (const SpaceGenerator& space_generator : space_generators) { + space_generator->InitializeWithTuneContext(tune_context); + } + } + + Array GenerateDesignSpace(const IRModule& mod) final { + Array design_spaces; + for (const SpaceGenerator& space_generator : space_generators) { + // Generate partial design spaces from each design space generator. + Array partial = space_generator->GenerateDesignSpace(mod); + // Merge the partial design spaces. + design_spaces.insert(design_spaces.end(), partial.begin(), partial.end()); + } + return design_spaces; + } + + static constexpr const char* _type_key = "meta_schedule.SpaceGeneratorUnion"; + TVM_DECLARE_FINAL_OBJECT_INFO(SpaceGeneratorUnionNode, SpaceGeneratorNode); +}; + +/*! + * \brief Create a design space generator as union of given design space generators. + * \param space_generators Array of the design space generators to be unioned. + * \return The design space generator created. + */ +SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_generators) { + ObjectPtr n = make_object(); + n->space_generators = std::move(space_generators); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") + .set_body_typed(SpaceGenerator::SpaceGeneratorUnion); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 6e80081c1ec2..ad82b6f514a2 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -16,11 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -#include "./tune_context.h" - #include #include +#include "./utils.h" + namespace tvm { namespace meta_schedule { @@ -28,6 +28,7 @@ namespace meta_schedule { * \brief Constructor function of TuneContext class. * \param mod The mod to be optimized. * \param target The target to be optimized for. + * \param space_generator The design space generator. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -35,12 +36,14 @@ namespace meta_schedule { */ TuneContext::TuneContext(Optional mod, // Optional target, // + Optional space_generator, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { ObjectPtr n = make_object(); n->mod = mod; n->target = target; + n->space_generator = space_generator; n->task_name = task_name; if (rand_state == -1) { rand_state = std::random_device()(); @@ -55,10 +58,11 @@ TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") .set_body_typed([](Optional mod, // Optional target, // + Optional space_generator, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, task_name, rand_state, num_threads); + return TuneContext(mod, target, space_generator, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index e6eae4d0d915..a2b5ac4d3184 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -21,8 +21,10 @@ #include #include +#include +#include -#include "../src/support/array.h" +#include "../support/array.h" namespace tvm { namespace meta_schedule {} // namespace meta_schedule diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py new file mode 100644 index 000000000000..3ab60aced197 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule SpaceGenerator """ +# pylint: disable=missing-function-docstring + +import sys +import math + +import pytest + +import tvm +from tvm import tir +from tvm.script import ty + +from tvm.tir.schedule import Schedule, Trace +from tvm.meta_schedule.space_generator import ScheduleFn, SpaceGeneratorUnion + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.tir +class Matmul: + def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + tir.func_attr({"global_symbol": "main"}) + A = tir.match_buffer(a, (1024, 1024), "float32") + B = tir.match_buffer(b, (1024, 1024), "float32") + C = tir.match_buffer(c, (1024, 1024), "float32") + with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def test_meta_schedule_space_generator_schedule_fn(): + mod = Matmul() + space_generator = ScheduleFn(sch_fn=schedule_matmul) + design_spaces = space_generator.generate_design_space(mod) + assert len(design_spaces) == 1 + (schedule,) = design_spaces + _check_correct(schedule) + + +def test_meta_schedule_design_space_generator_union(): + mod = Matmul() + space_generator = ScheduleFn(sch_fn=schedule_matmul) + space_generator_union = SpaceGeneratorUnion([space_generator, space_generator]) + design_spaces = space_generator_union.generate_design_space(mod) + assert len(design_spaces) == 2 + for design_space in design_spaces: + _check_correct(design_space) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_tune_context.py b/tests/python/unittest/test_meta_schedule_tune_context.py index a6c2101928d7..2da4c85ab421 100644 --- a/tests/python/unittest/test_meta_schedule_tune_context.py +++ b/tests/python/unittest/test_meta_schedule_tune_context.py @@ -46,7 +46,7 @@ def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=n def test_tune_context_create(): mod = Matmul() - context = TuneContext(mod, Target("llvm"), "Test Task") + context = TuneContext(mod=mod, target=Target("llvm"), task_name="Test Task") assert context.num_threads > 0 assert context.rand_state != -1 assert context.task_name == "Test Task"