RL Training pipeline on 5-min data (#1415)
* Workflow runnable

* CI

* Slight changes to make the workflow runnable. The changes to the handler/provider should be reverted before merging.

* Train experiment successful

* Refine handler & provider

* CI issues

* Resolve PR comments

* Resolve PR comments

* CI issues

* Fix test issue

* Black
lihuoran authored Jan 18, 2023
1 parent d876466 commit d8fc9ae
Showing 9 changed files with 155 additions and 59 deletions.
61 changes: 32 additions & 29 deletions qlib/contrib/data/highfreq_handler.py
@@ -113,8 +113,11 @@ def __init__(
         fit_end_time=None,
         drop_raw=True,
         day_length=240,
+        freq="1min",
+        columns=["$open", "$high", "$low", "$close", "$vwap"],
     ):
         self.day_length = day_length
+        self.columns = columns
 
         infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
         learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -124,7 +127,7 @@ def __init__(
             "kwargs": {
                 "config": self.get_feature_config(),
                 "swap_level": False,
-                "freq": "1min",
+                "freq": freq,
             },
         }
         super().__init__(
@@ -160,19 +163,13 @@ def get_normalized_price_feature(price_field, shift=0):
             )
             return feature_ops
 
-        fields += [get_normalized_price_feature("$open", 0)]
-        fields += [get_normalized_price_feature("$high", 0)]
-        fields += [get_normalized_price_feature("$low", 0)]
-        fields += [get_normalized_price_feature("$close", 0)]
-        fields += [get_normalized_price_feature("$vwap", 0)]
-        names += ["$open", "$high", "$low", "$close", "$vwap"]
+        for column_name in self.columns:
+            fields.append(get_normalized_price_feature(column_name, 0))
+            names.append(column_name)
 
-        fields += [get_normalized_price_feature("$open", self.day_length)]
-        fields += [get_normalized_price_feature("$high", self.day_length)]
-        fields += [get_normalized_price_feature("$low", self.day_length)]
-        fields += [get_normalized_price_feature("$close", self.day_length)]
-        fields += [get_normalized_price_feature("$vwap", self.day_length)]
-        names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
+        for column_name in self.columns:
+            fields.append(get_normalized_price_feature(column_name, self.day_length))
+            names.append(column_name + "_1")
 
         # calculate and fill nan with 0
         fields += [
@@ -258,14 +255,17 @@ def __init__(
         start_time=None,
         end_time=None,
         day_length=240,
+        freq="1min",
+        columns=["$close", "$vwap", "$volume"],
     ):
         self.day_length = day_length
+        self.columns = set(columns)
         data_loader = {
             "class": "QlibDataLoader",
             "kwargs": {
                 "config": self.get_feature_config(),
                 "swap_level": False,
-                "freq": "1min",
+                "freq": freq,
             },
         }
         super().__init__(
@@ -279,21 +279,24 @@ def get_feature_config(self):
         fields = []
         names = []
 
-        template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
-        template_fillnan = "FFillNan({0})"
-        template_if = "If(IsNull({1}), {0}, {1})"
-        fields += [
-            template_paused.format(template_fillnan.format("$close")),
-        ]
-        names += ["$close0"]
-
-        fields += [
-            template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
-        ]
-        names += ["$vwap0"]
-
-        fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
-        names += ["$volume0"]
+        if "$close" in self.columns:
+            template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
+            template_fillnan = "FFillNan({0})"
+            template_if = "If(IsNull({1}), {0}, {1})"
+            fields += [
+                template_paused.format(template_fillnan.format("$close")),
+            ]
+            names += ["$close0"]
+
+        if "$vwap" in self.columns:
+            fields += [
+                template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
+            ]
+            names += ["$vwap0"]
+
+        if "$volume" in self.columns:
+            fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
+            names += ["$volume0"]
 
         return fields, names
 
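
The net effect of the handler changes: sampling frequency, trading-day length, and the set of price columns are now constructor arguments instead of hard-coded 1-minute values. A minimal usage sketch for 5-minute data follows (not part of the PR; the class name HighFreqBacktestHandler matches this file, but the 5-minute values are assumptions — a 240-minute session has 48 five-minute bars, so day_length=48 assumes day_length counts bars):

from qlib.contrib.data.highfreq_handler import HighFreqBacktestHandler

# Sketch only: a backtest handler configured for 5-minute bars.
handler = HighFreqBacktestHandler(
    start_time="2021-01-01",
    end_time="2021-12-31",
    day_length=48,                # assumption: 240-minute session / 5-minute bars
    freq="5min",                  # new parameter; previously hard-coded to "1min"
    columns=["$close", "$vwap"],  # new parameter; here $volume is deliberately omitted
)

fields, names = handler.get_feature_config()
# With day_length=48, template_paused is "Cut({0}, 96, None)", so the $close0
# feature expands to "Cut(FFillNan($close), 96, None)" and names == ["$close0", "$vwap0"].
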
10 changes: 6 additions & 4 deletions qlib/contrib/data/highfreq_provider.py
@@ -28,6 +28,7 @@ def __init__(
         feature_conf: dict,
         label_conf: Optional[dict] = None,
         backtest_conf: dict = None,
+        freq: str = "1min",
         **kwargs,
     ) -> None:
         self.start_time = start_time
@@ -42,6 +43,7 @@ def __init__(
         self.backtest_conf = backtest_conf
         self.qlib_conf = qlib_conf
         self.logger = get_module_logger("HighFreqProvider")
+        self.freq = freq
 
     def get_pre_datasets(self):
         """Generate the training, validation and test datasets for prediction
@@ -116,8 +118,8 @@ def _prepare_calender_cache(self):
         # This code used the copy-on-write feature of Linux
         # to avoid calculating the calendar multiple times in the subprocess.
         # This code may accelerate, but may be not useful on Windows and Mac Os
-        Cal.calendar(freq="1min")
-        get_calendar_day(freq="1min")
+        Cal.calendar(freq=self.freq)
+        get_calendar_day(freq=self.freq)
 
     def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
         try:
@@ -240,7 +242,7 @@ def _gen_day_dataset(self, config, conf_type):
         with open(path + "tmp_dataset.pkl", "rb") as f:
             new_dataset = pkl.load(f)
 
-        time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
+        time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]
 
         def generate_dataset(times):
             if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
@@ -283,7 +285,7 @@ def _gen_stock_dataset(self, config, conf_type):
 
         instruments = D.instruments(market="all")
         stock_list = D.list_instruments(
-            instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
+            instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True
         )
 
         def generate_dataset(stock):
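
With freq stored on the provider, every calendar and instrument query above follows the configured frequency. A construction sketch for 5-minute data (assumptions: the keyword order and the prepared config dicts, which are not fully visible in this diff):

# Sketch only: thread the new freq argument through HighFreqProvider.
provider = HighFreqProvider(
    start_time="2021-01-01",
    end_time="2021-12-31",
    qlib_conf=qlib_conf,        # assumed to be prepared elsewhere
    feature_conf=feature_conf,  # assumed handler/dataset configs
    label_conf=label_conf,
    backtest_conf=backtest_conf,
    freq="5min",                # new: replaces the hard-coded "1min" calls above
)

Note that the [::240] day stride in _gen_day_dataset is unchanged, so it still assumes 240 calendar entries per trading day regardless of freq.
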
68 changes: 50 additions & 18 deletions qlib/rl/contrib/train_onpolicy.py
@@ -8,6 +8,7 @@
 
 import numpy as np
 import pandas as pd
+import qlib
 import torch
 import yaml
 from qlib.backtest import Order
@@ -17,7 +18,9 @@
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
 from qlib.rl.reward import Reward
-from qlib.rl.trainer import Checkpoint, train
+from qlib.rl.trainer import Checkpoint, backtest, train
+from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
+from qlib.rl.utils.log import CsvWriter
 from qlib.utils import init_instance_by_config
 from tianshou.policy import BasePolicy
 from torch import nn
@@ -98,40 +101,54 @@ def train_and_test(
     action_interpreter: ActionInterpreter,
     policy: BasePolicy,
     reward: Reward,
+    run_backtest: bool,
 ) -> None:
+    qlib.init()
+
     order_root_path = Path(data_config["source"]["order_dir"])
 
+    data_granularity = simulator_config.get("data_granularity", 1)
+
     def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
         return SingleAssetOrderExecutionSimple(
             order=order,
             data_dir=Path(data_config["source"]["data_dir"]),
             ticks_per_step=simulator_config["time_per_step"],
+            data_granularity=data_granularity,
             deal_price_type=data_config["source"].get("deal_price_column", "close"),
             vol_threshold=simulator_config["vol_limit"],
         )
 
-    train_dataset = LazyLoadDataset(
-        order_file_path=order_root_path / "train",
-        data_dir=Path(data_config["source"]["data_dir"]),
-        default_start_time_index=data_config["source"]["default_start_time"],
-        default_end_time_index=data_config["source"]["default_end_time"],
-    )
-    valid_dataset = LazyLoadDataset(
-        order_file_path=order_root_path / "valid",
-        data_dir=Path(data_config["source"]["data_dir"]),
-        default_start_time_index=data_config["source"]["default_start_time"],
-        default_end_time_index=data_config["source"]["default_end_time"],
-    )
+    assert data_config["source"]["default_start_time_index"] % data_granularity == 0
+    assert data_config["source"]["default_end_time_index"] % data_granularity == 0
+
+    train_dataset, valid_dataset, test_dataset = [
+        LazyLoadDataset(
+            order_file_path=order_root_path / tag,
+            data_dir=Path(data_config["source"]["data_dir"]),
+            default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
+            default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
+        )
+        for tag in ("train", "valid", "test")
+    ]
 
-    callbacks = []
+    callbacks: List[Callback] = []
     if "checkpoint_path" in trainer_config:
+        callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
         callbacks.append(
             Checkpoint(
-                dirpath=Path(trainer_config["checkpoint_path"]),
-                every_n_iters=trainer_config["checkpoint_every_n_iters"],
+                dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
+                every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
                 save_latest="copy",
            ),
         )
+    if "earlystop_patience" in trainer_config:
+        callbacks.append(
+            EarlyStopping(
+                patience=trainer_config["earlystop_patience"],
+                monitor="val/pa",
+            )
+        )
 
     trainer_kwargs = {
         "max_iters": trainer_config["max_epoch"],
@@ -160,8 +177,21 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
         vessel_kwargs=vessel_kwargs,
     )
 
+    if run_backtest:
+        backtest(
+            simulator_fn=_simulator_factory_simple,
+            state_interpreter=state_interpreter,
+            action_interpreter=action_interpreter,
+            initial_states=test_dataset,
+            policy=policy,
+            logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
+            reward=reward,
+            finite_env_type=trainer_kwargs["finite_env_type"],
+            concurrency=trainer_kwargs["concurrency"],
+        )
+
 
-def main(config: dict) -> None:
+def main(config: dict, run_backtest: bool) -> None:
     if "seed" in config["runtime"]:
         seed_everything(config["runtime"]["seed"])
 
@@ -200,6 +230,7 @@ def main(config: dict) -> None:
         state_interpreter=state_interpreter,
         policy=policy,
         reward=reward,
+        run_backtest=run_backtest,
     )
 
 
@@ -211,9 +242,10 @@ def main(config: dict) -> None:
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
+    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
     args = parser.parse_args()
 
     with open(args.config_path, "r") as input_stream:
         config = yaml.safe_load(input_stream)
 
-    main(config)
+    main(config, run_backtest=args.run_backtest)
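
Two details are worth spelling out. First, the new --run_backtest flag triggers a test-set backtest immediately after training, e.g. python train_onpolicy.py --config_path exp.yml --run_backtest (the config file name is illustrative). Second, the data_granularity arithmetic: time indices in the config are expressed in base ticks and must divide evenly into data rows before being rescaled for LazyLoadDataset. A self-contained sketch with assumed 5-minute values:

# Assumed values for 5-minute data on a 240-minute session.
data_granularity = 5          # one data row covers 5 one-minute ticks
default_start_time_index = 0  # config values are in ticks...
default_end_time_index = 240  # ...and must divide evenly by the granularity

assert default_start_time_index % data_granularity == 0
assert default_end_time_index % data_granularity == 0

# LazyLoadDataset then receives indices rescaled to data rows:
print(default_start_time_index // data_granularity)  # 0
print(default_end_time_index // data_granularity)    # 48 rows per day
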
12 changes: 11 additions & 1 deletion qlib/rl/data/pickle_styled.py
@@ -83,7 +83,16 @@ def _find_pickle(filename_without_suffix: Path) -> Path:
 
 @lru_cache(maxsize=10)  # 10 * 40M = 400MB
 def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
-    return pd.read_pickle(_find_pickle(filename_without_suffix))
+    df = pd.read_pickle(_find_pickle(filename_without_suffix))
+    index_cols = df.index.names
+
+    df = df.reset_index()
+    for date_col_name in ["date", "datetime"]:
+        if date_col_name in df:
+            df[date_col_name] = pd.to_datetime(df[date_col_name])
+    df = df.set_index(index_cols)
+
+    return df
 
 
 class SimpleIntradayBacktestData(BaseIntradayBacktestData):
@@ -161,6 +170,7 @@ def __init__(
         time_index: pd.Index,
     ) -> None:
         proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
+
         # We have to infer the names here because,
         # unfortunately they are not included in the original data.
         cnames = _infer_processed_data_column_names(feature_dim)
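
The _read_pickle change guards against pickled frames whose "date"/"datetime" index level was stored as strings: it rebuilds the index with true datetime64 values so later time-based indexing behaves. A self-contained sketch of the same normalization (the sample frame is invented for illustration):

import pandas as pd

# A frame whose "datetime" index level holds strings, as some pickled data might.
df = pd.DataFrame(
    {"close": [1.0, 2.0]},
    index=pd.MultiIndex.from_tuples(
        [("SH600000", "2020-01-02"), ("SH600000", "2020-01-03")],
        names=["instrument", "datetime"],
    ),
)

index_cols = df.index.names
df = df.reset_index()
for date_col_name in ["date", "datetime"]:
    if date_col_name in df:  # membership test checks column names
        df[date_col_name] = pd.to_datetime(df[date_col_name])
df = df.set_index(index_cols)

assert df.index.get_level_values("datetime").dtype == "datetime64[ns]"
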
7 changes: 5 additions & 2 deletions qlib/rl/order_execution/reward.py
@@ -21,10 +21,13 @@ class PAPenaltyReward(Reward[SAOEState]):
     ----------
     penalty
         The penalty for large volume in a short time.
+    scale
+        The weight used to scale up or down the reward.
     """
 
-    def __init__(self, penalty: float = 100.0):
+    def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None:
         self.penalty = penalty
+        self.scale = scale
 
     def reward(self, simulator_state: SAOEState) -> float:
         whole_order = simulator_state.order.amount
@@ -43,4 +46,4 @@ def reward(self, simulator_state: SAOEState) -> float:
 
         self.log("reward/pa", pa)
         self.log("reward/penalty", penalty)
-        return reward
+        return reward * self.scale
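
Because the final return value is simply multiplied by scale, the parameter rescales the whole reward signal without changing the balance between price advantage and the volume penalty. A one-line usage sketch (values are illustrative):

# Sketch: same pa-vs-penalty trade-off, reward magnitude halved.
reward_fn = PAPenaltyReward(penalty=100.0, scale=0.5)
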
9 changes: 8 additions & 1 deletion qlib/rl/order_execution/simulator_simple.py
@@ -36,6 +36,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
     ----------
     order
         The seed to start an SAOE simulator is an order.
+    data_granularity
+        Number of ticks between consecutive data entries.
     ticks_per_step
         How many ticks per step.
     data_dir
@@ -71,14 +73,17 @@ def __init__(
         self,
         order: Order,
         data_dir: Path,
+        data_granularity: int = 1,
         ticks_per_step: int = 30,
         deal_price_type: DealPriceType = "close",
         vol_threshold: Optional[float] = None,
     ) -> None:
         super().__init__(initial=order)
 
+        assert ticks_per_step % data_granularity == 0
+
         self.order = order
-        self.ticks_per_step: int = ticks_per_step
+        self.ticks_per_step: int = ticks_per_step // data_granularity
         self.deal_price_type = deal_price_type
         self.vol_threshold = vol_threshold
         self.data_dir = data_dir
@@ -132,6 +137,8 @@ def step(self, amount: float) -> None:
         ticks_position = self.position - np.cumsum(exec_vol)
 
         self.position -= exec_vol.sum()
+        if abs(self.position) < 1e-6:
+            self.position = 0.0
         if self.position < -EPS or (exec_vol < -EPS).any():
             raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")
 
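
Since the simulator's arrays are indexed by data rows rather than raw ticks, ticks_per_step is divided by data_granularity, and the assert guarantees the division is exact. Illustrative arithmetic with assumed 5-minute values:

# Assumed values: 30-minute decision steps over 5-minute data rows.
ticks_per_step = 30   # step length in base (1-minute) ticks
data_granularity = 5  # ticks covered by one data row

assert ticks_per_step % data_granularity == 0
rows_per_step = ticks_per_step // data_granularity
print(rows_per_step)  # 6 data rows per decision step

The separate step() change, zeroing positions smaller than 1e-6, avoids spurious "Execution volume is invalid" errors caused by floating-point residue after the final execution.
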
13 changes: 11 additions & 2 deletions qlib/rl/trainer/__init__.py
@@ -4,8 +4,17 @@
 """Train, test, inference utilities."""
 
 from .api import backtest, train
-from .callbacks import Checkpoint, EarlyStopping
+from .callbacks import Checkpoint, EarlyStopping, MetricsWriter
 from .trainer import Trainer
 from .vessel import TrainingVessel, TrainingVesselBase
 
-__all__ = ["Trainer", "TrainingVessel", "TrainingVesselBase", "Checkpoint", "EarlyStopping", "train", "backtest"]
+__all__ = [
+    "Trainer",
+    "TrainingVessel",
+    "TrainingVesselBase",
+    "Checkpoint",
+    "EarlyStopping",
+    "MetricsWriter",
+    "train",
+    "backtest",
+]