Skip to content

Commit

Permalink
plot smd script
Browse files Browse the repository at this point in the history
  • Loading branch information
qklopfenstein-owkin committed Aug 20, 2024
1 parent 3b4802b commit 923af51
Showing 1 changed file with 84 additions and 16 deletions.
100 changes: 84 additions & 16 deletions experiments/smd/plot_smd.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
import pandas as pd
import seaborn as sns
import scipy
import numpy as np

from fedeca.utils.experiment_utils import load_dataframe_from_pickles
from fedeca.viz.plot import owkin_palette
import matplotlib.pyplot as plt
import matplotlib

# Load raw results
sns.set(rc={'figure.figsize': (11.69, 8.27)})
plt.style.use('default')

# plot lines for smd as a function of covariate shift
dict_colors = {
"FedECA": list(owkin_palette.values())[5],
"MAIC": list(owkin_palette.values())[4],
"Unweighted": list(owkin_palette.values())[2],
}
dict_markers = {
"FedECA": "s",
"MAIC": "P",
"Unweighted": "d"
}


fname = (
"/home/owkin/project/results_experiments/smd_cov_shift/results_smd_cov_shift.pkl"
)
df_res = load_dataframe_from_pickles(fname)

df_res["cov_shift"] = 0.5 * (df_res["overlap"] + 1)
df_res = df_res.loc[df_res['method'] != 'IPTW']
df_res = df_res.loc[df_res['method'] != 'Unweighted']

# Preprocess results
df = df_res.filter(regex=r"cov_shift|method|rep_id|smd_\w+_X_")
Expand Down Expand Up @@ -38,39 +60,85 @@
.reset_index()
.drop(columns="index")
)
df["weighted"] = df["weighted"].replace({"weighted": True, "raw": False})

df = df.loc[np.logical_or(df['method'] == "FederatedIPTW", df['weighted'] == "weighted")]
df.loc[df['weighted'] == "raw", 'method'] = "Unweighted"
df['smd'] = np.abs(df['smd'])
method_recoding = {
"FedECA": "FedECA",
"IPTW": "IPTW",
"FederatedIPTW": "FedECA",
"MAIC": "MAIC",
"Unweighted": "Unweighted",
}
df["method"] = df["method"].replace(method_recoding)


# Plot
g = sns.FacetGrid(
df[
df["cov_shift"].isin([0, 2])
& df["covariate"].isin(["X_0", "X_1", "X_2", "X_3", "X_4"])
],
col="method",
col_order=["IPTW", "FedECA", "MAIC"],
row="cov_shift",
height=3.5, # type: ignore
aspect=0.8, # type: ignore
row=None,
col="cov_shift",
margin_titles=True,
)
g.map_dataframe(
sns.boxplot,
x="smd",
y="covariate",
hue="weighted",
width=0.3,
palette=owkin_palette.values(),
hue="method",
width=0.6,
palette=list(owkin_palette.values())[4:5] + list(owkin_palette.values())[5:6] + list(owkin_palette.values())[2:3],
gap=0.3,
whis=5,
linewidth=0.7,
)
g.set_xlabels("Standardized mean difference")
g.set_xlabels("Absolute Standardized \n mean difference")
g.set_ylabels("Covariate")
g.set_titles(col_template="{col_name}", row_template="Covariate shift = {row_name}")
g.set_titles(row_template="{row_name}", col_template="Covariate shift = {col_name}")
for ax in g.axes.flat:
ax.axvline(0, color="black", linestyle="--", alpha=0.2)
g.add_legend(title="Weighted")
g.savefig("smd_cov_shift.pdf", bbox_inches="tight")
ax.axvline(0.1, color="black", linestyle="--", alpha=0.2)
# g.add_legend(title="Method")

g.savefig("smd_cov_shift.png", bbox_inches="tight")




df_smd = df_overlap = (
df.groupby(["method", "cov_shift", "weighted"])
.agg(
smd=pd.NamedAgg(column="smd", aggfunc=lambda x: np.abs(x).sum() / x.size),
lower=pd.NamedAgg(column="smd", aggfunc=lambda x: np.abs(x).sum() / x.size - scipy.stats.norm.ppf(0.975) * np.sqrt(np.var(np.abs(x)) / x.size)),
upper=pd.NamedAgg(column="smd", aggfunc=lambda x: np.abs(x).sum() / x.size + scipy.stats.norm.ppf(0.975) * np.sqrt(np.var(np.abs(x)) / x.size)),
)
.reset_index()
)

plt.style.use('default')
matplotlib.rc('xtick', labelsize=18)
matplotlib.rc('ytick', labelsize=18)
plt.rcParams['legend.title_fontsize'] = '20'


fig, axarr = plt.subplots(1, 1, figsize=(11.69, 8.27))
selected_methods = ["FedECA", "MAIC", "Unweighted"]
for i, method in enumerate(selected_methods):
axarr.plot(
df_smd[df_smd['method'] == method]['cov_shift'],
df_smd[df_smd['method'] == method]['smd'], marker=dict_markers[method],
label=method, color=dict_colors[method], markersize=10
)
axarr.fill_between(
df_smd[df_smd['method'] == method]['cov_shift'],
df_smd[df_smd['method'] == method]['lower'],
df_smd[df_smd['method'] == method]['upper'],
color=dict_colors[method], alpha=0.4
)

axarr.set_xlabel("Covariate shift", fontsize=20)
axarr.set_ylabel("Mean absolute standardized \n mean difference", fontsize=20)

axarr.grid()
fig.legend(bbox_to_anchor=(0.81, 1.03), ncol=3, title="Method", fontsize=20)
fig.savefig("smd_curves.pdf", bbox_inches="tight")

0 comments on commit 923af51

Please sign in to comment.