add default protocol_version #677

Merged (2 commits, Nov 10, 2021)
5 changes: 5 additions & 0 deletions qlib/config.py
@@ -73,6 +73,9 @@ def set_conf_from_C(self, config_c):
REG_CN = "cn"
REG_US = "us"

# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
PROTOCOL_VERSION = 4

NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1)

DISK_DATASET_CACHE = "DiskDatasetCache"
@@ -107,6 +110,8 @@ def set_conf_from_C(self, config_c):
# for simple dataset cache
"local_cache_path": None,
"kernels": NUM_USABLE_CPU,
# pickle.dump protocol version
"dump_protocol_version": PROTOCOL_VERSION,
# How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data.
"maxtasksperchild": None,
# If joblib_backend is None, use loky
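For reference, the pattern this new default enables looks roughly like the sketch below. This is a minimal illustration, assuming `qlib.config.C` is importable and carries the `dump_protocol_version` entry shown above; the object and file name are placeholders. Like the other entries in this config block, the value can presumably be overridden through `qlib.init(dump_protocol_version=...)`.

```python
import pickle

from qlib.config import C  # C.dump_protocol_version defaults to PROTOCOL_VERSION (4)

obj = {"example": [1, 2, 3]}  # any picklable object, stand-in for qlib's cached data
with open("obj.pkl", "wb") as f:
    # the pattern this PR applies across qlib: pass the configured protocol explicitly
    # instead of relying on the interpreter-dependent pickle default
    pickle.dump(obj, f, protocol=C.dump_protocol_version)
```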
9 changes: 3 additions & 6 deletions qlib/contrib/online/manager.py
@@ -1,17 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import pickle
import yaml
import pathlib
import pandas as pd
import shutil
from ..backtest.account import Account
from ..backtest.exchange import Exchange
from ...backtest.account import Account
from .user import User
from .utils import load_instance
from ...utils import save_instance, init_instance_by_config
from .utils import load_instance, save_instance
from ...utils import init_instance_by_config


class UserManager:
6 changes: 3 additions & 3 deletions qlib/contrib/online/utils.py
@@ -6,10 +6,10 @@
import yaml
import pandas as pd
from ...data import D
from ...config import C
from ...log import get_module_logger
from ...utils import get_module_by_module_path, init_instance_by_config
from ...utils import get_next_trading_date
from ..backtest.exchange import Exchange
from ...backtest.exchange import Exchange

log = get_module_logger("utils")

@@ -42,7 +42,7 @@ def save_instance(instance, file_path):
"""
file_path = pathlib.Path(file_path)
with file_path.open("wb") as fr:
pickle.dump(instance, fr)
pickle.dump(instance, fr, C.dump_protocol_version)


def create_user_folder(path):
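The change only touches the writing side: pickle streams are self-describing, so `load_instance` needs no corresponding update. A rough sketch of the round trip, assuming `load_instance` takes the same path argument as `save_instance` (object and file name are placeholders):

```python
from qlib.contrib.online.utils import load_instance, save_instance

obj = {"positions": {}}  # stand-in for an Account/Exchange instance
save_instance(obj, "account.pkl")        # now written with protocol=C.dump_protocol_version
restored = load_instance("account.pkl")  # pickle.load infers the protocol from the stream
assert restored == obj
```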
10 changes: 5 additions & 5 deletions qlib/data/cache.py
@@ -230,7 +230,7 @@ def visit(cache_path: Union[str, Path]):
d["meta"]["visits"] = d["meta"]["visits"] + 1
except KeyError:
raise KeyError("Unknown meta keyword")
pickle.dump(d, f)
pickle.dump(d, f, protocol=C.dump_protocol_version)
except Exception as e:
get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}")

@@ -573,7 +573,7 @@ def gen_expression_cache(self, expression_data, cache_path, instrument, field, f
meta_path = cache_path.with_suffix(".meta")

with meta_path.open("wb") as f:
pickle.dump(meta, f)
pickle.dump(meta, f, protocol=C.dump_protocol_version)
meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
df = expression_data.to_frame()

@@ -638,7 +638,7 @@ def update(self, sid, cache_uri, freq: str = "day"):
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with meta_path.open("wb") as f:
pickle.dump(d, f)
pickle.dump(d, f, protocol=C.dump_protocol_version)
return 0


@@ -935,7 +935,7 @@ def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, f
"meta": {"last_visit": time.time(), "visits": 1},
}
with cache_path.with_suffix(".meta").open("wb") as f:
pickle.dump(meta, f)
pickle.dump(meta, f, protocol=C.dump_protocol_version)
cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
# write index file
im = DiskDatasetCache.IndexManager(cache_path)
@@ -1057,7 +1057,7 @@ def update(self, cache_uri, freq: str = "day"):
# update meta file
d["info"]["last_update"] = str(new_calendar[-1])
with meta_path.open("wb") as f:
pickle.dump(d, f)
pickle.dump(d, f, protocol=C.dump_protocol_version)
return 0


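All of the cache/meta writes above now pin the protocol rather than inherit whatever `pickle.DEFAULT_PROTOCOL` the running interpreter happens to use. The protocol is recorded in the stream itself, so the readers stay unchanged. A small, self-contained illustration (standard library only):

```python
import io
import pickle
import pickletools

buf = io.BytesIO()
pickle.dump({"meta": {"visits": 1}}, buf, protocol=4)  # what cache.py now does via C.dump_protocol_version
print(pickle.loads(buf.getvalue()))                    # loaders never pass a protocol

pickletools.dis(buf.getvalue())                        # first opcode prints "PROTO 4"
```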
205 changes: 103 additions & 102 deletions qlib/data/client.py
@@ -1,102 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


from __future__ import division
from __future__ import print_function

import socketio

import qlib
from ..config import C
from ..log import get_module_logger
import pickle


class Client:
"""A client class

Provide the connection tool functions for ClientProvider.
"""

def __init__(self, host, port):
super(Client, self).__init__()
self.sio = socketio.Client()
self.server_host = host
self.server_port = port
self.logger = get_module_logger(self.__class__.__name__)
# bind connect/disconnect callbacks
self.sio.on(
"connect",
lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)),
)
self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!"))

def connect_server(self):
"""Connect to server."""
try:
self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port))
except socketio.exceptions.ConnectionError:
self.logger.error("Cannot connect to server - check your network or server status")

def disconnect(self):
"""Disconnect from server."""
try:
self.sio.eio.disconnect(True)
except Exception as e:
self.logger.error("Cannot disconnect from server : %s" % e)

def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None):
"""Send a certain request to server.

Parameters
----------
request_type : str
type of proposed request, 'calendar'/'instrument'/'feature'.
request_content : dict
records the information of the request.
msg_proc_func : func
the function to process the message when receiving response, should have arg `*args`.
msg_queue: Queue
The queue to pass the message after callback.
"""
head_info = {"version": qlib.__version__}

def request_callback(*args):
"""callback_wrapper

:param *args: args[0] is the response content
"""
# args[0] is the response content
self.logger.debug("receive data and enter queue")
msg = dict(args[0])
if msg["detailed_info"] is not None:
if msg["status"] != 0:
self.logger.error(msg["detailed_info"])
else:
self.logger.info(msg["detailed_info"])
if msg["status"] != 0:
ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}")
msg_queue.put(ex)
else:
if msg_proc_func is not None:
try:
ret = msg_proc_func(msg["result"])
except Exception as e:
self.logger.exception("Error when processing message.")
ret = e
else:
ret = msg["result"]
msg_queue.put(ret)
self.disconnect()
self.logger.debug("disconnected")

self.logger.debug("try connecting")
self.connect_server()
self.logger.debug("connected")
# The pickle is for passing some parameters with special type(such as
# pd.Timestamp)
request_content = {"head": head_info, "body": pickle.dumps(request_content, protocol=C.dump_protocol_version)}
self.sio.on(request_type + "_response", request_callback)
self.logger.debug("try sending")
self.sio.emit(request_type + "_request", request_content)
self.sio.wait()
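For context, the request the client now emits looks roughly like the sketch below. The request body and the receiving side are illustrative assumptions (the real handler lives in the qlib server), but they show that only the sender needs the protocol setting:

```python
import pickle

import qlib
from qlib.config import C

head_info = {"version": qlib.__version__}
body = {"instruments": "csi300"}  # hypothetical request content
request_content = {
    "head": head_info,
    "body": pickle.dumps(body, protocol=C.dump_protocol_version),
}

# the receiving side only needs pickle.loads; the protocol is self-describing
assert pickle.loads(request_content["body"]) == body
```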
2 changes: 1 addition & 1 deletion qlib/utils/objm.py
@@ -106,7 +106,7 @@ def create_path(self) -> str:

def save_obj(self, obj, name):
with (self.path / name).open("wb") as f:
pickle.dump(obj, f)
pickle.dump(obj, f, protocol=C.dump_protocol_version)

def save_objs(self, obj_name_l):
for obj, name in obj_name_l:
5 changes: 3 additions & 2 deletions qlib/utils/serial.py
@@ -5,6 +5,7 @@
import dill
from pathlib import Path
from typing import Union
from ..config import C


class Serializable:
@@ -85,7 +86,7 @@ def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list
"""
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
self.get_backend().dump(self, f)
self.get_backend().dump(self, f, protocol=C.dump_protocol_version)
Review comment (Collaborator):
The backend may not be a pickle.
Maybe you can add Serializable.dump_kwargs for dumping and encourage users to override it.


@classmethod
def load(cls, filepath):
@@ -140,4 +141,4 @@ def general_dump(obj, path: Union[Path, str]):
obj.to_pickle(path)
else:
with path.open("wb") as f:
pickle.dump(obj, f)
pickle.dump(obj, f, protocol=C.dump_protocol_version)
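A rough sketch of the reviewer's `dump_kwargs` suggestion, not the merged code; the class and attribute names here are illustrative, and the real `Serializable` may pick `dill` rather than `pickle` as its backend when `dump_all` is set:

```python
import pickle
from pathlib import Path
from typing import Union

from qlib.config import C


class Serializable:
    # overridable per subclass (or per user), e.g. {} for a backend whose
    # dump() does not accept a `protocol` keyword
    default_dump_kwargs = {"protocol": C.dump_protocol_version}

    def get_backend(self):
        return pickle  # simplified; qlib may return dill instead

    def to_pickle(self, path: Union[Path, str], **dump_kwargs):
        kwargs = {**self.default_dump_kwargs, **dump_kwargs}
        with Path(path).open("wb") as f:
            self.get_backend().dump(self, f, **kwargs)
```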
8 changes: 6 additions & 2 deletions qlib/workflow/task/manage.py
@@ -27,6 +27,7 @@
from tqdm.cli import tqdm

from .utils import get_mongodb
from ...config import C


class TaskManager:
@@ -108,7 +109,7 @@ def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = Binary(pickle.dumps(task[k]))
task[k] = Binary(pickle.dumps(task[k], protocol=C.dump_protocol_version))
return task

def _decode_task(self, task):
@@ -359,7 +360,10 @@ def commit_task_res(self, task, res, status=STATUS_DONE):
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
self.task_pool.update_one(
{"_id": task["_id"]},
{"$set": {"status": status, "res": Binary(pickle.dumps(res, protocol=C.dump_protocol_version))}},
)

def return_task(self, task, status=STATUS_WAITING):
"""
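Only the write path (`_encode_task` and `commit_task_res`) needs the protocol; decoding is symmetric and version-agnostic. A hedged sketch of that symmetry with an illustrative task document (requires `pymongo` for `bson`, which `TaskManager` already depends on):

```python
import pickle

from bson.binary import Binary  # ships with pymongo, as used by TaskManager
from qlib.config import C

task = {"def": {"model": {"class": "LGBModel"}}}  # illustrative task document
encoded = {k: Binary(pickle.dumps(v, protocol=C.dump_protocol_version)) for k, v in task.items()}
decoded = {k: pickle.loads(v) for k, v in encoded.items()}
assert decoded == task
```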