Skip to content

Commit

Permalink
fixing SMD plot (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Aug 20, 2024
1 parent c516438 commit 722e2e8
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions experiments/smd/plot_smd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


fname = (
"/home/owkin/project/results_experiments/smd_cov_shift/results_smd_cov_shift.pkl"
"/home/owkin/fedeca/results_experiments/smd_cov_shift/results_smd_cov_shift.pkl"
)
df_res = load_dataframe_from_pickles(fname)

Expand All @@ -36,6 +36,7 @@

# Preprocess results
df = df_res.filter(regex=r"cov_shift|method|rep_id|smd_\w+_X_")

df = (
pd.wide_to_long(
df.reset_index(drop=True).reset_index(),
Expand All @@ -61,15 +62,9 @@
.drop(columns="index")
)

df = df.loc[np.logical_or(df['method'] == "FederatedIPTW", df['weighted'] == "weighted")]
df = df.loc[np.logical_or(df['method'] == "FedECA", df['weighted'] == "weighted")]
df.loc[df['weighted'] == "raw", 'method'] = "Unweighted"
df['smd'] = np.abs(df['smd'])
method_recoding = {
"FederatedIPTW": "FedECA",
"MAIC": "MAIC",
"Unweighted": "Unweighted",
}
df["method"] = df["method"].replace(method_recoding)


# Plot
Expand Down Expand Up @@ -97,10 +92,10 @@
g.set_ylabels("Covariate")
g.set_titles(row_template="{row_name}", col_template="Covariate shift = {col_name}")
for ax in g.axes.flat:
ax.axvline(0.1, color="black", linestyle="--", alpha=0.2)
ax.axvline(10, color="black", linestyle="--", alpha=0.2)
# g.add_legend(title="Method")

g.savefig("smd_cov_shift.png", bbox_inches="tight")
g.savefig("smd_cov_shift.pdf", bbox_inches="tight")



Expand Down

0 comments on commit 722e2e8

Please sign in to comment.