From 8b9d720ff4cf3219f419d8516eaaaf35206d02b9 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 24 May 2022 18:16:40 -0700 Subject: [PATCH] Add BlockFrame (#34) * add BlockFrame * upd * add T::axis::Spatial/Reduce * include dom in for-frame * finish T.axis.remap --- src/script/builder/builder.h | 20 +++- src/script/builder/tir/block_frame.cc | 135 ++++++++++++++++++++++ src/script/builder/tir/block_frame.h | 76 ++++++++++++ src/script/builder/tir/for_frame.cc | 69 ++++++----- src/script/builder/tir/for_frame.h | 27 ++--- src/script/builder/tir/prim_func_frame.cc | 4 +- 6 files changed, 283 insertions(+), 48 deletions(-) create mode 100644 src/script/builder/tir/block_frame.cc create mode 100644 src/script/builder/tir/block_frame.h diff --git a/src/script/builder/builder.h b/src/script/builder/builder.h index 7357223dd64e..53700ba8c64d 100644 --- a/src/script/builder/builder.h +++ b/src/script/builder/builder.h @@ -39,6 +39,10 @@ class FrameNode : public runtime::Object { TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, runtime::Object); public: + virtual void EnterWithScope() {} + + virtual void ExitWithScope() {} + virtual ~FrameNode() { for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { (*it)(); @@ -48,6 +52,17 @@ class FrameNode : public runtime::Object { class Frame : public runtime::ObjectRef { public: + void EnterWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->EnterWithScope(); + } + + void ExitWithScope() { + ICHECK(data_ != nullptr); + static_cast(data_.get())->ExitWithScope(); + data_.reset(); + } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode); protected: @@ -67,15 +82,14 @@ class BuilderNode : public runtime::Object { public: template - TFrame FindFrame() const { + Optional FindFrame() const { using TFrameNode = typename TFrame::ContainerType; for (auto it = frames.rbegin(); it != frames.rend(); ++it) { if (const TFrameNode* p = (*it).template as()) { return GetRef(p); } } - LOG(FATAL) << 
"IndexError: Cannot find frame: " << TFrameNode::_type_key; - throw; + return NullOpt; } }; diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc new file mode 100644 index 000000000000..14ae64fdf533 --- /dev/null +++ b/src/script/builder/tir/block_frame.cc @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "./block_frame.h" + +#include "./for_frame.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +BlockFrame::BlockFrame(String name) { + ObjectPtr n = make_object(); + n->name = name; + n->iter_vars.clear(); + n->reads = NullOpt; + n->writes = NullOpt; + n->init = NullOpt; + n->alloc_buffers.clear(); + n->match_buffers.clear(); + n->annotations.clear(); + n->iter_values.clear(); + n->predicate = NullOpt; + data_ = n; +} + +namespace axis { + +// TODO(@junrushao1994): figure out the Block syntax without BlockRealize + +tvm::tir::IterVar PushBlockVar(tvm::tir::IterVar iter_var, PrimExpr binding) { + if (const BlockFrameNode* opt_frame = Builder::Current()->frames.back().as()) { + BlockFrame frame = GetRef(opt_frame); + frame->iter_vars.push_back(iter_var); + frame->iter_values.push_back(binding); + } else { + LOG(FATAL) << "TypeError: The last frame is not BlockFrame"; + } + return iter_var; +} + +tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) { + using namespace tvm::tir; + ICHECK(dom.defined()) << "Spatial axis must have a domain"; + int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); + return PushBlockVar(IterVar(/*dom=*/dom, // + /*var=*/Var("_", dtype.with_bits(bits)), // + /*iter_type=*/IterVarType::kDataPar, // + /*thread_tag=*/""), + binding); +} + +tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) { + using namespace tvm::tir; + ICHECK(dom.defined()) << "Reduce axis must have a domain"; + int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); + return PushBlockVar(IterVar(/*dom=*/dom, // + /*var=*/Var("_", dtype.with_bits(bits)), // + /*iter_type=*/IterVarType::kCommReduce, // + /*thread_tag=*/""), + binding); +} + +Array Remap(String kinds, Array bindings, DataType dtype) { + using namespace tvm::tir; + Array results; + ICHECK_EQ(kinds.size(), bindings.size()); + int n = bindings.size(); + 
results.reserve(n); + for (int i = 0; i < n; ++i) { + char c = kinds.c_str()[i]; + PrimExpr e = bindings[i]; + const VarNode* v = e.as(); + ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap"; + Range dom{nullptr}; + for (const auto& frame : Builder::Current()->frames) { + if (const auto* for_frame = frame.as()) { + ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size()); + int n = for_frame->doms.size(); + for (int i = 0; i < n; ++i) { + if (for_frame->vars[i].get() == v) { + dom = for_frame->doms[i]; + break; + } + } + if (dom.defined()) { + break; + } + } + } + ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef(v); + DataType dtype = v->dtype; + if (c == 'S') { + results.push_back(PushBlockVar(IterVar(/*dom=*/dom, + /*var=*/Var("_", dtype), + /*iter_type=*/IterVarType::kDataPar, + /*thread_tag=*/""), + e)); + } else if (c == 'R') { + results.push_back(PushBlockVar(IterVar(/*dom=*/dom, + /*var=*/Var("_", dtype), + /*iter_type=*/IterVarType::kCommReduce, + /*thread_tag=*/""), + e)); + } else { + LOG(FATAL) << "Unknown axis kind: " << c; + } + } + return results; +} + +} // namespace axis + +TVM_REGISTER_NODE_TYPE(BlockFrameNode); + +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm diff --git a/src/script/builder/tir/block_frame.h b/src/script/builder/tir/block_frame.h new file mode 100644 index 000000000000..bec8db18b7ef --- /dev/null +++ b/src/script/builder/tir/block_frame.h @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_ +#define TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_ + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace builder { +namespace tir { + +class BlockFrameNode : public TIRFrameNode { + public: + String name; + Array iter_vars; + Optional> reads; + Optional> writes; + Optional init; + Array alloc_buffers; + Array match_buffers; + Map annotations; + + Array iter_values; + Optional predicate; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("iter_vars", &iter_vars); + v->Visit("reads", &reads); + v->Visit("writes", &writes); + v->Visit("init", &init); + v->Visit("alloc_buffers", &alloc_buffers); + v->Visit("match_buffers", &match_buffers); + v->Visit("annotations", &annotations); + v->Visit("iter_values", &iter_values); + v->Visit("predicate", &predicate); + } + + static constexpr const char* _type_key = "script.builder.tir.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode); +}; + +class BlockFrame : public TIRFrame { + public: + explicit BlockFrame(String name); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode); +}; + +namespace axis { +tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype); +tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype); +Array Remap(String kinds, Array bindings, DataType dtype); +} // namespace axis +} // namespace tir +} // namespace builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_BUILDER_TIR_BLOCK_FRAME_H_ 
diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc index e17191036120..3aa02bf48997 100644 --- a/src/script/builder/tir/for_frame.cc +++ b/src/script/builder/tir/for_frame.cc @@ -23,24 +23,28 @@ namespace script { namespace builder { namespace tir { -ForFrame::ForFrame(Array loop_vars, ForFrame::FMakeForLoop f_make_for_loop) { +ForFrame::ForFrame(Array vars, Array doms, + ForFrameNode::FMakeForLoop f_make_for_loop) { ObjectPtr n = make_object(); - n->loop_vars = std::move(loop_vars); + n->vars = std::move(vars); + n->doms = std::move(doms); n->f_make_for_loop = std::move(f_make_for_loop); data_ = std::move(n); } -#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ - With Method(PrimExpr min, PrimExpr extent, Map attrs) { \ - ObjectPtr n = make_object(); \ - int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ - n->loop_vars = {tvm::tir::Var("v", DataType::Int(bits))}; \ - n->f_make_for_loop = [=](Array vars, tvm::tir::Stmt body) -> tvm::tir::For { \ - ICHECK_EQ(vars.size(), 1); \ - return tvm::tir::For(/*loop_var=*/vars[0], min, extent, Kind, body, \ - /*thread_binding=*/NullOpt, attrs); \ - }; \ - return With(n); \ +#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ + ForFrame Method(PrimExpr min, PrimExpr extent, Map attrs) { \ + using namespace tvm::tir; \ + ObjectPtr n = make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType::Int(bits))}; \ + n->doms = {Range(min, extent)}; \ + n->f_make_for_loop = [attrs](Array vars, Array doms, Stmt body) { \ + ICHECK_EQ(vars.size(), 1); \ + ICHECK_EQ(doms.size(), 1); \ + return For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, NullOpt, attrs); \ + }; \ + return ForFrame(n); \ } TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Serial, tvm::tir::ForKind::kSerial); @@ -50,39 +54,44 @@ TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Unroll, tvm::tir::ForKind::kUnrolled); #undef TVM_SCRIPT_BUILDER_TIR_FOR_CREATE -With 
ThreadBinding(PrimExpr min, PrimExpr extent, String thread, - Map attrs) { +ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread, Map attrs) { using namespace tvm::tir; ObjectPtr n = make_object(); int bits = std::max(min.dtype().bits(), extent.dtype().bits()); - n->loop_vars = {Var("v", DataType::Int(bits))}; - n->f_make_for_loop = [=](Array vars, Stmt body) -> For { + n->vars = {Var("v", DataType::Int(bits))}; + n->doms = {Range(min, extent)}; + n->f_make_for_loop = [attrs, thread](Array vars, Array doms, Stmt body) -> For { ICHECK_EQ(vars.size(), 1); - IterVar iter_var(Range(nullptr), Var(ObjectPtr(nullptr)), IterVarType::kThreadIndex, - thread); - return For(vars[0], min, extent, tvm::tir::ForKind::kThreadBinding, body, iter_var, attrs); + ICHECK_EQ(doms.size(), 1); + IterVar iter_var(Range(nullptr), NullValue(), IterVarType::kThreadIndex, thread); + return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, + attrs); }; - return With(n); + return ForFrame(n); } -With Grid(Array extents) { +ForFrame Grid(Array extents) { using namespace tvm::tir; ObjectPtr n = make_object(); - n->loop_vars.reserve(extents.size()); + n->vars.reserve(extents.size()); + n->doms.reserve(extents.size()); for (const auto& extent : extents) { - n->loop_vars.push_back(Var("v", extent.dtype())); + DataType dtype = extent.dtype(); + n->vars.push_back(Var("v", extent.dtype())); + n->doms.push_back(Range(make_const(dtype, 0), extent)); } - n->f_make_for_loop = [=](Array vars, Stmt body) -> Stmt { - ICHECK_EQ(extents.size(), vars.size()); - int n = extents.size(); + n->f_make_for_loop = [](Array vars, Array doms, Stmt body) -> Stmt { + ICHECK_EQ(vars.size(), doms.size()); + int n = vars.size(); for (int i = n - 1; i >= 0; --i) { + Range dom = doms[i]; Var var = vars[i]; - PrimExpr extent = extents[i]; - body = For(var, Integer(0), extent, ForKind::kSerial, body, /*thread_binding=*/NullOpt, {}); + body = For(var, dom->min, dom->extent, 
ForKind::kSerial, std::move(body), + /*thread_binding=*/NullOpt, /*annotations=*/{}); } return body; }; - return With(n); + return ForFrame(n); } TVM_REGISTER_NODE_TYPE(ForFrameNode); diff --git a/src/script/builder/tir/for_frame.h b/src/script/builder/tir/for_frame.h index b4e634905eed..7b94645b46e8 100644 --- a/src/script/builder/tir/for_frame.h +++ b/src/script/builder/tir/for_frame.h @@ -34,13 +34,15 @@ namespace tir { class ForFrameNode : public TIRFrameNode { public: using FMakeForLoop = - runtime::TypedPackedFunc, tvm::tir::Stmt)>; + runtime::TypedPackedFunc, Array, tvm::tir::Stmt)>; - Array loop_vars; + Array vars; + Array doms; FMakeForLoop f_make_for_loop; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("loop_vars", &loop_vars); + v->Visit("vars", &vars); + v->Visit("doms", &doms); // `f_make_for_loop` is not visited. } @@ -50,9 +52,8 @@ class ForFrameNode : public TIRFrameNode { class ForFrame : public TIRFrame { public: - using FMakeForLoop = ForFrameNode::FMakeForLoop; - - explicit ForFrame(Array loop_vars, FMakeForLoop f_make_for_loop); + explicit ForFrame(Array vars, Array doms, + ForFrameNode::FMakeForLoop f_make_for_loop); void EnterWithScope() { ICHECK(data_ != nullptr); } @@ -64,13 +65,13 @@ class ForFrame : public TIRFrame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode); }; -With Serial(PrimExpr min, PrimExpr extent, Map annotations); -With Parallel(PrimExpr min, PrimExpr extent, Map annotations); -With Vectorized(PrimExpr min, PrimExpr extent, Map annotations); -With Unroll(PrimExpr min, PrimExpr extent, Map annotations); -With ThreadBinding(PrimExpr min, PrimExpr extent, String thread, - Map annotations); -With Grid(Array extents); +ForFrame Serial(PrimExpr min, PrimExpr extent, Map annotations); +ForFrame Parallel(PrimExpr min, PrimExpr extent, Map annotations); +ForFrame Vectorized(PrimExpr min, PrimExpr extent, Map annotations); +ForFrame Unroll(PrimExpr min, PrimExpr extent, Map annotations); 
+ForFrame ThreadBinding(PrimExpr min, PrimExpr extent, String thread, + Map annotations); +ForFrame Grid(Array extents); } // namespace tir } // namespace builder diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 3736f692de53..70fb93e0e53a 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -25,13 +25,13 @@ namespace builder { namespace tir { void Arg(tvm::tir::Var var) { - PrimFuncFrame frame = Builder::Current()->FindFrame(); + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); frame->args.push_back(var); } void Arg(tvm::tir::Buffer buffer) { using namespace tvm::tir; - PrimFuncFrame frame = Builder::Current()->FindFrame(); + PrimFuncFrame frame = Builder::Current()->FindFrame().value(); Var handle(buffer->name + "_handle", DataType::Handle()); frame->args.push_back(handle); frame->buffer_map.Set(handle, buffer);