From efa7617270885c311dd6f34f08116f7a37c9093c Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 11 Nov 2021 13:05:58 +0000 Subject: [PATCH] Callable Exp --- qlib/workflow/record_temp.py | 13 ++++++++++--- qlib/workflow/task/collect.py | 21 +++++++++++++-------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 07422243de..2c72ae4af0 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -8,7 +8,7 @@ import pandas as pd from pathlib import Path from pprint import pprint -from typing import Union, List +from typing import Union, List, Optional from collections import defaultdict from qlib.utils.exceptions import LoadObjectError @@ -270,7 +270,13 @@ def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, self.label_col = label_col self.skip_existing = skip_existing - def generate(self, **kwargs): + def generate(self, label: Optional[pd.DataFrame] = None, **kwargs): + """ + Parameters + ---------- + label : Optional[pd.DataFrame] + Label should be a dataframe. + """ if self.skip_existing: try: self.check(include_self=True, parents=False) @@ -283,7 +289,8 @@ def generate(self, **kwargs): self.check() pred = self.load("pred.pkl") - label = self.load("label.pkl") + if label is None: + label = self.load("label.pkl") if label is None or not isinstance(label, pd.DataFrame) or label.empty: logger.warn(f"Empty label.") return diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index f48fc7c0db..b5b63bba6c 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -10,6 +10,7 @@ from qlib.log import get_module_logger from qlib.utils.serial import Serializable from qlib.workflow import R +from qlib.workflow.exp import Experiment class Collector(Serializable): @@ -146,7 +147,9 @@ def __init__( Init RecorderCollector. Args: - experiment (Experiment or str): an instance of an Experiment or the name of an Experiment + experiment: + (Experiment or str): an instance of an Experiment or the name of an Experiment + (Callable): an callable function, which returns a list of experiments process_list (list or Callable): the list of processors or the instance of a processor to process dict. rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. @@ -157,6 +160,7 @@ def __init__( super().__init__(process_list=process_list) if isinstance(experiment, str): experiment = R.get_exp(experiment_name=experiment) + assert isinstance(experiment, (Experiment, Callable)) self.experiment = experiment self.artifacts_path = artifacts_path if rec_key_func is None: @@ -192,15 +196,16 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> collect_dict = {} # filter records - with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"): - recs = self.experiment.list_recorders(**self.list_kwargs) - recs_flt = {} - for rid, rec in recs.items(): - if rec_filter_func is None or rec_filter_func(rec): - recs_flt[rid] = rec + if isinstance(self.experiment, Experiment): + with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"): + recs = list(self.experiment.list_recorders(**self.list_kwargs).values()) + elif isinstance(self.experiment, Callable): + recs = self.experiment() + + recs = [rec for rec in recs if rec_filter_func is None or rec_filter_func(rec)] logger = get_module_logger("RecorderCollector") - for _, rec in recs_flt.items(): + for rec in recs: rec_key = self.rec_key_func(rec) for key in artifacts_key: if self.ART_KEY_RAW == key: