Skip to content

Commit

Permalink
Add REG_US and REG_TW into test case: test_utils.py. (#1310)
Browse files Browse the repository at this point in the history
* Add REG_US and REG_TW into test case: test_utils.py.

* Fix black.

* Trigger checks.

* Add REG_US and REG_TW into test case: test_utils.py.

* Fix black.

* Trigger checks.
  • Loading branch information
ChiahungTai authored Oct 14, 2022
1 parent 216a8ec commit 84e9df2
Showing 1 changed file with 48 additions and 21 deletions.
69 changes: 48 additions & 21 deletions tests/misc/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from unittest.case import TestCase
import unittest
import pandas as pd
Expand All @@ -6,10 +7,13 @@
from qlib import init
from qlib.config import C
from qlib.log import TimeInspector
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal
from qlib.constant import REG_CN, REG_US, REG_TW
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME

REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME}

def cal_sam_minute(x, sam_minutes):

def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str):
"""
Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time
- open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
Expand All @@ -20,21 +24,43 @@ def cal_sam_minute(x, sam_minutes):
# TODO: actually, this version is much faster when no cache or optimization
day_time = pd.Timestamp(x.date())
shift = C.min_data_shift

open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1)
close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
region_time = REG_MAP[region]

open_time = (
day_time
+ pd.Timedelta(hours=region_time[0].hour, minutes=region_time[0].minute)
- shift * pd.Timedelta(minutes=1)
)
close_time = (
day_time
+ pd.Timedelta(hours=region_time[-1].hour, minutes=region_time[-1].minute)
- shift * pd.Timedelta(minutes=1)
)
if region_time == CN_TIME:
mid_close_time = (
day_time
+ pd.Timedelta(hours=region_time[1].hour, minutes=region_time[1].minute - 1)
- shift * pd.Timedelta(minutes=1)
)
mid_open_time = (
day_time
+ pd.Timedelta(hours=region_time[2].hour, minutes=region_time[2].minute)
- shift * pd.Timedelta(minutes=1)
)
else:
mid_close_time = close_time
mid_open_time = open_time

if open_time <= x <= mid_close_time:
minute_index = (x - open_time).seconds // 60
elif mid_open_time <= x <= close_time:
minute_index = (x - mid_open_time).seconds // 60 + 120
else:
raise ValueError("datetime of calendar is out of range")

minute_index = minute_index // sam_minutes * sam_minutes

if 0 <= minute_index < 120:
if 0 <= minute_index < 120 or region_time != CN_TIME:
return open_time + minute_index * pd.Timedelta(minutes=1)
elif 120 <= minute_index < 240:
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
Expand All @@ -50,9 +76,9 @@ def setUpClass(cls):
def test_cal_sam_minute(self):
# test the correctness of the code
random_n = 1000
cal = get_min_cal()
regions = [REG_CN, REG_US, REG_TW]

def gen_args():
def gen_args(cal: List):
for time in np.random.choice(cal, size=random_n, replace=True):
sam_minutes = np.random.choice([1, 2, 3, 4, 5, 6])
dt = pd.Timestamp(
Expand All @@ -69,20 +95,21 @@ def gen_args():
args = dt, sam_minutes
yield args

for args in gen_args():
assert cal_sam_minute(*args) == cal_sam_minute_new(*args)

# test the performance of the code
for region in regions:
cal_time = get_min_cal(region=region)
for args in gen_args(cal_time):
assert cal_sam_minute(*args, region) == cal_sam_minute_new(*args, region=region)

args_l = list(gen_args())
# test the performance of the code
args_l = list(gen_args(cal_time))

with TimeInspector.logt():
for args in args_l:
cal_sam_minute(*args)
with TimeInspector.logt():
for args in args_l:
cal_sam_minute(*args, region=region)

with TimeInspector.logt():
for args in args_l:
cal_sam_minute_new(*args)
with TimeInspector.logt():
for args in args_l:
cal_sam_minute_new(*args, region=region)


if __name__ == "__main__":
Expand Down

0 comments on commit 84e9df2

Please sign in to comment.