Skip to content

Commit

Permalink
[pallas:triton] Do not DCE the jaxpr in the lowering pass
Browse files Browse the repository at this point in the history
There isn't an obvious reason for doing DCE there, and the Mosaic TPU backend
in fact doesn't DCE.

PiperOrigin-RevId: 680530571
  • Loading branch information
superbobry authored and Google-ML-Automation committed Sep 30, 2024
1 parent cdc7278 commit a6311ef
Showing 1 changed file with 0 additions and 4 deletions.
4 changes: 0 additions & 4 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,6 @@ def lower_jaxpr_to_triton_module(
raise NotImplementedError(
"scratch memory not implemented in the Triton backend"
)
with grid_mapping.trace_env():
jaxpr, _ = pe.dce_jaxpr(
jaxpr, [True] * len(jaxpr.outvars), instantiate=True
)
with _new_ir_context(), ir.Location.unknown():
module = ir.Module.create()
attrs = module.operation.attributes
Expand Down

0 comments on commit a6311ef

Please sign in to comment.