Skip to content

Commit

Permalink
warn when num_threads is not used in MinTrace (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Jul 4, 2024
1 parent b52b044 commit 3817c69
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
6 changes: 3 additions & 3 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ jobs:
- run:
name: Install dependencies
command: |
micromamba install -n base -c conda-forge -y python=3.7
micromamba install -n base -c conda-forge -y python=3.10
micromamba update -n base -f environment.yml
- run:
name: Run nbdev tests
command: |
eval "$(micromamba shell hook --shell bash)"
micromamba activate base
pip install ".[dev]"
nbdev_test --do_print --timing --n_workers 1
nbdev_test --do_print --timing --n_workers 0
test-model-performance:
resource_class: large
docker:
Expand All @@ -27,7 +27,7 @@ jobs:
- run:
name: Install dependencies
command: |
micromamba install -n base -c conda-forge -y python=3.7
micromamba install -n base -c conda-forge -y python=3.10
micromamba update -n base -f environment.yml
- run:
name: Run model performance tests
Expand Down
4 changes: 3 additions & 1 deletion hierarchicalforecast/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ class MinTrace(HReconciler):
`method`: str, one of `ols`, `wls_struct`, `wls_var`, `mint_shrink`, `mint_cov`.<br>
`nonnegative`: bool, reconciled forecasts should be nonnegative?<br>
`mint_shr_ridge`: float=2e-8, ridge numeric protection to MinTrace-shr covariance estimator.<br>
`num_threads`: int=1, number of threads to use for solving the optimization problems.
`num_threads`: int=1, number of threads to use for solving the optimization problems (when nonnegative=True).
**References:**<br>
- [Wickramasuriya, S. L., Athanasopoulos, G., & Hyndman, R. J. (2019). \"Optimal forecast reconciliation for
Expand All @@ -596,6 +596,8 @@ def __init__(self,
if method == 'mint_shrink':
self.mint_shr_ridge = mint_shr_ridge
self.num_threads = num_threads
if not self.nonnegative and self.num_threads > 1:
warnings.warn('`num_threads` is only used when `nonnegative=True`')

def _get_PW_matrices(self,
S: np.ndarray,
Expand Down
4 changes: 3 additions & 1 deletion nbs/methods.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@
" `method`: str, one of `ols`, `wls_struct`, `wls_var`, `mint_shrink`, `mint_cov`.<br>\n",
" `nonnegative`: bool, reconciled forecasts should be nonnegative?<br>\n",
" `mint_shr_ridge`: float=2e-8, ridge numeric protection to MinTrace-shr covariance estimator.<br>\n",
" `num_threads`: int=1, number of threads to use for solving the optimization problems.\n",
" `num_threads`: int=1, number of threads to use for solving the optimization problems (when nonnegative=True).\n",
"\n",
" **References:**<br>\n",
" - [Wickramasuriya, S. L., Athanasopoulos, G., & Hyndman, R. J. (2019). \\\"Optimal forecast reconciliation for\n",
Expand All @@ -1050,6 +1050,8 @@
" if method == 'mint_shrink':\n",
" self.mint_shr_ridge = mint_shr_ridge\n",
" self.num_threads = num_threads\n",
" if not self.nonnegative and self.num_threads > 1:\n",
" warnings.warn('`num_threads` is only used when `nonnegative=True`')\n",
"\n",
" def _get_PW_matrices(self, \n",
" S: np.ndarray,\n",
Expand Down

0 comments on commit 3817c69

Please sign in to comment.