From 10e4f7da724d3e45135d4544678793ecd3e74029 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 16 Sep 2024 13:45:04 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 2f8e4d677c54..5f0e9261c0de 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -121,7 +121,8 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size
             )
 
         if size >= MEMORY_BOUND:
-            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
+            if isinstance(ColoAttention._flash_kernel_dispatch, KernelLoader):
+                ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()  # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
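
For context, the added isinstance guard makes the lazy kernel load idempotent: _flash_kernel_dispatch starts out as a KernelLoader and is replaced by the loaded kernel callable on first dispatch, so without the guard a second call that hits the size >= MEMORY_BOUND branch would invoke .load() on the already-loaded callable and fail. What follows is a minimal, self-contained Python sketch of that pattern; the toy KernelLoader and Dispatcher below are illustrative stand-ins, not the ColossalAI classes.

from typing import Callable, Union


class KernelLoader:
    """Toy stand-in: defers building the kernel until load() is called."""

    def load(self) -> Callable:
        def kernel(x):
            # Placeholder for the real flash-attention kernel.
            return x

        return kernel


class Dispatcher:
    # Starts as a loader; replaced by the loaded kernel on first use.
    _flash_kernel_dispatch: Union[KernelLoader, Callable] = KernelLoader()

    @classmethod
    def dispatch(cls) -> Callable:
        # The isinstance guard makes the lazy load idempotent. Without it,
        # a repeat call would invoke .load() on the already-loaded kernel
        # callable and raise AttributeError.
        if isinstance(cls._flash_kernel_dispatch, KernelLoader):
            cls._flash_kernel_dispatch = cls._flash_kernel_dispatch.load()
        return cls._flash_kernel_dispatch


kernel = Dispatcher.dispatch()  # first call performs the load
kernel = Dispatcher.dispatch()  # second call is a safe no-op
assert kernel(42) == 42

As the unchanged context lines of the hunk show, the per-dtype _kernel_dispatch_map entries were already guarded this way, so the patch brings the memory-bound flash path in line with the existing dispatch logic.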