Commit

precision tests passed
Edenzzzz committed Jul 18, 2024
1 parent 77f4eaf commit ca16753
Showing 11 changed files with 585 additions and 316 deletions.
4 changes: 0 additions & 4 deletions colossalai/lazy/pretrained.py
@@ -62,7 +62,6 @@ def new_from_pretrained(
     config = kwargs.pop("config", None)
     cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", False)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", False)
     use_auth_token = kwargs.pop("use_auth_token", None)
@@ -116,7 +115,6 @@ def new_from_pretrained(
         cache_dir=cache_dir,
         return_unused_kwargs=True,
         force_download=force_download,
-        resume_download=resume_download,
         proxies=proxies,
         local_files_only=local_files_only,
         use_auth_token=use_auth_token,
@@ -195,7 +193,6 @@ def new_from_pretrained(
         "cache_dir": cache_dir,
         "force_download": force_download,
         "proxies": proxies,
-        "resume_download": resume_download,
         "local_files_only": local_files_only,
         "use_auth_token": use_auth_token,
         "user_agent": user_agent,
@@ -312,7 +309,6 @@ def new_from_pretrained(
         pretrained_model_name_or_path,
         cache_dir=cache_dir,
         force_download=force_download,
-        resume_download=resume_download,
         proxies=proxies,
         local_files_only=local_files_only,
         use_auth_token=use_auth_token,
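
Context for the four deletions above: they drop every use of the resume_download kwarg from new_from_pretrained. The commit message does not state a motivation, but this matches huggingface_hub's deprecation of resume_download (resuming partial downloads became the default, making the flag a no-op). A minimal sketch of the resulting kwarg-handling pattern — hypothetical names, not the actual ColossalAI function:

def new_from_pretrained_sketch(pretrained_model_name_or_path, *model_args, **kwargs):
    # Pop download-related options out of **kwargs so the remainder can be
    # forwarded untouched. Note that "resume_download" is no longer popped
    # or forwarded anywhere.
    download_kwargs = {
        "cache_dir": kwargs.pop("cache_dir", None),
        "force_download": kwargs.pop("force_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", False),
        "use_auth_token": kwargs.pop("use_auth_token", None),
    }
    # Sketch only: the real function resolves the config and weights using
    # download_kwargs and returns a model instance.
    return download_kwargs, kwargs
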
4 changes: 0 additions & 4 deletions colossalai/shardformer/layer/_operation.py
@@ -812,11 +812,7 @@ def backward(ctx, *grad_output):
     process_group = ctx.process_group
     scatter_dim = ctx.gather_dim
     gather_dim = ctx.scatter_dim
-    if torch.distributed.get_rank() == 0:
-        print(f"shape before A2A: {grad_output[0].shape}")
     return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
-    if torch.distributed.get_rank() == 0:
-        print(f"shape after A2A: {return_grad.shape}")
     return (return_grad, None, None, None)
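
A note on the hunk above: the backward pass of an all-to-all is another all-to-all with the scatter and gather dimensions swapped, which is why backward reads scatter_dim from ctx.gather_dim and vice versa; the four deleted lines were temporary rank-0 debug prints. A self-contained sketch of that pattern — assuming torch.distributed is initialized and the scatter dimension is divisible by the world size; this is not ColossalAI's actual _AllToAll:

import torch
import torch.distributed as dist

def _all_to_all(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # Split along scatter_dim, exchange one chunk with every rank, then
    # reassemble the received chunks along gather_dim. Assumes equal chunk
    # shapes across ranks.
    world_size = dist.get_world_size(group)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    return torch.cat(outputs, dim=gather_dim)

class AllToAllDemo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, group, scatter_dim, gather_dim):
        ctx.group, ctx.scatter_dim, ctx.gather_dim = group, scatter_dim, gather_dim
        return _all_to_all(x, group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Inverse communication: the same all-to-all with the two
        # dimensions swapped, matching the backward in the diff above.
        grad_input = _all_to_all(grad_output, ctx.group, ctx.gather_dim, ctx.scatter_dim)
        return grad_input, None, None, None
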
[Diffs for the remaining 9 changed files are not shown here.]
