Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy committed Jul 1, 2021
1 parent b629c0b commit 22a61ff
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 25 deletions.
38 changes: 38 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
* \param param_map The mapping from function params to the instance.
* \return The new function with parameter specialized.
* \note We can define a Meta TIR function with symbolic shape:
*
* \code
* @tvm.script.tir
* def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
* A = tir.match_buffer(a, (m, n), "float32")
* B = tir.match_buffer(b, (m, n), "float32")
*
* with tir.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*
* Then we can make it specialized with given shapes or buffers.
*
* \code
* a, _, m, n = mem_copy.params
* func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
* # or
* func = mem_copy.specialize({n: 16, m: 16})
* \endcode
*
 * The specialized function will be:
 *
 * \code
* @tvm.script.tir
* def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
* A = tir.match_buffer(a, (16, 16), "float32")
* B = tir.match_buffer(b, (16, 16), "float32")
*
* with tir.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*/
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);

/*!
* \brief PrimFunc specific attribute names.
*
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def with_body(self, new_body, span=None):
return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)

def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
"""Metaprogramming usage: specialize parameters of PrimFunc
"""Specialize parameters of PrimFunc
Parameters
----------
Expand Down
62 changes: 44 additions & 18 deletions src/tir/ir/specialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,23 @@ class PrimFuncSpecializer : public StmtExprMutator {
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
};

/*!
* \brief Update Specialize var map with buffer matching.
* \param func The function to be specialized.
* \param param The given function parameter
* \param specific_buf The matching buffer.
* \param var_map The var mapping to be updated.
 * \note This function will match the target buffer's shape, strides and element_offset.
* For example, we define a buffer in PrimFunc:
* A = tir.match_buffer(a, [m, n])
*
* Then we match it with a buffer B = tir.decl_buffer((8, 16))
*
* It means we have two var mappings here: m = 8 and n = 16
*
* If the buffer signature is not a Var, the mapping will fail.
* e.g. A = tir.match_buffer(a, [m * 2, n + 1])
*/
void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,
VarMap* var_map) {
// preliminaries
Expand Down Expand Up @@ -275,6 +292,13 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer
<< " vs. " << specific_buf->offset_factor << ".";
}

/*!
* \brief Update Specialize var map with parameter value.
* \param func The function to be specialized.
* \param param The given function parameter
* \param specific_expr The parameter value.
* \param var_map The var mapping to be updated.
*/
void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr,
VarMap* var_map) {
// check param is in PrimFunc's parameters
Expand All @@ -286,26 +310,28 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx
(*var_map)[param] = specific_expr;
}

/**************** Implementation ****************/

/*!
 * \brief Entry point of PrimFunc specialization: fold the given param-to-instance
 *        bindings into a var substitution map, then rewrite the function once.
 */
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
  // Accumulate every Var substitution implied by the user-supplied bindings
  // before touching the function body, so the rewrite happens in a single pass.
  VarMap var_map;
  for (const auto& entry : param_map) {
    const Var& param = entry.first;
    const ObjectRef& instance = entry.second;
    if (const auto* buf_node = instance.as<BufferNode>()) {
      // Binding a concrete Buffer: match shape/strides/elem_offset against the param.
      UpdateSpecializeVarMap(func, param, GetRef<Buffer>(buf_node), &var_map);
    } else if (const auto* expr_node = instance.as<PrimExprNode>()) {
      // Binding a scalar value directly to the param var.
      UpdateSpecializeVarMap(func, param, GetRef<PrimExpr>(expr_node), &var_map);
    } else {
      LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
                 << instance->GetTypeKey();
    }
  }
  return PrimFuncSpecializer::Specialize(func, std::move(var_map));
}

/**************** FFI ****************/

TVM_REGISTER_GLOBAL("tir.Specialize")
.set_body_typed<PrimFunc(PrimFunc, Map<Var, ObjectRef>)>([](PrimFunc func,
Map<Var, ObjectRef> param_map) {
VarMap var_map;
for (const auto& kv : param_map) {
const Var& param = kv.first;
const ObjectRef& instance = kv.second;
if (instance->IsInstance<BufferNode>()) {
UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), &var_map);
} else if (instance->IsInstance<PrimExprNode>()) {
UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
} else {
LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
<< instance->GetTypeKey();
}
}
return PrimFuncSpecializer::Specialize(std::move(func), std::move(var_map));
});
TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize);

} // namespace tir
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_specialize_nothing():
assert func.same_as(matmul) # Pointer the same


def test_tensor_dimension_invariant_code_matmul():
def test_specialize_matmul():
a, _, _, n = matmul.params
# fully specialized
func = matmul.specialize({a: tir.decl_buffer((128, 128))})
Expand All @@ -163,7 +163,7 @@ def test_tensor_dimension_invariant_code_matmul():
tvm.ir.assert_structural_equal(func, matmul_m_8x)


def test_tensor_dimension_invariant_code_elemwise():
def test_specialize_elemwise():
a, c = element_wise.params
C = element_wise.buffer_map[c]
# fully specialized
Expand All @@ -174,7 +174,7 @@ def test_tensor_dimension_invariant_code_elemwise():
tvm.ir.assert_structural_equal(func, element_wise_128_n)


def test_tensor_dimension_invariant_code_mem_copy():
def test_specialize_mem_copy():
a, _, m, n, p, q = mem_copy.params
# fully specialized
func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)})
Expand All @@ -186,8 +186,14 @@ def test_tensor_dimension_invariant_code_mem_copy():
tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n)


def test_specialize_recursive_load():
    """Placeholder — recursive Load specialization is not exercised yet."""
    # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]


if __name__ == "__main__":
test_specialize_nothing()
test_tensor_dimension_invariant_code_matmul()
test_tensor_dimension_invariant_code_elemwise()
test_tensor_dimension_invariant_code_mem_copy()
test_specialize_matmul()
test_specialize_elemwise()
test_specialize_mem_copy()
test_specialize_recursive_load()

0 comments on commit 22a61ff

Please sign in to comment.