Skip to content

Commit

Permalink
Complete FusedMM body (#39)
Browse files Browse the repository at this point in the history
* add softmax

* fix styling

Co-authored-by: sidlak <sidlak@cs.washington.edu>
  • Loading branch information
sidlak-c137 and sidlak-c137 committed Sep 26, 2022
1 parent 957d029 commit da966fe
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tests/python/sparsetir/fusedmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,33 @@ def fusedmm(
O = T.match_sparse_buffer(o, [I, F], "float32")

score = T.alloc_sparse_buffer([I, J], "float32")
temp = T.alloc_sparse_buffer([I,], "float32")
temp1 = T.alloc_sparse_buffer([I,], "float32")
softmax = T.alloc_sparse_buffer([I, J], "float32")
# Q^T * K
with T.iter([I, J, F], "SSR", "sddmm") as [i, j, f]:
with T.init():
score[i, j] = T.float32(0)
score[i, j] += Q[i, f] * K[j, f]

# softmax
with T.iter([I], "S", "softmax") as [i]:
with T.iter([J], "R", "computer_max") as [j]:
with T.init():
temp[i] = score[i, j]
temp[i] = T.max(temp[i], score[i, j])
with T.iter([J], "R", "sum_of_exp") as [j]:
with T.init():
temp1[i] = T.float32(0)
temp1[i] += T.exp(score[i, j] - temp[i], dtype="float32")
with T.iter([J], "S", "normalize") as [j]:
softmax[i, j] = T.exp(score[i, j], dtype="float32") / temp1[i]

# softmax * V
with T.iter([I, J, F], "SRS", "spmm") as [i, j, f]:
with T.init():
O[i, f] = T.float32(0)
O[i, f] = O[i, f] + score[i, j] * V[j, f]
O[i, f] = O[i, f] + softmax[i, j] * V[j, f]


def bench_fusedmm():
Expand Down

0 comments on commit da966fe

Please sign in to comment.