From b0b020e4e3735493bfb56133d368ba534d552267 Mon Sep 17 00:00:00 2001
From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
Date: Wed, 9 Aug 2023 17:06:04 -0700
Subject: [PATCH] Generalize make_model_config_json function (#200) (#212)

* Improve make_model_config

Signed-off-by: Thanawan Atchariyachanvanit

* Update sentencetransformermodel.py

Signed-off-by: Thanawan Atchariyachanvanit

* Update CHANGELOG.md

Signed-off-by: Thanawan Atchariyachanvanit

* Update CHANGELOG.md

Signed-off-by: Thanawan Atchariyachanvanit

* Update CHANGELOG.md

Signed-off-by: Thanawan Atchariyachanvanit

* Update CHANGELOG.md

Signed-off-by: Thanawan Atchariyachanvanit

---------

Signed-off-by: Thanawan Atchariyachanvanit
(cherry picked from commit 513ac111c3962710588e21842793712acd014804)

Co-authored-by: Thanawan Atchariyachanvanit
---
 CHANGELOG.md                                  |  3 +-
 .../ml_models/sentencetransformermodel.py     | 30 +++++++------------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7250985c..70b016b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Fix ModelUploader bug & Update model tracing demo notebook by @thanawan-atc in ([#185](https://github.com/opensearch-project/opensearch-py-ml/pull/185))
 - Fix make_model_config_json function by @thanawan-atc in ([#188](https://github.com/opensearch-project/opensearch-py-ml/pull/188))
 - Make make_model_config_json function more concise by @thanawan-atc in ([#191](https://github.com/opensearch-project/opensearch-py-ml/pull/191))
-- Enabled auto-truncation for any pretrained models ([#192]https://github.com/opensearch-project/opensearch-py-ml/pull/192)
+- Enabled auto-truncation for any pretrained models by @Yerzhaisang in ([#192](https://github.com/opensearch-project/opensearch-py-ml/pull/192))
+- Generalize make_model_config_json function by @thanawan-atc in ([#200](https://github.com/opensearch-project/opensearch-py-ml/pull/200))
 
 ## [1.0.0]
 
diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py
index fb166f28..8fbcb1a0 100644
--- a/opensearch_py_ml/ml_models/sentencetransformermodel.py
+++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py
@@ -1068,28 +1068,20 @@ def make_model_config_json(
             or normalize_result is None
         ):
             try:
-                if (
-                    model_type is None
-                    and len(model._modules) >= 1
-                    and isinstance(model._modules["0"], Transformer)
-                ):
-                    model_type = model._modules["0"].auto_model.__class__.__name__
-                    model_type = model_type.lower().rstrip("model")
                 if embedding_dimension is None:
                     embedding_dimension = model.get_sentence_embedding_dimension()
-                if (
-                    pooling_mode is None
-                    and len(model._modules) >= 2
-                    and isinstance(model._modules["1"], Pooling)
-                ):
-                    pooling_mode = model._modules["1"].get_pooling_mode_str().upper()
-                if normalize_result is None:
-                    if len(model._modules) >= 3 and isinstance(
-                        model._modules["2"], Normalize
-                    ):
+
+                for str_idx, module in model._modules.items():
+                    if model_type is None and isinstance(module, Transformer):
+                        model_type = module.auto_model.__class__.__name__
+                        model_type = model_type.lower().rstrip("model")
+                    elif pooling_mode is None and isinstance(module, Pooling):
+                        pooling_mode = module.get_pooling_mode_str().upper()
+                    elif normalize_result is None and isinstance(module, Normalize):
                         normalize_result = True
-                    else:
-                        normalize_result = False
+                    # TODO: Support 'Dense' module
+                if normalize_result is None:
+                    normalize_result = False
             except Exception as e:
                 raise Exception(
                     f"Raised exception while getting model data from pre-trained hugging-face model object: {e}"
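
For context, the loop introduced above walks a SentenceTransformer's ordered submodules instead of assuming the fixed keys "0", "1", and "2". The standalone sketch below illustrates that detection logic outside the patch; it is only a sketch, assuming the sentence-transformers package is installed, using "sentence-transformers/all-MiniLM-L6-v2" purely as an example model name, and relying on the same private `_modules` attribute that the patched code reads.

# Standalone sketch of the module-detection approach used in the patch above.
# Assumptions: sentence-transformers is installed; the model name is only an example.
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize, Pooling, Transformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # example model

model_type = None
pooling_mode = None
normalize_result = None

# Walk the ordered submodules rather than indexing fixed positions,
# so models with extra or re-ordered layers are still handled.
for _str_idx, module in model._modules.items():
    if model_type is None and isinstance(module, Transformer):
        # e.g. "BertModel" -> "bert"
        model_type = module.auto_model.__class__.__name__.lower().rstrip("model")
    elif pooling_mode is None and isinstance(module, Pooling):
        pooling_mode = module.get_pooling_mode_str().upper()  # e.g. "MEAN"
    elif normalize_result is None and isinstance(module, Normalize):
        normalize_result = True

if normalize_result is None:
    normalize_result = False  # no Normalize module found

print(model_type, pooling_mode, normalize_result)

Iterating over the items keeps the original first-match behaviour while tolerating architectures whose Pooling or Normalize layers sit at different indices; a Dense module is still not inspected, matching the TODO left in the patch.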