From cf48cc94d0422e6354a08765b4176552442987d5 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Fri, 10 Jun 2022 12:14:48 +0800 Subject: [PATCH] Refine backtest codes (#1120) * Refine backtest code * Keep working * Minor * Resolve PR comments * Fix import error * Fix import error --- qlib/backtest/__init__.py | 131 +++++++----- qlib/backtest/account.py | 90 ++++---- qlib/backtest/backtest.py | 32 ++- qlib/backtest/decision.py | 109 +++++----- qlib/backtest/exchange.py | 306 ++++++++++++++++++--------- qlib/backtest/executor.py | 180 ++++++++++------ qlib/backtest/high_performance_ds.py | 19 +- qlib/backtest/position.py | 154 +++++++------- qlib/backtest/profit_attribution.py | 16 +- qlib/backtest/report.py | 34 ++- qlib/backtest/signal.py | 13 +- qlib/backtest/utils.py | 107 ++++++---- qlib/strategy/base.py | 61 ++++-- qlib/utils/__init__.py | 3 +- tests/backtest/test_file_strategy.py | 13 +- 15 files changed, 784 insertions(+), 484 deletions(-) diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index ab62b7d559..20fbe14a43 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -2,24 +2,29 @@ # Licensed under the MIT License. from __future__ import annotations + import copy -from typing import List, Tuple, Union, TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union + +import pandas as pd from .account import Account +from .report import Indicator, PortfolioMetrics if TYPE_CHECKING: from ..strategy.base import BaseStrategy from .executor import BaseExecutor from .decision import BaseTradeDecision -from .position import Position + +from ..config import C +from ..log import get_module_logger +from ..utils import init_instance_by_config +from .backtest import backtest_loop, collect_data_loop +from .decision import Order from .exchange import Exchange -from .backtest import backtest_loop -from .backtest import collect_data_loop +from .position import Position from .utils import CommonInfrastructure -from .decision import Order -from ..utils import init_instance_by_config -from ..log import get_module_logger -from ..config import C # make import more user-friendly by adding `from qlib.backtest import STH` @@ -28,26 +33,34 @@ def get_exchange( - exchange=None, - freq="day", - start_time=None, - end_time=None, - codes="all", - subscribe_fields=[], - open_cost=0.0015, - close_cost=0.0025, - min_cost=5.0, - limit_threshold=None, + exchange: Union[str, dict, object, Path] = None, + freq: str = "day", + start_time: Union[pd.Timestamp, str] = None, + end_time: Union[pd.Timestamp, str] = None, + codes: Union[list, str] = "all", + subscribe_fields: list = [], + open_cost: float = 0.0015, + close_cost: float = 0.0025, + min_cost: float = 5.0, + limit_threshold: Union[Tuple[str, str], float, None] = None, deal_price: Union[str, Tuple[str], List[str]] = None, **kwargs, -): +) -> Exchange: """get_exchange Parameters ---------- # exchange related arguments - exchange: Exchange(). + exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`. + freq: str + frequency of data. + start_time: Union[pd.Timestamp, str] + closed start time for backtest. + end_time: Union[pd.Timestamp, str] + closed end time for backtest. + codes: list|str + list stock_id list or a string of instruments (i.e. all, csi500, sse50) subscribe_fields: list subscribe fields. open_cost : float @@ -57,8 +70,6 @@ def get_exchange( min_cost : float min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount. e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount. - trade_unit : int - Included in kwargs. Please refer to the docs of `__init__` of `Exchange` deal_price: Union[str, Tuple[str], List[str]] The `deal_price` supports following two types of input - : str @@ -101,10 +112,14 @@ def get_exchange( def create_account_instance( - start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position" + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + benchmark: str, + account: Union[float, int, dict], + pos_type: str = "Position", ) -> Account: """ - # TODO: is very strange pass benchmark_config in the account(maybe for report) + # TODO: is very strange pass benchmark_config in the account (maybe for report) # There should be a post-step to process the report. Parameters @@ -132,6 +147,8 @@ def create_account_instance( key "cash" means initial cash. key "stock1" means the information of first stock with amount and price(optional). ... + pos_type: str + Postion type. """ if isinstance(account, (int, float)): pos_kwargs = {"init_cash": account} @@ -159,15 +176,15 @@ def create_account_instance( def get_strategy_executor( - start_time, - end_time, - strategy: BaseStrategy, - executor: BaseExecutor, + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + strategy: Union[str, dict, object, Path], + executor: Union[str, dict, object, Path], benchmark: str = "SH000300", account: Union[float, int, Position] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", -): +) -> Tuple[BaseStrategy, BaseExecutor]: # NOTE: # - for avoiding recursive import @@ -176,7 +193,11 @@ def get_strategy_executor( from .executor import BaseExecutor # pylint: disable=C0415 trade_account = create_account_instance( - start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type + start_time=start_time, + end_time=end_time, + benchmark=benchmark, + account=account, + pos_type=pos_type, ) exchange_kwargs = copy.copy(exchange_kwargs) @@ -196,29 +217,31 @@ def get_strategy_executor( def backtest( - start_time, - end_time, - strategy, - executor, - benchmark="SH000300", - account=1e9, - exchange_kwargs={}, + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + strategy: Union[str, dict, object, Path], + executor: Union[str, dict, object, Path], + benchmark: str = "SH000300", + account: Union[float, int, Position] = 1e9, + exchange_kwargs: dict = {}, pos_type: str = "Position", -): - """initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution +) -> Tuple[PortfolioMetrics, Indicator]: + """initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and + executor in the nested decision execution Parameters ---------- - start_time : pd.Timestamp|str + start_time : Union[pd.Timestamp, str] closed start time for backtest **NOTE**: This will be applied to the outmost executor's calendar. - end_time : pd.Timestamp|str + end_time : Union[pd.Timestamp, str] closed end time for backtest **NOTE**: This will be applied to the outmost executor's calendar. E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301 - strategy : Union[str, dict, BaseStrategy] - for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information. - executor : Union[str, dict, BaseExecutor] + strategy : Union[str, dict, object, Path] + for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more + information. + executor : Union[str, dict, object, Path] for initializing the outermost executor. benchmark: str the benchmark for reporting. @@ -257,16 +280,16 @@ def backtest( def collect_data( - start_time, - end_time, - strategy, - executor, - benchmark="SH000300", - account=1e9, - exchange_kwargs={}, + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + strategy: Union[str, dict, object, Path], + executor: Union[str, dict, object, Path], + benchmark: str = "SH000300", + account: Union[float, int, Position] = 1e9, + exchange_kwargs: dict = {}, pos_type: str = "Position", return_value: dict = None, -): +) -> Generator[object, None, None]: """initialize the strategy and executor, then collect the trade decision data for rl training please refer to the docs of the backtest for the explanation of the parameters @@ -291,7 +314,7 @@ def collect_data( def format_decisions( decisions: List[BaseTradeDecision], -) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]: +) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]: """ format the decisions collected by `qlib.backtest.collect_data` The decisions will be organized into a tree-like structure. @@ -326,4 +349,4 @@ def format_decisions( return res -__all__ = ["Order"] +__all__ = ["Order", "backtest"] diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 4c9330e4c3..9d8adddb0b 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -1,15 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import annotations + import copy from typing import Dict, List, Tuple -from qlib.utils import init_instance_by_config + import pandas as pd -from .position import BasePosition -from .report import PortfolioMetrics, Indicator +from qlib.utils import init_instance_by_config + from .decision import BaseTradeDecision, Order from .exchange import Exchange +from .position import BasePosition +from .report import Indicator, PortfolioMetrics """ rtn & earning in the Account @@ -34,40 +37,42 @@ class AccumulatedInfo: AccumulatedInfo should be shared across different levels """ - def __init__(self): + def __init__(self) -> None: self.reset() - def reset(self): - self.rtn = 0 # accumulated return, do not consider cost - self.cost = 0 # accumulated cost - self.to = 0 # accumulated turnover + def reset(self) -> None: + self.rtn: float = 0.0 # accumulated return, do not consider cost + self.cost: float = 0.0 # accumulated cost + self.to: float = 0.0 # accumulated turnover - def add_return_value(self, value): + def add_return_value(self, value: float) -> None: self.rtn += value - def add_cost(self, value): + def add_cost(self, value: float) -> None: self.cost += value - def add_turnover(self, value): + def add_turnover(self, value: float) -> None: self.to += value @property - def get_return(self): + def get_return(self) -> float: return self.rtn @property - def get_cost(self): + def get_cost(self) -> float: return self.cost @property - def get_turnover(self): + def get_turnover(self) -> float: return self.to class Account: """ - The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor - Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object. + The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in + qlib/backtest/executor.py:NestedExecutor + Different level of executor has different Account object when calculating metrics. But the position object is + shared cross all the Account object. """ def __init__( @@ -78,7 +83,7 @@ def __init__( benchmark_config: dict = {}, pos_type: str = "Position", port_metr_enabled: bool = True, - ): + ) -> None: """the trade account of backtest. Parameters @@ -102,7 +107,7 @@ def __init__( self.benchmark_config = None # avoid no attribute error self.init_vars(init_cash, position_dict, freq, benchmark_config) - def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict): + def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None: # 1) the following variables are shared by multiple layers # - you will see a shallow copy instead of deepcopy in the NestedExecutor; self.init_cash = init_cash @@ -114,7 +119,7 @@ def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict) "position_dict": position_dict, }, "module_path": "qlib.backtest.position", - } + }, ) self.accum_info = AccumulatedInfo() @@ -123,13 +128,13 @@ def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict) self.hist_positions = {} self.reset(freq=freq, benchmark_config=benchmark_config) - def is_port_metr_enabled(self): + def is_port_metr_enabled(self) -> bool: """ Is portfolio-based metrics enabled. """ return self._port_metr_enabled and not self.current_position.skip_update() - def reset_report(self, freq, benchmark_config): + def reset_report(self, freq: str, benchmark_config: dict) -> None: # portfolio related metrics if self.is_port_metr_enabled(): # NOTE: @@ -140,13 +145,13 @@ def reset_report(self, freq, benchmark_config): # fill stock value # The frequency of account may not align with the trading frequency. # This may result in obscure bugs when data quality is low. - if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None: + if isinstance(self.benchmark_config, dict) and "start_time" in self.benchmark_config: self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq) # trading related metrics(e.g. high-frequency trading) self.indicator = Indicator() - def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None): + def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None: """reset freq and report of account Parameters @@ -155,6 +160,7 @@ def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None frequency of account & report, by default None benchmark_config : {}, optional benchmark config of report, by default None + port_metr_enabled: bool """ if freq is not None: self.freq = freq @@ -165,13 +171,13 @@ def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None self.reset_report(self.freq, self.benchmark_config) - def get_hist_positions(self): + def get_hist_positions(self) -> dict: return self.hist_positions - def get_cash(self): + def get_cash(self) -> float: return self.current_position.get_cash() - def _update_state_from_order(self, order, trade_val, cost, trade_price): + def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: if self.is_port_metr_enabled(): # update turnover self.accum_info.add_turnover(trade_val) @@ -191,13 +197,14 @@ def _update_state_from_order(self, order, trade_val, cost, trade_price): profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val self.accum_info.add_return_value(profit) # note here do not consider cost - def update_order(self, order, trade_val, cost, trade_price): + def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: if self.current_position.skip_update(): # TODO: supporting polymorphism for account # updating order for infinite position is meaningless return - # if stock is sold out, no stock price information in Position, then we should update account first, then update current position + # if stock is sold out, no stock price information in Position, then we should update account first, + # then update current position # if stock is bought, there is no stock in current position, update current, then update account # The cost will be subtracted from the cash at last. So the trading logic can ignore the cost calculation if order.direction == Order.SELL: @@ -212,8 +219,15 @@ def update_order(self, order, trade_val, cost, trade_price): self.current_position.update_order(order, trade_val, cost, trade_price) self._update_state_from_order(order, trade_val, cost, trade_price) - def update_current_position(self, trade_start_time, trade_end_time, trade_exchange): - """update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock""" + def update_current_position( + self, + trade_start_time: pd.Timestamp, + trade_end_time: pd.Timestamp, + trade_exchange: Exchange, + ) -> None: + """ + Update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock + """ # update price for stock in the position and the profit from changed_price # NOTE: updating position does not only serve portfolio metrics, it also serve the strategy if not self.current_position.skip_update(): @@ -228,7 +242,7 @@ def update_current_position(self, trade_start_time, trade_end_time, trade_exchan # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy self.current_position.add_count_all(bar=self.freq) - def update_portfolio_metrics(self, trade_start_time, trade_end_time): + def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None: """update portfolio_metrics""" # calculate earning # account_value - last_account_value @@ -243,14 +257,16 @@ def update_portfolio_metrics(self, trade_start_time, trade_end_time): last_account_value = self.portfolio_metrics.get_latest_account_value() last_total_cost = self.portfolio_metrics.get_latest_total_cost() last_total_turnover = self.portfolio_metrics.get_latest_total_turnover() + # get now_account_value, now_stock_value, now_earning, now_cost, now_turnover now_account_value = self.current_position.calculate_value() now_stock_value = self.current_position.calculate_stock_value() now_earning = now_account_value - last_account_value now_cost = self.accum_info.get_cost - last_total_cost now_turnover = self.accum_info.get_turnover - last_total_turnover + # update portfolio_metrics for today - # judge whether the the trading is begin. + # judge whether the trading is begin. # and don't add init account state into portfolio_metrics, due to we don't have excess return in those days. self.portfolio_metrics.update_portfolio_metrics_record( trade_start_time=trade_start_time, @@ -267,7 +283,7 @@ def update_portfolio_metrics(self, trade_start_time, trade_end_time): stock_value=now_stock_value, ) - def update_hist_positions(self, trade_start_time): + def update_hist_positions(self, trade_start_time: pd.Timestamp) -> None: """update history position""" now_account_value = self.current_position.calculate_value() # set now_account_value to position @@ -287,7 +303,7 @@ def update_indicator( inner_order_indicators: List[Dict[str, pd.Series]] = None, decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, indicator_config: dict = {}, - ): + ) -> None: """update trade indicators and order indicators in each bar end""" # TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():` @@ -323,7 +339,7 @@ def update_bar_end( inner_order_indicators: List[Dict[str, pd.Series]] = None, decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None, indicator_config: dict = {}, - ): + ) -> None: """update account at each trading bar step Parameters @@ -338,6 +354,8 @@ def update_bar_end( whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it - if atomic is True, calculate the indicators with trade_info - else, aggregate indicators with inner indicators + outer_trade_decision: BaseTradeDecision + external trade decision trade_info : List[(Order, float, float, float)], optional trading information, by default None - necessary if atomic is True @@ -377,7 +395,7 @@ def update_bar_end( indicator_config=indicator_config, ) - def get_portfolio_metrics(self): + def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]: """get the history portfolio_metrics and positions instance""" if self.is_port_metr_enabled(): _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe() diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index dff15a7c28..c42d6fc9b4 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -2,17 +2,29 @@ # Licensed under the MIT License. from __future__ import annotations + +from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union + +import pandas as pd + from qlib.backtest.decision import BaseTradeDecision -from typing import TYPE_CHECKING +from qlib.backtest.report import Indicator, PortfolioMetrics if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy from qlib.backtest.executor import BaseExecutor -from ..utils.time import Freq + from tqdm.auto import tqdm +from ..utils.time import Freq + -def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor): +def backtest_loop( + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + trade_strategy: BaseStrategy, + trade_executor: BaseExecutor, +) -> Tuple[PortfolioMetrics, Indicator]: """backtest function for the interaction of the outermost strategy and executor in the nested decision execution please refer to the docs of `collect_data_loop` @@ -31,19 +43,23 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec def collect_data_loop( - start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None -): + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + trade_strategy: BaseStrategy, + trade_executor: BaseExecutor, + return_value: dict = None, +) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]: """Generator for collecting the trade decision data for rl training Parameters ---------- - start_time : pd.Timestamp|str + start_time : Union[pd.Timestamp, str] closed start time for backtest **NOTE**: This will be applied to the outmost executor's calendar. - end_time : pd.Timestamp|str + end_time : Union[pd.Timestamp, str] closed end time for backtest **NOTE**: This will be applied to the outmost executor's calendar. - E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301 + E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301 trade_strategy : BaseStrategy the outermost portfolio strategy trade_executor : BaseExecutor diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index e8f787a9f2..9a6084214e 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -2,23 +2,26 @@ # Licensed under the MIT License. from __future__ import annotations -from enum import IntEnum -from qlib.data.data import Cal -from qlib.utils.time import concat_date_time, epsilon_change -from qlib.log import get_module_logger -from typing import ClassVar, Optional, Union, List, Tuple +from abc import abstractmethod +from enum import IntEnum # try to fix circular imports when enabling type hints -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union + +from qlib.backtest.utils import TradeCalendarManager +from qlib.data.data import Cal +from qlib.log import get_module_logger +from qlib.utils.time import concat_date_time, epsilon_change if TYPE_CHECKING: from qlib.strategy.base import BaseStrategy from qlib.backtest.exchange import Exchange -from qlib.backtest.utils import TradeCalendarManager + +from dataclasses import dataclass + import numpy as np import pandas as pd -from dataclasses import dataclass class OrderDir(IntEnum): @@ -46,7 +49,7 @@ class Order: # - they are set by users and is time-invariant. stock_id: str amount: float # `amount` is a non-negative and adjusted value - direction: int + direction: OrderDir # 2) time variant values: # - Users may want to set these values when using lower level APIs @@ -61,7 +64,7 @@ class Order: # What the value should be about in all kinds of cases # - not tradable: the deal_amount == 0 , factor is None # - the stock is suspended and the entire order fails. No cost for this order - # - dealed or partially dealed: deal_amount >= 0 and factor is not None + # - dealt or partially dealt: deal_amount >= 0 and factor is not None deal_amount: Optional[float] = None # `deal_amount` is a non-negative value factor: Optional[float] = None @@ -74,10 +77,10 @@ class Order: SELL: ClassVar[OrderDir] = OrderDir.SELL BUY: ClassVar[OrderDir] = OrderDir.BUY - def __post_init__(self): + def __post_init__(self) -> None: if self.direction not in {Order.SELL, Order.BUY}: raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") - self.deal_amount = 0 + self.deal_amount = 0.0 self.factor = None @property @@ -99,7 +102,7 @@ def deal_amount_delta(self) -> float: return self.deal_amount * self.sign @property - def sign(self) -> float: + def sign(self) -> int: """ return the sign of trading - `+1` indicates buying @@ -112,15 +115,12 @@ def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> U if isinstance(direction, OrderDir): return direction elif isinstance(direction, (int, float, np.integer, np.floating)): - if direction > 0: - return Order.BUY - else: - return Order.SELL + return Order.BUY if direction > 0 else Order.SELL elif isinstance(direction, str): - dl = direction.lower() - if dl.strip() == "sell": + dl = direction.lower().strip() + if dl == "sell": return OrderDir.SELL - elif dl.strip() == "buy": + elif dl == "buy": return OrderDir.BUY else: raise NotImplementedError(f"This type of input is not supported") @@ -138,14 +138,14 @@ class OrderHelper: Motivation - Make generating order easier - User may have no knowledge about the adjust-factor information about the system. - - It involves to much interaction with the exchange when generating orders. + - It involves too much interaction with the exchange when generating orders. """ - def __init__(self, exchange: Exchange): + def __init__(self, exchange: Exchange) -> None: self.exchange = exchange + @staticmethod def create( - self, code: str, amount: float, direction: OrderDir, @@ -175,21 +175,18 @@ def create( Order: The created order """ - if start_time is not None: - start_time = pd.Timestamp(start_time) - if end_time is not None: - end_time = pd.Timestamp(end_time) # NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders return Order( stock_id=code, amount=amount, - start_time=start_time, - end_time=end_time, + start_time=start_time if start_time is not None else pd.Timestamp(start_time), + end_time=end_time if end_time is not None else pd.Timestamp(end_time), direction=direction, ) class TradeRange: + @abstractmethod def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]: """ This method will be call with following way @@ -216,6 +213,7 @@ def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]: """ raise NotImplementedError(f"Please implement the `__call__` method") + @abstractmethod def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: """ Parameters @@ -234,23 +232,26 @@ def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> T class IdxTradeRange(TradeRange): - def __init__(self, start_idx: int, end_idx: int): + def __init__(self, start_idx: int, end_idx: int) -> None: self._start_idx = start_idx self._end_idx = end_idx def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]: return self._start_idx, self._end_idx + def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: + raise NotImplementedError + class TradeRangeByTime(TradeRange): """This is a helper function for make decisions""" - def __init__(self, start_time: str, end_time: str): + def __init__(self, start_time: str, end_time: str) -> None: """ This is a callable class. **NOTE**: - - It is designed for minute-bar for intraday trading!!!!! + - It is designed for minute-bar for intra-day trading!!!!! - Both start_time and end_time are **closed** in the range Parameters @@ -264,26 +265,25 @@ def __init__(self, start_time: str, end_time: str): self.end_time = pd.Timestamp(end_time).time() assert self.start_time < self.end_time - def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]: + def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]: if trade_calendar is None: raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.") - start = trade_calendar.start_time - val_start, val_end = concat_date_time(start.date(), self.start_time), concat_date_time( - start.date(), self.end_time - ) + + start_date = trade_calendar.start_time.date() + val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time) return trade_calendar.get_range_idx(val_start, val_end) def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: start_date = start_time.date() val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time) # NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day - # Assumption: start_time and end_time is for intraday trading. So it is OK for only using start_date + # Assumption: start_time and end_time is for intra-day trading. So it is OK for only using start_date return max(val_start, start_time), min(val_end, end_time) class BaseTradeDecision: """ - Trade decisions ara made by strategy and executed by exeuter + Trade decisions ara made by strategy and executed by executor Motivation: Here are several typical scenarios for `BaseTradeDecision` @@ -297,7 +297,7 @@ class BaseTradeDecision: 2. Same as `case 1.3` """ - def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None): + def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None: """ Parameters ---------- @@ -339,7 +339,7 @@ def get_decision(self) -> List[object]: """ raise NotImplementedError(f"This type of input is not supported") - def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]: + def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]: """ Be called at the **start** of each step. @@ -354,10 +354,8 @@ def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecisi Returns ------- - None: - No update, use previous decision(or unavailable) BaseTradeDecision: - New update, use new decision + New update, use new decision. If no updates, return None (use previous decision (or unavailable)) """ # purpose 1) self.total_step = trade_calendar.get_trade_len() @@ -412,12 +410,12 @@ def get_range_limit(self, **kwargs) -> Tuple[int, int]: """ try: _start_idx, _end_idx = self._get_range_limit(**kwargs) - except NotImplementedError: + except NotImplementedError as e: if "default_value" in kwargs: return kwargs["default_value"] else: # Default to get full index - raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError + raise NotImplementedError(f"The decision didn't provide an index range") from e # clip index if getattr(self, "total_step", None) is not None: @@ -426,7 +424,7 @@ def get_range_limit(self, **kwargs) -> Tuple[int, int]: if _start_idx < 0 or _end_idx >= self.total_step: logger = get_module_logger("decision") logger.warning( - f"[{_start_idx},{_end_idx}] go beyoud the total_step({self.total_step}), it will be clipped" + f"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.", ) _start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx) return _start_idx, _end_idx @@ -444,7 +442,7 @@ def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = Fals Parameters ---------- rtype: str - - "full": return the full limitation of the deicsion in the day + - "full": return the full limitation of the decision in the day - "step": return the limitation of current step raise_error: bool @@ -497,11 +495,10 @@ def empty(self) -> bool: return True return True - def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision): + def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None: """ - This method will be called on the inner_trade_decision after it is generated. - `inner_trade_decision` will be changed **inplaced**. + `inner_trade_decision` will be changed **inplace**. Motivation of the `mod_inner_decision` - Leave a hook for outer decision to affect the decision generated by the inner strategy @@ -520,6 +517,9 @@ def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision): class EmptyTradeDecision(BaseTradeDecision): + def get_decision(self) -> List[object]: + return [] + def empty(self) -> bool: return True @@ -544,4 +544,9 @@ def get_decision(self) -> List[object]: return self.order_list def __repr__(self) -> str: - return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]" + return ( + f"class: {self.__class__.__name__}; " + f"strategy: {self.strategy}; " + f"trade_range: {self.trade_range}; " + f"order_list[{len(self.order_list)}]" + ) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 4c020f8d88..ba1dd2c0b8 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -1,21 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import annotations + from collections import defaultdict -from typing import TYPE_CHECKING -from typing import List, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union + +from ..utils.index_data import IndexData if TYPE_CHECKING: from .account import Account -from qlib.backtest.position import BasePosition, Position import random + import numpy as np import pandas as pd -from ..data.data import D +from qlib.backtest.position import BasePosition + from ..config import C from ..constant import REG_CN +from ..data.data import D from ..log import get_module_logger from .decision import Order, OrderDir, OrderHelper from .high_performance_ds import BaseQuote, NumpyQuote @@ -24,22 +28,22 @@ class Exchange: def __init__( self, - freq="day", - start_time=None, - end_time=None, - codes="all", + freq: str = "day", + start_time: Union[pd.Timestamp, str] = None, + end_time: Union[pd.Timestamp, str] = None, + codes: Union[list, str] = "all", deal_price: Union[str, Tuple[str], List[str]] = None, - subscribe_fields=[], + subscribe_fields: list = [], limit_threshold: Union[Tuple[str, str], float, None] = None, - volume_threshold=None, - open_cost=0.0015, - close_cost=0.0025, - min_cost=5, - impact_cost=0.0, - extra_quote=None, - quote_cls=NumpyQuote, + volume_threshold: Union[tuple, dict] = None, + open_cost: float = 0.0015, + close_cost: float = 0.0025, + min_cost: float = 5.0, + impact_cost: float = 0.0, + extra_quote: pd.DataFrame = None, + quote_cls: Type[BaseQuote] = NumpyQuote, **kwargs, - ): + ) -> None: """__init__ :param freq: frequency of data :param start_time: closed start time for backtest @@ -72,11 +76,12 @@ def __init__( ] 1) ("cum" or "current", limit_str) denotes a single volume limit. - limit_str is qlib data expression which is allowed to define your own Operator. - Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for high frequency, - such as DayCumsum. !!!NOTE: if you want you use the custom operator, you need to - register it in qlib_init. - - "cum" means that this is a cumulative value over time, such as cumulative market volume. - So when it is used as a volume limit, it is necessary to subtract the dealt amount. + Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for + high frequency, such as DayCumsum. !!!NOTE: if you want you use the custom + operator, you need to register it in qlib_init. + - "cum" means that this is a cumulative value over time, such as cumulative market + volume. So when it is used as a volume limit, it is necessary to subtract the dealt + amount. - "current" means that this is a real-time value and will not accumulate over time, so it can be directly used as a capacity limit. e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1") @@ -84,7 +89,7 @@ def __init__( "buy" means the volume limits of buying. "sell" means the volume limits of selling. Different volume limits will be aggregated with min(). If volume_threshold is only ("cum" or "current", limit_str) instead of a dict, the volume limits are for - both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}. + both by default. In other words, it is same as {"all": ("cum" or "current", limit_str)}. 3) e.g. "volume_threshold": { "all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), "buy": ("current", "$askV1"), @@ -104,13 +109,14 @@ def __init__( Necessary fields: $close is for calculating the total value at end of each day. Optional fields: - $volume is only necessary when we limit the trade amount or calculate PA(vwap) indicator + $volume is only necessary when we limit the trade amount or calculate + PA(vwap) indicator $vwap is only necessary when we use the $vwap price as the deal price $factor is for rounding to the trading unit - limit_sell will be set to False by default(False indicates we can sell this - target on this day). - limit_buy will be set to False by default(False indicates we can buy this - target on this day). + limit_sell will be set to False by default (False indicates we can sell + this target on this day). + limit_buy will be set to False by default (False indicates we can buy + this target on this day). index: MultipleIndex(instrument, pd.Datetime) """ self.freq = freq @@ -163,7 +169,7 @@ def __init__( if self.limit_type == self.LT_TP_EXP: for exp in limit_threshold: necessary_fields.add(exp) - all_fields = necessary_fields | vol_lt_fields + all_fields = necessary_fields | set(vol_lt_fields) all_fields = list(all_fields | set(subscribe_fields)) self.all_fields = all_fields @@ -182,17 +188,22 @@ def __init__( self.quote_cls = quote_cls self.quote: BaseQuote = self.quote_cls(self.quote_df, freq) - def get_quote_from_qlib(self): + def get_quote_from_qlib(self) -> None: # get stock data from qlib if len(self.codes) == 0: self.codes = D.instruments() self.quote_df = D.features( - self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True + self.codes, + self.all_fields, + self.start_time, + self.end_time, + freq=self.freq, + disk_cache=True, ).dropna(subset=["$close"]) self.quote_df.columns = self.all_fields # check buy_price data and sell_price data - for attr in "buy_price", "sell_price": + for attr in ("buy_price", "sell_price"): pstr = getattr(self, attr) # price string if self.quote_df[pstr].isna().any(): self.logger.warning("{} field data contains nan.".format(pstr)) @@ -238,7 +249,7 @@ def get_quote_from_qlib(self): LT_FLT = "float" # float LT_NONE = "none" # none - def _get_limit_type(self, limit_threshold): + def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str: """get limit type""" if isinstance(limit_threshold, Tuple): return self.LT_TP_EXP @@ -249,7 +260,7 @@ def _get_limit_type(self, limit_threshold): else: raise NotImplementedError(f"This type of `limit_threshold` is not supported") - def _update_limit(self, limit_threshold): + def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None: # check limit_threshold limit_type = self._get_limit_type(limit_threshold) if limit_type == self.LT_NONE: @@ -263,9 +274,10 @@ def _update_limit(self, limit_threshold): self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130 - def _get_vol_limit(self, volume_threshold): + @staticmethod + def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]: """ - preproccess the volume limit. + preprocess the volume limit. get the fields need to get from qlib. get the volume limit list of buying and selling which is composed of all limits. Parameters @@ -295,8 +307,7 @@ def _get_vol_limit(self, volume_threshold): volume_threshold = {"all": volume_threshold} assert isinstance(volume_threshold, dict) - for key in volume_threshold: - vol_limit = volume_threshold[key] + for key, vol_limit in volume_threshold.items(): assert isinstance(vol_limit, tuple) fields.add(vol_limit[1]) @@ -307,10 +318,19 @@ def _get_vol_limit(self, volume_threshold): return buy_vol_limit, sell_vol_limit, fields - def check_stock_limit(self, stock_id, start_time, end_time, direction=None): + def check_stock_limit( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + direction: int = None, + ) -> bool: """ Parameters ---------- + stock_id : str + start_time: pd.Timestamp + end_time: pd.Timestamp direction : int, optional trade direction, by default None - if direction is None, check if tradable for buying and selling. @@ -328,39 +348,42 @@ def check_stock_limit(self, stock_id, start_time, end_time, direction=None): else: raise ValueError(f"direction {direction} is not supported!") - def check_stock_suspended(self, stock_id, start_time, end_time): + def check_stock_suspended( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + ) -> bool: # is suspended if stock_id in self.quote.get_all_stock(): return self.quote.get_data(stock_id, start_time, end_time, "$close") is None else: return True - def is_stock_tradable(self, stock_id, start_time, end_time, direction=None): + def is_stock_tradable( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + direction: int = None, + ) -> bool: # check if stock can be traded - # same as check in check_order - if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit( - stock_id, start_time, end_time, direction - ): - return False - else: - return True + return not ( + self.check_stock_suspended(stock_id, start_time, end_time) + or self.check_stock_limit(stock_id, start_time, end_time, direction) + ) - def check_order(self, order): + def check_order(self, order: Order) -> bool: # check limit and suspended - if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit( - order.stock_id, order.start_time, order.end_time, order.direction - ): - return False - else: - return True + return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction) def deal_order( self, - order, + order: Order, trade_account: Account = None, position: BasePosition = None, dealt_order_amount: defaultdict = defaultdict(float), - ): + ) -> Tuple[float, float, float]: """ Deal order when the actual transaction the results section in `Order` will be changed. @@ -371,9 +394,9 @@ def deal_order( :return: trade_val, trade_cost, trade_price """ # check order first. - if self.check_order(order) is False: + if not self.check_order(order): order.deal_amount = 0.0 - # using np.nan instead of None to make it more convenient to should the value in format string + # using np.nan instead of None to make it more convenient to show the value in format string self.logger.debug(f"Order failed due to trading limitation: {order}") return 0.0, 0.0, np.nan @@ -382,7 +405,9 @@ def deal_order( # NOTE: order will be changed in this function trade_price, trade_val, trade_cost = self._calc_trade_info_by_order( - order, trade_account.current_position if trade_account else position, dealt_order_amount + order, + trade_account.current_position if trade_account else position, + dealt_order_amount, ) if trade_val > 1e-5: # If the order can only be deal 0 value. Nothing to be updated @@ -396,23 +421,49 @@ def deal_order( return trade_val, trade_cost, trade_price - def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"): - return self.quote.get_data(stock_id, start_time, end_time, method=method) - - def get_close(self, stock_id, start_time, end_time, method="ts_data_last"): + def get_quote_info( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + method: str = "ts_data_last", + ) -> Union[None, int, float, bool, IndexData]: + return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`? + + def get_close( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + method: str = "ts_data_last", + ) -> Union[None, int, float, bool, IndexData]: return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method) - def get_volume(self, stock_id, start_time, end_time, method="sum"): + def get_volume( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + method: str = "sum", + ) -> float: """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)""" return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method) - def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"): + def get_deal_price( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + direction: OrderDir, + method: str = "ts_data_last", + ) -> float: if direction == OrderDir.SELL: pstr = self.sell_price elif direction == OrderDir.BUY: pstr = self.buy_price else: raise NotImplementedError(f"This type of input is not supported") + deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method) if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08): self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") @@ -420,11 +471,16 @@ def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, me deal_price = self.get_close(stock_id, start_time, end_time, method) return deal_price - def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]: + def get_factor( + self, + stock_id: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + ) -> Optional[float]: """ Returns ------- - Union[float, None]: + Optional[float]: `None`: if the stock is suspended `None` may be returned `float`: return factor if the factor exists """ @@ -434,11 +490,16 @@ def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]: return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last") def generate_amount_position_from_weight_position( - self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY - ): + self, + weight_position: dict, + cash: float, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + direction: OrderDir = OrderDir.BUY, + ) -> dict: """ The generate the target position according to the weight and the cash. - NOTE: All the cash will assigned to the tadable stock. + NOTE: All the cash will assigned to the tradable stock. Parameter: weight_position : dict {stock_id : weight}; allocate cash by weight_position among then, weight must be in this range: 0 < weight < 1 @@ -451,15 +512,14 @@ def generate_amount_position_from_weight_position( # calculate the total weight of tradable value tradable_weight = 0.0 - for stock_id in weight_position: + for stock_id, wp in weight_position.items(): if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): # weight_position must be greater than 0 and less than 1 - if weight_position[stock_id] < 0 or weight_position[stock_id] > 1: + if wp < 0 or wp > 1: raise ValueError( - "weight_position is {}, " - "weight_position is not in the range of (0, 1).".format(weight_position[stock_id]) + "weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp), ) - tradable_weight += weight_position[stock_id] + tradable_weight += wp if tradable_weight - 1.0 >= 1e-5: raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight)) @@ -467,19 +527,24 @@ def generate_amount_position_from_weight_position( amount_dict = {} for stock_id in weight_position: if weight_position[stock_id] > 0.0 and self.is_stock_tradable( - stock_id=stock_id, start_time=start_time, end_time=end_time + stock_id=stock_id, + start_time=start_time, + end_time=end_time, ): amount_dict[stock_id] = ( cash * weight_position[stock_id] / tradable_weight // self.get_deal_price( - stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction + stock_id=stock_id, + start_time=start_time, + end_time=end_time, + direction=direction, ) ) return amount_dict - def get_real_deal_amount(self, current_amount, target_amount, factor): + def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float: """ Calculate the real adjust deal amount when considering the trading unit :param current_amount: @@ -501,7 +566,13 @@ def get_real_deal_amount(self, current_amount, target_amount, factor): deal_amount = self.round_amount_by_trade_unit(deal_amount, factor) return -deal_amount - def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time): + def generate_order_for_target_amount_position( + self, + target_position: dict, + current_position: dict, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + ) -> list: """ Note: some future information is used in this function Parameter: @@ -517,7 +588,8 @@ def generate_order_for_target_amount_position(self, target_position, current_pos # three parts: kept stock_id, dropped stock_id, new stock_id # handle kept stock_id - # because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different; + # because the order of the set is not fixed, the trading order of the stock is different, so that the backtest + # results of the same parameter are different; # so here we sort stock_id, and then randomly shuffle the order of stock_id # because the same random seed is used, the final stock_id order is fixed sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys()))) @@ -546,7 +618,7 @@ def generate_order_for_target_amount_position(self, target_position, current_pos start_time=start_time, end_time=end_time, factor=factor, - ) + ), ) else: # sell stock @@ -558,14 +630,19 @@ def generate_order_for_target_amount_position(self, target_position, current_pos start_time=start_time, end_time=end_time, factor=factor, - ) + ), ) # return order_list : buy + sell return sell_order_list + buy_order_list def calculate_amount_position_value( - self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL - ): + self, + amount_dict: dict, + start_time: pd.Timestamp, + end_time: pd.Timestamp, + only_tradable: bool = False, + direction: OrderDir = OrderDir.SELL, + ) -> float: """Parameter position : Position() amount_dict : {stock_id : amount} @@ -576,21 +653,28 @@ def calculate_amount_position_value( """ value = 0 for stock_id in amount_dict: - if ( - only_tradable is True - and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False - and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False - or only_tradable is False + if not only_tradable or ( + not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) + and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) ): value += ( self.get_deal_price( - stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction + stock_id=stock_id, + start_time=start_time, + end_time=end_time, + direction=direction, ) * amount_dict[stock_id] ) return value - def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None): + def _get_factor_or_raise_error( + self, + factor: float = None, + stock_id: str = None, + start_time: pd.Timestamp = None, + end_time: pd.Timestamp = None, + ) -> float: """Please refer to the docs of get_amount_of_trade_unit""" if factor is None: if stock_id is not None and start_time is not None and end_time is not None: @@ -599,7 +683,13 @@ def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None") return factor - def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None): + def get_amount_of_trade_unit( + self, + factor: float = None, + stock_id: str = None, + start_time: pd.Timestamp = None, + end_time: pd.Timestamp = None, + ) -> Optional[float]: """ get the trade unit of amount based on **factor** the factor can be given directly or calculated in given time range and stock id. @@ -617,14 +707,22 @@ def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, s """ if not self.trade_w_adj_price and self.trade_unit is not None: factor = self._get_factor_or_raise_error( - factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time + factor=factor, + stock_id=stock_id, + start_time=start_time, + end_time=end_time, ) return self.trade_unit / factor else: return None def round_amount_by_trade_unit( - self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None + self, + deal_amount, + factor: float = None, + stock_id: str = None, + start_time=None, + end_time=None, ): """Parameter Please refer to the docs of get_amount_of_trade_unit @@ -635,7 +733,10 @@ def round_amount_by_trade_unit( if not self.trade_w_adj_price and self.trade_unit is not None: # the minimal amount is 1. Add 0.1 for solving precision problem. factor = self._get_factor_or_raise_error( - factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time + factor=factor, + stock_id=stock_id, + start_time=start_time, + end_time=end_time, ) return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return deal_amount @@ -714,7 +815,12 @@ def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): max_trade_amount = (cash - self.min_cost) / trade_price return max_trade_amount - def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount): + def _calc_trade_info_by_order( + self, + order: Order, + position: Optional[BasePosition], + dealt_order_amount: dict, + ) -> Tuple[float, float, float]: """ Calculation of trade info **NOTE**: Order will be changed in this function @@ -753,7 +859,8 @@ def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amoun if not np.isclose(order.deal_amount, current_amount): # when not selling last stock. rounding is necessary order.deal_amount = self.round_amount_by_trade_unit( - min(current_amount, order.deal_amount), order.factor + min(current_amount, order.deal_amount), + order.factor, ) # in case of negative value of cash @@ -778,7 +885,8 @@ def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amoun # The money is not enough max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio) order.deal_amount = self.round_amount_by_trade_unit( - min(max_buy_amount, order.deal_amount), order.factor + min(max_buy_amount, order.deal_amount), + order.factor, ) self.logger.debug(f"Order clipped due to cash limitation: {order}") else: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index e1199667cd..2105471e1b 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -1,19 +1,28 @@ -from abc import abstractmethod +from __future__ import annotations + import copy -from qlib.backtest.position import BasePosition -from qlib.log import get_module_logger +from abc import abstractmethod +from collections import defaultdict from types import GeneratorType -from qlib.backtest.account import Account +from typing import Generator, List, Optional, Tuple, Union + import pandas as pd -from typing import List, Tuple, Union -from collections import defaultdict -from .decision import Order, BaseTradeDecision -from .exchange import Exchange -from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx +from qlib.backtest.account import Account +from qlib.backtest.position import BasePosition +from qlib.log import get_module_logger -from ..utils import init_instance_by_config from ..strategy.base import BaseStrategy +from ..utils import init_instance_by_config +from .decision import BaseTradeDecision, Order +from .exchange import Exchange +from .utils import ( + BaseInfrastructure, + CommonInfrastructure, + LevelInfrastructure, + TradeCalendarManager, + get_start_end_idx, +) class BaseExecutor: @@ -30,9 +39,9 @@ def __init__( track_data: bool = False, trade_exchange: Exchange = None, common_infra: CommonInfrastructure = None, - settle_type=BasePosition.ST_NO, + settle_type=BasePosition.ST_NO, # TODO: add typehint **kwargs, - ): + ) -> None: """ Parameters ---------- @@ -53,15 +62,21 @@ def __init__( - 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap' - If 'base_price' is 'twap', the based price is the time weighted average price - If 'base_price' is 'vwap', the based price is the volume weighted average price - - 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean' + - 'weight_method': weighted method when calculating total trading pa by different orders' pa in each + step, optional, default by 'mean' - If 'weight_method' is 'mean', calculating mean value of different orders' pa - - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa - - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa + - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different + orders' pa + - If 'weight_method' is 'value_weighted', calculating value weighted average value of different + orders' pa - 'ffr_config': config for calculating fulfill rate(ffr), optional - - 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean' + - 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each + step, optional, default by 'mean' - If 'weight_method' is 'mean', calculating mean value of different orders' ffr - - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr - - If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr + - If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different + orders' ffr + - If 'weight_method' is 'value_weighted', calculating value weighted average value of different + orders' ffr Example: { 'show_indicator': True, @@ -79,7 +94,8 @@ def __init__( whether to print trading info, by default False track_data : bool, optional whether to generate trade_decision, will be used when training rl agent - - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data` + - If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will + be generated by `collect_data` - Else, `trade_decision` will not be generated trade_exchange : Exchange @@ -114,7 +130,7 @@ def __init__( self.dealt_order_amount = defaultdict(float) self.deal_day = None - def reset_common_infra(self, common_infra, copy_trade_account=False): + def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None: """ reset infrastructure for trading - reset trade_account @@ -132,7 +148,7 @@ def reset_common_infra(self, common_infra, copy_trade_account=False): # 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics) self.trade_account: Account = copy.copy(common_infra.get("trade_account")) else: - self.trade_account = common_infra.get("trade_account") + self.trade_account: Account = common_infra.get("trade_account") self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics) @property @@ -148,7 +164,7 @@ def trade_calendar(self) -> TradeCalendarManager: """ return self.level_infra.get("trade_calendar") - def reset(self, common_infra: CommonInfrastructure = None, **kwargs): + def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None: """ - reset `start_time` and `end_time`, used in trade calendar - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc @@ -161,13 +177,13 @@ def reset(self, common_infra: CommonInfrastructure = None, **kwargs): if common_infra is not None: self.reset_common_infra(common_infra) - def get_level_infra(self): + def get_level_infra(self) -> LevelInfrastructure: return self.level_infra - def finished(self): + def finished(self) -> bool: return self.trade_calendar.finished() - def execute(self, trade_decision: BaseTradeDecision, level: int = 0): + def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]: """execute the trade decision and return the executed result NOTE: this function is never used directly in the framework. Should we delete it? @@ -189,9 +205,15 @@ def execute(self, trade_decision: BaseTradeDecision, level: int = 0): pass return return_value.get("execute_result") - @classmethod @abstractmethod - def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: + def _collect_data( + self, + trade_decision: BaseTradeDecision, + level: int = 0, + ) -> Union[ + Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]], + Tuple[List[object], dict], + ]: """ Please refer to the doc of collect_data The only difference between `_collect_data` and `collect_data` is that some common steps are moved into @@ -209,8 +231,11 @@ def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tup """ def collect_data( - self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0 - ) -> List[object]: + self, + trade_decision: BaseTradeDecision, + return_value: dict = None, + level: int = 0, + ) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]: """Generator for collecting the trade decision data for rl training his function will make a step forward @@ -253,7 +278,9 @@ def collect_data( obj = self._collect_data(trade_decision=trade_decision, level=level) if isinstance(obj, GeneratorType): - res, kwargs = yield from obj + yield_res = yield from obj + assert isinstance(yield_res, tuple) and len(yield_res) == 2 + res, kwargs = yield_res else: # Some concrete executor don't have inner decisions res, kwargs = obj @@ -279,7 +306,7 @@ def collect_data( return_value.update({"execute_result": res}) return res - def get_all_executors(self): + def get_all_executors(self) -> List[BaseExecutor]: """get all executors""" return [self] @@ -287,7 +314,8 @@ def get_all_executors(self): class NestedExecutor(BaseExecutor): """ Nested Executor with inner strategy and executor - - At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env. + - At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` + in a higher frequency env. """ def __init__( @@ -305,7 +333,7 @@ def __init__( align_range_limit: bool = True, common_infra: CommonInfrastructure = None, **kwargs, - ): + ) -> None: """ Parameters ---------- @@ -323,10 +351,14 @@ def __init__( It is only for nested executor, because range_limit is given by outer strategy """ self.inner_executor: BaseExecutor = init_instance_by_config( - inner_executor, common_infra=common_infra, accept_types=BaseExecutor + inner_executor, + common_infra=common_infra, + accept_types=BaseExecutor, ) self.inner_strategy: BaseStrategy = init_instance_by_config( - inner_strategy, common_infra=common_infra, accept_types=BaseStrategy + inner_strategy, + common_infra=common_infra, + accept_types=BaseStrategy, ) self._skip_empty_decision = skip_empty_decision @@ -344,10 +376,10 @@ def __init__( **kwargs, ) - def reset_common_infra(self, common_infra, copy_trade_account=False): + def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None: """ reset infrastructure for trading - - reset inner_strategyand inner_executor common infra + - reset inner_strategy and inner_executor common infra """ # NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account` @@ -358,7 +390,7 @@ def reset_common_infra(self, common_infra, copy_trade_account=False): self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True) self.inner_strategy.reset_common_infra(common_infra) - def _init_sub_trading(self, trade_decision): + def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None: trade_start_time, trade_end_time = self.trade_calendar.get_step_time() self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time) sub_level_infra = self.inner_executor.get_level_infra() @@ -368,14 +400,18 @@ def _init_sub_trading(self, trade_decision): def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision: # outer strategy have chance to update decision each iterator updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar) - if updated_trade_decision is not None: + if updated_trade_decision is not None: # TODO: always is None for now? trade_decision = updated_trade_decision # NEW UPDATE # create a hook for inner strategy to update outer decision self.inner_strategy.alter_outer_trade_decision(trade_decision) return trade_decision - def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): + def _collect_data( + self, + trade_decision: BaseTradeDecision, + level: int = 0, + ) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]: execute_result = [] inner_order_indicators = [] decision_list = [] @@ -390,8 +426,8 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): if trade_decision.empty() and self._skip_empty_decision: # give one chance for outer strategy to update the strategy - # - For updating some information in the sub executor(the strategy have no knowledge of the inner - # executor when generating the decision) + # - For updating some information in the sub executor (the strategy have no knowledge of the inner + # executor when generating the decision) break sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar @@ -405,15 +441,19 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): # NOTE: !!!!! # the two lines below is for a special case in RL - # To solve the confliction below - # - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop - # For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)]) + # To solve the conflicts below + # - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction + # loop For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> + # (inner Qlib Executor)]) # - However, RL-based framework has it's own script to run the loop # For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)]) - # To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution below is proposed - # - The entry script follow the example of _RL learning example_ to be compatible with all kinds of RL Framework + # To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution + # below is proposed + # - The entry script follow the example of _RL learning example_ to be compatible with all kinds of + # RL Framework # - Each step of (RL Env) will make (inner Qlib Executor) one step forward - # - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy + # - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) + # by `yield from` and wait for the action from the policy # So the two lines below is the implementation of yielding control rights if isinstance(res, GeneratorType): res = yield from res @@ -427,13 +467,15 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): # NOTE: Trade Calendar will step forward in the follow line _inner_execute_result = yield from self.inner_executor.collect_data( - trade_decision=_inner_trade_decision, level=level + 1 + trade_decision=_inner_trade_decision, + level=level + 1, ) + assert isinstance(_inner_execute_result, list) self.post_inner_exe_step(_inner_execute_result) execute_result.extend(_inner_execute_result) inner_order_indicators.append( - self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True) + self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True), ) else: # do nothing and just step forward @@ -441,7 +483,7 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list} - def post_inner_exe_step(self, inner_exe_res): + def post_inner_exe_step(self, inner_exe_res: List[object]) -> None: """ A hook for doing sth after each step of inner strategy @@ -451,11 +493,23 @@ def post_inner_exe_step(self, inner_exe_res): the execution result of inner task """ - def get_all_executors(self): + def get_all_executors(self) -> List[object]: """get all executors, including self and inner_executor.get_all_executors()""" return [self, *self.inner_executor.get_all_executors()] +def _retrieve_orders_from_decision(trade_decision: BaseTradeDecision) -> List[Order]: + """ + IDE-friendly helper function. + """ + decisions = trade_decision.get_decision() + orders: List[Order] = [] + for decision in decisions: + assert isinstance(decision, Order) + orders.append(decision) + return orders + + class SimulatorExecutor(BaseExecutor): """Executor that simulate the true market""" @@ -464,10 +518,10 @@ class SimulatorExecutor(BaseExecutor): # available trade_types TT_SERIAL = "serial" - ## The orders will be executed serially in a sequence + # The orders will be executed serially in a sequence # In each trading step, it is possible that users sell instruments first and use the money to buy new instruments TT_PARAL = "parallel" - ## The orders will be executed parallelly + # The orders will be executed in parallel # In each trading step, if users try to sell instruments first and buy new instruments with money, failure will # occur @@ -483,7 +537,7 @@ def __init__( common_infra: CommonInfrastructure = None, trade_type: str = TT_SERIAL, **kwargs, - ): + ) -> None: """ Parameters ---------- @@ -517,7 +571,7 @@ def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]: List[Order]: get a list orders according to `self.trade_type` """ - orders = trade_decision.get_decision() + orders = _retrieve_orders_from_decision(trade_decision) if self.trade_type == self.TT_SERIAL: # Orders will be traded in a parallel way @@ -525,15 +579,15 @@ def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]: elif self.trade_type == self.TT_PARAL: # NOTE: !!!!!!! # Assumption: there will not be orders in different trading direction in a single step of a strategy !!!! - # The parallel trading failure will be caused only by the confliction of money - # Therefore, make the buying go first will make sure the confliction happen. + # The parallel trading failure will be caused only by the conflicts of money + # Therefore, make the buying go first will make sure the conflicts happen. # It equals to parallel trading after sorting the order by direction order_it = sorted(orders, key=lambda order: -order.direction) else: raise NotImplementedError(f"This type of input is not supported") return order_it - def _update_dealt_order_amount(self, order): + def _update_dealt_order_amount(self, order: Order) -> None: """update date and dealt order amount in the day.""" now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D") @@ -542,8 +596,7 @@ def _update_dealt_order_amount(self, order): self.deal_day = now_deal_day self.dealt_order_amount[order.stock_id] += order.deal_amount - def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): - + def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: trade_start_time, _ = self.trade_calendar.get_step_time() execute_result = [] @@ -559,7 +612,8 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): self._update_dealt_order_amount(order) if self.verbose: print( - "[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cash {:.2f}.".format( + "[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, " + "value {:.2f}, cash {:.2f}.".format( trade_start_time, "sell" if order.direction == Order.SELL else "buy", order.stock_id, @@ -569,6 +623,6 @@ def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0): order.factor, trade_val, self.trade_account.get_cash(), - ) + ), ) return execute_result, {"trade_info": execute_result} diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 95a32022e1..8cfa9bacc8 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -1,20 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from functools import lru_cache +import inspect import logging -from typing import List, Text, Union, Callable, Iterable, Dict from collections import OrderedDict +from functools import lru_cache +from typing import Callable, Dict, Iterable, List, Text, Union -import inspect -import pandas as pd import numpy as np +import pandas as pd +import qlib.utils.index_data as idd + +from ..log import get_module_logger from ..utils.index_data import IndexData, SingleData from ..utils.resam import resam_ts_data, ts_data_last -from ..log import get_module_logger -from ..utils.time import is_single_value, Freq -import qlib.utils.index_data as idd +from ..utils.time import Freq, is_single_value class BaseQuote: @@ -627,7 +628,9 @@ def sum_all_indicators(order_indicator, indicators: list, metrics: Union[str, Li metrics = [metrics] for metric in metrics: order_indicator.data[metric] = idd.sum_by_index( - [indicator.data[metric] for indicator in indicators], stocks, fill_value + [indicator.data[metric] for indicator in indicators], + stocks, + fill_value, ) def __repr__(self): diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index a025a05a83..06218a67d3 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -2,24 +2,28 @@ # Licensed under the MIT License. +from datetime import timedelta from typing import Dict, List, Union -import pandas as pd -from datetime import timedelta import numpy as np +import pandas as pd -from .decision import Order from ..data.data import D +from .decision import Order class BasePosition: """ - The Position want to maintain the position like a dictionary + The Position wants to maintain the position like a dictionary Please refer to the `Position` class for the position """ - def __init__(self, *args, cash=0.0, **kwargs): + def __init__(self, *args, cash: float = 0.0, **kwargs) -> None: self._settle_type = self.ST_NO + self.position = {} + + def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None: + pass def skip_update(self) -> bool: """ @@ -49,7 +53,7 @@ def check_stock(self, stock_id: str) -> bool: """ raise NotImplementedError(f"Please implement the `check_stock` method") - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float): + def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: """ Parameters ---------- @@ -64,7 +68,7 @@ def update_order(self, order: Order, trade_val: float, cost: float, trade_price: """ raise NotImplementedError(f"Please implement the `update_order` method") - def update_stock_price(self, stock_id, price: float): + def update_stock_price(self, stock_id: str, price: float) -> None: """ Updating the latest price of the order The useful when clearing balance at each bar end @@ -89,6 +93,9 @@ def calculate_stock_value(self) -> float: """ raise NotImplementedError(f"Please implement the `calculate_stock_value` method") + def calculate_value(self) -> float: + raise NotImplementedError(f"Please implement the `calculate_value` method") + def get_stock_list(self) -> List: """ Get the list of stocks in the position. @@ -124,14 +131,16 @@ def get_stock_amount(self, code) -> float: def get_cash(self, include_settle: bool = False) -> float: """ + Parameters + ---------- + include_settle: + will the unsettled(delayed) cash included + Default: not include those unavailable cash Returns ------- float: the available(tradable) cash in position - include_settle: - will the unsettled(delayed) cash included - Default: not include those unavailable cash """ raise NotImplementedError(f"Please implement the `get_cash` method") @@ -165,7 +174,7 @@ def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: """ raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method") - def add_count_all(self, bar): + def add_count_all(self, bar) -> None: """ Will be called at the end of each bar on each level @@ -176,24 +185,19 @@ def add_count_all(self, bar): """ raise NotImplementedError(f"Please implement the `add_count_all` method") - def update_weight_all(self): + def update_weight_all(self) -> None: """ Updating the position weight; # TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order # and before updating weight. - - Parameters - ---------- - bar : - The level to be updated """ raise NotImplementedError(f"Please implement the `add_count_all` method") ST_CASH = "cash" ST_NO = None - def settle_start(self, settle_type: str): + def settle_start(self, settle_type: str) -> None: """ settlement start It will act like start and commit a transaction @@ -210,14 +214,9 @@ def settle_start(self, settle_type: str): """ raise NotImplementedError(f"Please implement the `settle_conf` method") - def settle_commit(self): + def settle_commit(self) -> None: """ settlement commit - - Parameters - ---------- - settle_type : str - please refer to the documents of Executor """ raise NotImplementedError(f"Please implement the `settle_commit` method") @@ -242,13 +241,11 @@ class Position(BasePosition): } """ - def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}): + def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None: """Init position by cash and position_dict. Parameters ---------- - start_time : - the start time of backtest. It's for filling the initial value of stocks. cash : float, optional initial cash in account, by default 0 position_dict : Dict[ @@ -268,9 +265,9 @@ def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = # Otherwise the initial value self.init_cash = cash self.position = position_dict.copy() - for stock in self.position: - if isinstance(self.position[stock], int): - self.position[stock] = {"amount": self.position[stock]} + for stock, value in self.position.items(): + if isinstance(value, int): + self.position[stock] = {"amount": value} self.position["cash"] = cash # If the stock price information is missing, the account value will not be calculated temporarily @@ -279,21 +276,23 @@ def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = except KeyError: pass - def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30): + def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None: """fill the stock value by the close price of latest last_days from qlib. Parameters ---------- start_time : the start time of backtest. + freq : str + Frequency last_days : int, optional the days to get the latest close price, by default 30. """ stock_list = [] - for stock in self.position: - if not isinstance(self.position[stock], dict): + for stock, value in self.position.items(): + if not isinstance(value, dict): continue - if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None): + if value.get("price", None) is None: stock_list.append(stock) if len(stock_list) == 0: @@ -304,7 +303,12 @@ def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last price_end_time = start_time price_start_time = start_time - timedelta(days=last_days) price_df = D.features( - stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True + stock_list, + ["$close"], + price_start_time, + price_end_time, + freq=freq, + disk_cache=True, ).dropna() price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict() @@ -316,7 +320,7 @@ def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last self.position[stock]["price"] = price_dict[stock] self.position["now_account_value"] = self.calculate_value() - def _init_stock(self, stock_id, amount, price=None): + def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None: """ initialization the stock in current position @@ -334,7 +338,7 @@ def _init_stock(self, stock_id, amount, price=None): self.position[stock_id]["price"] = price self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date - def _buy_stock(self, stock_id, trade_val, cost, trade_price): + def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None: trade_amount = trade_val / trade_price if stock_id not in self.position: self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price) @@ -344,15 +348,16 @@ def _buy_stock(self, stock_id, trade_val, cost, trade_price): self.position["cash"] -= trade_val + cost - def _sell_stock(self, stock_id, trade_val, cost, trade_price): + def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None: trade_amount = trade_val / trade_price if stock_id not in self.position: raise KeyError("{} not in current position".format(stock_id)) else: if np.isclose(self.position[stock_id]["amount"], trade_amount): # Selling all the stocks - # we use np.isclose instead of abs() <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount - # Using abs() <= 1e-5 will result in error when the amount is large + # we use np.isclose instead of abs() <= 1e-5 because `np.isclose` consider both + # relative amount and absolute amount + # Using abs() <= 1e-5 will result in error when the amount is large self._del_stock(stock_id) else: # decrease the amount of stock @@ -361,8 +366,10 @@ def _sell_stock(self, stock_id, trade_val, cost, trade_price): if self.position[stock_id]["amount"] < -1e-5: raise ValueError( "only have {} {}, require {}".format( - self.position[stock_id]["amount"] + trade_amount, stock_id, trade_amount - ) + self.position[stock_id]["amount"] + trade_amount, + stock_id, + trade_amount, + ), ) new_cash = trade_val - cost @@ -373,13 +380,13 @@ def _sell_stock(self, stock_id, trade_val, cost, trade_price): else: raise NotImplementedError(f"This type of input is not supported") - def _del_stock(self, stock_id): + def _del_stock(self, stock_id: str) -> None: del self.position[stock_id] - def check_stock(self, stock_id): + def check_stock(self, stock_id: str) -> bool: return stock_id in self.position - def update_order(self, order, trade_val, cost, trade_price): + def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: # handle order, order is a order class, defined in exchange.py if order.direction == Order.BUY: # BUY @@ -390,54 +397,54 @@ def update_order(self, order, trade_val, cost, trade_price): else: raise NotImplementedError("do not support order direction {}".format(order.direction)) - def update_stock_price(self, stock_id, price): + def update_stock_price(self, stock_id: str, price: float) -> None: self.position[stock_id]["price"] = price - def update_stock_count(self, stock_id, bar, count): + def update_stock_count(self, stock_id: str, bar: str, count: float) -> None: # TODO: check type of `bar` self.position[stock_id][f"count_{bar}"] = count - def update_stock_weight(self, stock_id, weight): + def update_stock_weight(self, stock_id: str, weight: float) -> None: self.position[stock_id]["weight"] = weight - def calculate_stock_value(self): + def calculate_stock_value(self) -> float: stock_list = self.get_stock_list() value = 0 for stock_id in stock_list: value += self.position[stock_id]["amount"] * self.position[stock_id]["price"] return value - def calculate_value(self): + def calculate_value(self) -> float: value = self.calculate_stock_value() value += self.position["cash"] + self.position.get("cash_delay", 0.0) return value - def get_stock_list(self): + def get_stock_list(self) -> List[str]: stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"}) return stock_list - def get_stock_price(self, code): + def get_stock_price(self, code: str) -> float: return self.position[code]["price"] - def get_stock_amount(self, code): + def get_stock_amount(self, code: str) -> float: return self.position[code]["amount"] if code in self.position else 0 - def get_stock_count(self, code, bar): + def get_stock_count(self, code: str, bar: str) -> float: """the days the account has been hold, it may be used in some special strategies""" if f"count_{bar}" in self.position[code]: return self.position[code][f"count_{bar}"] else: return 0 - def get_stock_weight(self, code): + def get_stock_weight(self, code: str) -> float: return self.position[code]["weight"] - def get_cash(self, include_settle=False): + def get_cash(self, include_settle: bool = False) -> float: cash = self.position["cash"] if include_settle: cash += self.position.get("cash_delay", 0.0) return cash - def get_stock_amount_dict(self): + def get_stock_amount_dict(self) -> dict: """generate stock amount dict {stock_id : amount of stock}""" d = {} stock_list = self.get_stock_list() @@ -445,7 +452,7 @@ def get_stock_amount_dict(self): d[stock_code] = self.get_stock_amount(code=stock_code) return d - def get_stock_weight_dict(self, only_stock=False): + def get_stock_weight_dict(self, only_stock: bool = False) -> dict: """get_stock_weight_dict generate stock weight dict {stock_id : value weight of stock in the position} it is meaningful in the beginning or the end of each trade date @@ -463,7 +470,7 @@ def get_stock_weight_dict(self, only_stock=False): d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value return d - def add_count_all(self, bar): + def add_count_all(self, bar: str) -> None: stock_list = self.get_stock_list() for code in stock_list: if f"count_{bar}" in self.position[code]: @@ -471,18 +478,18 @@ def add_count_all(self, bar): else: self.position[code][f"count_{bar}"] = 1 - def update_weight_all(self): + def update_weight_all(self) -> None: weight_dict = self.get_stock_weight_dict() for stock_code, weight in weight_dict.items(): self.update_stock_weight(stock_code, weight) - def settle_start(self, settle_type): + def settle_start(self, settle_type: str) -> None: assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!" self._settle_type = settle_type if settle_type == self.ST_CASH: self.position["cash_delay"] = 0.0 - def settle_commit(self): + def settle_commit(self) -> None: if self._settle_type != self.ST_NO: if self._settle_type == self.ST_CASH: self.position["cash"] += self.position["cash_delay"] @@ -507,10 +514,10 @@ def check_stock(self, stock_id: str) -> bool: # InfPosition always have any stocks return True - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float): + def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: pass - def update_stock_price(self, stock_id, price: float): + def update_stock_price(self, stock_id: str, price: float) -> None: pass def calculate_stock_value(self) -> float: @@ -522,17 +529,20 @@ def calculate_stock_value(self) -> float: """ return np.inf - def get_stock_list(self) -> List: + def calculate_value(self) -> float: + raise NotImplementedError(f"InfPosition doesn't support calculating value") + + def get_stock_list(self) -> list: raise NotImplementedError(f"InfPosition doesn't support stock list position") - def get_stock_price(self, code) -> float: + def get_stock_price(self, code: str) -> float: """the price of the inf position is meaningless""" return np.nan - def get_stock_amount(self, code) -> float: + def get_stock_amount(self, code: str) -> float: return np.inf - def get_cash(self, include_settle=False) -> float: + def get_cash(self, include_settle: bool = False) -> float: return np.inf def get_stock_amount_dict(self) -> Dict: @@ -541,14 +551,14 @@ def get_stock_amount_dict(self) -> Dict: def get_stock_weight_dict(self, only_stock: bool = False) -> Dict: raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict") - def add_count_all(self, bar): + def add_count_all(self, bar: str) -> None: raise NotImplementedError(f"InfPosition doesn't support add_count_all") - def update_weight_all(self): + def update_weight_all(self) -> None: raise NotImplementedError(f"InfPosition doesn't support update_weight_all") - def settle_start(self, settle_type: str): + def settle_start(self, settle_type: str) -> None: pass - def settle_commit(self): + def settle_commit(self) -> None: pass diff --git a/qlib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py index 371cb422ad..63a1d692db 100644 --- a/qlib/backtest/profit_attribution.py +++ b/qlib/backtest/profit_attribution.py @@ -4,13 +4,15 @@ This module is not well maintained. """ +import datetime +from pathlib import Path + import numpy as np import pandas as pd -from .position import Position -from ..data import D + from ..config import C -import datetime -from pathlib import Path +from ..data import D +from .position import Position def get_benchmark_weight( @@ -214,7 +216,9 @@ def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, g for idx, row in (~bench_stock_weight_df.isna()).iterrows(): bench_values = stock_group_field_df.loc[idx, row[row].index] new_stock_group_df.loc[idx] = get_daily_bin_group( - bench_values, stock_group_field_df.loc[idx], group_n=group_n + bench_values, + stock_group_field_df.loc[idx], + group_n=group_n, ) return new_stock_group_df @@ -315,7 +319,7 @@ def brinson_pa( # The excess profit from the interaction of assets allocation and stocks selection "RIN": Q4 - Q3 - Q2 + Q1, "RTotal": Q4 - Q1, # The totoal excess profit - } + }, ), { "port_group_ret": port_group_ret_df, diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 0231146233..77e43c8e73 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -2,19 +2,20 @@ # Licensed under the MIT License. -from collections import OrderedDict import pathlib +from collections import OrderedDict from typing import Dict, List, Tuple, Union import numpy as np import pandas as pd -from qlib.backtest.exchange import Exchange +import qlib.utils.index_data as idd from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir -from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric +from qlib.backtest.exchange import Exchange + from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data -import qlib.utils.index_data as idd +from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric class PortfolioMetrics: @@ -161,7 +162,8 @@ def update_portfolio_metrics_record( stock_value, ]: raise ValueError( - "None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]" + "None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, " + "total_cost, cost_rate, stock_value]", ) if trade_end_time is None and bench_value is None: @@ -335,7 +337,10 @@ def trade_amount_func(deal_amount, trade_price): # sum inner order indicators with same metric. all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"] self.order_indicator_cls.sum_all_indicators( - self.order_indicator, inner_order_indicators, all_metric, fill_value=0 + self.order_indicator, + inner_order_indicators, + all_metric, + fill_value=0, ) def func(trade_price, deal_amount): @@ -378,12 +383,17 @@ def _get_base_vol_pri( if decision.trade_range is not None: trade_start_time, trade_end_time = decision.trade_range.clip_time_range( - start_time=trade_start_time, end_time=trade_end_time + start_time=trade_start_time, + end_time=trade_end_time, ) if price == "deal_price": price_s = trade_exchange.get_deal_price( - inst, trade_start_time, trade_end_time, direction=direction, method=None + inst, + trade_start_time, + trade_end_time, + direction=direction, + method=None, ) else: raise NotImplementedError(f"This type of input is not supported") @@ -599,8 +609,12 @@ def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}): if show_indicator: print( "[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format( - freq, trade_start_time, fulfill_rate, price_advantage, positive_rate - ) + freq, + trade_start_time, + fulfill_rate, + price_advantage, + positive_rate, + ), ) def get_order_indicator(self, raw: bool = True): diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py index 5ff5fa9451..4615a89c09 100644 --- a/qlib/backtest/signal.py +++ b/qlib/backtest/signal.py @@ -1,13 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from qlib.utils import init_instance_by_config +import abc from typing import Dict, List, Text, Tuple, Union -from ..model.base import BaseModel + +import pandas as pd + +from qlib.utils import init_instance_by_config + from ..data.dataset import Dataset from ..data.dataset.utils import convert_index_format +from ..model.base import BaseModel from ..utils.resam import resam_ts_data -import pandas as pd -import abc class Signal(metaclass=abc.ABCMeta): @@ -82,7 +85,7 @@ def _update_model(self): def create_signal_from( - obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] + obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame], ) -> Signal: """ create signal from diverse information diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 5fa02420d9..2077986bca 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -2,16 +2,22 @@ # Licensed under the MIT License. from __future__ import annotations + import bisect +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Set, Tuple, Union + +import numpy as np + from qlib.utils.time import epsilon_change -from typing import TYPE_CHECKING, Tuple, Union if TYPE_CHECKING: from qlib.backtest.decision import BaseTradeDecision -import pandas as pd import warnings +import pandas as pd + from ..data.data import Cal @@ -26,8 +32,8 @@ def __init__( freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - level_infra: "LevelInfrastructure" = None, - ): + level_infra: LevelInfrastructure = None, + ) -> None: """ Parameters ---------- @@ -43,19 +49,26 @@ def __init__( self.level_infra = level_infra self.reset(freq=freq, start_time=start_time, end_time=end_time) - def reset(self, freq, start_time, end_time): + def reset( + self, + freq: str, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + ) -> None: """ Please refer to the docs of `__init__` Reset the trade calendar - self.trade_len : The total count for trading step - - self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1] + - self.trade_step : The number of trading step finished, self.trade_step can be + [0, 1, 2, ..., self.trade_len - 1] """ self.freq = freq self.start_time = pd.Timestamp(start_time) if start_time else None self.end_time = pd.Timestamp(end_time) if end_time else None _calendar = Cal.calendar(freq=freq, future=True) + assert isinstance(_calendar, np.ndarray) self._calendar = _calendar _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True) self.start_index = _start_index @@ -63,7 +76,7 @@ def reset(self, freq, start_time, end_time): self.trade_len = _end_index - _start_index + 1 self.trade_step = 0 - def finished(self): + def finished(self) -> bool: """ Check if the trading finished - Should check before calling strategy.generate_decisions and executor.execute @@ -72,29 +85,32 @@ def finished(self): """ return self.trade_step >= self.trade_len - def step(self): + def step(self) -> None: if self.finished(): raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!") - self.trade_step = self.trade_step + 1 + self.trade_step += 1 - def get_freq(self): + def get_freq(self) -> str: return self.freq - def get_trade_len(self): + def get_trade_len(self) -> int: """get the total step length""" return self.trade_len - def get_trade_step(self): + def get_trade_step(self) -> int: return self.trade_step - def get_step_time(self, trade_step=None, shift=0): + def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]: """ Get the left and right endpoints of the trade_step'th trading interval About the endpoints: - - Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc - # - The returned right endpoints should minus 1 seconds because of the closed interval representation in Qlib. - # Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval. + - Qlib uses the closed interval in time-series data selection, which has the same performance as + pandas.Series.loc + # - The returned right endpoints should minus 1 seconds because of the closed interval representation in + # Qlib. + # Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time + # interval. Parameters ---------- @@ -105,15 +121,14 @@ def get_step_time(self, trade_step=None, shift=0): Returns ------- - Tuple[pd.Timestamp, pd.Timestap] + Tuple[pd.Timestamp, pd.Timestamp] - If shift == 0, return the trading time range - If shift > 0, return the trading time range of the earlier shift bars - If shift < 0, return the trading time range of the later shift bar """ if trade_step is None: trade_step = self.get_trade_step() - trade_step = trade_step - shift - calendar_index = self.start_index + trade_step + calendar_index = self.start_index + trade_step - shift return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1]) def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: @@ -126,7 +141,7 @@ def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: Parameters ---------- rtype: str - - "full": return the full limitation of the deicsion in the day + - "full": return the full limitation of the decision in the day - "step": return the limitation of current step Returns @@ -148,7 +163,7 @@ def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: return start_idx - day_start_idx, end_index - day_start_idx - def get_all_time(self): + def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]: """Get the start_time and end_time for trading""" return self.start_time, self.end_time @@ -167,30 +182,33 @@ def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tup Tuple[int, int]: the index of the range. **the left and right are closed** """ - left, right = ( - bisect.bisect_right(self._calendar, start_time) - 1, - bisect.bisect_right(self._calendar, end_time) - 1, - ) + left = bisect.bisect_right(self._calendar, start_time) - 1 + right = bisect.bisect_right(self._calendar, end_time) - 1 left -= self.start_index right -= self.start_index - def clip(idx): + def clip(idx: int) -> int: return min(max(0, idx), self.trade_len - 1) return clip(left), clip(right) def __repr__(self) -> str: - return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]" + return ( + f"class: {self.__class__.__name__}; " + f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: " + f"[{self.trade_step}/{self.trade_len}]" + ) class BaseInfrastructure: - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: self.reset_infra(**kwargs) - def get_support_infra(self): + @abstractmethod + def get_support_infra(self) -> Set[str]: raise NotImplementedError("`get_support_infra` is not implemented!") - def reset_infra(self, **kwargs): + def reset_infra(self, **kwargs) -> None: support_infra = self.get_support_infra() for k, v in kwargs.items(): if k in support_infra: @@ -198,53 +216,58 @@ def reset_infra(self, **kwargs): else: warnings.warn(f"{k} is ignored in `reset_infra`!") - def get(self, infra_name): + def get(self, infra_name: str) -> Any: if hasattr(self, infra_name): return getattr(self, infra_name) else: warnings.warn(f"infra {infra_name} is not found!") - def has(self, infra_name): + def has(self, infra_name: str) -> bool: return infra_name in self.get_support_infra() and hasattr(self, infra_name) - def update(self, other): + def update(self, other: BaseInfrastructure) -> None: support_infra = other.get_support_infra() infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)} self.reset_infra(**infra_dict) class CommonInfrastructure(BaseInfrastructure): - def get_support_infra(self): - return ["trade_account", "trade_exchange"] + def get_support_infra(self) -> Set[str]: + return {"trade_account", "trade_exchange"} class LevelInfrastructure(BaseInfrastructure): """level infrastructure is created by executor, and then shared to strategies on the same level""" - def get_support_infra(self): + def get_support_infra(self) -> Set[str]: """ Descriptions about the infrastructure sub_level_infra: - **NOTE**: this will only work after _init_sub_trading !!! """ - return ["trade_calendar", "sub_level_infra", "common_infra"] + return {"trade_calendar", "sub_level_infra", "common_infra"} - def reset_cal(self, freq, start_time, end_time): + def reset_cal( + self, + freq: str, + start_time: Union[str, pd.Timestamp, None], + end_time: Union[str, pd.Timestamp, None], + ) -> None: """reset trade calendar manager""" if self.has("trade_calendar"): self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time) else: self.reset_infra( - trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self) + trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self), ) - def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure): - """this will make the calendar access easier when acrossing multi-levels""" + def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None: + """this will make the calendar access easier when crossing multi-levels""" self.reset_infra(sub_level_infra=sub_level_infra) -def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]: +def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]: """ A helper function for getting the decision-level index range limitation for inner strategy - NOTE: this function is not applicable to order-level diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index a2d5e198a6..3ca8a8bd0b 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,17 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import annotations -from typing import TYPE_CHECKING + +from abc import abstractmethod +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from qlib.backtest.exchange import Exchange from qlib.backtest.position import BasePosition + from typing import Tuple, Union +from ..backtest.decision import BaseTradeDecision +from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config -from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager -from ..backtest.decision import BaseTradeDecision __all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"] @@ -25,12 +28,13 @@ def __init__( level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, trade_exchange: Exchange = None, - ): + ) -> None: """ Parameters ---------- outer_trade_decision : BaseTradeDecision, optional - the trade decision of outer strategy which this strategy relies, and it will be traded in [start_time, end_time], by default None + the trade decision of outer strategy which this strategy relies, and it will be traded in + [start_time, end_time], by default None - If the strategy is used to split trade decision, it will be used - If the strategy is used for portfolio management, it can be ignored level_infra : LevelInfrastructure, optional @@ -41,9 +45,10 @@ def __init__( trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra - - It allowes different trade_exchanges is used in different executions. + - It allows different trade_exchanges is used in different executions. - For example: - - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster. + - In daily execution, both daily exchange and minutely are usable, but the daily exchange is + recommended because it run faster. - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. """ @@ -63,13 +68,13 @@ def trade_exchange(self) -> Exchange: """get trade exchange in a prioritized order""" return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange") - def reset_level_infra(self, level_infra: LevelInfrastructure): + def reset_level_infra(self, level_infra: LevelInfrastructure) -> None: if not hasattr(self, "level_infra"): self.level_infra = level_infra else: self.level_infra.update(level_infra) - def reset_common_infra(self, common_infra: CommonInfrastructure): + def reset_common_infra(self, common_infra: CommonInfrastructure) -> None: if not hasattr(self, "common_infra"): self.common_infra: CommonInfrastructure = common_infra else: @@ -79,9 +84,9 @@ def reset( self, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, - outer_trade_decision=None, - **kwargs, - ): + outer_trade_decision: BaseTradeDecision = None, + **kwargs, # TODO: remove this? + ) -> None: """ - reset `level_infra`, used to reset trade calendar, .etc - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc @@ -89,18 +94,20 @@ def reset( **NOTE**: split this function into `reset` and `_reset` will make following cases more convenient - 1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called - when initialization + 1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` + called when initialization """ self._reset( - level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs + level_infra=level_infra, + common_infra=common_infra, + outer_trade_decision=outer_trade_decision, ) def _reset( self, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, - outer_trade_decision=None, + outer_trade_decision: BaseTradeDecision = None, ): """ Please refer to the docs of `reset` @@ -114,7 +121,8 @@ def _reset( if outer_trade_decision is not None: self.outer_trade_decision = outer_trade_decision - def generate_trade_decision(self, execute_result=None): + @abstractmethod + def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: """Generate trade decision in each trading bar Parameters @@ -125,9 +133,11 @@ def generate_trade_decision(self, execute_result=None): """ raise NotImplementedError("generate_trade_decision is not implemented!") + @staticmethod def update_trade_decision( - self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager - ) -> Union[BaseTradeDecision, None]: + trade_decision: BaseTradeDecision, + trade_calendar: TradeCalendarManager, + ) -> Optional[BaseTradeDecision]: """ update trade decision in each step of inner execution, this method enable all order @@ -145,7 +155,8 @@ def update_trade_decision( # default to return None, which indicates that the trade decision is not changed return None - def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision): + # FIXME: do not define this method as an abstract one since it is never implemented + def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: """ A method for updating the outer_trade_decision. The outer strategy may change its decision during updating. @@ -154,6 +165,10 @@ def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision): ---------- outer_trade_decision : BaseTradeDecision the decision updated by the outer strategy + + Returns + ------- + BaseTradeDecision """ # default to reset the decision directly # NOTE: normally, user should do something to the strategy due to the change of outer decision @@ -200,7 +215,7 @@ def __init__( level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, - ): + ) -> None: """ Parameters ---------- @@ -223,7 +238,7 @@ def __init__( level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, - ): + ) -> None: """ Parameters ---------- @@ -242,7 +257,7 @@ def __init__( self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter) self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter) - def generate_trade_decision(self, execute_result=None): + def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: _interpret_state = self.state_interpreter.interpret(execute_result=execute_result) _action = self.policy.step(_interpret_state) _trade_decision = self.action_interpreter.interpret(action=_action) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index e6b38b3891..c095acbb33 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -376,7 +376,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod def init_instance_by_config( - config: Union[str, dict, object, Path], + config: Union[str, dict, object, Path], # TODO: use a user-defined type to replace this Union. default_module=None, accept_types: Union[type, Tuple[type]] = (), try_kwargs: Dict = {}, @@ -1063,4 +1063,5 @@ def fname_to_code(fname: str): "unpack_archive_with_buffer", "get_tmp_file_with_buffer", "set_log_with_config", + "init_instance_by_config", ] diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py index c0bb87e346..f0497bc91f 100644 --- a/tests/backtest/test_file_strategy.py +++ b/tests/backtest/test_file_strategy.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import unittest -from qlib.backtest import backtest, decision +from qlib.backtest import backtest from qlib.tests import TestAutoData import pandas as pd from pathlib import Path @@ -52,13 +52,12 @@ def test_file_str(self): factor = df["$factor"].item() price_unit = price / factor * 100 dealt_num_for_1000 = (account_money // price_unit) * (100 / factor) + print(price, factor, price_unit, dealt_num_for_1000) # 2) generate orders orders = self._gen_orders(dealt_num_for_1000) - print(orders) orders.to_csv(self.EXAMPLE_FILE) - - orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"]) + print(orders) # 3) run the strategy strategy_config = { @@ -101,7 +100,11 @@ def test_file_str(self): }, }, } - report_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config) + report_dict, indicator_dict = backtest( + executor=executor_config, + strategy=strategy_config, + **backtest_config, + ) # ffr valid ffr_dict = indicator_dict["1day"]["ffr"].to_dict()