diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 12553411c3..25d832c1b8 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -578,7 +578,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): return calendar -def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"): +def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None): """get trading date with shift bias wil cur_date e.g. : shift == 1, return next trading date shift == -1, return previous trading date @@ -587,14 +587,25 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq=" current date shift : int clip_shift: bool + align : Optional[str] + When align is None, this function will raise ValueError if `trading_date` is not a trading date + when align is "left"/"right", it will try to align to left/right nearest trading date before shifting when `trading_date` is not a trading date """ from qlib.data import D cal = D.calendar(future=future, freq=freq) - if pd.to_datetime(trading_date) not in list(cal): - raise ValueError("{} is not trading day!".format(str(trading_date))) - _index = bisect.bisect_left(cal, trading_date) + trading_date = pd.to_datetime(trading_date) + if align is None: + if trading_date not in list(cal): + raise ValueError("{} is not trading day!".format(str(trading_date))) + _index = bisect.bisect_left(cal, trading_date) + elif align == "left": + _index = bisect.bisect_right(cal, trading_date) - 1 + elif align == "right": + _index = bisect.bisect_left(cal, trading_date) + else: + raise ValueError(f"align with value `{align}` is not supported") shift_index = _index + shift if shift_index < 0 or shift_index >= len(cal): if clip_shift: diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index f349a45b3e..8cdcad1f5d 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -90,14 +90,24 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta): SZ300676 -0.001321 """ - def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"): + def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"): """ Init PredUpdater. + Expected behavior in following cases: + - if `to_date` is greater than the max date in the calendar, the data will be updated to the latest date + - if there are data before `from_date` or after `to_date`, only the data between `from_date` and `to_date` are affected. + Args: record : Recorder to_date : update to prediction to the `to_date` + if to_date is None: + data will updated to the latest date. + from_date : + the update will start from `from_date` + if from_date is None: + the updating will occur on the next tick after the latest data in historical data hist_ref : int Sometimes, the dataset will have historical depends. Leave the problem to users to set the length of historical dependency @@ -127,13 +137,16 @@ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day" ) to_date = latest_date self.to_date = to_date + # FIXME: it will raise error when running routine with delay trainer # should we use another prediction updater for delay trainer? self.old_data: pd.DataFrame = record.load_object(fname) - - # dropna is for being compatible to some data with future information(e.g. label) - # The recent label data should be updated together - self.last_end = self.old_data.dropna().index.get_level_values("datetime").max() + if from_date is None: + # dropna is for being compatible to some data with future information(e.g. label) + # The recent label data should be updated together + self.last_end = self.old_data.dropna().index.get_level_values("datetime").max() + else: + self.last_end = get_date_by_shift(from_date, -1, align="left") def prepare_data(self) -> DatasetH: """ @@ -187,6 +200,15 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame: ... +def _replace_range(data, new_data): + dates = new_data.index.get_level_values("datetime") + data = data.sort_index() + data = data.drop(data.loc[dates.min() : dates.max()].index) + cb_data = pd.concat([data, new_data], axis=0) + cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index() + return cb_data + + class PredUpdater(DSBasedUpdater): """ Update the prediction in the Recorder @@ -196,11 +218,9 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame: # Load model model = self.rmdl.get_model() new_pred: pd.Series = model.predict(dataset) - - cb_pred = pd.concat([self.old_data, new_pred.to_frame("score")], axis=0) - cb_pred = cb_pred.sort_index() + data = _replace_range(self.old_data, new_pred.to_frame("score")) self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.") - return cb_pred + return data class LabelUpdater(DSBasedUpdater): @@ -216,6 +236,5 @@ def __init__(self, record: Recorder, to_date=None, **kwargs): def get_update_data(self, dataset: Dataset) -> pd.DataFrame: new_label = SignalRecord.generate_label(dataset) - cb_data = pd.concat([self.old_data, new_label], axis=0) - cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index() + cb_data = _replace_range(self.old_data.sort_index(), new_label) return cb_data diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 0fdec7b340..b1743d9329 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -158,7 +158,7 @@ def online_models(self, exp_name: str = None) -> list: exp_name = self._get_exp_name(exp_name) return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) - def update_online_pred(self, to_date=None, exp_name: str = None): + def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None): """ Update the predictions of online models to to_date. @@ -176,7 +176,7 @@ def update_online_pred(self, to_date=None, exp_name: str = None): if issubclass(cls, TSDatasetH): hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN) try: - updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref) + updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref) except LoadObjectError as e: # skip the recorder without pred self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.") diff --git a/tests/rolling_tests/test_update_pred.py b/tests/rolling_tests/test_update_pred.py index b22152fd2a..f3a295d318 100644 --- a/tests/rolling_tests/test_update_pred.py +++ b/tests/rolling_tests/test_update_pred.py @@ -21,11 +21,7 @@ def test_update_pred(self): """ task = copy.deepcopy(CSI300_GBDT_TASK) - task["record"] = { - "class": "SignalRecord", - "module_path": "qlib.workflow.record_temp", - "kwargs": {"dataset": "", "model": ""}, - } + task["record"] = ["qlib.workflow.record_temp.SignalRecord"] exp_name = "online_srv_test" @@ -57,6 +53,27 @@ def test_update_pred(self): online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10)) + good_pred = rec.load_object("pred.pkl") + + mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10)) + mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2)) + mod_pred = good_pred.copy() + + mod_pred.loc[mod_range] = -1 + mod_pred.loc[mod_range2] = -2 + + rec.save_objects(**{"pred.pkl": mod_pred}) + online_tool.update_online_pred( + to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20) + ) + + updated_pred = rec.load_object("pred.pkl") + + # this range is not fixed + self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item()) + # this range is fixed now + self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item()) + def test_update_label(self): task = copy.deepcopy(CSI300_GBDT_TASK)