From a6311ef172650a8793f3cb216d22974355abdf18 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 30 Sep 2024 05:19:13 -0700 Subject: [PATCH] [pallas:triton] Do not DCE the jaxpr in the lowering pass There isn't an obvious reason for doing DCE there, and the Mosaic TPU backend in fact doesn't DCE. PiperOrigin-RevId: 680530571 --- jax/_src/pallas/triton/lowering.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 9db5e4081239..33ebe4e34cf0 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -286,10 +286,6 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scratch memory not implemented in the Triton backend" ) - with grid_mapping.trace_env(): - jaxpr, _ = pe.dce_jaxpr( - jaxpr, [True] * len(jaxpr.outvars), instantiate=True - ) with _new_ir_context(), ir.Location.unknown(): module = ir.Module.create() attrs = module.operation.attributes