Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Fix HF GPT-2 example (#115)

Merged
Merged 1 commit on Aug 22, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/hf_gpt2/hf_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from energonai.nn import VocabParallelEmbedding1D
from torch.nn import Embedding
from energonai.utils import get_current_device, is_using_pp
from energonai.utils.checkpointing_hf_gpt2 import load_checkpoint
from energonai.utils.checkpointing import load_checkpoint
from energonai.utils.checkpointing_hf_gpt2 import processing_HF_GPT


__all__ = [
Expand Down Expand Up @@ -479,7 +480,7 @@ def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **
assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading"
print(model_kwargs["checkpoint_path"])
assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found"
load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs)
load_checkpoint(model_kwargs["checkpoint_path"], model, preprocess_fn=processing_HF_GPT, **model_kwargs)
logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
return model

Expand Down