From d28c5c2a140d56085d8e4ce4a52a85eb948b51df Mon Sep 17 00:00:00 2001 From: Ravindra Marella Date: Mon, 21 Aug 2023 00:40:18 +0530 Subject: [PATCH] Add `revision` option for models on Hugging Face Hub --- README.md | 2 ++ ctransformers/gptq/hub.py | 6 +++++- ctransformers/hub.py | 29 ++++++++++++++++++++++++++--- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a2aaca7..1a0295f 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,7 @@ from_pretrained( config: Optional[ctransformers.hub.AutoConfig] = None, lib: Optional[str] = None, local_files_only: bool = False, + revision: Optional[str] = None, **kwargs ) → LLM ``` @@ -207,6 +208,7 @@ Loads the language model from a local file or remote repo. - `config`: `AutoConfig` object. - `lib`: The path to a shared library or one of `avx2`, `avx`, `basic`. - `local_files_only`: Whether or not to only look at local files (i.e., do not try to download the model). +- `revision`: The specific model version to use. It can be a branch name, a tag name, or a commit id. **Returns:** `LLM` object. diff --git a/ctransformers/gptq/hub.py b/ctransformers/gptq/hub.py index 4324e07..225d58d 100644 --- a/ctransformers/gptq/hub.py +++ b/ctransformers/gptq/hub.py @@ -28,6 +28,7 @@ def from_pretrained( model_path_or_repo_id: str, *, local_files_only: bool = False, + revision: Optional[str] = None, **kwargs, ) -> LLM: """Loads the language model from a local file or remote repo. @@ -37,6 +38,8 @@ def from_pretrained( name of a Hugging Face Hub model repo. local_files_only: Whether or not to only look at local files (i.e., do not try to download the model). + revision: The specific model version to use. It can be a branch + name, a tag name, or a commit id. Returns: `LLM` object. @@ -56,12 +59,13 @@ def from_pretrained( model_path = None if path_type == "file": model_path = Path(model_path_or_repo_id).parent - elif path_type == "dir" : + elif path_type == "dir": model_path = Path(model_path_or_repo_id) elif path_type == "repo": model_path = snapshot_download( repo_id=model_path_or_repo_id, local_files_only=local_files_only, + revision=revision, ) return LLM(model_path=model_path, config=config) diff --git a/ctransformers/hub.py b/ctransformers/hub.py index 14c1258..02e55d6 100644 --- a/ctransformers/hub.py +++ b/ctransformers/hub.py @@ -32,6 +32,7 @@ def from_pretrained( cls, model_path_or_repo_id: str, local_files_only: bool = False, + revision: Optional[str] = None, **kwargs, ) -> "AutoConfig": path_type = get_path_type(model_path_or_repo_id) @@ -48,6 +49,7 @@ def from_pretrained( model_path_or_repo_id, auto_config, local_files_only=local_files_only, + revision=revision, ) for k, v in kwargs.items(): @@ -65,11 +67,13 @@ def _update_from_repo( repo_id: str, auto_config: "AutoConfig", local_files_only: bool, + revision: Optional[str] = None, ) -> None: path = snapshot_download( repo_id=repo_id, allow_patterns="config.json", local_files_only=local_files_only, + revision=revision, ) cls._update_from_dir(path, auto_config) @@ -110,6 +114,7 @@ def from_pretrained( config: Optional[AutoConfig] = None, lib: Optional[str] = None, local_files_only: bool = False, + revision: Optional[str] = None, **kwargs, ) -> LLM: """Loads the language model from a local file or remote repo. @@ -123,6 +128,8 @@ def from_pretrained( lib: The path to a shared library or one of `avx2`, `avx`, `basic`. local_files_only: Whether or not to only look at local files (i.e., do not try to download the model). + revision: The specific model version to use. It can be a branch + name, a tag name, or a commit id. Returns: `LLM` object. @@ -135,12 +142,14 @@ def from_pretrained( return gptq.AutoModelForCausalLM.from_pretrained( model_path_or_repo_id, local_files_only=local_files_only, + revision=revision, **kwargs, ) config = config or AutoConfig.from_pretrained( model_path_or_repo_id, local_files_only=local_files_only, + revision=revision, **kwargs, ) model_type = model_type or config.model_type @@ -163,6 +172,7 @@ def from_pretrained( model_path_or_repo_id, model_file, local_files_only=local_files_only, + revision=revision, ) return LLM( @@ -178,21 +188,34 @@ def _find_model_path_from_repo( repo_id: str, filename: Optional[str], local_files_only: bool, + revision: Optional[str] = None, ) -> str: if not filename and not local_files_only: - filename = cls._find_model_file_from_repo(repo_id=repo_id) + filename = cls._find_model_file_from_repo( + repo_id=repo_id, + revision=revision, + ) allow_patterns = filename or "*.bin" path = snapshot_download( repo_id=repo_id, allow_patterns=allow_patterns, local_files_only=local_files_only, + revision=revision, ) return cls._find_model_path_from_dir(path, filename=filename) @classmethod - def _find_model_file_from_repo(cls, repo_id: str) -> Optional[str]: + def _find_model_file_from_repo( + cls, + repo_id: str, + revision: Optional[str] = None, + ) -> Optional[str]: api = HfApi() - repo_info = api.repo_info(repo_id=repo_id, files_metadata=True) + repo_info = api.repo_info( + repo_id=repo_id, + files_metadata=True, + revision=revision, + ) files = [ (f.size, f.rfilename) for f in repo_info.siblings