FlaxLlamaForCausalLMModule hanging on jax-metal #24221

Open

alexlatif opened this issue Oct 9, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@alexlatif

Description

To reproduce the working (non-hanging) state, uncomment the jax.config line that switches the platform to CPU.

from transformers import AutoTokenizer
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state

# from llama import FlaxLLaMAForCausalLM  # From the ayaka14732/llama-2-jax repo
from transformers.models.llama.modeling_flax_llama import (
    FlaxLlamaForCausalLMModule,
    LlamaConfig,
)
from transformers.models.llama.tokenization_llama import LlamaTokenizer

# Uncomment to force CPU execution (the working state):
# jax.config.update("jax_platform_name", "cpu")
print(jax.devices())

# Download tokenizer and model config

tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
conf = LlamaConfig.from_pretrained("openlm-research/open_llama_3b_v2")
print(type(conf))
model = FlaxLlamaForCausalLMModule(conf)
print(type(model))

input_prompt = "The future of AI is"
input_ids = tokenizer(input_prompt, return_tensors="jax").input_ids

rng = jax.random.PRNGKey(0)
position_ids = jnp.broadcast_to(jnp.arange(input_ids.shape[-1]), input_ids.shape)

print(position_ids.device)

params = model.init(
    rng, input_ids, attention_mask=jnp.ones_like(input_ids), position_ids=position_ids
)["params"]

model_output = model.apply(
    {"params": params},
    input_ids,
    attention_mask=jnp.ones_like(input_ids),
    position_ids=position_ids,
)

print("Model output logits:", model_output.logits)

predicted_token_ids = jnp.argmax(model_output.logits, axis=-1)

predicted_text = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

print("Predicted text:", predicted_text)

System info (python version, jaxlib version, accelerator, etc.)

jaxlib: 0.4.34
numpy:  1.26.4
python: 3.11.0 (v3.11.0:deaf509e8f, Oct 24 2022, 14:43:23) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Alessandros-Air', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:19:22 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8112', machine='arm64')
@alexlatif added the bug label on Oct 9, 2024
@rajasekharporeddy
Contributor

Hi @alexlatif

I tested the provided code with jax-metal on a MacBook Pro (M1 Pro). While there was no hang, model.init and model.apply took longer than on CPU. Please find the attached screenshots below:

[Screenshots: wall-clock timings of the model.init and model.apply runs]

Thank you.
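
For reference, a minimal way to reproduce such timings (a sketch based on the repro above; block_until_ready is needed because JAX dispatches work asynchronously, so un-synchronized timings can under-report):

```python
import time

# Time init; block_until_ready forces the async dispatch to finish
# before we read the clock.
start = time.perf_counter()
params = model.init(
    rng, input_ids, attention_mask=jnp.ones_like(input_ids), position_ids=position_ids
)["params"]
jax.block_until_ready(params)
print(f"model.init:  {time.perf_counter() - start:.1f}s")

start = time.perf_counter()
model_output = model.apply(
    {"params": params},
    input_ids,
    attention_mask=jnp.ones_like(input_ids),
    position_ids=position_ids,
)
model_output.logits.block_until_ready()
print(f"model.apply: {time.perf_counter() - start:.1f}s")
```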

@alexlatif
Author

You're correct that it does eventually run. However, on a MacBook Air M2 (Sonoma 14.4.1) it took ~5 minutes. Any insight into why it's so much slower on Metal?
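
One plausible factor (an assumption, not confirmed in this thread) is that the first call pays one-time tracing and Metal compilation costs. A jit-based sketch to separate compile time from steady-state runtime:

```python
import time

# Hypothetical check: jit the forward pass and compare the first call
# (tracing + compilation) against the second (cached executable).
@jax.jit
def forward(params, ids, pos):
    return model.apply(
        {"params": params}, ids, attention_mask=jnp.ones_like(ids), position_ids=pos
    )

for label in ("first call (compile + run)", "second call (cached)"):
    start = time.perf_counter()
    out = forward(params, input_ids, position_ids)
    out.logits.block_until_ready()
    print(f"{label}: {time.perf_counter() - start:.1f}s")
```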
