Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR, Relay] improve bfloat16 support #2

Closed
wants to merge 16 commits into from
Closed

Conversation

yangulei
Copy link
Owner

@yangulei yangulei commented Jan 20, 2022

Motivation:

We are enabling bfloat16 in BYOC-oneDNN following the path: [float32 graph] --> <AMP> --> [bfloat16 graph] --> <BYOC> --> [TVM + oneDNN module]. While some of the Passes like FoldConstant can not work for bfloat16 before the improvements below.

Changes:

  • Add runtime datatype dispatch and skip asserts for uint16 for bfloat16 compatibility.
  • Add bfloat16 casting for unary intrinsic operators to enable the graph optimization.
  • Improve the bf16_legalize module to enable bfloat16 lowering.

With those improvements, a float32 graph could be converted to bfloat16 through AMP, and then be lowered to inference in bfloat16 mode now.

Tested Models (gluoncv):

  • ResNet<18/34/50/101/152>_v1b
  • VGG<11/13/16/19>
  • VGG<11/13/16/19>_bn
  • DenseNet121
  • InceptionV3

By tested I mean I confirm it did some transformation on the graph and a forward pass could be run on CPU and matches the fp32 output somewhat. I have nothing on performance metrics or other devices yet.

As @AndrewZhaoLuo said at apache#8069

Pending:

The support for bfloat16 in BYOC-oneDNN is based on multi-blocking layout transform and the extensions on BYOC-oneDNN and pending.

@yangulei yangulei changed the title improve bfloat16 support [TIR, Relay] improve bfloat16 support Jan 20, 2022
.gitignore Outdated
@@ -11,7 +11,10 @@ __pycache__/
.Python
env/
build/
build_debug/

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't change this. You can change it locally, but don't upsteam.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I'll fix this.

static const Op& op = Op::Get("tir." #OpName); \
if (x.dtype().is_bfloat16()) { \
DataType srcType = x.dtype(); \
DataType dstType(kDLFloat, 32, srcType.lanes()); \

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make those \ in a row.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

@@ -40,6 +40,8 @@
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
"nn.bias_add",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we can change this default list. Better to have another CPU list, otherwise you need to evaluate the impact to NV hardware.

@@ -126,3 +155,4 @@ def test_fp16_conversion(target, dev):
test_basic_build()
test_fp16_build()
test_fp16_conversion()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add test_bf16_conversion as fp16?

Copy link
Owner Author

@yangulei yangulei Jan 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZhennanQin
Copy link

LGTM!

@yangulei yangulei closed this Apr 1, 2022
@yangulei
Copy link
Owner Author

yangulei commented Apr 1, 2022

Thanks for reviewing, this PR has been merged to the official repo.

yangulei pushed a commit that referenced this pull request Sep 8, 2022
* Revert "[skip ci] Revert "[ci] Default to n=2 for test parallelism (apache#12376)" (apache#12413)"

This reverts commit 478b672.

* [ci] Default to n=2 for test parallelism

This is attempt #2 of apache#12376 which was reverted in apache#12413. The changes
in `plugin.py` should keep all the tests on the same node so sporadic
failures don't happen due to scheduling.

Co-authored-by: driazati <driazati@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants