Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support optimization based strategy #754

Merged
merged 9 commits into from
Dec 28, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions examples/portfolio/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Portfolio Optimization Strategy
evanzd marked this conversation as resolved.
Show resolved Hide resolved

## Introduction

In `qlib/examples/benchmarks` we have various **alpha** models that predict
the stock returns. We also use a simple rule based `TopkDropoutStrategy` to
evaluate the investing performance of these models. However, such a strategy
is too simple to control the portfolio risk like correlation and volatility.

To this end, an optimization based strategy should be used to for the
trade-off between return and risk. In this doc, we will show how to use
`EnhancedIndexingStrategy` to maximize portfolio return while minimizing
tracking error relative to a benchmark.


## Preparation

We use China stock market data for our example.

1. Prepare CSI300 weight:

```bash
wget http://fintech.msra.cn/stock_data/downloads/csi300_weight.zip
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
rm -f csi300_weight.zip
```

2. Prepare risk model data:

```bash
python prepare_riskdata.py
```

Here we use a **Statistical Risk Model** implemented in `qlib.model.riskmodel`.
However users are strongly recommended to use other risk models for better quality:
* **Fundamental Risk Model** like MSCI BARRA
* [Deep Risk Model](https://arxiv.org/abs/2107.05201)


## End-to-End Workflow

You can finish workflow with `EnhancedIndexingStrategy` by running
`qrun config_enhanced_indexing.yaml`.

In this config, we mainly changed the strategy section compared to
`qlib/examples/benchmarks/workflow_config_lightgbm_Alpha158.yaml`.
71 changes: 71 additions & 0 deletions examples/portfolio/config_enhanced_indexing.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
fit_start_time: 2008-01-01
fit_end_time: 2014-12-31
instruments: *market
port_analysis_config: &port_analysis_config
strategy:
class: EnhancedIndexingStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
riskmodel_root: ./riskdata
backtest:
start_time: 2017-01-01
end_time: 2020-08-01
account: 100000000
benchmark: *benchmark
exchange_kwargs:
limit_threshold: 0.095
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: Alpha158
module_path: qlib.contrib.data.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs:
model: <MODEL>
dataset: <DATASET>
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
53 changes: 53 additions & 0 deletions examples/portfolio/prepare_riskdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
evanzd marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import pandas as pd

from qlib.data import D
from qlib.model.riskmodel import StructuredCovEstimator


def prepare_data(riskdata_root="./riskdata", T=240, start_time="2016-01-01"):

universe = D.features(D.instruments("csi300"), ["$close"], start_time=start_time).swaplevel().sort_index()

price_all = (
D.features(D.instruments("all"), ["$close"], start_time=start_time).squeeze().unstack(level="instrument")
)

# StructuredCovEstimator is a statistical risk model
riskmodel = StructuredCovEstimator()

for i in range(T - 1, len(price_all)):

date = price_all.index[i]
ref_date = price_all.index[i - T + 1]

print(date)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussions about the data preparation


codes = universe.loc[date].index
price = price_all.loc[ref_date:date, codes]

# calculate return and remove extreme return
ret = price.pct_change()
you-n-g marked this conversation as resolved.
Show resolved Hide resolved
ret.clip(ret.quantile(0.025), ret.quantile(0.975), axis=1, inplace=True)

# run risk model
F, cov_b, var_u = riskmodel.predict(ret, is_price=False, return_decomposed_components=True)

# save risk data
root = riskdata_root + "/" + date.strftime("%Y%m%d")
os.makedirs(root, exist_ok=True)

pd.DataFrame(F, index=codes).to_pickle(root + "/factor_exp.pkl")
pd.DataFrame(cov_b).to_pickle(root + "/factor_cov.pkl")
# for specific_risk we follow the convention to save volatility
pd.Series(np.sqrt(var_u), index=codes).to_pickle(root + "/specific_risk.pkl")


if __name__ == "__main__":

import qlib

qlib.init(provider_uri="~/.qlib/qlib_data/cn_data")

prepare_data()
1 change: 1 addition & 0 deletions qlib/contrib/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .signal_strategy import (
TopkDropoutStrategy,
WeightStrategyBase,
EnhancedIndexingStrategy,
)

from .rule_strategy import (
Expand Down
203 changes: 203 additions & 0 deletions qlib/contrib/strategy/optimizer/enhanced_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import numpy as np
import cvxpy as cp
import pandas as pd

from typing import Union, Optional, Dict, Any, List

from qlib.log import get_module_logger
from .base import BaseOptimizer


logger = get_module_logger("EnhancedIndexingOptimizer")


class EnhancedIndexingOptimizer(BaseOptimizer):
"""
Portfolio Optimizer for Enhanced Indexing

Notations:
w0: current holding weights
wb: benchmark weight
r: expected return
F: factor exposure
cov_b: factor covariance
var_u: residual variance (diagonal)
lamb: risk aversion parameter
delta: total turnover limit
b_dev: benchmark deviation limit
f_dev: factor deviation limit

Also denote:
d = w - wb: benchmark deviation
v = d @ F: factor deviation

The optimization problem for enhanced indexing:
max_w d @ r - lamb * (v @ cov_b @ v + var_u @ d**2)
s.t. w >= 0
sum(w) == 1
sum(|w - w0|) <= delta
d >= -b_dev
d <= b_dev
v >= -f_dev
v <= f_dev
"""

def __init__(
self,
lamb: float = 1,
delta: Optional[float] = 0.2,
b_dev: Optional[float] = 0.01,
f_dev: Optional[Union[List[float], np.ndarray]] = None,
scale_return: bool = True,
epsilon: float = 5e-5,
solver_kwargs: Optional[Dict[str, Any]] = {},
):
"""
Args:
lamb (float): risk aversion parameter (larger `lamb` means more focus on risk)
delta (float): total turnover limit
b_dev (float): benchmark deviation limit
f_dev (list): factor deviation limit
scale_return (bool): whether scale return to match estimated volatility
epsilon (float): minumum weight
evanzd marked this conversation as resolved.
Show resolved Hide resolved
solver_kwargs (dict): kwargs for cvxpy solver
"""

assert lamb >= 0, "risk aversion parameter `lamb` should be positive"
self.lamb = lamb

assert delta >= 0, "turnover limit `delta` should be positive"
self.delta = delta

assert b_dev is None or b_dev >= 0, "benchmark deviation limit `b_dev` should be positive"
self.b_dev = b_dev

if isinstance(f_dev, float):
assert f_dev >= 0, "factor deviation limit `f_dev` should be positive"
elif f_dev is not None:
f_dev = np.array(f_dev)
assert all(f_dev >= 0), "factor deviation limit `f_dev` should be positive"
self.f_dev = f_dev

self.scale_return = scale_return
self.epsilon = epsilon
self.solver_kwargs = solver_kwargs

def __call__(
self,
r: np.ndarray,
F: np.ndarray,
cov_b: np.ndarray,
var_u: np.ndarray,
w0: np.ndarray,
wb: np.ndarray,
mfh: Optional[np.ndarray] = None,
mfs: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Args:
r (np.ndarray): expected returns
F (np.ndarray): factor exposure
cov_b (np.ndarray): factor covariance
var_u (np.ndarray): residual variance
w0 (np.ndarray): current holding weights
wb (np.ndarray): benchmark weights
mfh (np.ndarray): mask force holding
mfs (np.ndarray): mask force selling

Returns:
np.ndarray: optimized portfolio allocation
"""
# scale return to match volatility
if self.scale_return:
r = r / r.std()
r *= np.sqrt(np.mean(np.diag(F @ cov_b @ F.T) + var_u))

# target weight
w = cp.Variable(len(r), nonneg=True)
w.value = wb # for warm start

# precompute exposure
d = w - wb # benchmark exposure
v = d @ F # factor exposure

# objective
ret = d @ r # excess return
risk = cp.quad_form(v, cov_b) + var_u @ (d ** 2) # tracking error
obj = cp.Maximize(ret - self.lamb * risk)

# weight bounds
lb = np.zeros_like(wb)
ub = np.ones_like(wb)

# bench bounds
if self.b_dev is not None:
lb = np.maximum(lb, wb - self.b_dev)
ub = np.minimum(ub, wb + self.b_dev)

# force holding
if mfh is not None:
lb[mfh] = w0[mfh]
ub[mfh] = w0[mfh]

# force selling
# NOTE: this will override mfh
if mfs is not None:
lb[mfs] = 0
ub[mfs] = 0

# constraints
# TODO: currently we assume fullly invest in the stocks,
# in the future we should support holding cash as an asset
cons = [cp.sum(w) == 1, w >= lb, w <= ub]

# factor deviation
if self.f_dev is not None:
cons.extend([v >= -self.f_dev, v <= self.f_dev])

# total turnover constraint
t_cons = []
if self.delta is not None:
if w0 is not None and w0.sum() > 0:
t_cons.extend([cp.norm(w - w0, 1) <= self.delta])

# optimize
# trial 1: use all constraints
success = False
try:
prob = cp.Problem(obj, cons + t_cons)
prob.solve(solver=cp.ECOS, warm_start=True, **self.solver_kwargs)
assert prob.status == "optimal"
success = True
except Exception as e:
logger.warning(f"trial 1 failed {e} (status: {prob.status})")

# trial 2: remove turnover constraint
if not success and len(t_cons):
logger.info("try removing turnvoer constraint as last optimization failed")
evanzd marked this conversation as resolved.
Show resolved Hide resolved
try:
w.value = wb
prob = cp.Problem(obj, cons)
prob.solve(solver=cp.ECOS, warm_start=True, **self.solver_kwargs)
assert prob.status in ["optimal", "optimal_inaccurate"]
success = True
except Exception as e:
logger.warning(f"trial 2 failed {e} (status: {prob.status})")

# return current weight if not success
if not success:
logger.warning("optimization failed, will return current holding weight")
return w0

if prob.status == "optimal_inaccurate":
logger.warning(f"the optimization is inaccurate")

# remove small weight
w = np.asarray(w.value)
w[w < self.epsilon] = 0
w /= w.sum()

return w
Loading