
Cuda 11 support? #790

Closed
surak opened this issue Feb 24, 2021 · 16 comments

Comments

@surak

surak commented Feb 24, 2021

We have 4,000 NVIDIA A100 GPUs and would like to use DeepSpeed on them. The thing is, during setup.py we get:

 [WARNING]  sparse_attn requires CUDA version 10.1+, does not currently support >=11 or <10.1
 [WARNING]  sparse_attn requires CUDA version 10.1+, does not currently support >=11 or <10.1

By the way, the llvm line is wrong, too.

@arashashari
Contributor

Sparse Attention currently only works on V100. However, we will be updating it soon to be compatible with A100 as well.
In case you don't need Sparse Attention in your tests, feel free to opt out of it and you should be fine. Otherwise, we will notify you when the update is in.

@avacaondata

Excuse me, I'm having the same problem, as I want to use DeepSpeed on A100 GPUs. How can I opt out of sparse attention? Will this have a dramatic impact on efficiency?
@arashashari

@surak
Author

surak commented Feb 25, 2021

@arashashari Hello Arash! Thanks for the super quick reply.

It turns out we have another supercomputer with about 160 of those V100 GPUs, but there, too, we only provide CUDA 11.0 to our users. So for us it's really about CUDA 11, not about the GPU type.

@jeffra
Collaborator

jeffra commented Feb 25, 2021

@alexvaca0 by default you will be opted out of installing/using sparse attention in your DeepSpeed install. You should still be able to use all the other features of DeepSpeed just fine. These are just warnings saying that the sparse attention ops won't be compiled on your system.
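
As a side note (not from the original thread), one way to check whether the sparse attention op would compile on a given system is to ask its op builder directly. This is only a sketch: the import path and class name below are assumptions based on op_builder/sparse_attn.py and may differ between DeepSpeed versions.

# Hypothetical quick check: ask the sparse attention op builder whether the
# current system satisfies its dependencies (CUDA version, llvm-config, cmake).
# The import path and class name are assumed and may vary across versions.
from deepspeed.ops.op_builder import SparseAttnBuilder

print("sparse_attn compatible:", SparseAttnBuilder().is_compatible())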

@jeffra
Collaborator

jeffra commented Feb 25, 2021

@surak: we're actively working on adding support for A100 + CUDA 11 for sparse attention and will hopefully update this thread soon. Regarding V100 + CUDA 11, we suspect this will work as-is, but we have not had a chance or access to a machine with this config to test it out fully. Would you like to give it a try? If so, here's a branch that allows this config:
https://github.com/microsoft/DeepSpeed/tree/sparse-attn-cuda11

@surak
Author

surak commented Feb 26, 2021

> @surak: we're actively working on adding support for A100 + CUDA 11 for sparse attention and will hopefully update this thread soon. Regarding V100 + CUDA 11, we suspect this will work as-is, but we have not had a chance or access to a machine with this config to test it out fully. Would you like to give it a try? If so, here's a branch that allows this config:
> https://github.com/microsoft/DeepSpeed/tree/sparse-attn-cuda11

I had done EXACTLY the same patch myself, in order to at least let the thing install (I hadn't had time to test it before; I was just making sure everything installed). Thanks!

@surak
Author

surak commented Feb 26, 2021

Other patches I have are:

Pip tries to install a newer triton, and the version does not really matter:

--- DeepSpeed-0.3.11/requirements/requirements-sparse_attn.txt.orig	2021-02-24 23:11:00.212886868 +0100
+++ DeepSpeed-0.3.11/requirements/requirements-sparse_attn.txt	2021-02-24 23:11:08.221726647 +0100
@@ -1 +1 @@
-triton==0.2.3
+triton>=0.2.3

The "or" kinda fails when llvm 10 is present.

--- DeepSpeed-0.3.11/op_builder/sparse_attn.py.orig	2021-02-24 23:01:30.222302088 +0100
+++ DeepSpeed-0.3.11/op_builder/sparse_attn.py	2021-02-24 23:03:24.696006596 +0100
@@ -21,7 +21,7 @@
 
     def is_compatible(self):
         # Check to see if llvm and cmake are installed since they are dependencies
-        required_commands = ['llvm-config|llvm-config-9', 'cmake']
+        required_commands = ['llvm-config', 'cmake']
         command_status = list(map(self.command_exists, required_commands))
         deps_compatible = all(command_status)

tensorboardX has already moved on to newer versions:

--- DeepSpeed-0.3.11/requirements/requirements.txt.orig	2021-02-24 20:30:14.442256660 +0100
+++ DeepSpeed-0.3.11/requirements/requirements.txt	2021-02-24 20:30:51.209512682 +0100
@@ -1,6 +1,6 @@
 torch>=1.2
 torchvision>=0.4.0
 tqdm
-tensorboardX==1.8
+tensorboardX>=1.8
 ninja

@caffeinetoomuch

caffeinetoomuch commented Mar 3, 2021

> actively working on adding support for A100 + CUDA 11 for sparse attention and will hopefully update this thread soon. Regarding V100 + CUDA 11, we suspect this will work as-is, but we have not had a chance or access to …

@jeffra Hi, I'm hugely excited about this upcoming support! Would this update also be compatible with A6000 + CUDA 11 to utilize sparse attention? Also, would it have any issues with CUDA 11.2?

@awilson9

> @surak: we're actively working on adding support for A100 + CUDA 11 for sparse attention and will hopefully update this thread soon. Regarding V100 + CUDA 11, we suspect this will work as-is, but we have not had a chance or access to a machine with this config to test it out fully. Would you like to give it a try? If so, here's a branch that allows this config:
> https://github.com/microsoft/DeepSpeed/tree/sparse-attn-cuda11

What's the recommended way to install this? I tried

pip3 install https://github.com/microsoft/DeepSpeed/archive/sparse-attn-cuda11.zip

but that seems to have issues resolving internal dependencies.

@aced125

aced125 commented Apr 14, 2021

Hey @jeffra, I would indeed be keen to know whether the SparseAttention kernel is compatible with A100!

@denti

denti commented May 13, 2021

@jeffra I tried to run DeepSpeed from https://github.com/microsoft/DeepSpeed/tree/sparse-attn/support-latest-triton on A100 and V100, and both pipelines failed with an error.
Here are steps that I did:

  1. docker run -it --gpus=all --rm -v /home/dtimonin:/home/dtimonin deepspeed/deepspeed:latest-torch170-cuda110 /bin/bash
  2. pip uninstall deepspeed
  3. git clone https://github.com/microsoft/DeepSpeed
  4. git fetch origin pull/902/head:pr902 && git checkout pr902 # Switching to new triton and A100 support
  5. Change requirements/requirements-sparse_attn.txt to triton==0.4 # because triton==1.0xxx was removed from repo and merged into 0.4
  6. DS_BUILD_OPS=1 python setup.py install # into DeepSpeed repo
  7. Try to run this code:
import torch
from deepspeed.ops.sparse_attention import SparseSelfAttention, BigBirdSparsityConfig

sparsity_config = BigBirdSparsityConfig(num_heads=1)
ssa = SparseSelfAttention(sparsity_config=sparsity_config)
a = torch.rand((1, 1, 16, 256))
attention_mask = torch.ones((16,16))
ssa(a,a,a, key_padding_mask=attention_mask)
  8. The error is:
/opt/conda/lib/python3.8/site-packages/deepspeed-0.3.13+2ea0fee-py3.8-linux-x86_64.egg/deepspeed/ops/sparse_attention/matmul.py:271: UserWarning: This overload of nonzero is deprecated:
        nonzero()
Consider using one of the following signatures instead:
        nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729096996/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  nnz = layout.nonzero()
Traceback (most recent call last):
  File "test_sparse_attention.py", line 8, in <module>
    ssa(a,a,a, key_padding_mask=attention_mask)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed-0.3.13+2ea0fee-py3.8-linux-x86_64.egg/deepspeed/ops/sparse_attention/sparse_self_attention.py", line 152, in forward
    attn_output_weights = sparse_dot_sdd_nt(query, key)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed-0.3.13+2ea0fee-py3.8-linux-x86_64.egg/deepspeed/ops/sparse_attention/matmul.py", line 720, in __call__
    c = _sparse_matmul.apply(a,
  File "/opt/conda/lib/python3.8/site-packages/deepspeed-0.3.13+2ea0fee-py3.8-linux-x86_64.egg/deepspeed/ops/sparse_attention/matmul.py", line 537, in forward
    c = _sparse_matmul.fn[mode](a,
  File "/opt/conda/lib/python3.8/site-packages/deepspeed-0.3.13+2ea0fee-py3.8-linux-x86_64.egg/deepspeed/ops/sparse_attention/matmul.py", line 204, in _sdd_matmul
    current = kernel(
  File "/opt/conda/lib/python3.8/site-packages/triton/kernel.py", line 116, in __call__
    kernel = self.fn.autotune(params, grid, self.stream)
RuntimeError: CUDA: Error- context is destroyed

@jeffra I have A100 and V100 servers, and I'm ready to help you test updates as they come.

@ddomingof

Any clues on the exact configuration to run DeepSpeed on CUDA 11 and A100s? @surak

@surak
Author

surak commented May 26, 2021

> Any clues on the exact configuration to run DeepSpeed on CUDA 11 and A100s? @surak

As you saw, I have some patches (mentioned above), and we run it directly on the system, no container involved.

@jeffra
Collaborator

jeffra commented Jun 14, 2021

Hi @surak and @ddomingof, @RezaYazdaniAminabadi and @arashashari recently merged PR #902, which should help along these lines. Have you tried this again with our latest v0.4 release?

@hyunwoongko
Contributor

@denti Move the layer and the tensors to CUDA; they are located on the CPU.
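
For illustration, here is a minimal sketch of that suggestion applied to the repro above: the SparseSelfAttention module and the query/key/value tensors are created on (or moved to) the GPU before the call. This is an unverified sketch; the shapes come from the snippet above, and the underlying Triton kernels may additionally expect half-precision inputs.

import torch
from deepspeed.ops.sparse_attention import SparseSelfAttention, BigBirdSparsityConfig

device = torch.device("cuda")

sparsity_config = BigBirdSparsityConfig(num_heads=1)
# Move the sparse attention module onto the GPU; by default it is created on the CPU.
ssa = SparseSelfAttention(sparsity_config=sparsity_config).to(device)

# Same shapes as the repro above, but the tensors live on the GPU.
# The kernels may also require half precision, e.g. torch.rand(...).half().
a = torch.rand((1, 1, 16, 256), device=device)
attention_mask = torch.ones((16, 16), device=device)

out = ssa(a, a, a, key_padding_mask=attention_mask)
print(out.shape)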

@loadams
Contributor

loadams commented Aug 18, 2023

Closing this issue as it is stale - if anyone is hitting this issue, please re-open or link a new issue. Thanks!

@loadams loadams closed this as completed Aug 18, 2023