jax-metal slowdown on m3 max when external monitor is plugged in #24163

Open
bjeurissen opened this issue Oct 7, 2024 · 1 comment
Labels: Apple GPU (Metal) plugin, bug

bjeurissen commented Oct 7, 2024

Description

I am using macOS Sequoia 15.0.1 (24A348) on an Apple M3 Max.

python: 3.12.7
jax: 0.4.34
jaxlib: 0.4.34
jax-metal: 0.1.0

I consistently see a slowdown of more than a factor of 2 on a simple toy problem when I attach an external monitor.

For example, without an external monitor I get:

10000/10000 [00:09<00:00, 1021.68it/s]

With an external monitor, I get:

10000/10000 [00:24<00:00, 407.24it/s]

I suspected that the monitor's refresh rate might have something to do with it, so I unplugged the monitor and tested with the internal display only, this time lowering its refresh rate from 120 Hz (the default) to 60 Hz in the macOS Display settings. Again I got a reduction in speed, although slightly smaller than with the external monitor plugged in:

10000/10000 [00:18<00:00, 549.69it/s]
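To put the three runs side by side, the slowdown factors can be computed directly from the iterations-per-second figures reported above. This is only arithmetic on the numbers already shown; the assumption that the baseline run used the default 120 Hz refresh rate is inferred from the description, not stated explicitly.

```python
# Throughput figures copied from the tqdm output above (iterations per second).
baseline_its = 1021.68          # no external monitor (assumed: internal display at default 120 Hz)
external_monitor_its = 407.24   # external monitor attached
internal_60hz_its = 549.69      # internal display only, lowered to 60 Hz


def slowdown(baseline: float, measured: float) -> float:
    """Return how many times slower `measured` is than `baseline`."""
    return baseline / measured


for label, its in [("external monitor", external_monitor_its),
                   ("internal display at 60 Hz", internal_60hz_its)]:
    # Also recompute wall time for 10000 iterations to cross-check tqdm's elapsed column.
    print(f"{label}: {slowdown(baseline_its, its):.2f}x slower, "
          f"{10000 / its:.1f} s for 10000 iterations")
```

This gives roughly a 2.5x slowdown with the external monitor and roughly 1.9x with the internal display at 60 Hz, consistent with the elapsed times of 24 s and 18 s versus 9 s.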

The strange thing is that when I fiddle a lot with the display settings (switching between mirroring and extended desktop), I can sometimes get it to run at 1000 it/s even with the external display plugged in. To me this suggests a macOS video/MPS driver issue rather than something specific to jax-metal. However, I could not reproduce the problem with any other MPS-capable numerical libraries, although none of them exceeded 400 it/s to begin with, even without an external monitor attached.
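The toy problem itself is not included in the report. When comparing throughput across display configurations, or against other MPS-backed libraries, the timer should wait for the GPU to finish: JAX dispatches work asynchronously, so a loop that never blocks mostly measures dispatch rather than compute. Below is a minimal sketch of such a measurement, assuming a step function that maps one state to the next; `iterations_per_second` and the small matrix workload are hypothetical illustrations, not the reporter's code.

```python
import time
from typing import Any, Callable


def iterations_per_second(step: Callable[[Any], Any], state: Any, n_iters: int = 1000,
                          block: Callable[[Any], Any] = lambda x: x) -> float:
    """Run `step` n_iters times, feeding each output back in, and return it/s.

    `block` must wait until queued work is finished; for JAX pass
    jax.block_until_ready, otherwise mostly dispatch time is measured.
    """
    state = block(step(state))  # warm-up call (in JAX this triggers JIT compilation)
    start = time.perf_counter()
    for _ in range(n_iters):
        state = step(state)
    block(state)  # wait for the last queued computation before stopping the clock
    elapsed = time.perf_counter() - start
    return n_iters / elapsed


if __name__ == "__main__":
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        jax = None

    if jax is not None:
        # Hypothetical toy workload: a small jitted matrix update that stays bounded.
        @jax.jit
        def toy_step(x):
            return jnp.tanh(x @ x) * 0.5

        x0 = jnp.full((256, 256), 0.01, dtype=jnp.float32)
        rate = iterations_per_second(toy_step, x0, n_iters=200, block=jax.block_until_ready)
        print(f"{rate:.1f} it/s on {jax.devices()[0]}")
```

Because each step depends on the previous result, blocking once at the end still measures the full chain of GPU work while avoiding a host synchronization on every iteration.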

System info (python version, jaxlib version, accelerator, etc.)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1728313982.685485  115532 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Max

systemMemory: 128.00 GB
maxCacheSize: 48.00 GB

I0000 00:00:1728313982.692505  115532 service.cc:145] XLA service 0x600003911200 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1728313982.692514  115532 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1728313982.693547  115532 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1728313982.693555  115532 mps_client.cc:384] XLA backend will use up to 103078739968 bytes on device 0 for SimpleAllocator.
jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 15:57:01) [Clang 17.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='<REMOVED>', release='24.0.0', version='Darwin Kernel Version 24.0.0: Tue Sep 24 23:35:10 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6031', machine='arm64')
bjeurissen added the bug label Oct 7, 2024

hawkinsp (Collaborator) commented Oct 7, 2024

I also suspect this is a macOS issue and not something we can address from JAX, but @shuhand0 can determine that.
