[TIR, Relay] improve bfloat16 support #2

Closed
wants to merge 16 commits into from
Changes from 10 commits
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -11,7 +11,10 @@ __pycache__/
.Python
env/
build/
build_debug/

Please don't change this. You can change it locally, but don't upstream it.

Owner Author

ok, I'll fix this.

build_release/
develop-eggs/
dev_tvm/
dist/
downloads/
eggs/
Expand Down
16 changes: 12 additions & 4 deletions include/tvm/tir/op.h
Expand Up @@ -835,10 +835,18 @@ TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s
Span span = Span());

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
return tir::Call(x.dtype(), op, {x}, span); \
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType srcType = x.dtype(); \
DataType dstType(kDLFloat, 32, srcType.lanes()); \

Please align the trailing backslashes into a single column.

Owner Author

ok

PrimExpr castX = tir::Cast(dstType, {x}, span); \
PrimExpr result = tir::Call(dstType, op, {castX}, span); \
return tir::Cast(srcType, {result}, span); \
} else { \
return tir::Call(x.dtype(), op, {x}, span); \
} \
}

TVM_DECLARE_INTRIN_UNARY(exp);
Expand Down
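The bfloat16 branch added to this macro evaluates the intrinsic in float32 and casts the result back, since most targets have no native bfloat16 math. A minimal standalone sketch of that widen/compute/narrow pattern (plain C++, not TVM code; helper names are invented for illustration, and the narrowing step truncates rather than rounds for brevity):

#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

using bf16 = uint16_t;  // bfloat16 payload: the upper 16 bits of a float32

inline float bf16_to_f32(bf16 x) {
  uint32_t bits = static_cast<uint32_t>(x) << 16;  // widen by reattaching 16 zero fraction bits
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

inline bf16 f32_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<bf16>(bits >> 16);  // keep the upper half (truncation)
}

// What the bfloat16 branch expresses for, e.g., tir.exp:
//   cast<bf16>( exp( cast<float32>(x) ) )
inline bf16 bf16_exp(bf16 x) { return f32_to_bf16(std::exp(bf16_to_f32(x))); }

int main() {
  bf16 one = f32_to_bf16(1.0f);
  std::cout << bf16_to_f32(bf16_exp(one)) << "\n";  // exp(1) at bfloat16 precision: 2.703125
}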
4 changes: 2 additions & 2 deletions python/tvm/relay/transform/mixed_precision.py
Expand Up @@ -40,6 +40,8 @@
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
"nn.bias_add",

I'm not sure we can change this default list. It would be better to add a separate CPU-specific list; otherwise you need to evaluate the impact on NVIDIA hardware.

"nn.batch_norm",
]
DEFAULT_FOLLOW_LIST = [
# These ops add new data or change shape
Expand Down Expand Up @@ -80,8 +82,6 @@
"subtract",
"multiply",
"divide",
"nn.bias_add",
"nn.batch_norm",
"sqrt",
"shape_of",
# Simple activations
Expand Down
2 changes: 1 addition & 1 deletion src/arith/rewrite_simplify.cc
Expand Up @@ -461,7 +461,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {

// x / 2.0 = x * 0.5
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
ICHECK(op->dtype.is_float() ||
ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
}
Expand Down
20 changes: 10 additions & 10 deletions src/auto_scheduler/feature.cc
Expand Up @@ -246,14 +246,14 @@ int64_t GetLoopExtent(const ForNode* node) {
// Count math ops in an expr
class MathOpCounter : public StmtExprVisitor {
public:
#define VisitBinary(Type, float_ct, int_ct) \
void VisitExpr_(const Type* op) final { \
if (op->a.dtype().is_float()) { \
float_ct++; \
} else { \
int_ct++; \
} \
StmtExprVisitor::VisitExpr_(op); \
#define VisitBinary(Type, float_ct, int_ct) \
void VisitExpr_(const Type* op) final { \
if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \
float_ct++; \
} else { \
int_ct++; \
} \
StmtExprVisitor::VisitExpr_(op); \
}

VisitBinary(AddNode, float_addsub, int_addsub);
Expand Down Expand Up @@ -299,13 +299,13 @@ class MathOpCounter : public StmtExprVisitor {
effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;

if (is_pure) {
if (op->dtype.is_float()) {
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
float_math_func++;
} else {
int_math_func++;
}
} else {
if (op->dtype.is_float()) {
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
float_other_func++;
} else {
int_other_func++;
Expand Down
20 changes: 15 additions & 5 deletions src/autotvm/touch_extractor.h
Expand Up @@ -87,27 +87,37 @@ class TouchExtractor : public FeatureVisitor {

// arithmetic stats
void VisitExpr_(const AddNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].add_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const SubNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].add_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const MulNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].mul_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const DivNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].div_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

void VisitExpr_(const ModNode* op) final {
if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++;
if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
itervar_map[itervar_stack_.back()].div_ct++;
}
FeatureVisitor::VisitExpr_(op);
}

Expand Down
3 changes: 3 additions & 0 deletions src/contrib/hybrid/codegen_hybrid.cc
Expand Up @@ -69,6 +69,9 @@ void CodeGenHybrid::PrintType(DataType t, std::ostream& os) {
} else if (t.is_int()) {
os << "int";
ICHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64);
} else if (t.is_bfloat16()) {
os << "bfloat";
ICHECK(t.bits() == 16);
} else {
ICHECK(t.is_uint()) << "Unsupported type " << t;
os << "uint";
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Expand Up @@ -363,6 +363,8 @@ class CodegenCBase {
dtype = "float";
} else if (runtime::TypeMatch(ttype->dtype, kDLFloat, 16)) {
dtype = "half";
} else if (runtime::TypeMatch(ttype->dtype, kDLBfloat, 16)) {
dtype = "bfloat";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 32)) {
dtype = "int";
} else if (runtime::TypeMatch(ttype->dtype, kDLInt, 64)) {
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/utils.h
Expand Up @@ -302,6 +302,8 @@ inline std::string DType2String(const tvm::DataType dtype) {
os << "int";
} else if (dtype.is_uint()) {
os << "uint";
} else if (dtype.is_bfloat16()) {
os << "bfloat";
} else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
os << "custom["
<< (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/nn/nn.cc
Expand Up @@ -1177,7 +1177,8 @@ bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< ", weights shape = " << weights->shape);
return false;
}
if (!(predictions->dtype == weights->dtype && predictions->dtype.is_float())) {
if (!(predictions->dtype == weights->dtype &&
(predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "NLLLossRel: predictions and weights should"
<< " be of the same floating type.");
Expand Down
26 changes: 26 additions & 0 deletions src/relay/transforms/pattern_utils.h
Expand Up @@ -63,6 +63,9 @@ namespace relay {
} else if (type == DataType::Float(16)) { \
typedef uint16_t DType; \
{ __VA_ARGS__ } \
} else if (type == DataType::BFloat(16)) { \
typedef uint16_t DType; \
{ __VA_ARGS__ } \
} else if (type == DataType::Int(64)) { \
typedef int64_t DType; \
{ __VA_ARGS__ } \
Expand Down Expand Up @@ -259,6 +262,11 @@ inline Constant MakeConstantScalar(DataType dtype, T value) {
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
} else if (dtype == DataType::BFloat(16)) {
// convert to bfloat16
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value));
} else {
*static_cast<DType*>(arr->data) = value;
}
Expand Down Expand Up @@ -286,6 +294,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(value[i]));
} else if (dtype == DataType::BFloat(16)) {
// convert to bfloat16
// storage is uint16_t
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
static_cast<float>(value[i]));
} else {
*(static_cast<DType*>(arr->data) + i) = value[i];
}
Expand Down Expand Up @@ -314,6 +328,12 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(value[i]));
} else if (dtype == DataType::BFloat(16)) {
// convert to bfloat16
// storage is uint16_t
*(static_cast<DType*>(arr->data) + i) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
static_cast<float>(value[i]));
} else {
*(static_cast<DType*>(arr->data) + i) = value[i];
}
Expand Down Expand Up @@ -417,6 +437,12 @@ static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& ar
} else if (array->dtype.bits == 64) {
return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
}
} else if (array->dtype.code == kDLBfloat) {
if (array->dtype.bits == 16) {
return dmlc::optional<long double>(
__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
reinterpret_cast<uint16_t*>(array->data)[i]));
}
}
return dmlc::optional<long double>();
}
Expand Down
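The template arguments in the __truncXfYf2__ / __extendXfYf2__ calls above encode the significand widths: float32 carries 23 fraction bits, bfloat16 carries 7, and both share an 8-bit exponent, so conversion amounts to moving between the upper 16 bits of the float32 pattern and the 16-bit bfloat16 payload. A standalone sketch of both directions, assuming round-to-nearest-even on the narrowing side and ignoring NaN/overflow corner cases (helper names are illustrative, not the compiler-rt entry points):

#include <cstdint>
#include <cstring>

// Widen: bfloat16 -> float32. The 8-bit exponent and 7 fraction bits map
// directly onto the top half of the float32 encoding.
inline float ExtendBF16(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Narrow: float32 -> bfloat16, rounding the 16 discarded bits to nearest even.
inline uint16_t TruncToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t lsb = (bits >> 16) & 1u;  // bias so exact ties round to an even result
  bits += 0x7FFFu + lsb;
  return static_cast<uint16_t>(bits >> 16);
}

// Example: TruncToBF16(3.14159f) == 0x4049, and ExtendBF16(0x4049) reads back as 3.140625f.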
3 changes: 3 additions & 0 deletions src/runtime/crt/common/packed_func.c
Expand Up @@ -49,6 +49,9 @@ DLDataType String2DLDataType(const char* s) {
} else if (!strncmp(s, "float", 5)) {
t.code = kDLFloat;
scan = s + 5;
} else if (!strncmp(s, "bfloat", 6)) {
t.code = kDLBfloat;
scan = s + 6;
} else if (!strncmp(s, "handle", 6)) {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/vm/bytecode.cc
Expand Up @@ -497,6 +497,9 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
case kDLFloat:
os << "float";
break;
case kDLBfloat:
os << "bfloat";
break;
}

os << int(dtype.bits);
Expand Down
12 changes: 12 additions & 0 deletions src/tir/op/op.cc
Expand Up @@ -128,6 +128,14 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
!rtype.is_float()) {
// Cast int->float when the other operand is a float
rhs = cast(ltype, rhs);
} else if (!ltype.is_bfloat16() &&
(rtype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
// Cast lhs to bfloat16 (or the registered custom type) when rhs has it
lhs = cast(rtype, lhs);
} else if ((ltype.is_bfloat16() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
!rtype.is_bfloat16()) {
// Cast rhs to bfloat16 (or the registered custom type) when lhs has it
rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (ltype.bits() < rtype.bits()) {
Expand Down Expand Up @@ -186,6 +194,8 @@ PrimExpr max_value(const DataType& dtype, Span span) {
} else if (dtype.bits() == 16) {
return FloatImm(dtype, 65504.0, span);
}
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
}
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
return PrimExpr();
Expand Down Expand Up @@ -219,6 +229,8 @@ PrimExpr min_value(const DataType& dtype, Span span) {
} else if (dtype.bits() == 16) {
return FloatImm(dtype, -65504.0, span);
}
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
}
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
return PrimExpr();
Expand Down
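Because bfloat16 keeps float32's 8-bit exponent, its dynamic range matches float32: the largest finite bfloat16 is float32's maximum with the low 16 mantissa bits dropped (bit pattern 0x7F7F, about 3.39e38), and the minimum finite value is its negation since the format is sign-symmetric. A quick standalone check, independent of TVM:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>

int main() {
  // Largest finite bfloat16: exponent 0xFE (an all-ones exponent means inf/NaN),
  // all 7 fraction bits set -> payload 0x7F7F.
  uint32_t bits = 0x7F7Fu << 16;
  float bf16_max;
  std::memcpy(&bf16_max, &bits, sizeof(bf16_max));
  std::cout << bf16_max << "\n";                           // ~3.38953e+38
  std::cout << std::numeric_limits<float>::max() << "\n";  // ~3.40282e+38
}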
3 changes: 2 additions & 1 deletion src/tir/transforms/arg_binder.cc
Expand Up @@ -169,7 +169,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type,
IntImm(DataType::UInt(8), dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), dtype.lanes()));
if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) {
if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) ||
dtype == DataType::Int(1) || dtype == DataType::UInt(16))) {
auto type_msg = tvm::tir::StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
Expand Down