Skip to content

Commit

Permalink
support adding from date when updating pred (#703)
Browse files Browse the repository at this point in the history
* support adding from date when updating pred

* fix updating data error
  • Loading branch information
you-n-g authored Nov 22, 2021
1 parent 103f857 commit 45ebb1d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 22 deletions.
19 changes: 15 additions & 4 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar


def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
"""get trading date with shift bias wil cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
Expand All @@ -587,14 +587,25 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="
current date
shift : int
clip_shift: bool
align : Optional[str]
When align is None, this function will raise ValueError if `trading_date` is not a trading date
when align is "left"/"right", it will try to align to left/right nearest trading date before shifting when `trading_date` is not a trading date
"""
from qlib.data import D

cal = D.calendar(future=future, freq=freq)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
trading_date = pd.to_datetime(trading_date)
if align is None:
if trading_date not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
elif align == "left":
_index = bisect.bisect_right(cal, trading_date) - 1
elif align == "right":
_index = bisect.bisect_left(cal, trading_date)
else:
raise ValueError(f"align with value `{align}` is not supported")
shift_index = _index + shift
if shift_index < 0 or shift_index >= len(cal):
if clip_shift:
Expand Down
41 changes: 30 additions & 11 deletions qlib/workflow/online/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,24 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
SZ300676 -0.001321
"""

def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
"""
Init PredUpdater.
Expected behavior in following cases:
- if `to_date` is greater than the max date in the calendar, the data will be updated to the latest date
- if there are data before `from_date` or after `to_date`, only the data between `from_date` and `to_date` are affected.
Args:
record : Recorder
to_date :
update to prediction to the `to_date`
if to_date is None:
data will updated to the latest date.
from_date :
the update will start from `from_date`
if from_date is None:
the updating will occur on the next tick after the latest data in historical data
hist_ref : int
Sometimes, the dataset will have historical depends.
Leave the problem to users to set the length of historical dependency
Expand Down Expand Up @@ -127,13 +137,16 @@ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"
)
to_date = latest_date
self.to_date = to_date

# FIXME: it will raise error when running routine with delay trainer
# should we use another prediction updater for delay trainer?
self.old_data: pd.DataFrame = record.load_object(fname)

# dropna is for being compatible to some data with future information(e.g. label)
# The recent label data should be updated together
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
if from_date is None:
# dropna is for being compatible to some data with future information(e.g. label)
# The recent label data should be updated together
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
else:
self.last_end = get_date_by_shift(from_date, -1, align="left")

def prepare_data(self) -> DatasetH:
"""
Expand Down Expand Up @@ -187,6 +200,15 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
...


def _replace_range(data, new_data):
dates = new_data.index.get_level_values("datetime")
data = data.sort_index()
data = data.drop(data.loc[dates.min() : dates.max()].index)
cb_data = pd.concat([data, new_data], axis=0)
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
return cb_data


class PredUpdater(DSBasedUpdater):
"""
Update the prediction in the Recorder
Expand All @@ -196,11 +218,9 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
# Load model
model = self.rmdl.get_model()
new_pred: pd.Series = model.predict(dataset)

cb_pred = pd.concat([self.old_data, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
data = _replace_range(self.old_data, new_pred.to_frame("score"))
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
return cb_pred
return data


class LabelUpdater(DSBasedUpdater):
Expand All @@ -216,6 +236,5 @@ def __init__(self, record: Recorder, to_date=None, **kwargs):

def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
new_label = SignalRecord.generate_label(dataset)
cb_data = pd.concat([self.old_data, new_label], axis=0)
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
cb_data = _replace_range(self.old_data.sort_index(), new_label)
return cb_data
4 changes: 2 additions & 2 deletions qlib/workflow/online/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def online_models(self, exp_name: str = None) -> list:
exp_name = self._get_exp_name(exp_name)
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())

def update_online_pred(self, to_date=None, exp_name: str = None):
def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.
Expand All @@ -176,7 +176,7 @@ def update_online_pred(self, to_date=None, exp_name: str = None):
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try:
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref)
except LoadObjectError as e:
# skip the recorder without pred
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
Expand Down
27 changes: 22 additions & 5 deletions tests/rolling_tests/test_update_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ def test_update_pred(self):
"""
task = copy.deepcopy(CSI300_GBDT_TASK)

task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
}
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

exp_name = "online_srv_test"

Expand Down Expand Up @@ -57,6 +53,27 @@ def test_update_pred(self):

online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

good_pred = rec.load_object("pred.pkl")

mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10))
mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))
mod_pred = good_pred.copy()

mod_pred.loc[mod_range] = -1
mod_pred.loc[mod_range2] = -2

rec.save_objects(**{"pred.pkl": mod_pred})
online_tool.update_online_pred(
to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20)
)

updated_pred = rec.load_object("pred.pkl")

# this range is not fixed
self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
# this range is fixed now
self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())

def test_update_label(self):

task = copy.deepcopy(CSI300_GBDT_TASK)
Expand Down

0 comments on commit 45ebb1d

Please sign in to comment.