I tested the provided code with jax-metal on a MacBook Pro (M1 Pro). It did not hang, but both model.init and model.apply took longer than they did on the CPU. Please find the attached screenshots below:
You're correct that it does eventually run. However, on a MacBook Air M2 (macOS Sonoma 14.4.1) this took ~5 minutes. Any insight into why it's so much slower on Metal?
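One way to narrow this down is to time the first call separately from repeat calls, since a first call can include one-time work such as compilation. The sketch below is only an illustration: `time_calls` is a hypothetical helper, not part of JAX or jax-metal, and it does not by itself explain the slowdown.

```python
import time
from typing import Any, Callable, Tuple


def time_calls(
    fn: Callable[..., Any],
    *args: Any,
    sync: Callable[[Any], Any] = lambda x: x,
    repeats: int = 3,
) -> Tuple[float, float]:
    """Return (first_call_seconds, best_repeat_seconds) for fn(*args).

    `sync` is applied to each result before the clock stops, so that
    asynchronously dispatched work is finished before it is measured.
    """
    # First call: may include one-time costs such as tracing/compilation.
    start = time.perf_counter()
    sync(fn(*args))
    first = time.perf_counter() - start

    # Repeat calls: keep the fastest one as the steady-state estimate.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args))
        best = min(best, time.perf_counter() - start)
    return first, best
```

With JAX, one would pass `sync=jax.block_until_ready`, for example `time_calls(model.apply, params, x, sync=jax.block_until_ready)`, where `params` and `x` stand in for whatever the reproduction script uses. Running this once on Metal and once on CPU would show whether the gap is in the first call or in every call.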
Description
To reproduce the working state, uncomment the line that updates the device to CPU.
System info (python version, jaxlib version, accelerator, etc.)