From fb5888be9ee976c3e3879e7298e3a0a1030db37e Mon Sep 17 00:00:00 2001 From: Chia-hung Tai Date: Sun, 30 Oct 2022 16:27:59 +0800 Subject: [PATCH] Use mock data for element operator tests. (#1330) --- qlib/tests/__init__.py | 216 +++++++++++++++++++++++++++++++- tests/ops/test_elem_operator.py | 40 +++++- 2 files changed, 253 insertions(+), 3 deletions(-) diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index a23092a2ef..52c924918a 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -1,10 +1,16 @@ +from typing import Union, List, Dict, Tuple import unittest +import pandas as pd +import numpy as np +import io + from .data import GetData from .. import init -from ..constant import REG_CN +from ..constant import REG_CN, REG_TW from qlib.data.filter import NameDFilter from qlib.data import D from qlib.data.data import Cal, DatasetD +from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT class TestAutoData(unittest.TestCase): @@ -75,3 +81,211 @@ def setUpClass(cls, enable_1d_type="simple", enable_1min=False) -> None: cls.end_time = cal[-1] cls.inst = list(instruments_d.keys())[0] cls.spans = list(instruments_d.values())[0] + + +MOCK_DATA = """ +id,symbol,datetime,interval,volume,open,high,low,close +20275,0050,2022-01-03 00:00:00,day,6761.0,146.0,147.35,146.0,146.4 +20276,0050,2022-01-04 00:00:00,day,9608.0,147.7,149.6,147.7,149.6 +20277,0050,2022-01-05 00:00:00,day,11387.0,150.1,150.55,149.1,149.3 +20278,0050,2022-01-06 00:00:00,day,8611.0,148.3,148.75,147.0,147.9 +20279,0050,2022-01-07 00:00:00,day,6954.0,148.3,149.0,146.5,146.6 +20280,0050,2022-01-10 00:00:00,day,15684.0,146.0,147.8,145.4,147.55 +20281,0050,2022-01-11 00:00:00,day,17741.0,147.6,148.5,146.7,148.3 +20282,0050,2022-01-12 00:00:00,day,10134.0,149.35,149.6,148.7,149.55 +20283,0050,2022-01-13 00:00:00,day,7431.0,149.55,150.45,149.55,150.3 +20284,0050,2022-01-14 00:00:00,day,10091.0,150.8,151.2,149.05,150.3 +20285,0050,2022-01-17 00:00:00,day,6899.0,151.1,152.4,151.1,152.0 +20286,0050,2022-01-18 00:00:00,day,14360.0,152.2,152.25,150.15,150.3 +20287,0050,2022-01-19 00:00:00,day,14654.0,149.0,149.65,148.25,148.5 +20288,0050,2022-01-20 00:00:00,day,16201.0,148.5,149.2,147.6,149.1 +20289,0050,2022-01-21 00:00:00,day,29848.0,143.9,143.95,142.3,142.65 +20290,0050,2022-01-24 00:00:00,day,13143.0,142.1,144.0,141.7,144.0 +20291,0050,2022-01-25 00:00:00,day,23982.0,142.55,142.55,141.25,141.65 +20292,0050,2022-01-26 00:00:00,day,17729.0,141.15,142.2,141.05,141.55 +8547,1101,2021-12-01 00:00:00,day,16119.0,46.0,46.85,46.0,46.6 +8548,1101,2021-12-02 00:00:00,day,14521.0,46.6,46.7,46.3,46.3 +8549,1101,2021-12-03 00:00:00,day,14357.0,46.55,46.85,46.4,46.4 +8550,1101,2021-12-06 00:00:00,day,15115.0,46.45,47.35,46.4,47.3 +8551,1101,2021-12-07 00:00:00,day,13117.0,47.35,47.55,46.9,47.55 +8552,1101,2021-12-08 00:00:00,day,10329.0,47.75,47.8,47.5,47.7 +8553,1101,2021-12-09 00:00:00,day,9300.0,47.8,47.85,47.1,47.4 +8554,1101,2021-12-10 00:00:00,day,9919.0,47.4,47.6,47.1,47.3 +8555,1101,2021-12-13 00:00:00,day,7784.0,47.3,47.75,47.1,47.1 +8556,1101,2021-12-14 00:00:00,day,9373.0,47.05,47.2,46.95,47.0 +8557,1101,2021-12-15 00:00:00,day,11189.0,47.0,47.3,46.8,46.95 +8558,1101,2021-12-16 00:00:00,day,7516.0,47.0,47.15,46.8,46.9 +8559,1101,2021-12-17 00:00:00,day,18502.0,46.95,47.6,46.9,47.45 +8560,1101,2021-12-20 00:00:00,day,11309.0,47.45,47.5,47.1,47.4 +8561,1101,2021-12-21 00:00:00,day,5666.0,47.4,47.45,47.1,47.25 +8562,1101,2021-12-22 00:00:00,day,5460.0,47.4,47.45,47.2,47.4 +8563,1101,2021-12-23 00:00:00,day,9371.0,47.3,47.7,47.3,47.7 +8564,1101,2021-12-24 00:00:00,day,5980.0,47.75,47.95,47.75,47.9 +8565,1101,2021-12-27 00:00:00,day,5709.0,47.9,48.1,47.9,48.1 +8566,1101,2021-12-28 00:00:00,day,7777.0,48.1,48.15,47.95,48.15 +8567,1101,2021-12-29 00:00:00,day,5309.0,48.15,48.25,48.05,48.15 +8568,1101,2021-12-30 00:00:00,day,4616.0,48.15,48.2,48.0,48.0 +8569,1101,2022-01-03 00:00:00,day,12350.0,48.05,48.15,47.35,47.45 +8570,1101,2022-01-04 00:00:00,day,11439.0,47.5,47.6,47.0,47.3 +8571,1101,2022-01-05 00:00:00,day,9692.0,47.1,47.3,47.0,47.15 +8572,1101,2022-01-06 00:00:00,day,12361.0,47.3,47.6,47.15,47.6 +8573,1101,2022-01-07 00:00:00,day,10921.0,47.6,47.65,47.2,47.45 +8574,1101,2022-01-10 00:00:00,day,11925.0,47.45,47.5,47.0,47.3 +8575,1101,2022-01-11 00:00:00,day,11047.0,47.1,47.5,47.1,47.5 +8576,1101,2022-01-12 00:00:00,day,10817.0,47.5,47.5,47.1,47.5 +8577,1101,2022-01-13 00:00:00,day,13849.0,47.5,47.95,47.4,47.95 +8578,1101,2022-01-14 00:00:00,day,9460.0,47.85,47.85,47.45,47.6 +8579,1101,2022-01-17 00:00:00,day,9057.0,47.55,47.7,47.35,47.6 +8580,1101,2022-01-18 00:00:00,day,8089.0,47.6,47.75,47.45,47.75 +8581,1101,2022-01-19 00:00:00,day,5110.0,47.6,47.7,47.5,47.6 +8582,1101,2022-01-20 00:00:00,day,6327.0,47.55,47.7,47.45,47.5 +8583,1101,2022-01-21 00:00:00,day,9470.0,47.5,47.65,47.15,47.4 +8584,1101,2022-01-24 00:00:00,day,5475.0,47.1,47.3,47.0,47.15 +8585,1101,2022-01-25 00:00:00,day,16153.0,47.0,47.05,46.6,46.8 +8586,1101,2022-01-26 00:00:00,day,7772.0,46.7,47.0,46.55,46.85 +8587,1101,2022-02-07 00:00:00,day,17031.0,46.55,47.1,46.0,47.1 +8588,1101,2022-02-08 00:00:00,day,9741.0,47.1,47.25,46.9,46.95 +8589,1101,2022-02-09 00:00:00,day,7968.0,46.95,47.3,46.9,47.3 +8590,1101,2022-02-10 00:00:00,day,7479.0,47.15,47.55,47.05,47.55 +8591,1101,2022-02-11 00:00:00,day,6841.0,47.3,47.55,47.15,47.55 +8592,1101,2022-02-14 00:00:00,day,9136.0,47.2,47.3,46.95,47.15 +8593,1101,2022-02-15 00:00:00,day,5444.0,47.05,47.1,46.8,47.0 +8594,1101,2022-02-16 00:00:00,day,8751.0,47.0,47.15,47.0,47.0 +8595,1101,2022-02-17 00:00:00,day,10662.0,47.15,47.55,47.1,47.45 +8596,1101,2022-02-18 00:00:00,day,8781.0,47.25,47.55,47.2,47.45 +8597,1101,2022-02-21 00:00:00,day,8201.0,47.35,47.75,47.15,47.6 +8598,1101,2022-02-22 00:00:00,day,10655.0,47.4,47.7,47.1,47.7 +8599,1101,2022-02-23 00:00:00,day,8040.0,47.7,47.85,47.45,47.65 +8600,1101,2022-02-24 00:00:00,day,13124.0,47.5,47.5,47.1,47.3 +8601,1101,2022-02-25 00:00:00,day,14556.0,47.2,47.5,46.9,47.35 +""" + +MOCK_DF = pd.read_csv(io.StringIO(MOCK_DATA), header=0, dtype={"symbol": str}) + + +class MockStorageBase: + def __init__(self, **kwargs): + self.df = MOCK_DF + + +class MockCalendarStorage(MockStorageBase, CalendarStorage): + def __init__(self, **kwargs): + super().__init__() + self._data = sorted(self.df["datetime"].unique()) + + @property + def data(self) -> List[CalVT]: + return self._data + + def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]: + return self.data[i] + + def __len__(self) -> int: + return len(self.data) + + +class MockInstrumentStorage(MockStorageBase, InstrumentStorage): + def __init__(self, **kwargs): + super().__init__() + instruments = {} + for symbol, group in self.df.groupby(by="symbol"): + start = group["datetime"].iloc[0] + end = group["datetime"].iloc[-1] + instruments[symbol] = [(start, end)] + self._data = instruments + + @property + def data(self) -> Dict[InstKT, InstVT]: + return self._data + + def __getitem__(self, k: InstKT) -> InstVT: + return self.data[k] + + def __len__(self) -> int: + return len(self.data) + + +class MockFeatureStorage(MockStorageBase, FeatureStorage): + def __init__(self, instrument: str, field: str, freq: str, db_region: str = None, **kwargs): # type: ignore + super().__init__(instrument=instrument, field=field, freq=freq, db_region=db_region, **kwargs) + self.field = field + calendar = sorted(self.df["datetime"].unique()) + df_calendar = pd.DataFrame(calendar, columns=["datetime"]).set_index("datetime") + df = self.df[self.df["symbol"] == instrument] + data_dt_field = "datetime" + cal_df = df_calendar[ + (df_calendar.index >= df[data_dt_field].min()) & (df_calendar.index <= df[data_dt_field].max()) + ] + df = df.set_index(data_dt_field) + df_data = df.reindex(cal_df.index) + date_index = df_calendar.index.get_loc(df_data.index.min()) # type: ignore + df_data.reset_index(inplace=True) + df_data.index += date_index + self._data = df_data + + @property + def data(self) -> pd.Series: + return self._data[self.field] + + @property + def start_index(self) -> Union[int, None]: + if self._data.empty: + return None + return self._data.index[0] + + @property + def end_index(self) -> Union[int, None]: + if self._data.empty: + return None + # The next data appending index point will be `end_index + 1` + return self._data.index[-1] + + def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]: + df = self._data + storage_start_index = df.index[0] + storage_end_index = df.index[-1] + if isinstance(i, int): + if storage_start_index > i or i > storage_end_index: + raise IndexError(f"{i}: start index is {storage_start_index}") + data = self.data[i] + return i, data + elif isinstance(i, slice): + start_index = storage_start_index if i.start is None else i.start + end_index = storage_end_index if i.stop is None else i.stop + si = max(start_index, storage_start_index) + if si > end_index or self.field not in df.columns: + return pd.Series(dtype=np.float32) # type: ignore + data = df[self.field].tolist() + result = data[si - storage_start_index : end_index - storage_start_index] + return pd.Series(result, index=pd.RangeIndex(si, si + len(result))) # type: ignore + else: + raise TypeError(f"type(i) = {type(i)}") + + def __len__(self) -> int: + return len(self.data) + + +class TestMockData(unittest.TestCase): + _setup_kwargs = { + "calendar_provider": { + "class": "LocalCalendarProvider", + "module_path": "qlib.data.data", + "kwargs": {"backend": {"class": "MockCalendarStorage", "module_path": "qlib.tests"}}, + }, + "instrument_provider": { + "class": "LocalInstrumentProvider", + "module_path": "qlib.data.data", + "kwargs": {"backend": {"class": "MockInstrumentStorage", "module_path": "qlib.tests"}}, + }, + "feature_provider": { + "class": "LocalFeatureProvider", + "module_path": "qlib.data.data", + "kwargs": {"backend": {"class": "MockFeatureStorage", "module_path": "qlib.tests"}}, + }, + } + + @classmethod + def setUpClass(cls) -> None: + + provider_uri = "Not necessary." + init(region=REG_TW, provider_uri=provider_uri, expression_cache=None, dataset_cache=None, **cls._setup_kwargs) diff --git a/tests/ops/test_elem_operator.py b/tests/ops/test_elem_operator.py index e641b1ac2e..8349157ff4 100644 --- a/tests/ops/test_elem_operator.py +++ b/tests/ops/test_elem_operator.py @@ -1,17 +1,52 @@ import unittest +import numpy as np +import pytest from qlib.data import DatasetProvider -from qlib.tests import TestOperatorData +from qlib.data.data import ExpressionD +from qlib.tests import TestOperatorData, TestMockData, MOCK_DF from qlib.config import C +class TestElementOperator(TestMockData): + def setUp(self) -> None: + self.instrument = "0050" + self.start_time = "2022-01-01" + self.end_time = "2022-02-01" + self.freq = "day" + self.mock_df = MOCK_DF[MOCK_DF["symbol"] == self.instrument] + + def test_Abs(self): + field = "Abs($close-Ref($close, 1))" + result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq) + self.assertGreaterEqual(result.min(), 0) + result = result.to_numpy() + prev_close = self.mock_df["close"].shift(1) + close = self.mock_df["close"] + change = prev_close - close + golden = change.abs().to_numpy() + self.assertIsNone(np.testing.assert_allclose(result, golden)) + + def test_Sign(self): + field = "Sign($close-Ref($close, 1))" + result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq) + result = result.to_numpy() + prev_close = self.mock_df["close"].shift(1) + close = self.mock_df["close"] + change = close - prev_close + change[change > 0] = 1.0 + change[change < 0] = -1.0 + golden = change.to_numpy() + self.assertIsNone(np.testing.assert_allclose(result, golden)) + + class TestOperatorDataSetting(TestOperatorData): def test_setting(self): self.assertEqual(len(self.instruments_d), 1) self.assertGreater(len(self.cal), 0) -class TestElementOperator(TestOperatorData): +class TestInstElementOperator(TestOperatorData): def setUp(self) -> None: freq = "day" expressions = [ @@ -24,6 +59,7 @@ def setUp(self) -> None: ) self.data.columns = columns + @pytest.mark.slow def test_abs(self): abs_values = self.data["abs"] self.assertGreater(abs_values[2], 0)