Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Remove legacy const pattern functions #5387

Merged
merged 1 commit into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions src/arith/compute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,6 @@ template<typename Op>
inline PrimExpr ComputeReduce(
const Array<PrimExpr>& values, PrimExpr empty_value);

/*!
 * \brief Extract a scalar integer constant from an expression.
 * \param e The expression to inspect.
 * \param out Receives the constant value on success.
 * \return true if e is a scalar integer constant, false otherwise.
 */
inline bool GetConst(PrimExpr e, int64_t* out) {
  // Vector expressions never represent a single scalar constant.
  if (e.dtype().is_vector()) return false;
  if (const int64_t* value = tir::as_const_int(e)) {
    *out = *value;
    return true;
  }
  return false;
}

/*!
 * \brief Extract a constant that fits in a plain int.
 * \param e The expression to inspect.
 * \param out Receives the constant value on success.
 * \return true if e is a scalar integer constant representable as int.
 */
inline bool GetConstInt(PrimExpr e, int* out) {
  int64_t v1 = 0;
  if (GetConst(e, &v1)) {
    // Reject values outside int's range. The original checked only the
    // upper bound, so values below INT_MIN were silently narrowed
    // (implementation-defined behavior) instead of being rejected.
    if (v1 > static_cast<int64_t>(std::numeric_limits<int>::max()) ||
        v1 < static_cast<int64_t>(std::numeric_limits<int>::min())) {
      return false;
    }
    *out = static_cast<int>(v1);
    return true;
  }
  return false;
}

template<>
inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b;
Expand Down
11 changes: 11 additions & 0 deletions src/arith/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
base.derived(), stride.derived(), lanes.derived());
}

/*!
 * \brief Pattern ramp expression with constant stride and lanes.
 * \param base The pattern of the base expression.
 * \param stride The constant stride, typed like base.
 * \param lanes The constant number of lanes.
 * \tparam TBase The pattern type of the base.
 */
template<typename TBase>
inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
ramp(const Pattern<TBase>& base,
     int stride,
     int lanes) {
  using Result = PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>;
  return Result(base.derived(),
                PConstWithTypeLike<TBase>(base.derived(), stride),
                PConst<int>(lanes));
}

/*!
* \brief Pattern broadcast expression.
* \tparam TA The pattern type of the value.
Expand Down
39 changes: 20 additions & 19 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "codegen_llvm.h"
#include "codegen_cpu.h"
#include "../../arith/pattern_match.h"
#include "../build_common.h"
namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -363,27 +364,27 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
md_builder_->createTBAAStructTagNode(meta, meta, 0));
return;
}
int base = 0, width = 0;

int64_t base = 0, width = 0;
arith::PVar<IntImm> pbase, pstride;
arith::PVar<int> planes;
// create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks.
if (index.defined()) {
const RampNode* ramp = index.as<RampNode>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1;
while (width < xwith) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
if (arith::ramp(pbase, pstride, planes).Match(index)) {
base = pbase.Eval()->value;
int64_t xwith = planes.Eval() * pstride.Eval()->value;
width = 1;
while (width < xwith) {
width *= 2;
}
} else {
if (arith::GetConstInt(index, &base)) width = 1;
while (base % width) {
base -= base % width;
width *= 2;
}
} else if (auto* ptr = index.as<tir::IntImmNode>()) {
width = 1;
base = ptr->value;
}
}
llvm::MDNode* meta = md_tbaa_root_;
Expand All @@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
for (int64_t w = 1024; w >= width; w /= 2) {
int64_t b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
Expand Down
19 changes: 10 additions & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
#include <iomanip>
#include <cctype>
#include "codegen_c.h"
#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t) && !is_vol) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
if (auto* ptr = index.as<tir::IntImmNode>()) {
int64_t offset = ptr->value;
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
Expand Down Expand Up @@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
} else {
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
PrimExpr base;
if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base);

arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
} else {
std::ostringstream svalue_expr;
Expand Down Expand Up @@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
PrimExpr base;
if (GetRamp1Base(op->index, t.lanes(), &base)) {
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
Expand Down
6 changes: 3 additions & 3 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
spirv::Value v;
if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index);
int size = 0;
CHECK(arith::GetConstInt(extent, &size))
auto* sizeptr = extent.as<tir::IntImmNode>();
CHECK(sizeptr)
<< "SPIRV only allows constant thread group size " << " get " << extent;
CHECK_LT(ts.dim_index, 3);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
} else {
v = builder_->GetWorkgroupID(ts.dim_index);
}
Expand Down
6 changes: 3 additions & 3 deletions src/tir/pass/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) {
Bind_(make_const(DataType::UInt(64), const_offset * data_bytes),

if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
Expand Down
16 changes: 0 additions & 16 deletions src/tir/pass/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
return align;
}

/*!
 * \brief Match index against a Ramp whose stride is 1.
 * This pattern commonly arises from contiguous memory loads.
 * \param index The index expression to match.
 * \param lanes The expected number of lanes of the ramp.
 * \param base Receives the ramp's base expression on success.
 * \return true when index is a stride-1 Ramp; base is then set.
 * \note CHECKs (aborts) if a stride-1 ramp is found whose lane
 *       count differs from the expected lanes.
 */
inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
  const RampNode* ramp = index.as<RampNode>();
  if (ramp == nullptr || !is_one(ramp->stride)) {
    return false;
  }
  CHECK_EQ(ramp->lanes, lanes);
  *base = ramp->base;
  return true;
}
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_PASS_IR_UTIL_H_
8 changes: 4 additions & 4 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
}
void VisitExpr_(const CallNode *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
int rw_mask = 0;
CHECK(arith::GetConstInt(op->args[4], &rw_mask));
const auto* rw_mask = op->args[4].as<IntImmNode>();
const VarNode* buffer_var = op->args[1].as<VarNode>();
CHECK(buffer_var);
CHECK(rw_mask);
// read
if (rw_mask & 1) {
if (rw_mask->value & 1) {
HandleUseVar(buffer_var);
}
if (rw_mask & 2) {
if (rw_mask->value & 2) {
HandleWriteVar(buffer_var);
}
this->VisitExpr(op->args[2]);
Expand Down
4 changes: 3 additions & 1 deletion src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) {
CHECK(arith::GetConstInt(attr->value, &(e.extent)))
const auto* ptr = attr->value.as<IntImmNode>();
CHECK(ptr)
<< "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
Expand Down
6 changes: 2 additions & 4 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <unordered_set>

#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
// Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->dtype);
if (device_type_.defined()) {
if (arith::GetConst(device_type_, &dev_type)) {
if (dev_type == kDLCPU) {
if (const auto* dev_type = device_type_.as<IntImmNode>()) {
if (dev_type->value == kDLCPU) {
int32_t constant_size = op->constant_allocation_size();
if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
return stmt;
Expand Down
38 changes: 20 additions & 18 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

#include <unordered_set>

#include "../pass/ir_util.h"
#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

Expand Down Expand Up @@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index);
} else {
PrimExpr base;
CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base))
arith::PVar<PrimExpr> base;
CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
<< "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store";
UpdatePattern(base);
UpdatePattern(base.Eval());
}
} else {
StmtVisitor::VisitStmt_(op);
Expand All @@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
int coeff = 0;
PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);

CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0)
const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
<< "LowerWarpMemory failed due to store index=" << index
<< ", require positive constant coefficient on warp index " << warp_index_
<< " but get " << mcoeff;

if (warp_coeff_ != 0) {
CHECK_EQ(warp_coeff_, coeff)
CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
<< "LowerWarpMemory failed due to two different store coefficient to warp index";
} else {
warp_coeff_ = coeff;
warp_coeff_ = mcoeff_as_int->value;
}
}

Expand All @@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
// the warp index
Var warp_index_;
// the coefficient
int warp_coeff_{0};
int64_t warp_coeff_{0};
// analyzer.
arith::Analyzer* analyzer_;
};
Expand All @@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value) &&
value <= warp_size_ &&
warp_size_ % value == 0)
auto* value_as_int = op->value.as<IntImmNode>();
CHECK(value_as_int &&
value_as_int->value <= warp_size_ &&
warp_size_ % value_as_int->value == 0)
<< "Expect threadIdx.x 's size to be no larger than, and a factor of"
<< " warp size(" << warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead";
Expand All @@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
<< "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel";
} else {
width_ = value;
width_ = value_as_int->value;
warp_index_ = iv;
}
}
Expand Down Expand Up @@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
// in this access pattern.
std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
if (index.dtype().lanes() != 1) {
PrimExpr base, local_index, group;
CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base);
PrimExpr local_index, group;

arith::PVar<PrimExpr> base;
CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));

std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
local_index =
RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
return std::make_pair(local_index, group);
Expand Down
11 changes: 6 additions & 5 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
<< "Prefetch dim should be the same as buffer dim";

int block_size = 1,
elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(),
shape = 0;
elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();

int starts = op->bounds.size() - 1;
while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
&& elem_cnt >= block_size * shape) {
block_size *= shape;

while (starts > 0) {
auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break;
block_size *= static_cast<int>(shape_as_int->value);
starts--;
}
PrimExpr stride(elem_cnt / block_size);
Expand Down
7 changes: 2 additions & 5 deletions src/tir/transforms/unroll_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,13 @@ class LoopUnroller : public StmtExprMutator {

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value));
int value = static_cast<int>(Downcast<Integer>(op->value)->value);
std::swap(value, auto_max_step_);
Stmt ret = this->VisitStmt(op->body);
std::swap(value, auto_max_step_);
return ret;
} else if (op->attr_key == "pragma_unroll_explicit") {
int value = 0;
CHECK(arith::GetConstInt(op->value, &value));
bool explicit_unroll = value;
bool explicit_unroll = Downcast<Integer>(op->value)->value;
std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->VisitStmt(op->body);
std::swap(explicit_unroll, explicit_unroll_);
Expand Down
Loading