Skip to content

Commit

Permalink
Add unroll=True to all the calls of fori_loop
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589842015
  • Loading branch information
jax authors committed Dec 11, 2023
1 parent 352e10e commit 9d0a991
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/experimental/pallas/ops/tpu/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,7 @@ def k_body(i, _):
pl.store(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)),
pl.load(dk_scratch_ref, (pl.ds(start_k, block_k), slice(None)))
+ dk.astype(dk_scratch_ref.dtype))
lax.fori_loop(0, block_k_major // block_k, k_body, None)
lax.fori_loop(0, block_k_major // block_k, k_body, None, unroll=True)

if causal:
should_run = below_or_on_diag(
Expand All @@ -887,7 +887,7 @@ def k_body(i, _):

@pl.when(should_run)
def run():
lax.fori_loop(0, block_q_major // block_q, q_body, None)
lax.fori_loop(0, block_q_major // block_q, q_body, None, unroll=True)

@pl.when(q_seq_index == q_seq_len // block_q_major - 1)
def end_of_q_sequence():
Expand Down Expand Up @@ -1234,7 +1234,7 @@ def body(i, _):

@pl.when(should_run)
def run():
lax.fori_loop(0, block_k_major // block_k, body, None)
lax.fori_loop(0, block_k_major // block_k, body, None, unroll=True)

@pl.when(should_not_run)
def zero_out_ds():
Expand Down

0 comments on commit 9d0a991

Please sign in to comment.