DeepNVMe ZeRO-inf Tutorial #921

Merged · 9 commits · Sep 17, 2024
28 changes: 28 additions & 0 deletions deepnvme/zero_inference/README.md
@@ -0,0 +1,28 @@
# Using DeepNVMe for ZeRO-Inference
ZeRO-Inference is an ideal use case for the DeepNVMe technology. When you have a model that exceeds the size of available GPU memory, the [DeepNVMe](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md) library along with ZeRO-Inference can be leveraged for high-throughput offline inference.

Maximizing inference throughput (measured in tokens/sec) in this scenario has two parts. First, the model parameters are offloaded to fast Non-Volatile Memory (NVMe) storage, either a single device or several devices RAIDed together to further increase the effective bandwidth of the system. These parameters are then swapped into GPU memory layer by layer to compute the forward pass for inference. This enables the second part of the process: maximizing the batch size. Because parameters are swapped in layer by layer, the remaining GPU memory is free for the computational batch, which maximizes total inference throughput.
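
As a rough illustration, a minimal ZeRO stage 3 configuration with NVMe parameter offload might look like the sketch below. The NVMe path and buffer values are placeholders, not the tuned per-model values used in this PR (those live in `run_model.py`, shown in the diff further down):

```python
# Minimal sketch of a ZeRO stage 3 config with NVMe parameter offload.
# Illustrative values only: "/local_nvme" is a hypothetical mount point for the
# NVMe device or RAID0 array, and buffer sizes should be tuned per model.
GB = 1 << 30
hidden_size = 8192  # taken from the model config in practice

ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "stage3_prefetch_bucket_size": 2 * hidden_size * hidden_size,
        "stage3_param_persistence_threshold": hidden_size,
        "stage3_max_live_parameters": 2 * hidden_size * hidden_size,
        "offload_param": {
            "device": "nvme",            # keep parameters on NVMe ...
            "nvme_path": "/local_nvme",  # ... at this (hypothetical) path
            "pin_memory": True,          # pinned host buffers speed up transfers
            "buffer_count": 5,
            "buffer_size": 2 * GB,
        },
    },
}
```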

## Testing Environment
The environment for these tests was a VM with NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS) installed, along with a single NVIDIA H100 GPU with 96 GB of memory. The VM also had two NVMe drives, each with a read bandwidth of ~6.5 GB/sec. The two drives were combined in a RAID0 configuration, bringing the effective read bandwidth up to ~13 GB/sec.
<div align="center">
<img src="./media/nvme_config.png" style="width:6.5in;height:3.42153in" />
</div>
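
A quick way to sanity-check read-bandwidth figures like these is to time a large sequential read from the array. The following is a minimal sketch only, assuming a hypothetical large file on the RAID0 mount (the file should exceed system RAM, or the page cache should be dropped first, to avoid inflated numbers):

```python
# Rough sequential-read bandwidth check (sketch; the path is hypothetical).
import os
import time

TEST_FILE = "/mnt/nvme_raid0/testfile"  # hypothetical large file on the RAID0 mount
CHUNK = 64 * 1024 * 1024                # read in 64 MiB chunks

def sequential_read_gbps(path: str) -> float:
    size = os.path.getsize(path)
    start = time.perf_counter()
    with open(path, "rb", buffering=0) as f:  # unbuffered to avoid extra copies
        while f.read(CHUNK):
            pass
    return size / (time.perf_counter() - start) / 1e9  # GB/sec

if __name__ == "__main__":
    print(f"~{sequential_read_gbps(TEST_FILE):.1f} GB/sec sequential read")
```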

## Initial Results
The following models were run from the folder `DeepSpeedExamples/inference/huggingface/zero_inference` using disk offload of parameters via the following command:

```bash
deepspeed --num_gpus 1 run_model.py --model $model_name --batch_size $bsz --prompt-len 512 --gen-len 32 --disk-offload $path_to_folder --use_gds
```

Here, `--use_gds` enables NVIDIA GDS so that parameters are moved directly between the NVMe and the GPU; when it is omitted, an intermediate CPU bounce buffer is used to move the parameters between the NVMe and the GPU.
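
In the DeepSpeed configuration this flag maps to the `use_gds` field of the `aio` section (visible in the `run_model.py` diff below). A small sketch, using the aio values from this PR, of how the two I/O paths are selected:

```python
# Sketch: the aio settings used by run_model.py in this PR, parameterized by
# whether GDS is enabled. With use_gds=True, reads go NVMe -> GPU directly;
# with use_gds=False, they are staged through a pinned CPU bounce buffer.
def aio_config(use_gds: bool) -> dict:
    return {
        "block_size": 1048576 * 16,  # 16 MiB I/O requests
        "queue_depth": 64,
        "thread_count": 8,
        "use_gds": use_gds,
        "single_submit": False,
        "overlap_events": True,
    }

ds_config = {"aio": aio_config(use_gds=True)}  # merged into the full DeepSpeed config
```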

All models tested were chosen so they could not fit into 96 GB of GPU memory.

| GDS   | Mixtral-8x22B   | Llama3-70B     | Bloom-176B    |
|-------|-----------------|----------------|---------------|
| False | 9.152 (bsz=200) | 8.606 (bsz=96) | 0.291 (bsz=8) |
| True  | 9.233 (bsz=200) | 8.876 (bsz=96) | 0.293 (bsz=8) |

Throughput measured in tokens/sec.
Binary file added deepnvme/zero_inference/media/nvme_config.png
31 changes: 25 additions & 6 deletions inference/huggingface/zero_inference/run_model.py
@@ -87,7 +87,7 @@ def get_ds_model(
         },
         "zero_optimization": {
             "stage": 3,
-            "stage3_prefetch_bucket_size": 2 * hidden_size * hidden_size, # 0,
+            "stage3_prefetch_bucket_size": 2 * hidden_size * hidden_size,
             "stage3_param_persistence_threshold": hidden_size,
             "stage3_max_live_parameters": 2 * hidden_size * hidden_size,
         },
@@ -105,17 +105,29 @@ def get_ds_model(
     )

     if disk_offload:
+        if config.model_type == 'bloom':
+            buffer_count = 3 if args.use_gds else 5
+            buffer_size = 8*GB if args.use_gds else 9*GB
+
+        elif config.model_type == 'mixtral':
+            buffer_count = 10
+            buffer_size = 1*GB
+        else:
+            buffer_count = 5
+            buffer_size = 2*GB
+
         ds_config["zero_optimization"]["offload_param"] = dict(
             device="nvme",
             pin_memory=pin_memory,
             nvme_path=offload_dir,
-            buffer_count=5,
-            buffer_size=9 * GB if config.model_type == 'bloom' else 2 * GB,
+            buffer_count=buffer_count,
+            buffer_size=buffer_size,
         )
         ds_config["aio"] = {
-            "block_size": 1048576,
-            "queue_depth": 8,
-            "thread_count": 1,
+            "block_size": 1048576*16,
+            "queue_depth": 64,
+            "thread_count": 8,
+            "use_gds": args.use_gds,
             "single_submit": False,
             "overlap_events": True,
         }
@@ -140,6 +152,10 @@ def get_ds_model(
         model = LlamaForCausalLM.from_pretrained(
             dummy_weights or model_name, torch_dtype=dtype,
         )
+    elif config.model_type == "mixtral":
+        model = AutoModelForCausalLM.from_pretrained(
+            dummy_weights or model_name, torch_dtype=dtype,
+        )
     else:
         raise ValueError(f"Unexpected model type: {config.model_type}")

@@ -192,6 +208,8 @@ def run_generation(
             model = BloomForCausalLM(config)
         elif config.model_type == "llama":
             model = LlamaForCausalLM(config)
+        elif config.model_type == "mixtral":
+            model = AutoModelForCausalLM(config)
         else:
             raise ValueError(f"Unexpected model type: {config.model_type}")
         model.save_pretrained(
@@ -354,6 +372,7 @@ def remove_model_hooks(module):
     parser.add_argument("--quant_group_size", type=int, default=64, help="model weight quantization group size")
     parser.add_argument("--pin_kv_cache", action="store_true", help="Allocate kv cache in pinned memory for offloading.")
     parser.add_argument("--async_kv_offload", action="store_true", help="Using non_blocking copy for kv cache offloading.")
+    parser.add_argument("--use_gds", action="store_true", help="Use NVIDIA GPU DirectStorage to transfer between NVMe and GPU.")
     args = parser.parse_args()

     deepspeed.init_distributed()