-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] DeduceBound #40
Changes from 25 commits
2b9cb33
461d778
bccdfce
31478b2
2b4d091
2539265
0ea07ab
cf9f3ba
8f72e50
9106944
1f1ff8f
b409040
e4bee27
e3a5f9e
b1617a8
7abe378
f829694
96ded33
5a8fa91
71349f6
f3e3fa9
d5aedde
35683a8
d9794bb
696976a
434835a
2527b2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# pylint: disable=protected-access, no-member | ||
"""Arithmetic data structure and utility""" | ||
from __future__ import absolute_import as _abs | ||
from ._ctypes._node import NodeBase, register_node | ||
from . import _api_internal | ||
|
||
@register_node | ||
class IntSet(NodeBase): | ||
"""Represent a set of integer in one dimension.""" | ||
def is_nothing(self): | ||
"""Whether the set represent nothing""" | ||
return _api_internal._IntSetIsNothing(self) | ||
|
||
def is_everything(self): | ||
"""Whether the set represent everything""" | ||
return _api_internal._IntSetIsEverything(self) | ||
|
||
@register_node | ||
class IntervalSet(IntSet): | ||
"""Represent set of continuous interval""" | ||
def min(self): | ||
"""get the minimum value""" | ||
return _api_internal._IntervalSetGetMin(self) | ||
|
||
def max(self): | ||
"""get the maximum value""" | ||
return _api_internal._IntervalSetGetMax(self) | ||
|
||
@register_node | ||
class StrideSet(IntSet): | ||
"""Represent set of strided integers""" | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* Implementation of API functions related to arith | ||
* \file api_arith.cc | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir.h> | ||
#include <tvm/api_registry.h> | ||
#include "../arithmetic/int_set.h" | ||
#include "../arithmetic/int_set_internal.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
TVM_REGISTER_API(_arith_intset_single_point) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntSet::single_point(args[0]); | ||
}); | ||
|
||
TVM_REGISTER_API(_arith_intset_interval) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = IntSet::interval(args[0], args[1]); | ||
}); | ||
|
||
TVM_REGISTER_API(_arith_DeduceBound) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = DeduceBound(args[0], args[1], args[2]); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntervalSetGetMin) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().min(); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntervalSetGetMax) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().max(); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntSetIsNothing) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().is_nothing(); | ||
}); | ||
|
||
TVM_REGISTER_API(_IntSetIsEverything) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = args[0].operator IntSet().is_everything(); | ||
}); | ||
|
||
} // namespace arith | ||
} // namespace tvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file bound_deducer.cc | ||
* \brief Utility to deduce bound of expression | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir_pass.h> | ||
#include <tvm/ir_visitor.h> | ||
#include <tvm/api_registry.h> | ||
#include <unordered_set> | ||
#include <unordered_map> | ||
#include "./int_set.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
using namespace ir; | ||
using Halide::Internal::Interval; | ||
|
||
// a visitor to find the path to the target variable | ||
// from a expression. | ||
class VariablePathFinder: public IRVisitor { | ||
public: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. VariablePathFinder There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to look out for errors when a variable appears in multiple locations in the expression |
||
explicit VariablePathFinder(Var target) : target_(target) {} | ||
|
||
void Visit(const NodeRef& node) final { | ||
if (visited_.count(node.get()) != 0) return; | ||
visited_.insert(node.get()); | ||
|
||
if (!found_) path_.push_back(node.get()); | ||
if (node.same_as(target_)) found_ = true; | ||
IRVisitor::Visit(node); | ||
if (!found_) path_.pop_back(); | ||
} | ||
|
||
std::vector<const Node*> path_; | ||
|
||
private: | ||
bool found_{false}; | ||
Var target_; | ||
std::unordered_set<const Node*> visited_; | ||
}; | ||
|
||
// get the path to the variable, | ||
// return empty vector to represent failure | ||
std::vector<const Node*> GetPath(Var target, Expr expr) { | ||
VariablePathFinder v(target); | ||
v.Visit(expr); | ||
return v.path_; | ||
} | ||
|
||
class BoundDeduceIntputChecker; | ||
class Converter; | ||
|
||
// a visitor to deduce the bound of a variable from a expression | ||
class BoundDeducer: public IRVisitor { | ||
public: | ||
friend class BoundDeduceInputChecker; | ||
friend class Converter; | ||
BoundDeducer(Var target, Expr expr, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doing everything in constructor have a problem of not being able to throw exception out, consider do it in another function, say Deduce |
||
const std::unordered_map<const Variable*, IntSet>& dom_map) | ||
: target_(target), expr_(expr), dom_map_(dom_map) {} | ||
|
||
bool Init(); | ||
void Deduce(); | ||
|
||
void Visit(const NodeRef& e) final { | ||
if (!success) return; | ||
if (e.get() == path_[iter_++]) { | ||
IRVisitor::Visit(e); | ||
} else { | ||
success = false; | ||
return; | ||
} | ||
} | ||
|
||
void Visit_(const LT* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const LE* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const GT* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const GE* op) final { | ||
LOG(FATAL) << "unable to deduce due to multiple comparison operator"; | ||
} | ||
|
||
void Visit_(const Add* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
result -= left ? op->b : op->a; | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Sub* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
if (left) { | ||
result += op->b; | ||
} else { | ||
result -= op->a; | ||
result = - result; | ||
is_greater = !is_greater; | ||
} | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Mul* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
Expr operand = left ? op->b : op->a; | ||
|
||
SignType sign; | ||
if (operand.type().is_uint()) { | ||
sign = kPositive; | ||
} else { | ||
sign = expr_map_[operand].sign_type(); | ||
} | ||
|
||
if (sign == SignType::kNegative) { | ||
is_greater = !is_greater; | ||
} else if (sign == SignType::kUnknown) { | ||
// unable to get the sign of operand | ||
success = false; | ||
return; | ||
} | ||
|
||
// always use relax bound | ||
result = result / operand + (is_greater ? 1 : -1); | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
Expr result; | ||
bool is_greater{true}; | ||
bool is_equal{true}; | ||
bool success{true}; | ||
|
||
private: | ||
Var target_; | ||
Expr expr_; | ||
const std::unordered_map<const Variable*, IntSet>& dom_map_; | ||
ExprIntSetMap expr_map_; | ||
std::vector<const Node*> path_; | ||
size_t iter_{0}; | ||
}; | ||
|
||
class BoundDeduceInputChecker: public IRVisitor { | ||
public: | ||
bool Check(BoundDeducer* deducer) { | ||
deducer_ = deducer; | ||
Visit(deducer_->expr_); | ||
return target_count == 1; | ||
} | ||
|
||
void Visit(const NodeRef& e) final { | ||
if (e.same_as(deducer_->target_)) ++target_count; | ||
IRVisitor::Visit(e); | ||
} | ||
|
||
private: | ||
BoundDeducer* deducer_; | ||
size_t target_count{0}; | ||
}; | ||
|
||
bool BoundDeducer::Init() { | ||
BoundDeduceInputChecker checker; | ||
if (!checker.Check(this)) success = false; | ||
|
||
if (const LT* op = expr_.as<LT>()) { | ||
is_greater = false; | ||
is_equal = false; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else if (const LE* op = expr_.as<LE>()) { | ||
is_greater = false; | ||
is_equal = true; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else if (const GT* op = expr_.as<GT>()) { | ||
is_greater = true; | ||
is_equal = false; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else if (const GE* op = expr_.as<GE>()) { | ||
is_greater = true; | ||
is_equal = true; | ||
expr_ = op->a; | ||
result = op->b; | ||
} else { | ||
success = false; | ||
} | ||
return success; | ||
} | ||
|
||
void BoundDeducer::Deduce() { | ||
Init(); | ||
if (!success) return; | ||
|
||
// get the path | ||
path_ = GetPath(target_, expr_); | ||
// get the sign of every subexpr | ||
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); | ||
|
||
Visit(expr_); | ||
} | ||
|
||
// assuming e >= 0, deduce the bound of variable from it. | ||
// return empty set to represent deduce failure. | ||
IntSet DeduceBound(Var v, Expr e, | ||
const Map<Var, IntSet>& dom_map) { | ||
std::unordered_map<const Variable*, IntSet> dmap; | ||
for (auto kv : dom_map) { | ||
dmap[kv.first.get()] = kv.second; | ||
} | ||
BoundDeducer d(v, e, dmap); | ||
d.Deduce(); | ||
if (!d.success) return IntSet::nothing(); | ||
Expr min = Interval::neg_inf, max = Interval::pos_inf; | ||
if (d.is_greater) { | ||
min = d.is_equal ? d.result : d.result + 1; | ||
} else { | ||
max = d.is_equal ? d.result : d.result - 1; | ||
} | ||
return IntSet::interval(min, max); | ||
} | ||
|
||
} // namespace arith | ||
} // namespace tvm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this include