diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 9db5e4081239..33ebe4e34cf0 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -286,10 +286,6 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scratch memory not implemented in the Triton backend" ) - with grid_mapping.trace_env(): - jaxpr, _ = pe.dce_jaxpr( - jaxpr, [True] * len(jaxpr.outvars), instantiate=True - ) with _new_ir_context(), ir.Location.unknown(): module = ir.Module.create() attrs = module.operation.attributes