Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : improve BPE pre-processing + LLaMA 3 and Deepseek support #6920

Merged
merged 61 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
6fbab2d
merged the changes from deepseeker models to main branch
jaggzh Feb 12, 2024
d2cfc22
Moved regex patterns to unicode.cpp and updated unicode.h
dragnil1 Mar 22, 2024
54f93eb
Moved header files
dragnil1 Mar 22, 2024
1c924e4
Resolved issues
dragnil1 Mar 23, 2024
4056dc5
added and refactored unicode_regex_split and related functions
dragnil1 Mar 31, 2024
c8e7d95
Updated/merged the deepseek coder pr
jaggzh Feb 12, 2024
4c3e882
Refactored code
dragnil1 Apr 13, 2024
a5710a4
Adding unicode regex mappings
dragnil1 Apr 15, 2024
7e308ed
Adding unicode regex function
dragnil1 Apr 15, 2024
feeaf4f
Added needed functionality, testing remains
dragnil1 Apr 15, 2024
7535803
Fixed issues
dragnil1 Apr 15, 2024
36d9832
Fixed issue with gpt2 regex custom preprocessor
dragnil1 Apr 17, 2024
06d3e69
unicode : fix? unicode_wstring_to_utf8
ggerganov Apr 26, 2024
c56e19d
lint : fix whitespaces
ggerganov Apr 26, 2024
7a44e44
tests : add tokenizer tests for numbers
ggerganov Apr 26, 2024
d999cf6
unicode : remove redundant headers
ggerganov Apr 26, 2024
aeafb43
tests : remove and rename tokenizer test scripts
ggerganov Apr 26, 2024
e1b2bf7
tests : add sample usage
ggerganov Apr 26, 2024
ed42711
gguf-py : reader prints warnings on duplicate keys
ggerganov Apr 26, 2024
4907e41
llama : towards llama3 tokenization support (wip)
ggerganov Apr 26, 2024
e8c206b
unicode : shot in the dark to fix tests on Windows
ggerganov Apr 26, 2024
e989176
unicode : first try custom implementations
ggerganov Apr 26, 2024
e3f6dc7
Merge branch 'master' into gg/bpe-preprocess
ggerganov Apr 26, 2024
9b4d63a
convert : add "tokenizer.ggml.pre" GGUF KV (wip)
ggerganov Apr 26, 2024
43e12ce
llama : use new pre-tokenizer type
ggerganov Apr 26, 2024
1b9b79d
convert : fix pre-tokenizer type writing
ggerganov Apr 26, 2024
8791e94
lint : fix
ggerganov Apr 26, 2024
a774d70
make : add test-tokenizer-0-llama-v3
ggerganov Apr 26, 2024
c160818
wip
ggerganov Apr 26, 2024
96965f6
models : add llama v3 vocab file
ggerganov Apr 27, 2024
ad92983
llama : adapt punctuation regex + add llama 3 regex
ggerganov Apr 27, 2024
4434c9d
minor
ggerganov Apr 27, 2024
a22645c
unicode : set bomb
ggerganov Apr 27, 2024
2affd0b
unicode : set bomb
ggerganov Apr 27, 2024
ce5485a
unicode : always use std::wregex
ggerganov Apr 27, 2024
91eaa41
unicode : support \p{N}, \p{L} and \p{P} natively
ggerganov Apr 27, 2024
581c4a0
unicode : try fix windows
ggerganov Apr 27, 2024
b97add5
unicode : category support via std::regex
ggerganov Apr 28, 2024
d63cc90
Merge branch 'master' into gg/bpe-preprocess
ggerganov Apr 28, 2024
e972e6c
unicode : clean-up
ggerganov Apr 28, 2024
ee6d1b3
unicode : simplify
ggerganov Apr 28, 2024
7642973
convert : add convert-hf-to-gguf-update.py
ggerganov Apr 28, 2024
4e3e6d8
lint : update
ggerganov Apr 28, 2024
1c888eb
convert : add falcon
ggerganov Apr 28, 2024
1545550
unicode : normalize signatures
ggerganov Apr 28, 2024
491f233
lint : fix
ggerganov Apr 28, 2024
e8dd4a1
lint : fix
ggerganov Apr 28, 2024
02fd977
convert : remove unused functions
ggerganov Apr 28, 2024
0f9058c
convert : add comments
ggerganov Apr 28, 2024
7808150
convert : exercise contractions
ggerganov Apr 28, 2024
7b1210f
lint : fix
ggerganov Apr 28, 2024
ef4cca9
cmake : refactor test targets
ggerganov Apr 29, 2024
43708d2
tests : refactor vocab tests
ggerganov Apr 29, 2024
c68d259
tests : add more vocabs and tests
ggerganov Apr 29, 2024
af05268
unicode : cleanup
ggerganov Apr 29, 2024
c21ab18
scripts : ignore new update script in check-requirements.sh
ggerganov Apr 29, 2024
120cf37
models : add phi-3, mpt, gpt-2, starcoder
ggerganov Apr 29, 2024
9a7d430
tests : disable obsolete
ggerganov Apr 29, 2024
6d6ce93
tests : use faster bpe test
ggerganov Apr 29, 2024
3202676
llama : more prominent warning for old BPE models
ggerganov Apr 29, 2024
80cb312
tests : disable test-tokenizer-1-bpe due to slowness
ggerganov Apr 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,20 @@ examples/server/*.mjs.hpp
poetry.lock
poetry.toml
nppBackup

# Test binaries
/tests/test-grammar-parser
/tests/test-llama-grammar
/tests/test-double-float
/tests/test-grad0
/tests/test-opt
/tests/test-quantize-fns
/tests/test-quantize-perf
/tests/test-sampling
/tests/test-tokenizer-0-llama
/tests/test-tokenizer-0-falcon
/tests/test-tokenizer-0-deepseek-coder
/tests/test-tokenizer-1-llama
/tests/test-tokenizer-1-bpe
/tests/test-rope
/tests/test-backend-ops
13 changes: 12 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ BUILD_TARGETS = \

# Binaries only useful for tests
TEST_TARGETS = \
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
tests/test-llama-grammar tests/test-tokenizer-0-deepseek-coder tests/test-tokenizer-0-deepseek-llm \
tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
Expand Down Expand Up @@ -53,6 +54,10 @@ test: $(TEST_TARGETS)
./$$test_target $(CURDIR)/models/ggml-vocab-llama.gguf; \
elif [ "$$test_target" = "tests/test-tokenizer-0-falcon" ]; then \
./$$test_target $(CURDIR)/models/ggml-vocab-falcon.gguf; \
elif [ "$$test_target" = "tests/test-tokenizer-0-deepseek-coder" ]; then \
./$$test_target $(CURDIR)/models/ggml-vocab-deepseek-coder.gguf; \
elif [ "$$test_target" = "tests/test-tokenizer-0-deepseek-llm" ]; then \
./$$test_target $(CURDIR)/models/ggml-vocab-deepseek-llm.gguf; \
elif [ "$$test_target" = "tests/test-tokenizer-1-llama" ]; then \
continue; \
elif [ "$$test_target" = "tests/test-tokenizer-1-bpe" ]; then \
Expand Down Expand Up @@ -979,6 +984,12 @@ tests/test-tokenizer-0-llama: tests/test-tokenizer-0-llama.cpp ggml.o llama.o $(
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

tests/test-tokenizer-0-deepseek-coder: tests/test-tokenizer-0-deepseek-coder.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

tests/test-tokenizer-0-deepseek-llm: tests/test-tokenizer-0-deepseek-llm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

tests/test-tokenizer-1-bpe: tests/test-tokenizer-1-bpe.cpp ggml.o llama.o $(COMMON_DEPS) console.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
Expand Down
149 changes: 147 additions & 2 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,78 @@ def from_model_architecture(cls, arch):
except KeyError:
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None

@staticmethod
def from_model_architecture(model_architecture):
if model_architecture == "GPTNeoXForCausalLM":
return GPTNeoXModel
if model_architecture == "BloomForCausalLM":
return BloomModel
if model_architecture == "MPTForCausalLM":
return MPTModel
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return BaichuanModel
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
return FalconModel
if model_architecture == "GPTBigCodeForCausalLM":
return StarCoderModel
if model_architecture == "GPTRefactForCausalLM":
return RefactModel
if model_architecture == "PersimmonForCausalLM":
return PersimmonModel
if model_architecture == "LlamaForCausalLM":
return DeepseekCoderModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
if model_architecture == "QWenLMHeadModel":
return QwenModel
if model_architecture == "Qwen2ForCausalLM":
return Model
if model_architecture == "MixtralForCausalLM":
return MixtralModel
if model_architecture == "GPT2LMHeadModel":
return GPT2Model
if model_architecture == "PhiForCausalLM":
return Phi2Model
if model_architecture == "PlamoForCausalLM":
return PlamoModel
if model_architecture == "CodeShellForCausalLM":
return CodeShellModel
if model_architecture == "OrionForCausalLM":
return OrionModel
if model_architecture == "InternLM2ForCausalLM":
return InternLM2Model
if model_architecture == "MiniCPMForCausalLM":
return MiniCPMModel
if model_architecture == "BertModel":
return BertModel

@staticmethod
def from_model_name(model_name: str):
model_name_lower = model_name.lower()
if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
return StableLMModel
if model_name_lower == "gptneox":
return GPTNeoXModel
if model_name_lower == "bloom":
return BloomModel
if model_name_lower == "mpt":
return MPTModel
if model_name_lower in ("baichuan"):
return BaichuanModel
if model_name_lower in ("falcon", "rw"):
return FalconModel
if model_name_lower == "gptbigcode":
return StarCoderModel
if model_name_lower == "gptrefact":
return RefactModel
if model_name_lower == "persimmon":
return PersimmonModel
if model_name_lower == "deepseekcoder":
return DeepseekCoderModel
if model_name_lower == "deepseekllm":
return DeepseekLLMModel
return Model

def _is_model_safetensors(self) -> bool:
return Model.count_model_parts(self.dir_model, ".safetensors") > 0

Expand All @@ -228,6 +300,53 @@ def _get_part_names(self):
return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))

def _get_model_architecture(self) -> gguf.MODEL_ARCH:
arch = self.hparams["architectures"][0]
if arch == "GPTNeoXForCausalLM":
return gguf.MODEL_ARCH.GPTNEOX
if arch == "BloomForCausalLM":
return gguf.MODEL_ARCH.BLOOM
if arch == "MPTForCausalLM":
return gguf.MODEL_ARCH.MPT
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return gguf.MODEL_ARCH.BAICHUAN
if arch in ("FalconForCausalLM", "RWForCausalLM"):
return gguf.MODEL_ARCH.FALCON
if arch == "GPTBigCodeForCausalLM":
return gguf.MODEL_ARCH.STARCODER
if arch == "GPTRefactForCausalLM":
return gguf.MODEL_ARCH.REFACT
if arch == "PersimmonForCausalLM":
return gguf.MODEL_ARCH.PERSIMMON
if arch == "LlamaForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM
if arch == "QWenLMHeadModel":
return gguf.MODEL_ARCH.QWEN
if arch == "Qwen2ForCausalLM":
return gguf.MODEL_ARCH.QWEN2
if arch == "MixtralForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch == "GPT2LMHeadModel":
return gguf.MODEL_ARCH.GPT2
if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2
if arch == "PlamoForCausalLM":
return gguf.MODEL_ARCH.PLAMO
if arch == "CodeShellForCausalLM":
return gguf.MODEL_ARCH.CODESHELL
if arch == "OrionForCausalLM":
return gguf.MODEL_ARCH.ORION
if arch == "InternLM2ForCausalLM":
return gguf.MODEL_ARCH.INTERNLM2
if arch == "MiniCPMForCausalLM":
return gguf.MODEL_ARCH.MINICPM
if arch == "BertModel":
return gguf.MODEL_ARCH.BERT

raise NotImplementedError(f'Architecture "{arch}" not supported!')

compilade marked this conversation as resolved.
Show resolved Hide resolved
compilade marked this conversation as resolved.
Show resolved Hide resolved
# used for GPT-2 BPE and WordPiece vocabs
def get_basic_vocab(self) -> tuple[list[str], list[int]]:
tokens: list[str] = []
Expand Down Expand Up @@ -257,9 +376,10 @@ def get_basic_vocab(self) -> tuple[list[str], list[int]]:

return tokens, toktypes

def _set_vocab_gpt2(self) -> None:

def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
tokens, toktypes = self.get_basic_vocab()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_model(tokenizer_model)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

Expand Down Expand Up @@ -1192,7 +1312,31 @@ def write_tensors(self):
n_dims = len(data.shape)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
@Model.register("LlamaForCausalLM")
class DeepseekCoderModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

def set_gguf_parameters(self):
super().set_gguf_parameters()
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

def set_vocab(self):
self._set_vocab_gpt2("deepseek_coder")

class DeepseekLLMModel(DeepseekCoderModel):
def set_vocab(self):
self._set_vocab_gpt2("deepseek_llm")

@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
class StableLMModel(Model):
Expand Down Expand Up @@ -2843,6 +2987,7 @@ def parse_args() -> argparse.Namespace:
help="directory containing model file",
)
parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
parser.add_argument("--model-name", type=str, default=None, help="name of the model")

return parser.parse_args()

Expand Down
Loading
Loading