Skip to content

Commit

Permalink
Fix error when tp > 1 (#2644)
Browse files Browse the repository at this point in the history
Co-authored-by: zhaoyang-star <zhao.yang16@zte.com.cn>
  • Loading branch information
zhaoyang-star and zhaoyang-star committed Jan 29, 2024
1 parent 9090bf0 commit b72af8f
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
cache_config = copy.deepcopy(self.cache_config)

for rank, (worker, (node_id,
_)) in enumerate(zip(self.workers,
Expand All @@ -252,7 +251,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
rank,
distributed_init_method,
lora_config=self.lora_config,
cache_config=cache_config,
kv_cache_dtype=self.cache_config.cache_dtype,
))

driver_rank = 0
Expand All @@ -265,7 +264,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
cache_config=cache_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)

Expand Down

0 comments on commit b72af8f

Please sign in to comment.