From cced72d037f2a6bc0bb0475c8bc4c1762603aa94 Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:49:58 +0800 Subject: [PATCH] fix tensor data update for gemini loss caluculation (#5442) --- applications/Colossal-LLaMA-2/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py index 2e4bab75a085..d97da61e4dc8 100644 --- a/applications/Colossal-LLaMA-2/train.py +++ b/applications/Colossal-LLaMA-2/train.py @@ -56,6 +56,7 @@ def format_numel_str(numel: int) -> str: def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor = tensor.data tensor.div_(dist.get_world_size()) return tensor