Skip to content

Commit

Permalink
Add aten::_softmax to eager ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
WilBrady committed Jun 12, 2022
1 parent 5562b47 commit c31e01e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions orttraining/orttraining/eager/opgen/opgen/atenops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(self, dY, X):
"aten::_local_scalar_dense": MakeTorchFallback(),
"aten::gt.Scalar_out": MakeTorchFallback(),
"aten::equal": MakeTorchFallback(),
"aten::_softmax": Softmax("self", axis="dim"),
}

# Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439
Expand Down
7 changes: 7 additions & 0 deletions orttraining/orttraining/eager/test/ort_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def test_zero_stride(self):
cpu_tensor_copied = ort_tensor.cpu()
assert cpu_tensor_copied.stride() == (0, 0, 0)

def test_softmax(self):
    """Verify that softmax on an ORT tensor matches the CPU reference result."""
    device = self.get_device()
    reference = torch.rand(3, 5)
    on_device = reference.to(device)
    # Compute softmax along the row dimension on both backends.
    expected = torch.softmax(reference, dim=1)
    actual = torch.softmax(on_device, dim=1)
    # Copy the ORT result back to CPU before the element-wise comparison.
    assert torch.allclose(expected, actual.cpu())

# Standard unittest entry point: discover and run the tests in this module
# when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit c31e01e

Please sign in to comment.