Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mixed precision support to deepxde #1650

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open

Conversation

g-w1
Copy link
Contributor

@g-w1 g-w1 commented Feb 11, 2024

No description provided.

deepxde/config.py Outdated Show resolved Hide resolved
@g-w1 g-w1 force-pushed the mixed branch 4 times, most recently from d3590e4 to 1352213 Compare February 11, 2024 23:19
deepxde/config.py Outdated Show resolved Hide resolved
deepxde/config.py Outdated Show resolved Hide resolved
deepxde/config.py Outdated Show resolved Hide resolved
deepxde/config.py Outdated Show resolved Hide resolved
deepxde/model.py Outdated Show resolved Hide resolved
deepxde/real.py Outdated Show resolved Hide resolved
deepxde/model.py Outdated Show resolved Hide resolved
@lululxvi
Copy link
Owner

If I use the API here to set mixed precision, then all the demo code can run in mixed precision?

@g-w1
Copy link
Contributor Author

g-w1 commented Jun 26, 2024

Not all. The L-BFGS optimizer doesn't work in mixed precision. But if you add the line dde.config.set_default_float("mixed") to the top of your file, it should work unless it uses a feature (like L-BFGS) that is not supported with mixed precision. I have tested this on the Burgers.py example and it works.

@@ -74,7 +74,7 @@ def set_default_float(value):
The default floating point type is 'float32'.

Args:
value (String): 'float16', 'float32', or 'float64'.
value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision in https://arxiv.org/abs/2401.16645).
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value (String): 'float16', 'float32', 'float64', or 'mixed' (mixed precision).

@@ -74,7 +74,7 @@ def set_default_float(value):
The default floating point type is 'float32'.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default floating point type is 'float32'. Mixed precision uses the method in the paper: `J. Hayford, J. Goldman-Wetzler, E. Wang, & L. Lu. Speeding up and reducing memory usage for scientific machine learning via mixed precision. Computer Methods in Applied Mechanics and Engineering, 428, 117093, 2024 <https://doi.org/10.1016/j.cma.2024.117093>`_.


self.opt.step(closure)
def closure_mixed():
Copy link
Owner

@lululxvi lululxvi Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why delete line 360

@@ -85,6 +85,20 @@ def set_default_float(value):
elif value == "float64":
print("Set the default float type to float64")
real.set_float64()
elif value == "mixed":
print("Set the float type to mixed precision of float16 and float32")
real.set_mixed()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is confusing. Here you do real.set_mixed(), but later you do either real.set_float16() or real.set_float32(). It seems you only need a flag mixed. You can do this flag after line 42.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants