diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py
index 225c2edad38f..a47022426083 100644
--- a/orttraining/orttraining/eager/opgen/opgen/atenops.py
+++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py
@@ -138,6 +138,7 @@ def __init__(self, dY, X):
     "aten::_local_scalar_dense": MakeTorchFallback(),
     "aten::gt.Scalar_out": MakeTorchFallback(),
     "aten::equal": MakeTorchFallback(),
+    "aten::_softmax": Softmax("self", axis="dim"),
 }
 
 # Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py
index 62c12f01f849..515967d4a5f8 100644
--- a/orttraining/orttraining/eager/test/ort_ops.py
+++ b/orttraining/orttraining/eager/test/ort_ops.py
@@ -144,6 +144,14 @@ def test_zero_stride(self):
         cpu_tensor_copied = ort_tensor.cpu()
         assert cpu_tensor_copied.stride() == (0, 0, 0)
 
+    def test_softmax(self):
+        device = self.get_device()
+        cpu_tensor = torch.rand(3, 5)
+        ort_tensor = cpu_tensor.to(device)
+        cpu_result = torch.softmax(cpu_tensor, dim=1)
+        ort_result = torch.softmax(ort_tensor, dim=1)
+        assert torch.allclose(cpu_result, ort_result.cpu())
+
 
 if __name__ == "__main__":
     unittest.main()
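Note (not part of the diff): torch.softmax lowers to aten::_softmax, so the new opgen mapping routes that op to ONNX Softmax instead of a Torch fallback, and test_softmax compares the ORT-device result against the CPU result. Below is a minimal standalone sketch of the numerical property the test relies on, using only public torch APIs; the manual exp/sum computation is illustrative only, not how either backend implements softmax.

import torch

# Reference softmax along dim=1, computed by hand: exp(x) normalized row-wise.
x = torch.rand(3, 5)
manual = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)

# torch.softmax (which dispatches through aten::_softmax) should agree
# with the manual computation up to floating-point tolerance.
assert torch.allclose(torch.softmax(x, dim=1), manual)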