diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index e5c92794b9d126..36ed1846c551df 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -6689,7 +6689,8 @@ GetCudnnFlashAttentionOperationGraph(
           intermediate_ops, intermediate_bmm2_lhs_dims,
           intermediate_bmm2_lhs_strides, intermediate_bmm2_lhs_descriptor.type(),
-          /*input_tensor*/ softmax_fwd_out, *dropout_rate));
+          /*input_tensor*/ softmax_fwd_out,
+          use_dropout ? *dropout_rate : 0));
   bmm2_input_tensor = std::move(dropout_out);

   std::vector<int64_t> bmm2_rhs_dims =
@@ -7194,7 +7195,7 @@ GetCudnnFlashAttentionBackwardOperationGraph(
       auto tensor_p_after_scale_dropout,
       CreateCudnnFlashAttentionDropoutBwdTensor(
           intermediate_ops, p_dims, p_strides, dtype, tensor_p_after_softmax,
-          tensor_dropout_mask, *dropout_rate));
+          tensor_dropout_mask, use_dropout ? *dropout_rate : 0));

   // after_scale_dropout -> s_transpose
   auto p_transpose_dims = p_dims;
@@ -9427,7 +9428,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc(
        use_dropout ? (1.0f / (1.0f - *dropout_rate)) : 1.0f;
    ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat);
    // scale prob
-   double scale_prob_value = 1.0 - *dropout_rate;
+   double scale_prob_value = use_dropout ? 1.0 - *dropout_rate : 1.0f;
    ScalingParam scale_prob(scale_prob_value, dnn::DataType::kFloat);
    scalar_values = {alpha_scale, dropout_scale, scale_prob};
    // push dropout seed and offset here
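
For context, each hunk applies the same guard: the optional dropout rate is only dereferenced when dropout is actually enabled, and otherwise a neutral value is used (rate 0 in the forward/backward dropout tensors, scale 1.0 for scale_prob). Below is a minimal standalone C++ sketch of that pattern, not code from cuda_dnn.cc; the helpers EffectiveDropoutRate and ScaleProb are hypothetical, and the use_dropout condition shown is an assumption mirroring the patch.

    #include <iostream>
    #include <optional>

    // Hypothetical helper: only dereference the optional when dropout is in
    // use; otherwise fall back to a dropout rate of 0.
    double EffectiveDropoutRate(const std::optional<double>& dropout_rate) {
      const bool use_dropout = dropout_rate.has_value() && *dropout_rate > 0.0;
      return use_dropout ? *dropout_rate : 0.0;
    }

    // Hypothetical helper mirroring the scale_prob fix: with no dropout the
    // keep-probability scaling factor degenerates to 1.0.
    double ScaleProb(const std::optional<double>& dropout_rate) {
      const bool use_dropout = dropout_rate.has_value() && *dropout_rate > 0.0;
      return use_dropout ? 1.0 - *dropout_rate : 1.0;
    }

    int main() {
      std::cout << EffectiveDropoutRate(std::nullopt) << "\n";  // 0
      std::cout << EffectiveDropoutRate(0.1) << "\n";           // 0.1
      std::cout << ScaleProb(std::nullopt) << "\n";             // 1
      std::cout << ScaleProb(0.1) << "\n";                      // 0.9
      return 0;
    }

The point of the guard is that unconditionally evaluating *dropout_rate when the optional is empty is undefined behavior, so the branches above never touch the value unless use_dropout is true.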