Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] DeduceBound #40

Merged
merged 27 commits into from
Feb 17, 2017
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 642ae5 to e68ae6
1 change: 1 addition & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._ctypes._node import register_node

from . import tensor
from . import arith
from . import expr
from . import stmt
from . import make
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ctypes/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def _init_api_functions(root_namespace):
module_internal = sys.modules["%s._api_internal" % root_namespace]
namespace_match = {
"_make_": sys.modules["%s.make" % root_namespace],
"_arith_": sys.modules["%s.arith" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# pylint: disable=protected-access, no-member
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal

@register_node
class IntSet(NodeBase):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _api_internal._IntSetIsNothing(self)

def is_everything(self):
"""Whether the set represent everything"""
return _api_internal._IntSetIsEverything(self)

@register_node
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)

def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)

@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass

51 changes: 51 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to arith
* \file api_arith.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include "../arithmetic/int_set.h"
#include "../arithmetic/int_set_internal.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this include


namespace tvm {
namespace arith {

TVM_REGISTER_API(_arith_intset_single_point)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});

TVM_REGISTER_API(_arith_intset_interval)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});

TVM_REGISTER_API(_arith_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1], args[2]);
});

TVM_REGISTER_API(_IntervalSetGetMin)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});

TVM_REGISTER_API(_IntervalSetGetMax)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});

TVM_REGISTER_API(_IntSetIsNothing)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});

TVM_REGISTER_API(_IntSetIsEverything)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});

} // namespace arith
} // namespace tvm
230 changes: 230 additions & 0 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include <unordered_map>
#include "./int_set.h"

namespace tvm {
namespace arith {

using namespace ir;
using Halide::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
class VariablePathFinder: public IRVisitor {
public:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VariablePathFinder

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to look out for errors when a variable appears in multiple locations in the expression

explicit VariablePathFinder(Var target) : target_(target) {}

void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());

if (!found_) path_.push_back(node.get());
if (node.same_as(target_)) found_ = true;
IRVisitor::Visit(node);
if (!found_) path_.pop_back();
}

std::vector<const Node*> path_;

private:
bool found_{false};
Var target_;
std::unordered_set<const Node*> visited_;
};

// get the path to the variable,
// return empty vector to represent failure
std::vector<const Node*> GetPath(Var target, Expr expr) {
VariablePathFinder v(target);
v.Visit(expr);
return v.path_;
}

class BoundDeduceIntputChecker;
class Converter;

// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
friend class BoundDeduceInputChecker;
friend class Converter;
BoundDeducer(Var target, Expr expr,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doing everything in constructor have a problem of not being able to throw exception out, consider do it in another function, say Deduce

const std::unordered_map<const Variable*, IntSet>& dom_map)
: target_(target), expr_(expr), dom_map_(dom_map) {}

bool Init();
void Deduce();

void Visit(const NodeRef& e) final {
if (!success) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
return;
}
}

void Visit_(const LT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const LE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const GT* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const GE* op) final {
LOG(FATAL) << "unable to deduce due to multiple comparison operator";
}

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
} else {
result -= op->a;
result = - result;
is_greater = !is_greater;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;

SignType sign;
if (operand.type().is_uint()) {
sign = kPositive;
} else {
sign = expr_map_[operand].sign_type();
}

if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
return;
}

// always use relax bound
result = result / operand + (is_greater ? 1 : -1);
Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};
bool is_equal{true};
bool success{true};

private:
Var target_;
Expr expr_;
const std::unordered_map<const Variable*, IntSet>& dom_map_;
ExprIntSetMap expr_map_;
std::vector<const Node*> path_;
size_t iter_{0};
};

class BoundDeduceInputChecker: public IRVisitor {
public:
bool Check(BoundDeducer* deducer) {
deducer_ = deducer;
Visit(deducer_->expr_);
return target_count == 1;
}

void Visit(const NodeRef& e) final {
if (e.same_as(deducer_->target_)) ++target_count;
IRVisitor::Visit(e);
}

private:
BoundDeducer* deducer_;
size_t target_count{0};
};

bool BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;

if (const LT* op = expr_.as<LT>()) {
is_greater = false;
is_equal = false;
expr_ = op->a;
result = op->b;
} else if (const LE* op = expr_.as<LE>()) {
is_greater = false;
is_equal = true;
expr_ = op->a;
result = op->b;
} else if (const GT* op = expr_.as<GT>()) {
is_greater = true;
is_equal = false;
expr_ = op->a;
result = op->b;
} else if (const GE* op = expr_.as<GE>()) {
is_greater = true;
is_equal = true;
expr_ = op->a;
result = op->b;
} else {
success = false;
}
return success;
}

void BoundDeducer::Deduce() {
Init();
if (!success) return;

// get the path
path_ = GetPath(target_, expr_);
// get the sign of every subexpr
expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_);

Visit(expr_);
}

// assuming e >= 0, deduce the bound of variable from it.
// return empty set to represent deduce failure.
IntSet DeduceBound(Var v, Expr e,
const Map<Var, IntSet>& dom_map) {
std::unordered_map<const Variable*, IntSet> dmap;
for (auto kv : dom_map) {
dmap[kv.first.get()] = kv.second;
}
BoundDeducer d(v, e, dmap);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
if (d.is_greater) {
min = d.is_equal ? d.result : d.result + 1;
} else {
max = d.is_equal ? d.result : d.result - 1;
}
return IntSet::interval(min, max);
}

} // namespace arith
} // namespace tvm
Loading