add guard for optional dropout_rate
Cjkkkk committed Jan 11, 2024
1 parent 9de2496 commit 3b9ff9b
Showing 1 changed file with 4 additions and 3 deletions.
xla/stream_executor/cuda/cuda_dnn.cc
```diff
@@ -6689,7 +6689,8 @@ GetCudnnFlashAttentionOperationGraph(
           intermediate_ops, intermediate_bmm2_lhs_dims,
           intermediate_bmm2_lhs_strides,
           intermediate_bmm2_lhs_descriptor.type(),
-          /*input_tensor*/ softmax_fwd_out, *dropout_rate));
+          /*input_tensor*/ softmax_fwd_out,
+          use_dropout ? *dropout_rate : 0));
   bmm2_input_tensor = std::move(dropout_out);

   std::vector<int64_t> bmm2_rhs_dims =
```
```diff
@@ -7194,7 +7195,7 @@ GetCudnnFlashAttentionBackwardOperationGraph(
       auto tensor_p_after_scale_dropout,
       CreateCudnnFlashAttentionDropoutBwdTensor(
           intermediate_ops, p_dims, p_strides, dtype, tensor_p_after_softmax,
-          tensor_dropout_mask, *dropout_rate));
+          tensor_dropout_mask, use_dropout ? *dropout_rate : 0));

   // after_scale_dropout -> s_transpose
   auto p_transpose_dims = p_dims;
```
```diff
@@ -9427,7 +9428,7 @@ CudnnSupport::FusedMHABackwardRunnerFromDesc(
       use_dropout ? (1.0f / (1.0f - *dropout_rate)) : 1.0f;
   ScalingParam dropout_scale(dropout_scale_value, dnn::DataType::kFloat);
   // scale prob
-  double scale_prob_value = 1.0 - *dropout_rate;
+  double scale_prob_value = use_dropout ? 1.0 - *dropout_rate : 1.0f;
   ScalingParam scale_prob(scale_prob_value, dnn::DataType::kFloat);
   scalar_values = {alpha_scale, dropout_scale, scale_prob};
   // push dropout seed and offset here
```
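For context, all three hunks apply the same fix: `dropout_rate` is an optional parameter, and the old code dereferenced it unconditionally with `*dropout_rate`, which is undefined behavior when no rate was supplied. The guard `use_dropout ? *dropout_rate : 0` reads the value only when dropout is actually in effect. Below is a minimal, self-contained sketch of that pattern; `EffectiveDropoutRate` and its derivation of `use_dropout` are illustrative assumptions, not code from cuda_dnn.cc.

```cpp
// Sketch of the guard pattern this commit applies. EffectiveDropoutRate is
// a hypothetical helper, not part of the XLA codebase.
#include <iostream>
#include <optional>

double EffectiveDropoutRate(const std::optional<double>& dropout_rate) {
  // Assumed derivation of `use_dropout`: dropout is only in effect when the
  // optional holds a nonzero rate.
  const bool use_dropout = dropout_rate.has_value() && *dropout_rate > 0.0;
  // Guarded dereference: `*dropout_rate` is evaluated only when the optional
  // is known to be engaged, so std::nullopt can no longer be dereferenced.
  return use_dropout ? *dropout_rate : 0.0;
}

int main() {
  std::cout << EffectiveDropoutRate(std::nullopt) << "\n";  // prints 0
  std::cout << EffectiveDropoutRate(0.1) << "\n";           // prints 0.1
}
```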
