diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index da67e6b41fbf..211deea2d6d0 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -961,6 +961,7 @@ def __init__(
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
         enable_metadata_cache: bool = True,
+        make_vocab_size_divisible_by: int = 128,
     ) -> None:
         super().__init__()
         assert (
@@ -1033,6 +1034,7 @@ def __init__(
             enable_jit_fused=self.enable_jit_fused,
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
+            make_vocab_size_divisible_by=make_vocab_size_divisible_by,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index b5c9e66e0b87..114082bf9c9d 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -35,6 +35,7 @@ class ShardConfig:
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
+    make_vocab_size_divisible_by: int = 128
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
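
For context, here is a minimal sketch of how a `make_vocab_size_divisible_by` knob is typically consumed downstream: the embedding's vocab dimension is padded up to the nearest qualifying multiple so it splits evenly across tensor-parallel ranks. The padding rule below follows the common Megatron-LM-style convention and is an assumption for illustration, not the exact code this diff wires the value into.

```python
def padded_vocab_size(vocab_size: int,
                      make_vocab_size_divisible_by: int = 128,
                      tp_size: int = 1) -> int:
    """Hypothetical helper: round vocab_size up to a multiple of
    make_vocab_size_divisible_by * tp_size (assumed convention,
    mirroring Megatron-LM vocab padding)."""
    multiple = make_vocab_size_divisible_by * tp_size
    return ((vocab_size + multiple - 1) // multiple) * multiple


if __name__ == "__main__":
    # e.g. GPT-2's 50257-token vocab padded for 4-way tensor parallelism
    print(padded_vocab_size(50257, 128, 4))  # -> 50688
```

Under this convention, the default of 128 keeps padded embedding shards aligned to sizes that are friendly to GPU kernels while adding at most `128 * tp_size - 1` unused rows.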