Jean/fed kaplan #44

Merged: 34 commits, merged Aug 13, 2024

Commits
1d8a9e0
general architecture of fedkaplan
jeandut Jul 12, 2024
302e386
respecting naming conventions
jeandut Jul 12, 2024
c495edf
refactoring preprocessing
jeandut Jul 12, 2024
2d7b04b
passing in my head but not tested
jeandut Jul 12, 2024
70773f6
adding test for KM utils
jeandut Jul 12, 2024
6094a78
add credit
jeandut Jul 12, 2024
b5ba356
everything works in my head
jeandut Aug 2, 2024
fe28879
hacking
jeandut Aug 2, 2024
2e2f2fe
some refactoring
jeandut Aug 2, 2024
1615830
fixing bug
jeandut Aug 2, 2024
3c51103
fixing various stuff
jeandut Aug 2, 2024
a242ebc
fixing stuff
jeandut Aug 2, 2024
85cdf39
everything passing
jeandut Aug 12, 2024
31fa318
test passing
jeandut Aug 12, 2024
244b894
trying fixing tests
jeandut Aug 12, 2024
74c1e01
linting
jeandut Aug 12, 2024
cc03227
linting
jeandut Aug 12, 2024
b74aff6
linting
jeandut Aug 12, 2024
09d4d13
linting
jeandut Aug 12, 2024
2142efc
linting
jeandut Aug 12, 2024
67f0d1e
linting
jeandut Aug 12, 2024
6a64e58
linting fedkaplan
jeandut Aug 12, 2024
5c654cc
trying to finally fix linting
jeandut Aug 12, 2024
544fcf8
linting
jeandut Aug 12, 2024
19d55f2
fixing substra stuff in FedKM
jeandut Aug 12, 2024
24ba6d9
test almost working w/o weights
jeandut Aug 12, 2024
3bee459
linting
jeandut Aug 12, 2024
f6323fc
linting
jeandut Aug 12, 2024
98a5bac
now tests are not passing only because grid is not the same
jeandut Aug 13, 2024
369d0b8
tests passing
jeandut Aug 13, 2024
5dd6363
weights working
jeandut Aug 13, 2024
9c85be9
removing useless comments
jeandut Aug 13, 2024
4804492
tests should be passing
jeandut Aug 13, 2024
5369fc9
removing forgotten breakpoint
jeandut Aug 13, 2024
168 changes: 26 additions & 142 deletions fedeca/algorithms/torch_webdisco_algo.py
@@ -1,7 +1,6 @@
 """Implement webdisco algorithm with Torch."""
 import copy
 from copy import deepcopy
-from math import sqrt
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
@@ -11,7 +10,6 @@
 from autograd import elementwise_grad
 from autograd import numpy as anp
 from lifelines.utils import StepSizer
-from pandas.api.types import is_numeric_dtype
 from scipy.linalg import norm
 from scipy.linalg import solve as spsolve
 from substrafl.algorithms.pytorch import weight_manager
@@ -21,7 +19,11 @@
 
 from fedeca.schemas import WebDiscoAveragedStates, WebDiscoSharedState
 from fedeca.utils.moments_utils import compute_uncentered_moment
-from fedeca.utils.survival_utils import MockStepSizer
+from fedeca.utils.survival_utils import (
+    MockStepSizer,
+    build_X_y_function,
+    compute_X_y_and_propensity_weights_function,
+)
 
 
 class TorchWebDiscoAlgo(TorchAlgo):
@@ -597,124 +599,6 @@ def summary(self):
         summary = super().summary()
         return summary
 
-    def build_X_y(self, data_from_opener, shared_state={}):
-        """Build appropriate X and y times from output of opener.
-
-        This function 1. uses the event column to inject the censorship
-        information present in the duration column (given in absolute values)
-        in the form of a negative sign.
-        2. Drop every covariate except treatment if self.strategy == "iptw".
-        3. Standardize the data if self.standardize_data AND if it receives
-        an outmodel.
-        4. Return the (unstandardized) input to the propensity model Xprop if
-        necessary as well as the treated column to be able to compute the
-        propensity weights.
-
-        Parameters
-        ----------
-        data_from_opener : pd.DataFrame
-            The output of the opener
-        shared_state : dict, optional
-            Outmodel containing global means and stds.
-            by default {}
-
-        Returns
-        -------
-        tuple
-            standardized X, signed times, treatment column and unstandardized
-            propensity model input
-        """
-        # We need y to be in the format (2*event-1)*duration
-        data_from_opener["time_multiplier"] = [
-            2.0 * e - 1.0 for e in data_from_opener[self._event_col].tolist()
-        ]
-        # No funny business irrespective of the convention used
-        y = (
-            np.abs(data_from_opener[self._duration_col])
-            * data_from_opener["time_multiplier"]
-        )
-        y = y.to_numpy().astype("float64")
-        data_from_opener = data_from_opener.drop(columns=["time_multiplier"])
-        # dangerous but we need to do it
-        string_columns = [
-            col
-            for col in data_from_opener.columns
-            if not (is_numeric_dtype(data_from_opener[col]))
-        ]
-        data_from_opener = data_from_opener.drop(columns=string_columns)
-
-        # We drop the targets from X
-        columns_to_drop = self._target_cols
-        X = data_from_opener.drop(columns=columns_to_drop)
-        if self._propensity_model is not None:
-            assert self._treated_col is not None
-            if self._training_strategy == "iptw":
-                X = X.loc[:, [self._treated_col]]
-            elif self._training_strategy == "aiptw":
-                if len(self._cox_fit_cols) > 0:
-                    X = X.loc[:, [self._treated_col] + self._cox_fit_cols]
-                else:
-                    pass
-        else:
-            assert self._training_strategy == "webdisco"
-            if len(self._cox_fit_cols) > 0:
-                X = X.loc[:, self._cox_fit_cols]
-            else:
-                pass
-
-        # If X is to be standardized we do it
-        if self._standardize_data:
-            if shared_state:
-                # Careful this shouldn't happen apart from the predict
-                means = shared_state["global_uncentered_moment_1"]
-                vars = shared_state["global_centered_moment_2"]
-                # Careful we need to match pandas and use unbiased estimator
-                bias_correction = (shared_state["total_n_samples"]) / float(
-                    shared_state["total_n_samples"] - 1
-                )
-                self.global_moments = {
-                    "means": means,
-                    "vars": vars,
-                    "bias_correction": bias_correction,
-                }
-                stds = vars.transform(lambda x: sqrt(x * bias_correction + self._tol))
-                X = X.sub(means)
-                X = X.div(stds)
-            else:
-                X = X.sub(self.global_moments["means"])
-                stds = self.global_moments["vars"].transform(
-                    lambda x: sqrt(
-                        x * self.global_moments["bias_correction"] + self._tol
-                    )
-                )
-                X = X.div(stds)
-
-        X = X.to_numpy().astype("float64")
-
-        # If we have a propensity model we need to build X without the targets AND the
-        # treated column
-        if self._propensity_model is not None:
-            # We do not normalize the data for the propensity model !!!
-            Xprop = data_from_opener.drop(columns=columns_to_drop + [self._treated_col])
-            if self._propensity_fit_cols is not None:
-                Xprop = Xprop[self._propensity_fit_cols]
-            Xprop = Xprop.to_numpy().astype("float64")
-        else:
-            Xprop = None
-
-        # If WebDisco is used without propensity treated column does not exist
-        if self._treated_col is not None:
-            treated = (
-                data_from_opener[self._treated_col]
-                .to_numpy()
-                .astype("float64")
-                .reshape((-1, 1))
-            )
-        else:
-            treated = None
-
-        return (X, y, treated, Xprop)
-
     def compute_X_y_and_propensity_weights(self, data_from_opener, shared_state):
         """Build appropriate X, y and weights from raw output of opener.
 
@@ -731,26 +615,26 @@ def compute_X_y_and_propensity_weights(self, data_from_opener, shared_state):
         Returns
         -------
         tuple
-            _description_
+            X input to the Cox model, y target of Cox model, weights propensity weights
         """
-        X, y, treated, Xprop = self.build_X_y(data_from_opener, shared_state)
-        if self._propensity_model is not None:
-            assert (
-                treated is not None
-            ), f"""If you are using a propensity model the {self._treated_col} (Treated)
-            column should be available"""
-            assert np.all(
-                np.in1d(np.unique(treated.astype("uint8"))[0], [0, 1])
-            ), "The treated column should have all its values in set([0, 1])"
-            Xprop = torch.from_numpy(Xprop)
-            with torch.no_grad():
-                propensity_scores = self._propensity_model(Xprop)
-
-            propensity_scores = propensity_scores.detach().numpy()
-            # We robustify the division
-            weights = treated * 1.0 / np.maximum(propensity_scores, self._tol) + (
-                1 - treated
-            ) * 1.0 / (np.maximum(1.0 - propensity_scores, self._tol))
-        else:
-            weights = np.ones((X.shape[0], 1))
+        X, y, treated, Xprop, self.global_moments = build_X_y_function(
+            data_from_opener,
+            self._event_col,
+            self._duration_col,
+            self._treated_col,
+            self._target_cols,
+            self._standardize_data,
+            self._propensity_model,
+            self._cox_fit_cols,
+            self._propensity_fit_cols,
+            self._tol,
+            self._training_strategy,
+            shared_state=shared_state,
+            global_moments={}
+            if not hasattr(self, "global_moments")
+            else self.global_moments,
+        )
+        X, y, weights = compute_X_y_and_propensity_weights_function(
+            X, y, treated, Xprop, self._propensity_model, self._tol
+        )
         return X, y, weights
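
For readers skimming the diff: the removed build_X_y method (now centralized as build_X_y_function in fedeca.utils.survival_utils) encodes censorship in the sign of the time-to-event, y = (2*event - 1) * |duration|, and compute_X_y_and_propensity_weights_function turns propensity scores into inverse-probability-of-treatment (IPTW) weights via a tolerance-guarded division. Below is a minimal standalone sketch of those two steps; the names encode_signed_times and iptw_weights and the tol default are illustrative only and are not part of the fedeca API.

import numpy as np

def encode_signed_times(duration, event):
    # Observed events (event == 1) keep a positive time; censored rows
    # (event == 0) get a negative time, i.e. the (2*event - 1) * |duration|
    # convention used in the removed build_X_y body above.
    return (2.0 * np.asarray(event, dtype="float64") - 1.0) * np.abs(
        np.asarray(duration, dtype="float64")
    )

def iptw_weights(treated, propensity_scores, tol=1e-16):
    # Treated rows are weighted by 1 / ps and control rows by 1 / (1 - ps);
    # np.maximum(..., tol) robustifies both divisions, mirroring the old
    # compute_X_y_and_propensity_weights body.
    treated = np.asarray(treated, dtype="float64")
    ps = np.asarray(propensity_scores, dtype="float64")
    return treated / np.maximum(ps, tol) + (1.0 - treated) / np.maximum(1.0 - ps, tol)

# Example: an observed event at t=5 stays positive, a censored subject at
# t=3 becomes -3:
# encode_signed_times([5.0, 3.0], [1, 0])  ->  array([ 5., -3.])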
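
The standardization branch that also moved into build_X_y_function matches pandas' default unbiased (ddof=1) standard deviation: the federated centered second moment is rescaled by n / (n - 1), where n is the total number of samples across all centers, before taking the square root. A hedged sketch, assuming means and variances are the aggregated pandas Series from the shared state (standardize_with_global_moments is a hypothetical name):

from math import sqrt

def standardize_with_global_moments(X, means, variances, total_n_samples, tol=1e-16):
    # X: pd.DataFrame of covariates; means, variances: pd.Series of global
    # (cross-center) moments. The n / (n - 1) factor reproduces pandas'
    # unbiased std; tol keeps sqrt away from zero for near-constant columns.
    bias_correction = total_n_samples / float(total_n_samples - 1)
    stds = variances.transform(lambda v: sqrt(v * bias_correction + tol))
    return X.sub(means).div(stds)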