Skip to content

Commit

Permalink
Callable Exp
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Nov 11, 2021
1 parent 01bdf6c commit efa7617
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
13 changes: 10 additions & 3 deletions qlib/workflow/record_temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas as pd
from pathlib import Path
from pprint import pprint
from typing import Union, List
from typing import Union, List, Optional
from collections import defaultdict

from qlib.utils.exceptions import LoadObjectError
Expand Down Expand Up @@ -270,7 +270,13 @@ def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0,
self.label_col = label_col
self.skip_existing = skip_existing

def generate(self, **kwargs):
def generate(self, label: Optional[pd.DataFrame] = None, **kwargs):
"""
Parameters
----------
label : Optional[pd.DataFrame]
Label should be a dataframe.
"""
if self.skip_existing:
try:
self.check(include_self=True, parents=False)
Expand All @@ -283,7 +289,8 @@ def generate(self, **kwargs):
self.check()

pred = self.load("pred.pkl")
label = self.load("label.pkl")
if label is None:
label = self.load("label.pkl")
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warn(f"Empty label.")
return
Expand Down
21 changes: 13 additions & 8 deletions qlib/workflow/task/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
from qlib.workflow import R
from qlib.workflow.exp import Experiment


class Collector(Serializable):
Expand Down Expand Up @@ -146,7 +147,9 @@ def __init__(
Init RecorderCollector.
Args:
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
experiment:
(Experiment or str): an instance of an Experiment or the name of an Experiment
(Callable): an callable function, which returns a list of experiments
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
Expand All @@ -157,6 +160,7 @@ def __init__(
super().__init__(process_list=process_list)
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
assert isinstance(experiment, (Experiment, Callable))
self.experiment = experiment
self.artifacts_path = artifacts_path
if rec_key_func is None:
Expand Down Expand Up @@ -192,15 +196,16 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) ->
collect_dict = {}
# filter records

with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
recs = self.experiment.list_recorders(**self.list_kwargs)
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
if isinstance(self.experiment, Experiment):
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
recs = list(self.experiment.list_recorders(**self.list_kwargs).values())
elif isinstance(self.experiment, Callable):
recs = self.experiment()

recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)]

logger = get_module_logger("RecorderCollector")
for _, rec in recs_flt.items():
for rec in recs:
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
if self.ART_KEY_RAW == key:
Expand Down

0 comments on commit efa7617

Please sign in to comment.