From 39699759e6cb0feaa903848098da2cf94b6625aa Mon Sep 17 00:00:00 2001 From: cjkkkk Date: Thu, 7 Dec 2023 13:32:11 -0800 Subject: [PATCH] fix case with no bias but also no causal_mask --- xla/stream_executor/cuda/cuda_dnn.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 9ab38132c3989..1234540fbdd9d 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -9305,7 +9305,7 @@ CudnnSupport::FusedMHARunnerFromDesc( scalar_input_values.push_back(dropout_scale); dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - if (bias_descriptor == std::nullopt) { + if (is_causal_mask) { // push negative infinity here scalar_input_uids.push_back(CudnnfMHAUid::NEG_INFINITY_ID); double negative_infinity_value = -std::numeric_limits::infinity();