diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 97ee7f7211d4..25ed2f9ae8d1 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
 };
 
+/*!
+ * \brief Specialize parameters of PrimFunc.
+ * \param func The PrimFunc to be specialized.
+ * \param param_map The mapping from function params to the instance.
+ * \return The new function with parameters specialized.
+ * \note We can define a Meta TIR function with symbolic shape:
+ *
+ * \code
+ *  @tvm.script.tir
+ *  def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
+ *      A = tir.match_buffer(a, (m, n), "float32")
+ *      B = tir.match_buffer(b, (m, n), "float32")
+ *
+ *      with tir.block([m, n], "") as [vi, vj]:
+ *          B[vi, vj] = A[vi, vj]
+ * \endcode
+ *
+ * Then we can make it specialized with given shapes or buffers.
+ *
+ * \code
+ *  a, _, m, n = mem_copy.params
+ *  func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
+ *  # or
+ *  func = mem_copy.specialize({n: 16, m: 16})
+ * \endcode
+ *
+ * \code {.language-id}
+ *  @tvm.script.tir
+ *  def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
+ *      A = tir.match_buffer(a, (16, 16), "float32")
+ *      B = tir.match_buffer(b, (16, 16), "float32")
+ *
+ *      with tir.block([16, 16], "") as [vi, vj]:
+ *          B[vi, vj] = A[vi, vj]
+ * \endcode
+ */
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
+
 /*!
  * \brief PrimFunc specific attribute names.
  *
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index f75e61c8859e..b1081d436150 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -89,7 +89,7 @@ def with_body(self, new_body, span=None):
         return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)
 
     def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
-        """Metaprogramming usage: specialize parameters of PrimFunc
+        """Specialize parameters of PrimFunc
 
         Parameters
         ----------
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index c6d21ca70760..aa5f271c20c2 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -218,6 +218,23 @@ class PrimFuncSpecializer : public StmtExprMutator {
   std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
 };
 
+/*!
+ * \brief Update the Specialize var map with buffer matching.
+ * \param func The function to be specialized.
+ * \param param The given function parameter.
+ * \param specific_buf The matching buffer.
+ * \param var_map The var mapping to be updated.
+ * \note This function matches the target buffer's shape, strides and element_offset.
+ *   For example, we define a buffer in PrimFunc:
+ *       A = tir.match_buffer(a, [m, n])
+ *
+ *   Then we match it with a buffer B = tir.decl_buffer((8, 16))
+ *
+ *   It means we have two var mappings here: m = 8 and n = 16
+ *
+ *   If the buffer signature is not a Var, the mapping will fail,
+ *   e.g. A = tir.match_buffer(a, [m * 2, n + 1])
+ */
 void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,
                             VarMap* var_map) {
   // preliminaries
@@ -275,6 +292,13 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer
       << " vs. " << specific_buf->offset_factor << ".";
 }
 
+/*!
+ * \brief Update the Specialize var map with a parameter value.
+ * \param func The function to be specialized.
+ * \param param The given function parameter.
+ * \param specific_expr The parameter value.
+ * \param var_map The var mapping to be updated.
+ */
 void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr,
                             VarMap* var_map) {
   // check param is in PrimFunc's parameters
@@ -286,26 +310,28 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx
   (*var_map)[param] = specific_expr;
 }
 
+/**************** Implementation ****************/
+
+PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
+  VarMap var_map;
+  for (const auto& kv : param_map) {
+    const Var& param = kv.first;
+    const ObjectRef& instance = kv.second;
+    if (instance->IsInstance<BufferNode>()) {
+      UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), &var_map);
+    } else if (instance->IsInstance<PrimExprNode>()) {
+      UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
+    } else {
+      LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
+                 << instance->GetTypeKey();
+    }
+  }
+  return PrimFuncSpecializer::Specialize(func, std::move(var_map));
+}
+
 /**************** FFI ****************/
 
-TVM_REGISTER_GLOBAL("tir.Specialize")
-    .set_body_typed<PrimFunc(PrimFunc, Map<Var, ObjectRef>)>([](PrimFunc func,
-                                                                Map<Var, ObjectRef> param_map) {
-      VarMap var_map;
-      for (const auto& kv : param_map) {
-        const Var& param = kv.first;
-        const ObjectRef& instance = kv.second;
-        if (instance->IsInstance<BufferNode>()) {
-          UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), &var_map);
-        } else if (instance->IsInstance<PrimExprNode>()) {
-          UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
-        } else {
-          LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
-                     << instance->GetTypeKey();
-        }
-      }
-      return PrimFuncSpecializer::Specialize(std::move(func), std::move(var_map));
-    });
+TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tir_specialize.py
similarity index 94%
rename from tests/python/unittest/test_tvmscript_meta_programming.py
rename to tests/python/unittest/test_tir_specialize.py
index 5b304d42d785..2e9f1110732a 100644
--- a/tests/python/unittest/test_tvmscript_meta_programming.py
+++ b/tests/python/unittest/test_tir_specialize.py
@@ -150,7 +150,7 @@ def test_specialize_nothing():
     assert func.same_as(matmul)  # Pointer the same
 
 
-def test_tensor_dimension_invariant_code_matmul():
+def test_specialize_matmul():
     a, _, _, n = matmul.params
     # fully specialized
     func = matmul.specialize({a: tir.decl_buffer((128, 128))})
@@ -163,7 +163,7 @@ def test_tensor_dimension_invariant_code_matmul():
     tvm.ir.assert_structural_equal(func, matmul_m_8x)
 
 
-def test_tensor_dimension_invariant_code_elemwise():
+def test_specialize_elemwise():
     a, c = element_wise.params
     C = element_wise.buffer_map[c]
     # fully specialized
@@ -174,7 +174,7 @@ def test_tensor_dimension_invariant_code_elemwise():
     tvm.ir.assert_structural_equal(func, element_wise_128_n)
 
 
-def test_tensor_dimension_invariant_code_mem_copy():
+def test_specialize_mem_copy():
     a, _, m, n, p, q = mem_copy.params
     # fully specialized
     func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)})
@@ -186,8 +186,14 @@ def test_tensor_dimension_invariant_code_mem_copy():
     tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n)
 
 
+def test_specialize_recursive_load():
+    # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]
+    pass
+
+
 if __name__ == "__main__":
     test_specialize_nothing()
-    test_tensor_dimension_invariant_code_matmul()
-    test_tensor_dimension_invariant_code_elemwise()
-    test_tensor_dimension_invariant_code_mem_copy()
+    test_specialize_matmul()
+    test_specialize_elemwise()
+    test_specialize_mem_copy()
+    test_specialize_recursive_load()
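
For reference, a minimal end-to-end sketch of the new API, stitched together from the doxygen example in `function.h` and the renamed tests. It assumes the PR-era TVMScript spelling (`@tvm.script.tir` with `ty.*` annotations and `from tvm.script import ty`); treat it as an illustration rather than the committed code.

```python
# Sketch only: mirrors the mem_copy example from the new doxygen comment.
# The `from tvm.script import ty` import is an assumption based on the
# TVMScript API used elsewhere in this PR.
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
    A = tir.match_buffer(a, (m, n), "float32")
    B = tir.match_buffer(b, (m, n), "float32")

    with tir.block([m, n], "") as [vi, vj]:
        B[vi, vj] = A[vi, vj]


a, _, m, n = mem_copy.params

# Specialize through a concrete buffer: matching A's symbolic shape (m, n)
# against the (16, 16) buffer binds m = 16 and n = 16 in the var map.
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})

# Or bind the shape vars directly; both forms yield mem_copy_16_16.
func = mem_copy.specialize({m: 16, n: 16})
```

Either form routes through `UpdateSpecializeVarMap` to build the var map (buffer matching for the first, a direct binding for the second) and then through `PrimFuncSpecializer::Specialize`, which rewrites the buffer declarations and the body with the bound values.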