Skip to content

Commit

Permalink
fix plkl param prediction with multi_models=False
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Sep 27, 2024
1 parent e9daf67 commit 7026ab0
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 11 deletions.
14 changes: 12 additions & 2 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,10 +1143,18 @@ def predict(
# same for covariate matrices
for cov_type, data in covariate_matrices.items():
covariate_matrices[cov_type] = np.repeat(data, num_samples, axis=0)

# for concatenating target with predictions (or quantile parameters)
if predict_likelihood_parameters and self.likelihood is not None:
# with `multi_models=False`, the predictions are concatenated with the past target, even if `n<=ocl`
# to make things work, we just append the first predicted parameter (it will never be accessed)
sample_slice = slice(0, None, self.num_parameters)
else:
sample_slice = slice(None)

# prediction
predictions = []
last_step_shift = 0

# t_pred indicates the number of time steps after the first prediction
for t_pred in range(0, n, step):
# in case of autoregressive forecast `(t_pred > 0)` and if `n` is not a round multiple of `step`,
Expand All @@ -1157,7 +1165,9 @@ def predict(

# concatenate previous iteration forecasts
if "target" in self.lags and predictions:
series_matrix = np.concatenate([series_matrix, predictions[-1]], axis=1)
series_matrix = np.concatenate(
[series_matrix, predictions[-1][:, :, sample_slice]], axis=1
)

# extract and concatenate lags from target and covariates series
X = _create_lagged_data_autoregression(
Expand Down
50 changes: 41 additions & 9 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@
),
]

xgb_test_params = {
"n_estimators": 1,
"max_depth": 1,
"max_leaves": 1,
}
lgbm_test_params = {
"n_estimators": 1,
"max_depth": 1,
"num_leaves": 2,
"verbosity": -1,
}
cb_test_params = {
"iterations": 1,
"depth": 1,
"verbose": -1,
}

if TORCH_AVAILABLE:
models_cls_kwargs_errs += [
(
Expand Down Expand Up @@ -294,15 +311,17 @@ def helper_test_probabilistic_forecast_accuracy(self, model, err, ts, noisy_ts):
@pytest.mark.parametrize(
"config",
itertools.product(
[(LinearRegressionModel, False), (XGBModel, False)]
+ ([(LightGBMModel, False)] if lgbm_available else [])
+ ([(CatBoostModel, True)] if cb_available else []),
[1, 3],
[(LinearRegressionModel, False, {}), (XGBModel, False, xgb_test_params)]
+ ([(LightGBMModel, False, lgbm_test_params)] if lgbm_available else [])
+ ([(CatBoostModel, True, cb_test_params)] if cb_available else []),
[1, 3], # n components
[
"quantile",
"poisson",
"gaussian",
],
], # likelihood
[True, False], # multi models
[1, 2], # horizon
),
)
def test_predict_likelihood_parameters_regression_models(self, config):
Expand All @@ -312,7 +331,13 @@ def test_predict_likelihood_parameters_regression_models(self, config):
Note: values are not tested as it would be too time consuming
"""
(model_cls, supports_gaussian), n_comp, likelihood = config
(
(model_cls, supports_gaussian, model_kwargs),
n_comp,
likelihood,
multi_models,
horizon,
) = config

seed = 142857
n_times, n_samples = 100, 1
Expand Down Expand Up @@ -340,10 +365,17 @@ def test_predict_likelihood_parameters_regression_models(self, config):
else:
assert False, f"unknown likelihood {likelihood}"

model = model_cls(lags=3, random_state=seed, **lkl["kwargs"])
model = model_cls(
lags=3,
output_chunk_length=horizon,
random_state=seed,
**lkl["kwargs"],
multi_models=multi_models,
**model_kwargs,
)
model.fit(lkl["ts"])
pred_lkl_params = model.predict(
n=1, num_samples=1, predict_likelihood_parameters=True
n=horizon, num_samples=1, predict_likelihood_parameters=True
)
if n_comp == 1:
assert lkl["expected"].shape == pred_lkl_params.values()[0].shape, (
Expand All @@ -352,7 +384,7 @@ def test_predict_likelihood_parameters_regression_models(self, config):
)
else:
assert (
1,
horizon,
len(lkl["expected"]) * n_comp,
1,
) == pred_lkl_params.all_values().shape, (
Expand Down

0 comments on commit 7026ab0

Please sign in to comment.