From 0e25b61e977f93f8e21296e676fc27d306b9d45e Mon Sep 17 00:00:00 2001 From: Akmal Soliev Date: Mon, 23 Oct 2023 21:20:20 +0200 Subject: [PATCH] [FEAT] Ability to save and load StatsForecast (#667) --- nbs/src/core/core.ipynb | 168 +++++++++++++++++++++++++++++++++++++++ statsforecast/_modidx.py | 6 ++ statsforecast/core.py | 112 +++++++++++++++++++++++++- 3 files changed, 285 insertions(+), 1 deletion(-) diff --git a/nbs/src/core/core.ipynb b/nbs/src/core/core.ipynb index 5c8308d38..c6fae6c90 100644 --- a/nbs/src/core/core.ipynb +++ b/nbs/src/core/core.ipynb @@ -75,9 +75,14 @@ "import logging\n", "import reprlib\n", "import warnings\n", + "import errno\n", + "from pathlib import Path\n", "from os import cpu_count\n", "from typing import Any, List, Optional, Union, Dict\n", "import pkg_resources\n", + "import pickle\n", + "import datetime as dt\n", + "import re\n", "\n", "from fugue.execution.factory import make_execution_engine\n", "import numpy as np\n", @@ -1965,6 +1970,106 @@ " palette='tab20b',\n", " )\n", " \n", + " def save(\n", + " self, \n", + " path: Optional[Union[Path, str]] = None,\n", + " max_size: Optional[str] = None,\n", + " trim: bool = False,\n", + " ):\n", + " \"\"\"Function that will save StatsForecast class with certain settings to make it \n", + " reproducible.\n", + " \n", + " Parameters\n", + " ----------\n", + " path : str or pathlib.Path, optional (default=None)\n", + " Path of the file to be saved. If `None` will create one in the current \n", + " directory using the current UTC timestamp.\n", + " max_size : str, optional (default = None)\n", + " StatsForecast object should not exceed this size.\n", + " Available byte naming: ['B', 'KB', 'MB', 'GB']\n", + " trim : bool (default = False)\n", + " Delete any attributes not needed for inference.\n", + " \"\"\"\n", + " # Will be used to find the size of the fitted models\n", + " # Never expecting anything higher than GB (even that's a lot')\n", + " bytes_hmap = {\n", + " \"B\": 1,\n", + " \"KB\": 2**10,\n", + " \"MB\": 2**20,\n", + " \"GB\": 2**30,\n", + " }\n", + "\n", + " # Removing unnecessary attributes\n", + " # @jmoralez decide future implementation\n", + " trim_attr:list = [\"fcst_fitted_values_\", \"cv_fitted_values_\"]\n", + " if trim:\n", + " for attr in trim_attr:\n", + " # remove unnecessary attributes here\n", + " self.__dict__.pop(attr, None)\n", + "\n", + " sf_size = len(pickle.dumps(self))\n", + "\n", + " if max_size is not None:\n", + " cap_size = self._get_cap_size(max_size, bytes_hmap)\n", + " if sf_size >= cap_size:\n", + " err_messg = \"StatsForecast is larger than the specified max_size\"\n", + " raise OSError(errno.EFBIG, err_messg) \n", + "\n", + " converted_size, sf_byte = None, None\n", + " for key in reversed(list(bytes_hmap.keys())):\n", + " x_byte = bytes_hmap[key]\n", + " if sf_size >= x_byte:\n", + " converted_size = sf_size / x_byte\n", + " sf_byte = key\n", + " break\n", + " \n", + " if converted_size is None or sf_byte is None:\n", + " err_messg = \"Internal Error, this shouldn't happen, please open an issue\"\n", + " raise RuntimeError(err_messg)\n", + " \n", + " print(f\"Saving StatsForecast object of size {converted_size:.2f}{sf_byte}.\")\n", + " \n", + " if path is None:\n", + " datetime_record = dt.datetime.utcnow().strftime(\"%Y-%m-%d_%H-%M-%S\")\n", + " path = f\"StatsForecast_{datetime_record}.pkl\"\n", + " \n", + " with open(path, \"wb\") as m_file:\n", + " pickle.dump(self, m_file)\n", + " print(\"StatsForecast object saved\")\n", + "\n", + " def _get_cap_size(self, max_size, bytes_hmap):\n", + " max_size = max_size.upper().replace(\" \", \"\")\n", + " match = re.match(r'(\\d+\\.\\d+|\\d+)(\\w+)', max_size)\n", + " if match is None or len(match.groups()) < 2 or match[2] not in bytes_hmap.keys():\n", + " parsing_error = \"Couldn't parse `max_size`, it should be `None`\", \\\n", + " \" or a number followed by one of the following units: ['B', 'KB', 'MB', 'GB']\"\n", + " raise ValueError(parsing_error)\n", + " else:\n", + " m_size = float(match[1])\n", + " key_ = match[2]\n", + " cap_size = m_size * bytes_hmap[key_]\n", + " return cap_size\n", + " \n", + " @staticmethod\n", + " def load(path:Union[Path, str]):\n", + " \"\"\"\n", + " Automatically loads the model into ready StatsForecast.\n", + "\n", + " Parameters\n", + " ----------\n", + " path : str or pathlib.Path\n", + " Path to saved StatsForecast file.\n", + " \n", + " Returns\n", + " -------\n", + " sf: StatsForecast\n", + " Previously saved StatsForecast\n", + " \"\"\"\n", + " if not Path(path).exists():\n", + " raise ValueError(\"Specified path does not exist, check again and retry.\")\n", + " with open(path, \"rb\") as f:\n", + " return pickle.load(f)\n", + " \n", " def __repr__(self):\n", " return f\"StatsForecast(models=[{','.join(map(repr, self.models))}])\"" ] @@ -2196,6 +2301,49 @@ "fcsts_df.groupby('unique_id').tail(4)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "55b00e9b", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# Testing save and load \n", + "import tempfile\n", + "from polars.testing import assert_frame_equal" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b9eaa52", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "with tempfile.TemporaryDirectory() as td:\n", + " f_path = Path(td).joinpath(\"sf_test.pickle\")\n", + " \n", + " test_df = pl.from_pandas(panel_df.astype({\"unique_id\": str}))\n", + " test_frcs = StatsForecast(\n", + " df=test_df,\n", + " models=models,\n", + " freq='D', \n", + " n_jobs=1, \n", + " verbose=True\n", + " )\n", + "\n", + " origin_df = test_frcs.forecast(h=4, fitted=True)\n", + "\n", + " test_frcs.save(f_path)\n", + "\n", + " sf_test = StatsForecast.load(f_path)\n", + " load_df = sf_test.forecast(h=4, fitted=True)\n", + " \n", + " assert_frame_equal(origin_df, load_df)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -3044,6 +3192,26 @@ " name='StatsForecast.plot')" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c69a901", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(StatsForecast.save, title_level=2, name='StatsForecast.save')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac4134ac", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(StatsForecast.load, title_level=2, name='StatsForecast.load')" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/statsforecast/_modidx.py b/statsforecast/_modidx.py index 91c4471ea..893124e2b 100644 --- a/statsforecast/_modidx.py +++ b/statsforecast/_modidx.py @@ -160,6 +160,8 @@ 'statsforecast/core.py'), 'statsforecast.core._StatsForecast._forecast_parallel': ( 'src/core/core.html#_statsforecast._forecast_parallel', 'statsforecast/core.py'), + 'statsforecast.core._StatsForecast._get_cap_size': ( 'src/core/core.html#_statsforecast._get_cap_size', + 'statsforecast/core.py'), 'statsforecast.core._StatsForecast._get_gas_Xs': ( 'src/core/core.html#_statsforecast._get_gas_xs', 'statsforecast/core.py'), 'statsforecast.core._StatsForecast._get_pool': ( 'src/core/core.html#_statsforecast._get_pool', @@ -188,10 +190,14 @@ 'statsforecast/core.py'), 'statsforecast.core._StatsForecast.forecast_fitted_values': ( 'src/core/core.html#_statsforecast.forecast_fitted_values', 'statsforecast/core.py'), + 'statsforecast.core._StatsForecast.load': ( 'src/core/core.html#_statsforecast.load', + 'statsforecast/core.py'), 'statsforecast.core._StatsForecast.plot': ( 'src/core/core.html#_statsforecast.plot', 'statsforecast/core.py'), 'statsforecast.core._StatsForecast.predict': ( 'src/core/core.html#_statsforecast.predict', 'statsforecast/core.py'), + 'statsforecast.core._StatsForecast.save': ( 'src/core/core.html#_statsforecast.save', + 'statsforecast/core.py'), 'statsforecast.core._cv_dates': ('src/core/core.html#_cv_dates', 'statsforecast/core.py'), 'statsforecast.core._get_n_jobs': ('src/core/core.html#_get_n_jobs', 'statsforecast/core.py'), 'statsforecast.core._parse_ds_type': ('src/core/core.html#_parse_ds_type', 'statsforecast/core.py'), diff --git a/statsforecast/core.py b/statsforecast/core.py index 1b9829190..c60215810 100644 --- a/statsforecast/core.py +++ b/statsforecast/core.py @@ -8,9 +8,14 @@ import logging import reprlib import warnings +import errno +from pathlib import Path from os import cpu_count from typing import Any, List, Optional, Union, Dict import pkg_resources +import pickle +import datetime as dt +import re from fugue.execution.factory import make_execution_engine import numpy as np @@ -453,7 +458,6 @@ def __init__( sort_dataframe: bool, validate: Optional[bool] = True, ): - self.dataframe = dataframe self.sort_dataframe = sort_dataframe self.validate = validate @@ -1546,6 +1550,112 @@ def plot( palette="tab20b", ) + def save( + self, + path: Optional[Union[Path, str]] = None, + max_size: Optional[str] = None, + trim: bool = False, + ): + """Function that will save StatsForecast class with certain settings to make it + reproducible. + + Parameters + ---------- + path : str or pathlib.Path, optional (default=None) + Path of the file to be saved. If `None` will create one in the current + directory using the current UTC timestamp. + max_size : str, optional (default = None) + StatsForecast object should not exceed this size. + Available byte naming: ['B', 'KB', 'MB', 'GB'] + trim : bool (default = False) + Delete any attributes not needed for inference. + """ + # Will be used to find the size of the fitted models + # Never expecting anything higher than GB (even that's a lot') + bytes_hmap = { + "B": 1, + "KB": 2**10, + "MB": 2**20, + "GB": 2**30, + } + + # Removing unnecessary attributes + # @jmoralez decide future implementation + trim_attr: list = ["fcst_fitted_values_", "cv_fitted_values_"] + if trim: + for attr in trim_attr: + # remove unnecessary attributes here + self.__dict__.pop(attr, None) + + sf_size = len(pickle.dumps(self)) + + if max_size is not None: + cap_size = self._get_cap_size(max_size, bytes_hmap) + if sf_size >= cap_size: + err_messg = "StatsForecast is larger than the specified max_size" + raise OSError(errno.EFBIG, err_messg) + + converted_size, sf_byte = None, None + for key in reversed(list(bytes_hmap.keys())): + x_byte = bytes_hmap[key] + if sf_size >= x_byte: + converted_size = sf_size / x_byte + sf_byte = key + break + + if converted_size is None or sf_byte is None: + err_messg = "Internal Error, this shouldn't happen, please open an issue" + raise RuntimeError(err_messg) + + print(f"Saving StatsForecast object of size {converted_size:.2f}{sf_byte}.") + + if path is None: + datetime_record = dt.datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S") + path = f"StatsForecast_{datetime_record}.pkl" + + with open(path, "wb") as m_file: + pickle.dump(self, m_file) + print("StatsForecast object saved") + + def _get_cap_size(self, max_size, bytes_hmap): + max_size = max_size.upper().replace(" ", "") + match = re.match(r"(\d+\.\d+|\d+)(\w+)", max_size) + if ( + match is None + or len(match.groups()) < 2 + or match[2] not in bytes_hmap.keys() + ): + parsing_error = ( + "Couldn't parse `max_size`, it should be `None`", + " or a number followed by one of the following units: ['B', 'KB', 'MB', 'GB']", + ) + raise ValueError(parsing_error) + else: + m_size = float(match[1]) + key_ = match[2] + cap_size = m_size * bytes_hmap[key_] + return cap_size + + @staticmethod + def load(path: Union[Path, str]): + """ + Automatically loads the model into ready StatsForecast. + + Parameters + ---------- + path : str or pathlib.Path + Path to saved StatsForecast file. + + Returns + ------- + sf: StatsForecast + Previously saved StatsForecast + """ + if not Path(path).exists(): + raise ValueError("Specified path does not exist, check again and retry.") + with open(path, "rb") as f: + return pickle.load(f) + def __repr__(self): return f"StatsForecast(models=[{','.join(map(repr, self.models))}])"