Merge branch 'feature/lora' of github.com:hpcaitech/ColossalAI into rebase/lora
linsj20 committed Apr 22, 2024
2 parents 4be48f4 + 52a2dde commit 4a322e4
Showing 7 changed files with 108 additions and 47 deletions.
1 change: 1 addition & 0 deletions colossalai/booster/booster.py
@@ -20,6 +20,7 @@
from colossalai.quantization import BnbQuantizationConfig
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
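The only change to booster.py shown here is the BnbQuantizationConfig import. A plausible reason, which is an assumption rather than something visible in this diff, is that Booster.enable_lora now accepts an optional bitsandbytes quantization config and forwards it to a LoRA-capable plugin. A minimal hypothetical sketch:

# Hypothetical sketch only; not the actual Booster implementation.
# Only the plugin-side support_lora/enable_lora methods appear in this commit.
from typing import Dict, Optional

import torch.nn as nn

from colossalai.quantization import BnbQuantizationConfig


def enable_lora_sketch(
    plugin,
    model: nn.Module,
    pretrained_dir: Optional[str] = None,
    lora_config: Optional[Dict] = None,
    bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
    # Guard on the plugin capability before delegating.
    if not plugin.support_lora():
        raise ValueError("The current plugin does not support LoRA.")
    return plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)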
8 changes: 8 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1152,6 +1152,9 @@ def support_no_sync(self) -> bool:
def support_lora(self) -> bool:
return False

def support_lora(self) -> bool:
return False

def control_checkpoint_io(self) -> bool:
return True

@@ -1354,3 +1357,8 @@ def enable_lora(
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
raise NotImplementedError

def enable_lora(
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
) -> Module:
raise NotImplementedError
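Note that both support_lora and enable_lora are added twice in this hunk, which looks like a duplicate introduced by the merge; either way, HybridParallelPlugin opts out of LoRA for now. A minimal illustration of the expected behavior (usage assumed, and it requires an initialized distributed environment):

# Illustration only; assumes colossalai.launch(...) has already initialized torch.distributed.
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=1, pp_size=1)
assert plugin.support_lora() is False
# plugin.enable_lora(model) would raise NotImplementedError at this point.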
75 changes: 75 additions & 0 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -29,6 +29,8 @@
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase
@@ -408,6 +410,79 @@ def add_lora_params_to_optimizer(self, model, optimizer):
):
optimizer.param_groups[group_id]["params"].append(param)

def support_lora(self) -> bool:
return True

def enable_lora(
self,
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: Optional[Dict] = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
) -> nn.Module:
from peft import PeftModel, get_peft_model

assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")

if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)

if pretrained_dir is None:
peft_model = get_peft_model(model, lora_config)
else:
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
return peft_model

def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
origin_param_id = id(origin_param)
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group["params"]:
if id(p) == origin_param_id:
return group_id
return -1

def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
origin_param_id = id(origin_param)
lora_param_id = id(lora_param)
target_group_id = None
for group_id, param_group in enumerate(optimizer.param_groups):
for p in param_group["params"]:
if id(p) == lora_param_id:
# check if the lora parameter exists.
return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
if id(p) == origin_param_id:
target_group_id = group_id
if target_group_id is not None:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
else:
return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND

def add_lora_params_to_optimizer(self, model, optimizer):
"""add lora parameters to optimizer"""
name2param = {}
for name, param in model.named_parameters():
name2param[name] = param

for name, param in name2param.items():
if "lora_A" in name or "lora_B" in name:
origin_key = name.replace("lora_A.", "")
origin_key = origin_key.replace("lora_B.", "")
origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
and group_id is not None
and group_id >= 0
):
optimizer.param_groups[group_id]["params"].append(param)

def configure(
self,
model: nn.Module,
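Taken together, these additions let a user attach LoRA adapters (optionally on a bitsandbytes-quantized model) before boosting. The sketch below outlines the intended call order, based on the enable_lora added above and the updated test at the end of this diff; the tiny model, target_modules, and hyperparameters are placeholders, and a distributed environment set up via colossalai.launch is assumed:

# Usage sketch with placeholder model and LoRA settings; assumes colossalai.launch has run.
import torch.nn as nn
from peft import LoraConfig
from torch.optim import AdamW

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)


model = TinyNet()
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["proj"])

booster = Booster(plugin=LowLevelZeroPlugin())
# LoRA must be enabled before boosting (see the assert in enable_lora above).
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
# After training, booster.save_lora_as_pretrained(model, "ckpt_dir") stores only the
# adapter weights, and enable_lora(..., pretrained_dir="ckpt_dir") reloads them,
# mirroring the flow in tests/test_lora/test_lora.py below.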
1 change: 1 addition & 0 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -10,6 +10,7 @@
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model

from .dp_plugin_base import DPPluginBase

1 change: 0 additions & 1 deletion colossalai/quantization/bnb.py
@@ -55,7 +55,6 @@ def quantize_model(
if bnb_quantization_config.skip_modules is None:
bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

# add cpu modules to skip modules only for 4-bit modules
modules_to_not_convert = bnb_quantization_config.skip_modules

# We add the modules we want to keep in full precision
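For context, quantize_model consumes a BnbQuantizationConfig; when skip_modules is left unset, get_keys_to_not_convert(model) decides which modules (typically the output head) stay in full precision, as the hunk above shows. A hedged sketch of direct use (the 4-bit field names follow the accelerate-style config and should be treated as assumptions):

# Sketch only; verify the config fields against colossalai/quantization before relying on them.
import torch
from transformers import AutoModelForCausalLM

from colossalai.quantization import BnbQuantizationConfig, quantize_model

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
bnb_config = BnbQuantizationConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
# skip_modules is None here, so quantize_model falls back to get_keys_to_not_convert(model).
model = quantize_model(model, bnb_config)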
1 change: 0 additions & 1 deletion requirements/requirements-test.txt
@@ -20,5 +20,4 @@ datasets
pydantic
ray
peft>=0.7.1
bitsandbytes>=0.39.0
#auto-gptq now not support torch1.12
68 changes: 23 additions & 45 deletions tests/test_lora/test_lora.py
@@ -32,17 +32,34 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
},
]
for plugin, test_config in product(test_plugins, test_configs):
test_model = copy.deepcopy(model)
# checkpoint loaded model
model_save = model_fn()
model_load = copy.deepcopy(model_save)

optimizer = AdamW(model.parameters(), lr=0.001)
criterion = loss_fn

booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, **test_config)
model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)

test_model = booster.enable_lora(test_model, **test_config)
model_copy = copy.deepcopy(test_model)
with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()

optimizer = AdamW(test_model.parameters(), lr=0.001)
criterion = loss_fn
# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1

test_model, optimizer, criterion, _, _ = booster.boost(test_model, optimizer, criterion)
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
model_load, _, _, _, _ = booster.boost(model_load)

check_state_dict_equal(model_save.state_dict(), model_load.state_dict())

# test fwd bwd correctness
test_model = model_load
model_copy = copy.deepcopy(model_load)

data = data_gen_fn()
data = {
@@ -67,44 +84,6 @@
torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)


@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
test_configs = [
{
"lora_config": lora_config,
"quantize": False,
},
{
"lora_config": lora_config,
"quantize": True,
},
]
for plugin, test_config in product(test_plugins, test_configs):
model_save = model_fn()
model_load = copy.deepcopy(model_save)

booster = Booster(plugin=plugin)
model_save = booster.enable_lora(model_save, **test_config)
model_save, _, _, _, _ = booster.boost(model_save)

with shared_tempdir() as tempdir:
lora_ckpt_path = os.path.join(tempdir, "ckpt")
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
dist.barrier()

# The Lora checkpoint should be small in size
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
assert checkpoint_size_mb < 1

model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
model_load, _, _, _, _ = booster.boost(model_load)

check_state_dict_equal(model_save.state_dict(), model_load.state_dict())


def run_lora_test():
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@@ -114,7 +93,6 @@ def run_lora_test():
if name == "transformers_llama_for_sequence_classification":
task_type = "SEQ_CLS"
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
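The body of run_dist is cut off by the hunk above. For completeness, the usual ColossalAI test harness spawns it per rank, roughly as sketched below; the launch signature and process count are assumptions, not part of this diff:

# Sketch of the typical test entry point (details assumed, not shown in this commit).
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Assumed launch signature for this ColossalAI version.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_lora():
    spawn(run_dist, 2)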
