Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support adding from date when updating pred #703

Merged
merged 2 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar


def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day", align: Optional[str] = None):
"""get trading date with shift bias wil cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
Expand All @@ -587,14 +587,25 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="
current date
shift : int
clip_shift: bool
align : Optional[str]
When align is None, this function will raise ValueError if `trading_date` is not a trading date
when align is "left"/"right", it will try to align to left/right nearest trading date before shifting when `trading_date` is not a trading date

"""
from qlib.data import D

cal = D.calendar(future=future, freq=freq)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
trading_date = pd.to_datetime(trading_date)
if align is None:
if trading_date not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
elif align == "left":
_index = bisect.bisect_right(cal, trading_date) - 1
elif align == "right":
_index = bisect.bisect_left(cal, trading_date)
else:
raise ValueError(f"align with value `{align}` is not supported")
shift_index = _index + shift
if shift_index < 0 or shift_index >= len(cal):
if clip_shift:
Expand Down
41 changes: 30 additions & 11 deletions qlib/workflow/online/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,24 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
SZ300676 -0.001321
"""

def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
def __init__(self, record: Recorder, to_date=None, from_date=None, hist_ref: int = 0, freq="day", fname="pred.pkl"):
"""
Init PredUpdater.

Expected behavior in following cases:
- if `to_date` is greater than the max date in the calendar, the data will be updated to the latest date
- if there are data before `from_date` or after `to_date`, only the data between `from_date` and `to_date` are affected.

Args:
record : Recorder
to_date :
update the prediction up to `to_date`
if to_date is None:
data will be updated to the latest date.
from_date :
the update will start from `from_date`
if from_date is None:
the updating will occur on the next tick after the latest data in historical data
hist_ref : int
Sometimes, the dataset will have historical dependencies.
Leave the problem to users to set the length of historical dependency
Expand Down Expand Up @@ -127,13 +137,16 @@ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"
)
to_date = latest_date
self.to_date = to_date

# FIXME: it will raise error when running routine with delay trainer
# should we use another prediction updater for delay trainer?
self.old_data: pd.DataFrame = record.load_object(fname)

# dropna is for being compatible to some data with future information(e.g. label)
# The recent label data should be updated together
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
if from_date is None:
# dropna is for being compatible to some data with future information(e.g. label)
# The recent label data should be updated together
self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
else:
self.last_end = get_date_by_shift(from_date, -1, align="left")

def prepare_data(self) -> DatasetH:
"""
Expand Down Expand Up @@ -187,6 +200,15 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
...


def _replace_range(data, new_data):
    """Overwrite the datetime span covered by `new_data` inside `data`.

    The span runs from the smallest to the largest value found in the
    "datetime" index level of `new_data`. Every row of `data` selected by
    label-slicing over that span is discarded, `new_data` is appended, and the
    combined frame is returned sorted by index. If an index entry still occurs
    more than once after the merge, only its last occurrence is kept.

    NOTE(review): the label slice is applied to the row index as a whole, so
    this presumably expects "datetime" to be the leading index level of
    `data` -- confirm against callers.

    Args:
        data: existing frame; it is sorted into a new object, not in place.
        new_data: replacement rows whose index has a "datetime" level.

    Returns:
        The merged, de-duplicated frame sorted by index.
    """
    new_dates = new_data.index.get_level_values("datetime")
    start, end = new_dates.min(), new_dates.max()

    # Label slicing needs a sorted index, so sort before selecting the span.
    ordered = data.sort_index()
    stale_index = ordered.loc[start:end].index
    kept = ordered.drop(stale_index)

    merged = pd.concat([kept, new_data], axis=0)
    # On any remaining index collision the later row wins.
    is_repeat = merged.index.duplicated(keep="last")
    return merged[~is_repeat].sort_index()


class PredUpdater(DSBasedUpdater):
"""
Update the prediction in the Recorder
Expand All @@ -196,11 +218,9 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
# Load model
model = self.rmdl.get_model()
new_pred: pd.Series = model.predict(dataset)

cb_pred = pd.concat([self.old_data, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
data = _replace_range(self.old_data, new_pred.to_frame("score"))
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
return cb_pred
return data


class LabelUpdater(DSBasedUpdater):
Expand All @@ -216,6 +236,5 @@ def __init__(self, record: Recorder, to_date=None, **kwargs):

def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
new_label = SignalRecord.generate_label(dataset)
cb_data = pd.concat([self.old_data, new_label], axis=0)
cb_data = cb_data[~cb_data.index.duplicated(keep="last")].sort_index()
cb_data = _replace_range(self.old_data.sort_index(), new_label)
return cb_data
4 changes: 2 additions & 2 deletions qlib/workflow/online/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def online_models(self, exp_name: str = None) -> list:
exp_name = self._get_exp_name(exp_name)
return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())

def update_online_pred(self, to_date=None, exp_name: str = None):
def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None):
"""
Update the predictions of online models to to_date.

Expand All @@ -176,7 +176,7 @@ def update_online_pred(self, to_date=None, exp_name: str = None):
if issubclass(cls, TSDatasetH):
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
try:
updater = PredUpdater(rec, to_date=to_date, hist_ref=hist_ref)
updater = PredUpdater(rec, to_date=to_date, from_date=from_date, hist_ref=hist_ref)
except LoadObjectError as e:
# skip the recorder without pred
self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.")
Expand Down
27 changes: 22 additions & 5 deletions tests/rolling_tests/test_update_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ def test_update_pred(self):
"""
task = copy.deepcopy(CSI300_GBDT_TASK)

task["record"] = {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
"kwargs": {"dataset": "<DATASET>", "model": "<MODEL>"},
}
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

exp_name = "online_srv_test"

Expand Down Expand Up @@ -57,6 +53,27 @@ def test_update_pred(self):

online_tool.update_online_pred(to_date=latest_date + pd.Timedelta(days=10))

good_pred = rec.load_object("pred.pkl")

mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10))
mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))
mod_pred = good_pred.copy()

mod_pred.loc[mod_range] = -1
mod_pred.loc[mod_range2] = -2

rec.save_objects(**{"pred.pkl": mod_pred})
online_tool.update_online_pred(
to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20)
)

updated_pred = rec.load_object("pred.pkl")

# this range lies inside [from_date, to_date], so it is re-predicted and matches the good predictions again
self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
# this range lies outside [from_date, to_date], so it is left untouched and keeps the modified value -2
self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())

def test_update_label(self):

task = copy.deepcopy(CSI300_GBDT_TASK)
Expand Down