Skip to content

Commit

Permalink
Fix 1d-softmax schedule. (#11719)
Browse files Browse the repository at this point in the history
  • Loading branch information
lazycal committed Jun 15, 2022
1 parent f667342 commit d2e2f71
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def sched_warp_softmax():
return False
return True

if len(outs[0].shape) > 2:
if len(outs[0].shape) != 2:
ops = [max_elem.op, expsum.op, softmax_op]
if delta is not None:
ops.append(delta.op)
Expand Down
14 changes: 10 additions & 4 deletions tests/python/topi/python/test_topi_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@
"softmax": {
"topi": topi.nn.softmax,
"ref": tvm.topi.testing.softmax_python,
"dimensions": [2, 4],
"dimensions": [1, 2, 4],
},
"log_softmax": {
"topi": topi.nn.log_softmax,
"ref": tvm.topi.testing.log_softmax_python,
"dimensions": [2],
},
}
shapes = [(32, 10), (3, 4), (1, 16, 256, 256)]
shapes = [(32, 10), (3, 4), (1, 16, 256, 256), (32,)]
softmax_operation, shape = tvm.testing.parameters(
*[
(name, shape)
Expand All @@ -69,13 +69,19 @@ def ref_data(shape, dtype, softmax_operation):

a_np = np.random.uniform(size=shape).astype(dtype)

if len(shape) == 2:
if len(shape) == 1:
a_np_2d = a_np[None, :]
b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
b_np = b_np_2d[0]
elif len(shape) == 2:
b_np = ref_func(a_np)
elif len(shape) == 4:
_, c, h, w = a_np.shape
a_np_2d = a_np.transpose(0, 2, 3, 1).reshape(h * w, c)
b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
b_np = b_np_2d.reshape(1, h, w, c).transpose(0, 3, 1, 2)
else:
raise NotImplementedError(f"{len(shape)}-D shape not supported")

return a_np, b_np

Expand All @@ -89,7 +95,7 @@ def test_softmax(target, dev, shape, dtype, ref_data, softmax_operation):
A = te.placeholder(shape, dtype=dtype, name="A")

topi_op = configs[softmax_operation]["topi"]
B = topi_op(A, axis=1)
B = topi_op(A, axis=min(len(shape) - 1, 1))

with tvm.target.Target(target):
fschedule = tvm.topi.testing.dispatch(target, _softmax_schedule)
Expand Down

0 comments on commit d2e2f71

Please sign in to comment.