sdpa_fp8 having different seqlen_q and seqlen_k #112

Open

MustafaFayez opened this issue Sep 21, 2024 · 0 comments
Hi, I tried running an sdpa_fp8 graph where seqlen_q and seqlen_k are different; however, it seems that only seqlen_q is used, since performance stays the same when I sweep only seqlen_k (see the timing sketch after the function). Here is the function I wrote:

import math

import torch
import cudnn  # NVIDIA cudnn-frontend Python bindings

# convert_to_cudnn_type and generate_layout are helpers defined elsewhere in my script.
def cudnn_spda_setup(q, k, v, seqlen_q, seqlen_k, causal=False):
    b, _, nheads, headdim = q.shape
    assert cudnn is not None, 'CUDNN is not available'
    device = q.device
    # Output buffer in (b, s_q, h, d) layout, plus a (b, h, s_q, d) strided view for cuDNN.
    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=q.dtype, device=device)
    o_gpu_transposed = torch.as_strided(
        o_gpu,
        [b, nheads, seqlen_q, headdim],
        [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
    )
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=device)
    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=device)
    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    # Q and O use seqlen_q; K and V use seqlen_k.
    shape_q = (b, nheads, seqlen_q, headdim)
    shape_k = (b, nheads, seqlen_k, headdim)
    shape_v = (b, nheads, seqlen_k, headdim)
    shape_o = (b, nheads, seqlen_q, headdim)
    qkv_num_elems = math.prod(shape_q) + math.prod(shape_k) + math.prod(shape_v)
    (stride_q, stride_k, stride_v, stride_o, offset_q, offset_k, offset_v) = generate_layout(
        shape_q, shape_k
    )
    # Note: the incoming q/k/v tensors are only used for their shapes and dtype;
    # the actual data comes from a freshly allocated packed QKV buffer.
    qkv = torch.randn(qkv_num_elems, dtype=torch.float16, device="cuda")
    qkv_gpu = qkv.to(q.dtype)
    q_gpu = torch.as_strided(qkv_gpu, shape_q, stride_q, storage_offset=offset_q)
    k_gpu = torch.as_strided(qkv_gpu, shape_k, stride_k, storage_offset=offset_k)
    v_gpu = torch.as_strided(qkv_gpu, shape_v, stride_v, storage_offset=offset_v)
    # These graph tensors shadow the q/k/v arguments from here on.
    q = graph.tensor(
        name="Q",
        dim=list(q_gpu.shape),
        stride=list(q_gpu.stride()),
        data_type=convert_to_cudnn_type(qkv_gpu.dtype),
    )
    k = graph.tensor(
        name="K",
        dim=list(k_gpu.shape),
        stride=list(k_gpu.stride()),
        data_type=convert_to_cudnn_type(qkv_gpu.dtype),
    )
    v = graph.tensor(
        name="V",
        dim=list(v_gpu.shape),
        stride=list(v_gpu.stride()),
        data_type=convert_to_cudnn_type(qkv_gpu.dtype),
    )
    def get_default_scale_tensor():
        return graph.tensor(
            dim=[1, 1, 1, 1],
            stride=[1, 1, 1, 1],
            data_type=cudnn.data_type.FLOAT,
        )

    # All scale/descale factors are identity (1.0) for this benchmark.
    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
    descale_q = get_default_scale_tensor()
    descale_k = get_default_scale_tensor()
    descale_v = get_default_scale_tensor()
    descale_s = get_default_scale_tensor()
    scale_s = get_default_scale_tensor()
    scale_o = get_default_scale_tensor()

    o, _, amax_s, amax_o = graph.sdpa_fp8(
        q=q,
        k=k,
        v=v,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
        descale_s=descale_s,
        scale_s=scale_s,
        scale_o=scale_o,
        is_inference=True,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        name="sdpa",
    )

    o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())

    amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
    amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
    # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    # Validate, lower, and build an execution plan for the graph.
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        descale_q: default_scale_gpu,
        descale_k: default_scale_gpu,
        descale_v: default_scale_gpu,
        descale_s: default_scale_gpu,
        scale_s: default_scale_gpu,
        scale_o: default_scale_gpu,
        o: o_gpu_transposed,
        amax_s: amax_s_gpu,
        amax_o: amax_o_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):  # extra args are accepted but ignored
        graph.execute(variant_pack, workspace)
        return o_gpu, amax_o_gpu

    return run
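
For context, here is roughly how I drive the sweep. This is an illustrative sketch rather than my exact harness (the shapes, the fp8 conversion, and the CUDA-event timing loop are stand-ins), but the pattern is the same, and the reported times stay essentially flat as seqlen_k grows:

# Illustrative sweep over seqlen_k with seqlen_q fixed (not my exact harness).
b, nheads, headdim = 4, 16, 128
seqlen_q = 1024
q = torch.randn(b, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)

for seqlen_k in (512, 1024, 2048, 4096):
    k = torch.randn(b, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    v = torch.randn(b, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
    run = cudnn_spda_setup(q, k, v, seqlen_q, seqlen_k)

    for _ in range(5):  # warm-up
        run()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(20):
        run()
    end.record()
    torch.cuda.synchronize()
    print(f"seqlen_k={seqlen_k}: {start.elapsed_time(end) / 20:.3f} ms")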

What am I doing wrong?
Thanks.
