Skip to content

Commit

Permalink
[Relay] Type Relation Fixes (#7362)
Browse files Browse the repository at this point in the history
* fix an error in the dynamic Full Type Relation

* Add Diagnostic Errors to Broadcast Type Relations
  • Loading branch information
Matthew Brookhart committed Jan 28, 2021
1 parent f17cba7 commit b8ad146
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (fill_value == nullptr) {
return false;
}
if (fill_shape == nullptr) {
return false;
}

DataType out_dtype = param->dtype;
if (out_dtype.bits() == 0) {
Expand Down
12 changes: 10 additions & 2 deletions src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,11 @@ bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// << ",Out:" << types[2] << std::endl;
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
ICHECK_EQ(t0->dtype, t1->dtype);
if (t0->dtype != t1->dtype) {
reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
<< "data types " << t0->dtype << " and " << t1->dtype
<< "do not match in BroadcastRel");
}
reporter->Assign(
types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1), t0->dtype));
return true;
Expand All @@ -120,7 +124,11 @@ bool BroadcastCompRel(const Array<Type>& types, int num_inputs, const Attrs& att
// << ",Out:" << types[2] << std::endl;
if (auto* t0 = types[0].as<TensorTypeNode>()) {
if (auto* t1 = types[1].as<TensorTypeNode>()) {
ICHECK_EQ(t0->dtype, t1->dtype);
if (t0->dtype != t1->dtype) {
reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span)
<< "data types " << t0->dtype << " and " << t1->dtype
<< "do not match in BroadcastCompRel");
}
reporter->Assign(types[2], ConcreteBroadcast(GetRef<TensorType>(t0), GetRef<TensorType>(t1),
DataType::Bool()));
return true;
Expand Down

0 comments on commit b8ad146

Please sign in to comment.