Skip to content

Commit

Permalink
Fix error when tp > 1 (vllm-project#2644)
Browse files (browse the repository at this point in the history)
Co-authored-by: zhaoyang-star <zhao.yang16@zte.com.cn>
  • Loading branch information
2 people authored and NikolaBorisov committed Jan 31, 2024
1 parent deed1ff commit 5c1a1f8
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions vllm/engine/llm_engine.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -240,7 +240,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
- cache_config = copy.deepcopy(self.cache_config)

for rank, (worker, (node_id,
_)) in enumerate(zip(self.workers,
Expand All @@ -256,7 +255,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
rank,
distributed_init_method,
lora_config=self.lora_config,
- cache_config=cache_config,
+ kv_cache_dtype=self.cache_config.cache_dtype,
))

driver_rank = 0
Expand All @@ -269,7 +268,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
- cache_config=cache_config,
+ kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)

Expand Down

0 comments on commit 5c1a1f8

Please sign in to comment.