diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 1234540fbdd9d..0c66995f8f8e4 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -9427,7 +9427,8 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc( CudnnfMHAUid::d_Q_accum_ID, CudnnfMHAUid::O_ID}; if (bias_descriptor != std::nullopt) { uids.push_back(CudnnfMHAUid::BIAS_ID); - } else { + } + if (is_causal_mask) { // is causal mask // negative infinity double negative_infinity_value = -std::numeric_limits::infinity();