[Arith] add SizeVar representing non-neg valued variable in a tensor shape #4684
Conversation
@tqchen do you have any idea why shape_var fails to be accepted as a func arg after I rebased from upstream today? update:
@tqchen @icemelon9 CI's green. It's ready for review.
Some initial reviews. What will happen if we pass a var into the compute? Is it still supported?
include/tvm/expr.h
Outdated
 */
class ShapeVarNode : public VarNode {
 public:
  static ShapeVar make(DataType dtype, std::string name_hint);
use constructor in ShapeVar, as per #4648
python/tvm/build_module.py
Outdated
@@ -309,7 +313,7 @@ def get_binds(args, compact=False, binds=None):
            arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, expr.Var):
Keep the original one, as we can make ShapeVar a subclass of expr.Var.
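To illustrate the point, here is a minimal sketch with stand-in classes (these are not the actual TVM classes, just the subclassing relationship being proposed): a ShapeVar passes the existing isinstance check unchanged.

```python
# Stand-in classes sketching the relationship; not the real TVM classes.
class Var:
    def __init__(self, name):
        self.name = name

class ShapeVar(Var):
    """A Var additionally known to be non-negative (a tensor shape size)."""
    pass

def get_binds(x, arg_list):
    # The original isinstance check needs no change: a ShapeVar is a Var,
    # so it takes the same branch as before.
    if isinstance(x, Var):
        arg_list.append(x)
    return arg_list

n = ShapeVar("n")
print(isinstance(n, Var))       # True
print(get_binds(n, []) == [n])  # True
```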
cc @ZihengJiang @Hzfengsy would be great if you can help to take a look.
Passing var into the compute still works.
Some additional comments w.r.t. the integer analysis. It would be great to add tests that cover these cases.
src/arithmetic/bound_deducer.cc
Outdated
class BoundRemover : public ExprMutator {
 public:
  PrimExpr Remove(const PrimExpr& e) {
    remove_bounded_ = true;
Can you document its behavior? Is it still necessary, given the bound info is no longer wrapped in the assert_?
src/arithmetic/bound_deducer.cc
Outdated
@@ -297,6 +331,18 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() {
  Init();
  if (!success_) return;

  // Any variable appears in both expr and result,
I do not understand the following comment. Why would n/4 be simplified to 0? If n >= 4 the simplification does not hold; if n < 4, then it is a valid simplification.
For >=, the current logic is to first get the lower bound of the left side, then deduce for i.
I see, perhaps we can add more explanation then.
Also, would it result in an incorrect bound, or just a stricter bound (which still makes the condition hold)?
For a better approach, my feeling is that we should first have a rewriting that tries to move vars onto one side.
I believe it is just a stricter bound.
I agree, let me try to see if it is feasible.
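A toy numeric check of that claim (this is an informal model of the deduction, not the actual TVM BoundDeducer): when deducing a lower bound for i from i + n//4 >= c, substituting the lower bound 0 for n//4 (valid because a shape var n is non-negative) yields the stricter bound i >= c, which is safe but may exclude some valid values of i.

```python
# Informal model: deduce a lower bound for i in `i + n//4 >= c` by relaxing
# n//4 to its lower bound 0 (n is non-negative). Stricter is safe: every i
# that satisfies the deduced bound also satisfies the original condition.
def original_condition(i, n, c):
    return i + n // 4 >= c

def deduced_lower_bound(c):
    # Relax n//4 to its lower bound 0, then solve i + 0 >= c.
    return c

c = 10
for n in range(0, 100):
    assert original_condition(deduced_lower_bound(c), n, c)  # never incorrect

# But for n = 40, i = 0 already satisfies the condition (0 + 10 >= 10),
# while the deduced bound i >= 10 excludes it: stricter, not incorrect.
assert original_condition(0, 40, c) and 0 < deduced_lower_bound(c)
```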
As it is also using IntervalSetEvaluator, this is also not required after having shape_var return single_pt.
src/arithmetic/int_set.cc
Outdated
// in case the domain contains variables to be relaxed.
return Eval(res);
} else {
  return IntervalSet(0, GetRef<ShapeVar>(op));
Shouldn't it be IntervalSet::SinglePoint(var), as for Var? This is a bit overly relaxed.
fair. I'll change and verify.
src/arithmetic/int_set.cc
Outdated
@@ -519,6 +536,14 @@ class IntervalSetEvaluator :
    return set->min_value.same_as(value) && set->max_value.same_as(value);
  }

  bool SelfBoundedVar(const IntervalSet& set,
                      const PrimExpr& value) const {
Needs a comment explaining the function name.
src/arithmetic/int_set.cc
Outdated
IntervalSet b = this->Eval(op->b);
if ((MatchPoint(a, op->a) && (MatchPoint(b, op->b) || SelfBoundedVar(b, op->b)))
    || (SelfBoundedVar(a, op->a) && SelfBoundedVar(b, op->b))) {
  // e.g.,
This part is a bit confusing. If the divisor is already a set (which means it is relaxed), we should not return the original var.
But in Combine&lt;Div&gt; such a case evaluates to IntervalSet::Everything(). E.g., 4 / tvm.var() => 4 / tvm.var(), while 4 / tvm.shape_var() => (-inf, +inf).
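A toy interval model of that behavior (assumed semantics for illustration, not the actual TVM IntervalSet): once a shape var n is relaxed to the interval [0, n], the divisor interval contains 0, so a Combine-for-Div step can say nothing useful and must return Everything. Keeping n as a single point preserves the symbolic division, just as for a plain var.

```python
import math

# Toy interval division: a single-point divisor stays symbolic, while a
# relaxed interval containing 0 forces the result to Everything.
NEG_INF, POS_INF = -math.inf, math.inf

def combine_div(dividend, divisor_interval):
    lo, hi = divisor_interval
    if lo == hi:
        # Single point: keep the division symbolic (tagged tuple here).
        return ("div", dividend, lo)
    if lo <= 0 <= hi:
        # Divisor interval contains 0: nothing useful can be concluded.
        return (NEG_INF, POS_INF)
    # A zero-free interval could yield a finite result; omitted in this sketch.
    raise NotImplementedError

# shape_var relaxed to [0, +inf): result is Everything
print(combine_div(4, (0, POS_INF)))  # (-inf, inf)
# shape_var kept as a single point n: stays symbolic, like 4 / tvm.var()
print(combine_div(4, ("n", "n")))    # ('div', 4, 'n')
```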
I think this can be fixed by having shape_var return single_pt (as in the other comment).
yeah... you're right.
@tqchen updated per comments. Do you think there could be a better name than
I don't have a better idea for the name. Indeed "shape" could indicate a tuple rather than an integer. We could potentially rename relay's shape template variable for clarity, as it is really a type var. But we could think about a better name.
include/tvm/expr.h
Outdated
class ShapeVar : public Var {
 public:
  explicit ShapeVar(ObjectPtr<Object> n) : Var(n) {}
  TVM_DLL explicit ShapeVar(std::string name_hint = "s",
Document the constructor
How about
We can ask others' thoughts about it (perhaps list a few candidates and send an RFC?). TIndex seems to be a reasonable name, although it is a bit ambiguous because people need to guess what T means.
@tqchen I have addressed the comments and changed the name.
A minor nit. @icemelon9 please also take another look. Nice set of improvements!
python/tvm/api.py
Outdated
@@ -192,6 +192,25 @@ def var(name="tindex", dtype=int32):
    return _api_internal._Var(name, dtype)


def size_var(name="tindex", dtype=int32):
Is tindex a good default? size?
cc @yzhliu feel free to merge after fixing the nits. BTW, it also reminds me: would we like to make IterVar a subclass of Var? Right now it takes Var as a member.
Thanks @tqchen, I'll take a look at the IndexVar and send another PR.
…shape (apache#4684)

* [arith] add ShapeVar representing non-neg valued variable in a tensor shape
* bounder remover; deal with div in int_set differently
* fix bounder_remover
* migrate unittest to use shape_var
* use tvm.shape_var in integration & relay tests
* add test case; fix Var register
* fix lint
* fix lint again
* add default ShapeVar visitor in Relay
* fix override
* fix ShapeVar visit bug
* revert IntervalSet for shape_var
* remove bound_remover
* remove is_var; use constructor for shapevar/var instead
* ShapeVar -> SizeVar; add constructor comments
* shape_var -> size_var in doc
* tindex -> size
To provide extra information for arith simplification.
This is an alternative approach for #4486
More background:
https://discuss.tvm.ai/t/discuss-embed-more-bound-information-into-var-or-expr
https://discuss.tvm.ai/t/significant-increase-in-the-amount-of-cuda-code-gen-after-migrating-indexdiv-mod-to-floordiv-mod
I haven't changed tvm.var to tvm.shape_var in topi because @icemelon9 is in the middle of refactoring #4644. We can do it later. @tqchen @icemelon9 please review.