Skip to content

Commit

Permalink
Backport #203 to 1.x (#216)
Browse files Browse the repository at this point in the history
* updating notebook + bumping version (#199)

* updating notebook + bumping version

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addressing comments

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
(cherry picked from commit 1237aa6)

* Add description field with make_model_config_json function (#203)

* Add description field

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Restore notebook

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Debug test

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Resolve linting issues

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Update CHANGELOG.md

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Debug test_sentencetransformermodel_pytest.py

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Improve test coverage

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Edit test name

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Change CHANGELOG.md & Add comment to sentencetransformermodel.py

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Correct linting

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Improve add description

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Correct linting

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Loosen restriction

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Update sentencetransformermodel.py

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Change function name + Add comment + Add default description

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

* Debug

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>

---------

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>
(cherry picked from commit 20435b1)

---------

Signed-off-by: Thanawan Atchariyachanvanit <latchari@amazon.com>
Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
  • Loading branch information
thanawan-atc and dhrubo-os committed Aug 10, 2023
1 parent 34834dd commit 93dee14
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 10 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# CHANGELOG
Inspired by [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [1.2.0]

### Added

### Changed

### Fixed
- Enable make_model_config_json to add model description to model config file by @thanawan-atc ([#203](https://github.com/opensearch-project/opensearch-py-ml/pull/203))

## [1.1.0]

### Added
Expand All @@ -18,8 +27,6 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- adding jupyter notebook based documentation for metrics correlation algorithm by @AlibiZhenis ([#186](https://github.com/opensearch-project/opensearch-py-ml/pull/186))

### Changed
- adding documentation for model group id @dhrubo-os ([#176](https://github.com/opensearch-project/opensearch-py-ml/pull/176))
- adding jupyter notebook based documentation for metrics correlation algorithm by @AlibiZhenis ([#186](https://github.com/opensearch-project/opensearch-py-ml/pull/186))
- Update jenkins file to use updated docker image ([#189](https://github.com/opensearch-project/opensearch-py-ml/pull/189))
- Updated documentation @dhrubo-os ([#98](https://github.com/opensearch-project/opensearch-py-ml/pull/98))
- Updating ML Commons API documentation @AlibiZhenis ([#156](https://github.com/opensearch-project/opensearch-py-ml/pull/156))
Expand Down
2 changes: 2 additions & 0 deletions docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ sphinx_rtd_theme
nbsphinx
pandoc
deprecated

# using in SentenceTransformerModel
torch
pyyaml
accelerate
sentence_transformers
transformers
tqdm
mdutils

# traitlets has been having all sorts of release problems lately.
traitlets<5.1
95 changes: 95 additions & 0 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pickle
import platform
import random
import re
import shutil
import subprocess
import time
Expand All @@ -23,6 +24,7 @@
import torch
import yaml
from accelerate import Accelerator, notebook_launcher
from mdutils.fileutils import MarkDownFile
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize, Pooling, Transformer
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -1006,6 +1008,74 @@ def set_up_accelerate_config(
"Failed to open config file for ml common upload: " + file_path + "\n"
)

def _get_model_description_from_readme_file(self, readme_file_path) -> str:
    """
    Get description of the model from README.md file in the model folder
    after the model is saved in local directory

    See example here:
    https://huggingface.co/sentence-transformers/msmarco-distilbert-base-tas-b/blob/main/README.md

    This function assumes that the README.md has the following format:

    # sentence-transformers/msmarco-distilbert-base-tas-b
    This is [ ... further description ... ]

    # [ ... Next section ...]
    ...

    Only the first line of the description section is kept; markdown
    hyperlink targets, square brackets and asterisks are removed from it.

    :param readme_file_path: Path to README.md file
    :type readme_file_path: string
    :return: Description of the model
    :rtype: string
    :raises AssertionError: If the heading for the model or the heading of
        the following section cannot be found in the README.md file
    """
    readme_data = MarkDownFile.read_file(readme_file_path)

    # Find the description section: try the full model id as the heading
    # first, then fall back to the part after the "/".
    start_str = f"# {self.model_id}"
    start = readme_data.find(start_str)
    if start == -1:
        id_parts = self.model_id.split("/")
        # A model id without "/" has no separate short name. Reuse the id so
        # that we reach the explicit "not found" error below instead of
        # failing with an unrelated IndexError.
        model_name = id_parts[1] if len(id_parts) > 1 else id_parts[0]
        start_str = f"# {model_name}"
        start = readme_data.find(start_str)

    # The section ends where the next markdown heading begins. Skip the
    # search entirely when the opening heading was not found.
    end = -1
    if start != -1:
        end = readme_data.find("\n#", start + len(start_str))

    # If we cannot find the scope of description section, raise error.
    # Raised explicitly (rather than via `assert`) so the check still runs
    # when Python is started with -O; the exception type is unchanged.
    if start == -1 or end == -1:
        raise AssertionError("Cannot find description in README.md file")

    # Parse out the description section and keep its first line only
    description = readme_data[start + len(start_str) + 1 : end].strip()
    description = description.split("\n")[0]

    # Remove hyperlink and reformat text
    description = re.sub(r"\(.*?\)", "", description)
    description = re.sub(r"[\[\]]", "", description)
    description = re.sub(r"\*", "", description)

    # Remove unnecessary part if exists (i.e. " For an introduction to ...")
    # (Found in https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-dot-v1/blob/main/README.md)
    unnecessary_part = description.find(" For an introduction to")
    if unnecessary_part != -1:
        description = description[:unnecessary_part]

    return description

def _generate_default_model_description(self, embedding_dimension) -> str:
    """
    Build the fallback model description from the embedding dimension

    A notice is printed to tell the user that the default description is
    being used.

    :param embedding_dimension: Embedding dimension of the model.
    :type embedding_dimension: int
    :return: Description of the model
    :rtype: string
    """
    notice = "Using default description from embedding_dimension instead (You can overwrite this by specifying description parameter in make_model_config_json function"
    print(notice)
    return (
        "This is a sentence-transformers model: It maps sentences & paragraphs to a "
        f"{embedding_dimension} dimensional dense vector space."
    )

def make_model_config_json(
self,
model_name: str = None,
Expand All @@ -1014,6 +1084,7 @@ def make_model_config_json(
embedding_dimension: int = None,
pooling_mode: str = None,
normalize_result: bool = None,
description: str = None,
all_config: str = None,
model_type: str = None,
verbose: bool = False,
Expand All @@ -1040,6 +1111,9 @@ def make_model_config_json(
:param normalize_result: Optional, whether to normalize the result of the model. If None, check from the pre-trained
hugging-face model object.
:type normalize_result: bool
:param description: Optional, the description of the model. If None, get description from the README.md
file in the model folder.
:type description: str
:param all_config:
Optional, the all_config of the model. If None, parse all contents from the config file of pre-trained
hugging-face model
Expand Down Expand Up @@ -1087,6 +1161,26 @@ def make_model_config_json(
f"Raised exception while getting model data from pre-trained hugging-face model object: {e}"
)

if description is None:
readme_file_path = os.path.join(self.folder_path, "README.md")
if os.path.exists(readme_file_path):
try:
if verbose:
print("reading README.md file")
description = self._get_model_description_from_readme_file(
readme_file_path
)
except Exception as e:
print(f"Cannot scrape model description from README.md file: {e}")
description = self._generate_default_model_description(
embedding_dimension
)
else:
print("Cannot find README.md file to scrape model description")
description = self._generate_default_model_description(
embedding_dimension
)

if all_config is None:
if not os.path.exists(config_json_file_path):
raise Exception(
Expand Down Expand Up @@ -1114,6 +1208,7 @@ def make_model_config_json(
model_config_content = {
"name": model_name,
"version": version_number,
"description": description,
"model_format": model_format,
"model_task_type": "TEXT_EMBEDDING",
"model_config": {
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ sentence_transformers
tqdm
transformers
deprecated
mdutils

#
# Testing
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def check_values(self, oml_obj, pd_obj):

def check_exception(self, ed_exc, pd_exc):
"""Checks that either an exception was raised or not from both opensearch_py_ml and pandas"""
# NOTE(review): this is a rendered diff hunk ("1 addition & 1 deletion"), so the
# two asserts below are the before/after sides of a single changed line.
# Presumably the `type(...) == type(...)` form is the removed line and the
# `isinstance(...)` form is its replacement — confirm against the repository.
assert (ed_exc is None) == (pd_exc is None) and type(ed_exc) == type(pd_exc)
assert (ed_exc is None) == (pd_exc is None) and isinstance(ed_exc, type(pd_exc))
# Raise pd_exc once the two sides have been checked against each other.
if pd_exc is not None:
raise pd_exc

Expand Down
4 changes: 2 additions & 2 deletions tests/ml_commons/test_ml_commons_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def clean_test_folder(TEST_FOLDER):


def test_init():
# Checks the types of ml_client._client and ml_client._model_uploader.
# NOTE(review): this is a rendered diff hunk ("2 additions & 2 deletions"), so
# both sides of the change appear here. Presumably the `type(...) == ...`
# asserts are the removed lines and the `isinstance(...)` asserts replace
# them — confirm against the repository.
assert type(ml_client._client) == OpenSearch
assert type(ml_client._model_uploader) == ModelUploader
assert isinstance(ml_client._client, OpenSearch)
assert isinstance(ml_client._model_uploader, ModelUploader)


def test_execute():
Expand Down
Loading

0 comments on commit 93dee14

Please sign in to comment.