diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index cb66ece750a66..f75e61c8859ec 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -97,6 +97,41 @@ def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): param_map : Mapping[Var, Union[PrimExpr, Buffer]] The mapping from function params to the instance + Examples + -------- + We can define a Meta TIR function with symbolic shape: + + .. code-block:: python + + @tvm.script.tir + def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: + A = tir.match_buffer(a, (m, n), "float32") + B = tir.match_buffer(b, (m, n), "float32") + + with tir.block([m, n], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + + Then we can make it specialized with given shapes or buffers. + + .. code-block:: python + + a, _, m, n = mem_copy.params + func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + # or + func = mem_copy.specialize({n: 16, m: 16}) + + The specialized function: + + .. code-block:: python + + @tvm.script.tir + def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), "float32") + + with tir.block([16, 16], "") as [vi, vj]: + B[vi, vj] = A[vi, vj] + Returns ------- func : PrimFunc