Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support fused silu mul #427

Merged
merged 2 commits into from
Aug 9, 2024
Merged

feat: support fused silu mul #427

merged 2 commits into from
Aug 9, 2024

Conversation

zhyncs
Copy link
Member

@zhyncs zhyncs commented Aug 7, 2024

Motivation

as titled

I implemented a simplified version based on FasterTransformers, and I am considering whether to use optimizations like half2, and whether to consider using CUTLASS's LeftSiLUAndMul. Do you have any suggestions? Thanks. @yzh119

Modification

  • fused silu mul

@zhyncs zhyncs requested a review from yzh119 August 7, 2024 19:56
@zhyncs zhyncs self-assigned this Aug 7, 2024
@zhyncs zhyncs added the wip work in progress label Aug 7, 2024
@yzh119
Copy link
Collaborator

yzh119 commented Aug 7, 2024

Using cutlass would be great if they already incorporate half2 operations.

@zhyncs
Copy link
Member Author

zhyncs commented Aug 7, 2024

Using cutlass would be great if they already incorporate half2 operations.

make sense

@zhyncs zhyncs marked this pull request as draft August 7, 2024 20:05
@yzh119
Copy link
Collaborator

yzh119 commented Aug 7, 2024

IMO torch.compile can eventually fuse all these element-wise operations without external custom operators.

I'm okay with introducing these new operators as a workaround solution, and it's preferrable to use existing building blocks to minimize the maintainance overhead. Regarding this operator, can we try using triton directly? I think triton should already supported opitimizations such as half2.

@zhyncs
Copy link
Member Author

zhyncs commented Aug 7, 2024

Ok. I’ll take a look. Thanks!

@zhyncs
Copy link
Member Author

zhyncs commented Aug 8, 2024

import torch
from torch.utils.benchmark import Timer
from itertools import product

from vllm import _custom_ops as ops
from flashinfer.activation import silu_and_mul as flashinfer_silu_and_mul
from flag_gems import silu_and_mul as flag_gems_silu_and_mul

def forward_vllm(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    output_shape = x.shape[:-1] + (d,)
    out = torch.empty(output_shape, dtype=torch.float16, device=x.device)
    ops.silu_and_mul(out, x)
    return out

def forward_flashinfer(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    out = torch.empty((*x.shape[:-1], d), dtype=torch.float16, device=x.device)
    flashinfer_silu_and_mul(x, out)
    return out

def forward_flag_gems(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return flag_gems_silu_and_mul(x[..., :d], x[..., d:])

def test_consistency():
    x = torch.randn(2, 4, 2*d, dtype=torch.float16, device=device)
    out_vllm = forward_vllm(x)
    out_flashinfer = forward_flashinfer(x)
    out_flag_gems = forward_flag_gems(x)
    assert torch.allclose(out_vllm, out_flashinfer, atol=1e-3, rtol=1e-3)
    assert torch.allclose(out_vllm, out_flag_gems, atol=1e-3, rtol=1e-3)
    assert torch.allclose(out_flashinfer, out_flag_gems, atol=1e-3, rtol=1e-3)
    print("Consistency test passed!")

device = torch.device("cuda")
d = 4096

test_consistency()

results = []
sizes = [2, 8, 32, 128, 512]

for batch_size, seq_length in product(sizes, sizes):
    label = "SiLU and Mul"
    sub_label = f"[{batch_size}, {seq_length}]"

    input_tensor = torch.randn(batch_size, seq_length, 2*d, dtype=torch.float16, device=device)
    
    min_run_time = max(0.1, min(1, batch_size * seq_length / 1e6))

    for num_threads in [1, 4, 16, 32]:
        results.append(
            Timer(
                stmt="forward_vllm(input_tensor)",
                setup="from __main__ import forward_vllm",
                globals={"input_tensor": input_tensor},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description="vLLM",
            ).blocked_autorange(min_run_time=min_run_time)
        )

        results.append(
            Timer(
                stmt="forward_flashinfer(input_tensor)",
                setup="from __main__ import forward_flashinfer",
                globals={"input_tensor": input_tensor},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description="FlashInfer",
            ).blocked_autorange(min_run_time=min_run_time)
        )

        results.append(
            Timer(
                stmt="forward_flag_gems(input_tensor)",
                setup="from __main__ import forward_flag_gems",
                globals={"input_tensor": input_tensor},
                num_threads=num_threads,
                label=label,
                sub_label=sub_label,
                description="Flag_gems",
            ).blocked_autorange(min_run_time=min_run_time)
        )

compare = torch.utils.benchmark.Compare(results)
compare.print()
Consistency test passed!
[-------------------- SiLU and Mul --------------------]
                  |   vLLM   |  FlashInfer  |  Flag_gems
1 threads: ---------------------------------------------
      [2, 2]      |    11.8  |       7.8    |     123.5
      [2, 8]      |    11.7  |       7.9    |     121.3
      [2, 32]     |    12.4  |       7.9    |     123.7
      [2, 128]    |    12.1  |       7.8    |     123.5
      [2, 512]    |    17.6  |      15.6    |     122.1
      [8, 2]      |    12.0  |       7.9    |     121.1
      [8, 8]      |    12.0  |       8.0    |     109.4
      [8, 32]     |    12.3  |       8.0    |     124.1
      [8, 128]    |    17.6  |      15.5    |     117.6
      [8, 512]    |    90.2  |      85.7    |     122.7
      [32, 2]     |    11.9  |       7.9    |     123.6
      [32, 8]     |    12.8  |       8.3    |     121.3
      [32, 32]    |    17.6  |      15.5    |     123.2
      [32, 128]   |    90.3  |      85.7    |     120.7
      [32, 512]   |   347.5  |     329.6    |     264.4
      [128, 2]    |    12.2  |       8.0    |     121.9
      [128, 8]    |    17.7  |      15.5    |     125.1
      [128, 32]   |    90.3  |      85.7    |     121.2
      [128, 128]  |   347.3  |     330.0    |     264.4
      [128, 512]  |  1375.4  |    1305.8    |     967.7
      [512, 2]    |    17.8  |      15.8    |     123.4
      [512, 8]    |    90.2  |      85.7    |     122.3
      [512, 32]   |   347.6  |     329.3    |     264.6
      [512, 128]  |  1375.7  |    1305.9    |     967.4
      [512, 512]  |  5491.3  |    5204.4    |    3855.6
4 threads: ---------------------------------------------
      [2, 2]      |    12.4  |       8.1    |     122.2
      [2, 8]      |    11.9  |       7.9    |     125.1
      [2, 32]     |    11.8  |       7.9    |     115.1
      [2, 128]    |    11.1  |       8.0    |     123.4
      [2, 512]    |    17.6  |      15.6    |     124.8
      [8, 2]      |    11.8  |       7.8    |     122.2
      [8, 8]      |    12.2  |       8.0    |     122.5
      [8, 32]     |    12.9  |       7.9    |     122.0
      [8, 128]    |    17.6  |      15.6    |     120.8
      [8, 512]    |    90.3  |      85.7    |     122.6
      [32, 2]     |    11.9  |       7.9    |     122.7
      [32, 8]     |    12.9  |       8.1    |     123.2
      [32, 32]    |    17.6  |      15.5    |     124.2
      [32, 128]   |    90.3  |      85.8    |     123.5
      [32, 512]   |   347.7  |     330.0    |     264.3
      [128, 2]    |    12.1  |       7.9    |     123.0
      [128, 8]    |    17.6  |      15.6    |     125.4
      [128, 32]   |    90.2  |      85.7    |     122.8
      [128, 128]  |   347.5  |     329.3    |     264.3
      [128, 512]  |  1376.8  |    1309.1    |     968.9
      [512, 2]    |    17.8  |      15.8    |     123.9
      [512, 8]    |    90.3  |      85.7    |     124.0
      [512, 32]   |   347.4  |     329.6    |     264.6
      [512, 128]  |  1376.7  |    1304.5    |     967.6
      [512, 512]  |  5490.0  |    5211.0    |    3853.5
16 threads: --------------------------------------------
      [2, 2]      |    11.9  |       7.8    |     122.6
      [2, 8]      |    12.2  |       8.0    |     123.2
      [2, 32]     |    12.1  |       8.1    |     121.7
      [2, 128]    |    12.0  |       8.0    |     122.0
      [2, 512]    |    17.6  |      15.5    |     123.9
      [8, 2]      |    11.9  |       7.9    |     121.9
      [8, 8]      |    12.0  |       8.2    |     122.2
      [8, 32]     |    12.5  |       8.0    |     120.7
      [8, 128]    |    17.6  |      15.5    |     122.8
      [8, 512]    |    90.3  |      85.8    |     121.8
      [32, 2]     |    12.3  |       8.0    |     121.8
      [32, 8]     |    12.4  |       8.1    |     122.9
      [32, 32]    |    17.6  |      15.6    |     124.9
      [32, 128]   |    90.2  |      85.8    |     121.7
      [32, 512]   |   347.5  |     329.7    |     264.1
      [128, 2]    |    12.0  |       8.0    |     124.7
      [128, 8]    |    17.6  |      15.5    |     123.3
      [128, 32]   |    90.2  |      85.7    |     122.5
      [128, 128]  |   347.6  |     329.4    |     264.9
      [128, 512]  |  1375.7  |    1306.5    |     967.8
      [512, 2]    |    17.8  |      15.8    |     122.6
      [512, 8]    |    90.3  |      85.7    |     122.1
      [512, 32]   |   347.5  |     329.7    |     264.8
      [512, 128]  |  1376.2  |    1303.8    |     968.1
      [512, 512]  |  5491.3  |    5205.8    |    3860.3
32 threads: --------------------------------------------
      [2, 2]      |    11.6  |       7.9    |     123.9
      [2, 8]      |    12.1  |       7.9    |     124.2
      [2, 32]     |    12.1  |       8.0    |     122.2
      [2, 128]    |    12.2  |       8.0    |     123.5
      [2, 512]    |    17.6  |      15.5    |     125.5
      [8, 2]      |    12.0  |       8.1    |     120.9
      [8, 8]      |    11.9  |       8.0    |     122.2
      [8, 32]     |    12.5  |       8.0    |     123.0
      [8, 128]    |    17.6  |      15.5    |     124.5
      [8, 512]    |    90.2  |      85.7    |     122.9
      [32, 2]     |    11.9  |       8.2    |     122.5
      [32, 8]     |    12.2  |       8.1    |     124.3
      [32, 32]    |    17.6  |      15.5    |     124.4
      [32, 128]   |    90.2  |      85.7    |     122.8
      [32, 512]   |   347.5  |     329.5    |     264.3
      [128, 2]    |    12.2  |       7.9    |     122.8
      [128, 8]    |    17.6  |      15.5    |     124.3
      [128, 32]   |    90.3  |      85.6    |     123.9
      [128, 128]  |   347.4  |     329.4    |     264.1
      [128, 512]  |  1378.5  |    1304.8    |     967.9
      [512, 2]    |    17.9  |      15.8    |     123.8
      [512, 8]    |    90.2  |      85.8    |     122.6
      [512, 32]   |   347.5  |     329.4    |     264.6
      [512, 128]  |  1376.7  |    1304.8    |     968.6
      [512, 512]  |  5492.6  |    5208.3    |    3854.1

Times are in microseconds (us).

@yzh119 yzh119 marked this pull request as ready for review August 9, 2024 08:00
Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhyncs I did some simple change to the code (use vectorized read/write), and here is the results I got (by using triton's do_bench function) on H100:

Consistency test passed!
batch_size: 2 seq_length: 2 vllm_time: 0.007171261124312878
batch_size: 2 seq_length: 2 flashinfer_time: 0.005875087808817625
batch_size: 2 seq_length: 2 flaggems_time: 0.02994345873594284
batch_size: 2 seq_length: 8 vllm_time: 0.007260866463184357
batch_size: 2 seq_length: 8 flashinfer_time: 0.005772186443209648
batch_size: 2 seq_length: 8 flaggems_time: 0.0059105088002979755
batch_size: 2 seq_length: 32 vllm_time: 0.0077180881053209305
batch_size: 2 seq_length: 32 flashinfer_time: 0.006187621038407087
batch_size: 2 seq_length: 32 flaggems_time: 0.006364865694195032
batch_size: 2 seq_length: 128 vllm_time: 0.009424506686627865
batch_size: 2 seq_length: 128 flashinfer_time: 0.00816467683762312
batch_size: 2 seq_length: 128 flaggems_time: 0.008360029198229313
batch_size: 2 seq_length: 512 vllm_time: 0.02061079442501068
batch_size: 2 seq_length: 512 flashinfer_time: 0.014950418844819069
batch_size: 2 seq_length: 512 flaggems_time: 0.014861035160720348
batch_size: 8 seq_length: 2 vllm_time: 0.007269856985658407
batch_size: 8 seq_length: 2 flashinfer_time: 0.005773282144218683
batch_size: 8 seq_length: 2 flaggems_time: 0.005844910629093647
batch_size: 8 seq_length: 8 vllm_time: 0.00772811146453023
batch_size: 8 seq_length: 8 flashinfer_time: 0.006187872029840946
batch_size: 8 seq_length: 8 flaggems_time: 0.006329760421067476
batch_size: 8 seq_length: 32 vllm_time: 0.009468046016991138
batch_size: 8 seq_length: 32 flashinfer_time: 0.00817921757698059
batch_size: 8 seq_length: 32 flaggems_time: 0.008257889188826084
batch_size: 8 seq_length: 128 vllm_time: 0.020637067034840584
batch_size: 8 seq_length: 128 flashinfer_time: 0.015106520615518093
batch_size: 8 seq_length: 128 flaggems_time: 0.015257231891155243
batch_size: 8 seq_length: 512 vllm_time: 0.06076494976878166
batch_size: 8 seq_length: 512 flashinfer_time: 0.04020121321082115
batch_size: 8 seq_length: 512 flaggems_time: 0.04041324928402901
batch_size: 32 seq_length: 2 vllm_time: 0.007802661973983049
batch_size: 32 seq_length: 2 flashinfer_time: 0.006300441455096006
batch_size: 32 seq_length: 2 flaggems_time: 0.00637076934799552
batch_size: 32 seq_length: 8 vllm_time: 0.009482021443545818
batch_size: 32 seq_length: 8 flashinfer_time: 0.008183696307241917
batch_size: 32 seq_length: 8 flaggems_time: 0.008226810954511166
batch_size: 32 seq_length: 32 vllm_time: 0.020641470327973366
batch_size: 32 seq_length: 32 flashinfer_time: 0.015115585178136826
batch_size: 32 seq_length: 32 flaggems_time: 0.015271436423063278
batch_size: 32 seq_length: 128 vllm_time: 0.0607980377972126
batch_size: 32 seq_length: 128 flashinfer_time: 0.040251944214105606
batch_size: 32 seq_length: 128 flaggems_time: 0.04044438898563385
batch_size: 32 seq_length: 512 vllm_time: 0.21253922581672668
batch_size: 32 seq_length: 512 flashinfer_time: 0.1371561884880066
batch_size: 32 seq_length: 512 flaggems_time: 0.153084397315979
batch_size: 128 seq_length: 2 vllm_time: 0.00945486780256033
batch_size: 128 seq_length: 2 flashinfer_time: 0.008165393956005573
batch_size: 128 seq_length: 2 flaggems_time: 0.008223879151046276
batch_size: 128 seq_length: 8 vllm_time: 0.020657455548644066
batch_size: 128 seq_length: 8 flashinfer_time: 0.015147659927606583
batch_size: 128 seq_length: 8 flaggems_time: 0.015288702212274075
batch_size: 128 seq_length: 32 vllm_time: 0.06075974926352501
batch_size: 128 seq_length: 32 flashinfer_time: 0.04024820774793625
batch_size: 128 seq_length: 32 flaggems_time: 0.04044437035918236
batch_size: 128 seq_length: 128 vllm_time: 0.2123134285211563
batch_size: 128 seq_length: 128 flashinfer_time: 0.13708913326263428
batch_size: 128 seq_length: 128 flaggems_time: 0.15339134633541107
batch_size: 128 seq_length: 512 vllm_time: 0.8181041479110718
batch_size: 128 seq_length: 512 flashinfer_time: 0.5250738263130188
batch_size: 128 seq_length: 512 flaggems_time: 0.5300045013427734
batch_size: 512 seq_length: 2 vllm_time: 0.020511353388428688
batch_size: 512 seq_length: 2 flashinfer_time: 0.01491069421172142
batch_size: 512 seq_length: 2 flaggems_time: 0.015027211979031563
batch_size: 512 seq_length: 8 vllm_time: 0.060630060732364655
batch_size: 512 seq_length: 8 flashinfer_time: 0.040194932371377945
batch_size: 512 seq_length: 8 flaggems_time: 0.04028919339179993
batch_size: 512 seq_length: 32 vllm_time: 0.2125125527381897
batch_size: 512 seq_length: 32 flashinfer_time: 0.13712455332279205
batch_size: 512 seq_length: 32 flaggems_time: 0.15308579802513123
batch_size: 512 seq_length: 128 vllm_time: 0.818162202835083
batch_size: 512 seq_length: 128 flashinfer_time: 0.5249825119972229
batch_size: 512 seq_length: 128 flaggems_time: 0.529996395111084
batch_size: 512 seq_length: 512 vllm_time: 3.2437238693237305
batch_size: 512 seq_length: 512 flashinfer_time: 2.0770304203033447
batch_size: 512 seq_length: 512 flaggems_time: 2.1354780197143555

I think we achieve the best performance among the three in most cases. Let's merge this first and I don't want to spend too much time on optimizing elementwise kernels :)

@yzh119 yzh119 merged commit ea0ba9a into flashinfer-ai:main Aug 9, 2024
@zhyncs zhyncs deleted the silu branch August 9, 2024 08:09
@zhyncs zhyncs removed the wip work in progress label Aug 9, 2024
yzh119 added a commit that referenced this pull request Aug 9, 2024
🤖 I have created a release *beep* *boop*
---
##
[0.1.4](v0.1.3...v0.1.4)
(2024-08-09)


### Features

* append attention kernels for fp8 kv-cache
([#420](#420))
([906c2f5](906c2f5))
* support min_p sampling
([#422](#422))
([d52f2da](d52f2da))
* deterministic sampling
([#417](#417))
([0dd801d](0dd801d))
* more sampling operator options
([#431](#431))
([68df9c4](68df9c4))
* support fused add rmsnorm
([#419](#419))
([b781513](b781513))
* support fused silu mul
([#427](#427))
([ea0ba9a](ea0ba9a))

### Bug Fixes

* fix dispatch fp16 type when enable fp8
([#430](#430))
([daa5566](daa5566))
* improve numerical stability of sampling kernels
([#429](#429))
([898d8ea](898d8ea))

### Other improvements
* break up `_kernels` into multiple modules
([#428](#428))
([8e482d9](8e482d9))

### Acknowledgement

We thank contributions and feedbacks from the community:
[@comaniac](https://github.com/comaniac),
[@esmeetu](https://github.com/esmeetu),
[@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU),
[@peng1999](https://github.com/peng1999),
[@xslingcn](https://github.com/xslingcn),
[@Yard1](https://github.com/Yard1),
[@zhyncs](https://github.com/zhyncs).

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
@zhyncs zhyncs added the enhancement New feature or request label Aug 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants