LLM finetune sparsify masking (#278)
* add functions to mask weights during finetuning

* update logic for loading weights

* update yaml

* update mask name

* add logic to update batch size based on GPU count

* make sparsify requirements less broad; move sparseml[transformers] to nm deps

* remove flash-attn

* quality
dsikka committed Aug 23, 2023
1 parent 26d6f5f commit 2409ab8
Showing 5 changed files with 178 additions and 23 deletions.
19 changes: 11 additions & 8 deletions setup.py
@@ -26,11 +26,10 @@
# load and overwrite version and release info from sparseml package
exec(open(os.path.join("src", "sparsify", "version.py")).read())
print(f"loaded version {version} from src/sparsify/version.py")
version_nm_deps = f"{version_major_minor}.0"
version_nm_deps = f"{version_major_minor}.0.202308"

_PACKAGE_NAME = "sparsify" if is_release else "sparsify-nightly"


_deps = [
"pydantic>=1.8.2,<2.0.0",
"pyyaml>=5.0.0",
@@ -39,13 +38,14 @@
"setuptools>=56.0.0",
"optuna>=3.0.2",
"onnxruntime-gpu",
]
_nm_deps = [
f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_nm_deps}",
f"{'sparseml' if is_release else 'sparseml-nightly'}[torchvision,transformers,yolov5]~={version_nm_deps}", # noqa E501
f"{'deepsparse' if is_release else 'deepsparse-nightly'}~={version_nm_deps}",
f"{'sparseml' if is_release else 'sparseml-nightly'}[torchvision,yolov5]~={version_nm_deps}", # noqa E501
]

_nm_deps = [
f"{'sparseml' if is_release else 'sparseml-nightly'}[transformers]~={version_nm_deps}", # noqa E501
]

_dev_deps = [
"black>=20.8b1",
@@ -56,7 +56,10 @@
"fastai>=2.7.7",
]

_llm_deps = ["llm-foundry==0.2.0"]
_llm_deps = [
"llm-foundry==0.2.0",
f"{'nm-transformers' if is_release else 'nm-transformers-nightly'}",
]


def _setup_packages() -> List:
@@ -70,11 +73,11 @@ def _setup_package_dir() -> Dict:


def _setup_install_requires() -> List:
    return _nm_deps + _deps
    return _deps


def _setup_extras() -> Dict:
return {"dev": _dev_deps, "llm": _llm_deps}
return {"dev": _dev_deps, "_nm_deps": _nm_deps, "llm": _llm_deps}


def _setup_entry_points() -> Dict:
52 changes: 39 additions & 13 deletions src/sparsify/auto/samples/finetune_llmfoundry_sample.yaml
@@ -1,32 +1,43 @@
max_seq_len: 512
max_seq_len: 2048
global_seed: 17
model_name_or_path: t5-small
model_name: ${model_name_or_path}
load_path: # set via bash script to be absolute path to your sparse checkpoint
model_name_or_path: mosaicml/mpt-7b-instruct
load_path: /storage/dsikka/mpt_7b_instruct_oneshot_sp70.pt
precision: amp_bf16

max_duration: 1ep # run for 2 epochs
max_duration: 1ep
eval_interval: 1ep
# eval_subset_num_batches: 3 # use this for quick testing
eval_first: true
seed: ${global_seed}

device_train_microbatch_size: 1 # set to catch potential OOM cuda errors
device_train_batch_size: 4
device_eval_batch_size: 4
global_train_batch_size: 1
# for mpt-7b dense:
# 4 x A100_80GB = "device_train_microbatch_size: 12"
# 8 x A6000_48GB = "device_train_microbatch_size: 6"

# for mpt-7b sparse (with masks):
# 8 x A6000_48GB = "device_train_microbatch_size: 4"
device_train_batch_size: 1
device_train_microbatch_size: 1
device_eval_batch_size: 1

# Run Name
run_name: testing_run
run_name: test_run

model:
  name: hf_t5
  name: hf_causal_lm
  pretrained: true
  pretrained_model_name_or_path: t5-small
  pretrained_model_name_or_path: mosaicml/mpt-7b-instruct
  max_seq_len: ${max_seq_len}
  config_overrides:
    attn_config:
      attn_impl: torch
      # Set this to `true` if using `train_loader.dataset.packing_ratio` below
      attn_uses_sequence_id: true

# Tokenizer
tokenizer:
  name: ${model_name}
  name: EleutherAI/gpt-neox-20b
  kwargs:
    model_max_length: ${max_seq_len}

@@ -83,6 +94,14 @@ optimizer:
  eps: 1.0e-8
  weight_decay: 0.0

# we can't use gradient clipping for sparse training runs because we don't have
# a way to mask gradients of pruned weights, and thus the global gradient norm
# will be incorrect
# algorithms:
# gradient_clipping:
# clipping_type: norm
# clipping_threshold: 1.0

# FSDP
fsdp_config:
  sharding_strategy: FULL_SHARD
@@ -98,8 +117,15 @@ progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}

loggers:
  tensorboard: {"log_dir": "my_logs"}
  tensorboard: {}

# Checkpoint to local filesystem or remote object store
save_interval: 1ep
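
The device/microbatch comments above, and the `update_batch_size_info` call added in finetune.py below, all hang off the same arithmetic: the global batch is split across GPUs, and each per-device batch is accumulated from microbatches small enough to fit in memory. A rough sketch of that relationship, for orientation only — the real adjustment is performed by llm-foundry's `update_batch_size_info`, and the helper and field names here are illustrative assumptions:

# Illustrative sketch of the batch-size bookkeeping described in the comments
# above; not llm-foundry's implementation, and the names are assumptions.
def split_batch_size(
    global_train_batch_size: int,
    device_train_microbatch_size: int,
    num_gpus: int,
) -> dict:
    if global_train_batch_size % num_gpus != 0:
        raise ValueError("global batch size must be divisible by the GPU count")
    device_train_batch_size = global_train_batch_size // num_gpus
    # each per-device batch is accumulated from this many microbatches
    grad_accum = max(1, device_train_batch_size // device_train_microbatch_size)
    return {
        "device_train_batch_size": device_train_batch_size,
        "device_train_microbatch_size": device_train_microbatch_size,
        "grad_accum": grad_accum,
    }

# e.g. a hypothetical global batch of 32 on 8 GPUs with the sparse microbatch
# size of 4 from the comments above:
# split_batch_size(32, 4, 8) -> device_train_batch_size 4, grad_accum 1
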
56 changes: 54 additions & 2 deletions src/sparsify/auto/tasks/finetune/finetune.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Union
from typing import Dict, Tuple, Union

import torch
from torch.utils.data import DataLoader

import click
@@ -36,8 +38,10 @@
    build_scheduler,
    build_tokenizer,
)
from llmfoundry.utils.config_utils import update_batch_size_info
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from sparsify.auto.tasks.finetune.helpers import MaskPrunedWeights, attach_masks
from transformers import PreTrainedTokenizerBase


@@ -46,6 +50,9 @@
TEXT_DENOISING_MODELS = ["hf_prefix_lm", "hf_t5"]
TEXT_MODELS = ["hf_causal_lm"]

_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)


class LLMDataTypes(Enum):
TEXT = "text"
@@ -154,6 +161,36 @@ def _build_model(self, tokenizer: PreTrainedTokenizerBase) -> HuggingFaceModel:
            self._train_config.model, tokenizer
        )

    def _load_weights_and_attach_masks(
        self, tokenizer: PreTrainedTokenizerBase
    ) -> Tuple[torch.nn.Module, Union[None, "MaskPrunedWeights"]]:
        """
        If a load_path is provided, attempt to load weights from the specified
        location. Because the loaded model may be sparse, attach masks, masking
        where the weights have already been pruned.

        :return: tuple of the model with weights loaded from the `load_path` and
            with buffers attached for the pruning masks, along with the
            MaskPrunedWeights algorithm; if loading fails, the pretrained model
            and None are returned instead
        """
        model = self._build_model(tokenizer)
        try:
            model.load_state_dict(
                torch.load(self._train_config.get("load_path"), map_location="cpu")[
                    "state"
                ]["model"],
                strict=True,
            )
        except Exception as e:
            _LOGGER.error(f"Failed to load weights, returning pretrained model: {e}")
            if self._train_config.model.pretrained is False:
                self._train_config.model.pretrained = True
                model = self._build_model(tokenizer)
            return model, None

        attach_masks(model)
        return model, MaskPrunedWeights()

    def _build_dataloaders(
        self,
        dataloader_config: DictConfig,
@@ -218,8 +255,21 @@ def _build_trainer(self) -> Trainer:
        if dist.get_world_size() > 1:
            dist.initialize_dist(get_device(None))

        self._train_config = update_batch_size_info(self._train_config)

        tokenizer = build_tokenizer(self._train_config.tokenizer)
        model = self._build_model(tokenizer)

        algorithms = []
        # If a load_path is provided, try loading weights from the provided path
        if self._train_config.get("load_path"):
            self._train_config.model.pretrained = False
        else:
            self._train_config.model.pretrained = True

        model, algorithm = self._load_weights_and_attach_masks(tokenizer)
        if algorithm:
            algorithms.append(algorithm)

        optimizer = build_optimizer(self._train_config.optimizer, model)
        scheduler = build_scheduler(self._train_config.scheduler)

@@ -251,6 +301,7 @@ def _build_trainer(self) -> Trainer:
            optimizers=optimizer,
            schedulers=scheduler,
            loggers=loggers,
            algorithms=algorithms,
            max_duration=self._train_config.max_duration,
            eval_interval=self._train_config.eval_interval,
            precision=self._train_config.precision,
@@ -260,6 +311,7 @@ def _build_trainer(self) -> Trainer:
"eval_subset_num_batches", -1
),
log_to_console=self._train_config.get("log_to_console", False),
progress_bar=self._train_config.get("progress_bar", True),
console_log_interval=self._train_config.get("console_log_interval", "1ba"),
device_train_microbatch_size=self._train_config.get(
"device_train_microbatch_size", "auto"
62 changes: 62 additions & 0 deletions src/sparsify/auto/tasks/finetune/helpers.py
@@ -0,0 +1,62 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from composer.core import Algorithm, Event


all = ["attach_masks", "MaskPrunedWeights"]


class MaskPrunedWeights(Algorithm):
    """
    Composer-specific hook which allows us to mask weights after a specific event,
    in this case at the end of the batch. Provided as input to the Trainer while
    finetuning. Note: weights can also be masked before the forward pass by adding
    `or event == Event.BATCH_START` to `match`.
    """

    def match(self, event, state):
        return event == Event.BATCH_END

    @torch.no_grad()
    def apply(self, event, state, logger):
        def mask_weights(module):
            if hasattr(module, "constant_pruning_mask"):
                module.weight *= module.constant_pruning_mask

        state.model.apply(mask_weights)


def attach_masks(model: torch.nn.Module):
    """
    Recursively attach masks to weights which have already been pruned to avoid
    finetuning them further.

    :param model: torch.nn.Module to recursively attach masks to if the weights are
        already pruned
    """
    for _, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            constant_pruning_mask = torch.where(
                module.weight == 0,
                torch.tensor(0, dtype=torch.uint8),
                torch.tensor(1, dtype=torch.uint8),
            )
            module.register_buffer(
                "constant_pruning_mask", constant_pruning_mask, persistent=False
            )
        else:
            attach_masks(module)
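
A small sanity check for the two helpers above, assuming a sparsify install that includes this change: zero out part of a Linear layer to stand in for SparseML pruning, attach masks, take an optimizer step (which would otherwise regrow the pruned weights), then re-apply the mask the way `MaskPrunedWeights.apply` does at `Event.BATCH_END`:

# Sanity-check sketch for attach_masks / MaskPrunedWeights; assumes this branch
# of sparsify is installed so the helpers import resolves.
import torch

from sparsify.auto.tasks.finetune.helpers import attach_masks

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4)
)
with torch.no_grad():
    model[0].weight[:, ::2] = 0.0  # stand-in for weights already pruned by SparseML

attach_masks(model)  # registers constant_pruning_mask buffers on the pruned Linear

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(8, 16)).sum().backward()
optimizer.step()  # a dense update: pruned positions drift away from zero here

# equivalent to what MaskPrunedWeights.apply does at Event.BATCH_END
with torch.no_grad():
    for module in model.modules():
        if hasattr(module, "constant_pruning_mask"):
            module.weight *= module.constant_pruning_mask

assert (model[0].weight[:, ::2] == 0).all()
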
12 changes: 12 additions & 0 deletions src/sparsify/auto/tasks/transformers/__init__.py
@@ -15,5 +15,17 @@
# flake8: noqa
# isort: skip_file


def _check_nm_install():
    try:
        from .runner import *
    except ImportError as exception:
        raise ImportError(
            "Please install sparsify[nm] to use this pathway."
        ) from exception


_check_nm_install()

from .args import *
from .runner import *
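
For completeness, the guard above exists so that importing the transformers task pathway without the optional Neural Magic dependencies fails with an actionable message rather than a bare ImportError. Roughly what a user would see in that case (a hypothetical session; the message text is taken verbatim from the guard above):

# Hypothetical interpreter session, assuming the optional Neural Magic deps
# (the sparseml[transformers] requirement moved to an extra in setup.py) are
# not installed:
#
#   >>> import sparsify.auto.tasks.transformers
#   ImportError: Please install sparsify[nm] to use this pathway.
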
