[FEAT] Ability to save and load StatsForecast (#667)
akmalsoliev authored Oct 23, 2023
1 parent a4c0fe6 commit 0e25b61
Showing 3 changed files with 285 additions and 1 deletion.
168 changes: 168 additions & 0 deletions nbs/src/core/core.ipynb
@@ -75,9 +75,14 @@
"import logging\n",
"import reprlib\n",
"import warnings\n",
"import errno\n",
"from pathlib import Path\n",
"from os import cpu_count\n",
"from typing import Any, List, Optional, Union, Dict\n",
"import pkg_resources\n",
"import pickle\n",
"import datetime as dt\n",
"import re\n",
"\n",
"from fugue.execution.factory import make_execution_engine\n",
"import numpy as np\n",
@@ -1965,6 +1970,106 @@
" palette='tab20b',\n",
" )\n",
" \n",
" def save(\n",
" self, \n",
" path: Optional[Union[Path, str]] = None,\n",
" max_size: Optional[str] = None,\n",
" trim: bool = False,\n",
" ):\n",
" \"\"\"Function that will save StatsForecast class with certain settings to make it \n",
" reproducible.\n",
" \n",
" Parameters\n",
" ----------\n",
" path : str or pathlib.Path, optional (default=None)\n",
" Path of the file to be saved. If `None` will create one in the current \n",
" directory using the current UTC timestamp.\n",
" max_size : str, optional (default = None)\n",
" StatsForecast object should not exceed this size.\n",
" Available byte naming: ['B', 'KB', 'MB', 'GB']\n",
" trim : bool (default = False)\n",
" Delete any attributes not needed for inference.\n",
" \"\"\"\n",
" # Will be used to find the size of the fitted models\n",
" # Never expecting anything higher than GB (even that's a lot')\n",
" bytes_hmap = {\n",
" \"B\": 1,\n",
" \"KB\": 2**10,\n",
" \"MB\": 2**20,\n",
" \"GB\": 2**30,\n",
" }\n",
"\n",
" # Removing unnecessary attributes\n",
" # @jmoralez decide future implementation\n",
" trim_attr:list = [\"fcst_fitted_values_\", \"cv_fitted_values_\"]\n",
" if trim:\n",
" for attr in trim_attr:\n",
" # remove unnecessary attributes here\n",
" self.__dict__.pop(attr, None)\n",
"\n",
" sf_size = len(pickle.dumps(self))\n",
"\n",
" if max_size is not None:\n",
" cap_size = self._get_cap_size(max_size, bytes_hmap)\n",
" if sf_size >= cap_size:\n",
" err_messg = \"StatsForecast is larger than the specified max_size\"\n",
" raise OSError(errno.EFBIG, err_messg) \n",
"\n",
" converted_size, sf_byte = None, None\n",
" for key in reversed(list(bytes_hmap.keys())):\n",
" x_byte = bytes_hmap[key]\n",
" if sf_size >= x_byte:\n",
" converted_size = sf_size / x_byte\n",
" sf_byte = key\n",
" break\n",
" \n",
" if converted_size is None or sf_byte is None:\n",
" err_messg = \"Internal Error, this shouldn't happen, please open an issue\"\n",
" raise RuntimeError(err_messg)\n",
" \n",
" print(f\"Saving StatsForecast object of size {converted_size:.2f}{sf_byte}.\")\n",
" \n",
" if path is None:\n",
" datetime_record = dt.datetime.utcnow().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
" path = f\"StatsForecast_{datetime_record}.pkl\"\n",
" \n",
" with open(path, \"wb\") as m_file:\n",
" pickle.dump(self, m_file)\n",
" print(\"StatsForecast object saved\")\n",
"\n",
" def _get_cap_size(self, max_size, bytes_hmap):\n",
" max_size = max_size.upper().replace(\" \", \"\")\n",
" match = re.match(r'(\\d+\\.\\d+|\\d+)(\\w+)', max_size)\n",
" if match is None or len(match.groups()) < 2 or match[2] not in bytes_hmap.keys():\n",
" parsing_error = \"Couldn't parse `max_size`, it should be `None`\", \\\n",
" \" or a number followed by one of the following units: ['B', 'KB', 'MB', 'GB']\"\n",
" raise ValueError(parsing_error)\n",
" else:\n",
" m_size = float(match[1])\n",
" key_ = match[2]\n",
" cap_size = m_size * bytes_hmap[key_]\n",
" return cap_size\n",
" \n",
" @staticmethod\n",
" def load(path:Union[Path, str]):\n",
" \"\"\"\n",
" Automatically loads the model into ready StatsForecast.\n",
"\n",
" Parameters\n",
" ----------\n",
" path : str or pathlib.Path\n",
" Path to saved StatsForecast file.\n",
" \n",
" Returns\n",
" -------\n",
" sf: StatsForecast\n",
" Previously saved StatsForecast\n",
" \"\"\"\n",
" if not Path(path).exists():\n",
" raise ValueError(\"Specified path does not exist, check again and retry.\")\n",
" with open(path, \"rb\") as f:\n",
" return pickle.load(f)\n",
" \n",
" def __repr__(self):\n",
" return f\"StatsForecast(models=[{','.join(map(repr, self.models))}])\""
]
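
For reference, the size cap passed to `save` is handled by `_get_cap_size`: the string is uppercased, spaces are stripped, and a regex splits the numeric part from the unit suffix before converting to bytes. A minimal standalone sketch of that same parsing logic (the name `parse_max_size` is illustrative, not part of the library):

import re

# Unit multipliers, mirroring `bytes_hmap` in `save`.
BYTES = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30}

def parse_max_size(max_size: str) -> float:
    # Normalize: "1.5 gb" -> "1.5GB".
    s = max_size.upper().replace(" ", "")
    # Split the number ("1.5" or "512") from the unit suffix.
    m = re.match(r"(\d+\.\d+|\d+)(\w+)", s)
    if m is None or m[2] not in BYTES:
        raise ValueError(
            "Couldn't parse `max_size`, expected a number followed by "
            "one of the following units: ['B', 'KB', 'MB', 'GB']"
        )
    return float(m[1]) * BYTES[m[2]]

assert parse_max_size("1.5 MB") == 1.5 * 2**20
assert parse_max_size("512kb") == 512 * 2**10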
@@ -2196,6 +2301,49 @@
"fcsts_df.groupby('unique_id').tail(4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55b00e9b",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# Testing save and load \n",
"import tempfile\n",
"from polars.testing import assert_frame_equal"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9eaa52",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"with tempfile.TemporaryDirectory() as td:\n",
" f_path = Path(td).joinpath(\"sf_test.pickle\")\n",
" \n",
" test_df = pl.from_pandas(panel_df.astype({\"unique_id\": str}))\n",
" test_frcs = StatsForecast(\n",
" df=test_df,\n",
" models=models,\n",
" freq='D', \n",
" n_jobs=1, \n",
" verbose=True\n",
" )\n",
"\n",
" origin_df = test_frcs.forecast(h=4, fitted=True)\n",
"\n",
" test_frcs.save(f_path)\n",
"\n",
" sf_test = StatsForecast.load(f_path)\n",
" load_df = sf_test.forecast(h=4, fitted=True)\n",
" \n",
" assert_frame_equal(origin_df, load_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -3044,6 +3192,26 @@
" name='StatsForecast.plot')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c69a901",
"metadata": {},
"outputs": [],
"source": [
"show_doc(StatsForecast.save, title_level=2, name='StatsForecast.save')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac4134ac",
"metadata": {},
"outputs": [],
"source": [
"show_doc(StatsForecast.load, title_level=2, name='StatsForecast.load')"
]
},
{
"cell_type": "code",
"execution_count": null,
6 changes: 6 additions & 0 deletions statsforecast/_modidx.py
@@ -160,6 +160,8 @@
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._forecast_parallel': ( 'src/core/core.html#_statsforecast._forecast_parallel',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._get_cap_size': ( 'src/core/core.html#_statsforecast._get_cap_size',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._get_gas_Xs': ( 'src/core/core.html#_statsforecast._get_gas_xs',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast._get_pool': ( 'src/core/core.html#_statsforecast._get_pool',
@@ -188,10 +190,14 @@
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.forecast_fitted_values': ( 'src/core/core.html#_statsforecast.forecast_fitted_values',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.load': ( 'src/core/core.html#_statsforecast.load',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.plot': ( 'src/core/core.html#_statsforecast.plot',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.predict': ( 'src/core/core.html#_statsforecast.predict',
'statsforecast/core.py'),
'statsforecast.core._StatsForecast.save': ( 'src/core/core.html#_statsforecast.save',
'statsforecast/core.py'),
'statsforecast.core._cv_dates': ('src/core/core.html#_cv_dates', 'statsforecast/core.py'),
'statsforecast.core._get_n_jobs': ('src/core/core.html#_get_n_jobs', 'statsforecast/core.py'),
'statsforecast.core._parse_ds_type': ('src/core/core.html#_parse_ds_type', 'statsforecast/core.py'),
112 changes: 111 additions & 1 deletion statsforecast/core.py
@@ -8,9 +8,14 @@
import logging
import reprlib
import warnings
import errno
from pathlib import Path
from os import cpu_count
from typing import Any, List, Optional, Union, Dict
import pkg_resources
import pickle
import datetime as dt
import re

from fugue.execution.factory import make_execution_engine
import numpy as np
@@ -453,7 +458,6 @@ def __init__(
sort_dataframe: bool,
validate: Optional[bool] = True,
):

self.dataframe = dataframe
self.sort_dataframe = sort_dataframe
self.validate = validate
@@ -1546,6 +1550,112 @@ def plot(
palette="tab20b",
)

def save(
self,
path: Optional[Union[Path, str]] = None,
max_size: Optional[str] = None,
trim: bool = False,
):
"""Function that will save StatsForecast class with certain settings to make it
reproducible.
Parameters
----------
path : str or pathlib.Path, optional (default=None)
Path of the file to be saved. If `None` will create one in the current
directory using the current UTC timestamp.
max_size : str, optional (default = None)
Maximum size the pickled StatsForecast object may have.
Accepted units: ['B', 'KB', 'MB', 'GB']
trim : bool (default = False)
Delete any attributes not needed for inference.
"""
# Will be used to find the size of the fitted models
# Never expecting anything higher than GB (even that's a lot)
bytes_hmap = {
"B": 1,
"KB": 2**10,
"MB": 2**20,
"GB": 2**30,
}

# Removing unnecessary attributes
# @jmoralez decide future implementation
trim_attr: list = ["fcst_fitted_values_", "cv_fitted_values_"]
if trim:
for attr in trim_attr:
# remove unnecessary attributes here
self.__dict__.pop(attr, None)

sf_size = len(pickle.dumps(self))

if max_size is not None:
cap_size = self._get_cap_size(max_size, bytes_hmap)
if sf_size >= cap_size:
err_messg = "StatsForecast is larger than the specified max_size"
raise OSError(errno.EFBIG, err_messg)

converted_size, sf_byte = None, None
for key in reversed(list(bytes_hmap.keys())):
x_byte = bytes_hmap[key]
if sf_size >= x_byte:
converted_size = sf_size / x_byte
sf_byte = key
break

if converted_size is None or sf_byte is None:
err_messg = "Internal Error, this shouldn't happen, please open an issue"
raise RuntimeError(err_messg)

print(f"Saving StatsForecast object of size {converted_size:.2f}{sf_byte}.")

if path is None:
datetime_record = dt.datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
path = f"StatsForecast_{datetime_record}.pkl"

with open(path, "wb") as m_file:
pickle.dump(self, m_file)
print("StatsForecast object saved")

def _get_cap_size(self, max_size, bytes_hmap):
max_size = max_size.upper().replace(" ", "")
match = re.match(r"(\d+\.\d+|\d+)(\w+)", max_size)
if (
match is None
or len(match.groups()) < 2
or match[2] not in bytes_hmap
):
parsing_error = (
"Couldn't parse `max_size`, it should be `None` "
"or a number followed by one of the following units: ['B', 'KB', 'MB', 'GB']"
)
raise ValueError(parsing_error)
else:
m_size = float(match[1])
key_ = match[2]
cap_size = m_size * bytes_hmap[key_]
return cap_size

@staticmethod
def load(path: Union[Path, str]):
"""
Load a previously saved StatsForecast object from disk.

Parameters
----------
path : str or pathlib.Path
Path to the saved StatsForecast file.

Returns
-------
sf : StatsForecast
The previously saved StatsForecast object.
"""
if not Path(path).exists():
raise ValueError("Specified path does not exist, check again and retry.")
with open(path, "rb") as f:
return pickle.load(f)

def __repr__(self):
return f"StatsForecast(models=[{','.join(map(repr, self.models))}])"

Expand Down
