From b8ad146dfd00710376e9477dd2367cc94399d9bb Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 28 Jan 2021 15:06:12 -0700 Subject: [PATCH] [Relay] Type Relation Fixes (#7362) * fix an error in the dynamic Full Type Relation * Add Diagnostic Errors to Broadcast Type Relations --- src/relay/op/dyn/tensor/transform.cc | 3 +++ src/relay/op/type_relations.cc | 12 ++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index e4e81e3612fb..8bad3943f5ce 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -400,6 +400,9 @@ bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, if (fill_value == nullptr) { return false; } + if (fill_shape == nullptr) { + return false; + } DataType out_dtype = param->dtype; if (out_dtype.bits() == 0) { diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index 7a3bfcb21ce6..7b30aea2eb57 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -104,7 +104,11 @@ bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, // << ",Out:" << types[2] << std::endl; if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { - ICHECK_EQ(t0->dtype, t1->dtype); + if (t0->dtype != t1->dtype) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span) + << "data types " << t0->dtype << " and " << t1->dtype + << "do not match in BroadcastRel"); + } reporter->Assign( types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); return true; @@ -120,7 +124,11 @@ bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& att // << ",Out:" << types[2] << std::endl; if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { - ICHECK_EQ(t0->dtype, t1->dtype); + if (t0->dtype != t1->dtype) { + reporter->GetDiagCtx().Emit(Diagnostic::Error(t0->span) + << "data types " << t0->dtype << " and " << t1->dtype + << "do not match in BroadcastCompRel"); + } reporter->Assign(types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), DataType::Bool())); return true;