Skip to content

Commit

Permalink
Add revision option for models on Hugging Face Hub
Browse files Browse the repository at this point in the history
  • Loading branch information
marella committed Aug 20, 2023
1 parent bb98113 commit d28c5c2
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ from_pretrained(
config: Optional[ctransformers.hub.AutoConfig] = None,
lib: Optional[str] = None,
local_files_only: bool = False,
revision: Optional[str] = None,
**kwargs
) → LLM
```
Expand All @@ -207,6 +208,7 @@ Loads the language model from a local file or remote repo.
- <b>`config`</b>: `AutoConfig` object.
- <b>`lib`</b>: The path to a shared library or one of `avx2`, `avx`, `basic`.
- <b>`local_files_only`</b>: Whether or not to only look at local files (i.e., do not try to download the model).
- <b>`revision`</b>: The specific model version to use. It can be a branch name, a tag name, or a commit id.

**Returns:**
`LLM` object.
Expand Down
6 changes: 5 additions & 1 deletion ctransformers/gptq/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def from_pretrained(
model_path_or_repo_id: str,
*,
local_files_only: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> LLM:
"""Loads the language model from a local file or remote repo.
Expand All @@ -37,6 +38,8 @@ def from_pretrained(
name of a Hugging Face Hub model repo.
local_files_only: Whether or not to only look at local files
(i.e., do not try to download the model).
revision: The specific model version to use. It can be a branch
name, a tag name, or a commit id.
Returns:
`LLM` object.
Expand All @@ -56,12 +59,13 @@ def from_pretrained(
model_path = None
if path_type == "file":
model_path = Path(model_path_or_repo_id).parent
elif path_type == "dir" :
elif path_type == "dir":
model_path = Path(model_path_or_repo_id)
elif path_type == "repo":
model_path = snapshot_download(
repo_id=model_path_or_repo_id,
local_files_only=local_files_only,
revision=revision,
)

return LLM(model_path=model_path, config=config)
29 changes: 26 additions & 3 deletions ctransformers/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def from_pretrained(
cls,
model_path_or_repo_id: str,
local_files_only: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> "AutoConfig":
path_type = get_path_type(model_path_or_repo_id)
Expand All @@ -48,6 +49,7 @@ def from_pretrained(
model_path_or_repo_id,
auto_config,
local_files_only=local_files_only,
revision=revision,
)

for k, v in kwargs.items():
Expand All @@ -65,11 +67,13 @@ def _update_from_repo(
repo_id: str,
auto_config: "AutoConfig",
local_files_only: bool,
revision: Optional[str] = None,
) -> None:
path = snapshot_download(
repo_id=repo_id,
allow_patterns="config.json",
local_files_only=local_files_only,
revision=revision,
)
cls._update_from_dir(path, auto_config)

Expand Down Expand Up @@ -110,6 +114,7 @@ def from_pretrained(
config: Optional[AutoConfig] = None,
lib: Optional[str] = None,
local_files_only: bool = False,
revision: Optional[str] = None,
**kwargs,
) -> LLM:
"""Loads the language model from a local file or remote repo.
Expand All @@ -123,6 +128,8 @@ def from_pretrained(
lib: The path to a shared library or one of `avx2`, `avx`, `basic`.
local_files_only: Whether or not to only look at local files
(i.e., do not try to download the model).
revision: The specific model version to use. It can be a branch
name, a tag name, or a commit id.
Returns:
`LLM` object.
Expand All @@ -135,12 +142,14 @@ def from_pretrained(
return gptq.AutoModelForCausalLM.from_pretrained(
model_path_or_repo_id,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)

config = config or AutoConfig.from_pretrained(
model_path_or_repo_id,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_type = model_type or config.model_type
Expand All @@ -163,6 +172,7 @@ def from_pretrained(
model_path_or_repo_id,
model_file,
local_files_only=local_files_only,
revision=revision,
)

return LLM(
Expand All @@ -178,21 +188,34 @@ def _find_model_path_from_repo(
repo_id: str,
filename: Optional[str],
local_files_only: bool,
revision: Optional[str] = None,
) -> str:
if not filename and not local_files_only:
filename = cls._find_model_file_from_repo(repo_id=repo_id)
filename = cls._find_model_file_from_repo(
repo_id=repo_id,
revision=revision,
)
allow_patterns = filename or "*.bin"
path = snapshot_download(
repo_id=repo_id,
allow_patterns=allow_patterns,
local_files_only=local_files_only,
revision=revision,
)
return cls._find_model_path_from_dir(path, filename=filename)

@classmethod
def _find_model_file_from_repo(cls, repo_id: str) -> Optional[str]:
def _find_model_file_from_repo(
cls,
repo_id: str,
revision: Optional[str] = None,
) -> Optional[str]:
api = HfApi()
repo_info = api.repo_info(repo_id=repo_id, files_metadata=True)
repo_info = api.repo_info(
repo_id=repo_id,
files_metadata=True,
revision=revision,
)
files = [
(f.size, f.rfilename)
for f in repo_info.siblings
Expand Down

0 comments on commit d28c5c2

Please sign in to comment.