now tests are not passing only because grid is not the same
jeandut committed Aug 13, 2024
1 parent f6323fc commit 98a5bac
Showing 3 changed files with 43 additions and 17 deletions.
3 changes: 2 additions & 1 deletion fedeca/strategies/fed_kaplan.py
@@ -183,7 +183,7 @@ def compute_events_statistics(
del weights
# retrieve times and events
times = np.abs(y)
events = y > 0
events = y >= 0
assert np.allclose(events, data_from_opener[self._event_col].values)
treated = treated.astype(bool).flatten()

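For context on the hunk above: `y` appears to pack durations and event indicators into one signed array. Below is a minimal sketch of that convention, assuming positive values mark events and negative values mark censored samples; the encoding is inferred from `times = np.abs(y)` and is not a documented fedeca contract.

```python
# Sketch of the assumed signed-duration encoding (inference, not fedeca API):
# events carry a positive sign, censored observations a negative one.
import numpy as np

durations = np.array([2.5, 3.1, 1.2, 4.0])
observed = np.array([True, False, True, True])  # True = event, False = censored

y = np.where(observed, durations, -durations)   # signed encoding
times = np.abs(y)                               # recover durations
events = y >= 0                                 # recover event indicator
assert np.array_equal(events, observed) and np.allclose(times, durations)
```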
@@ -214,6 +214,7 @@ def compute_agg_km_curve(self, shared_states):
[sh["untreated"] for sh in shared_states]
),
}

return {
"treated": km_curve(*treated_untreated_tnd_agg["treated"]),
"untreated": km_curve(*treated_untreated_tnd_agg["untreated"]),
3 changes: 2 additions & 1 deletion fedeca/tests/strategies/test_km.py
@@ -30,7 +30,7 @@ class TestKM(TestTempDir):
def setUp(self, backend_type="subprocess", ndim=10) -> None:
"""Set up the quantities needed for the tests."""
# Let's generate 1000 data samples with 10 covariates
data = CoxData(seed=42, n_samples=1000, ndim=ndim)
data = CoxData(seed=42, n_samples=1000, ndim=ndim, percent_ties=0.2)
self.df = data.generate_dataframe()

# We remove the true propensity score
@@ -135,6 +135,7 @@ def test_end_to_end(self):
]
s_gts = [kmf.survival_function_["KM_estimate"].to_numpy() for kmf in kms]
grid_gts = [kmf.survival_function_.index.to_numpy() for kmf in kms]
breakpoint()
fl_grid_treated, fl_s_treated, _ = fl_results["treated"]
fl_grid_untreated, fl_s_untreated, _ = fl_results["untreated"]

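The ground-truth arrays `s_gts` and `grid_gts` above come from lifelines' `KaplanMeierFitter`. Here is a hedged sketch of how such a reference is typically built; the toy dataframe, the per-arm split, and the column names are illustrative assumptions rather than the test's exact code.

```python
# Illustrative lifelines ground truth, analogous to s_gts / grid_gts above.
import pandas as pd
from lifelines import KaplanMeierFitter

df = pd.DataFrame({
    "time": [1.0, 2.0, 2.0, 3.5, 4.0, 5.0],
    "event": [1, 0, 1, 1, 0, 1],
    "treatment": [1, 1, 1, 0, 0, 0],
})

kms = []
for arm in (1, 0):  # treated, then untreated
    sub = df[df["treatment"] == arm]
    kmf = KaplanMeierFitter()
    kmf.fit(sub["time"], event_observed=sub["event"], label="KM_estimate")
    kms.append(kmf)

s_gts = [kmf.survival_function_["KM_estimate"].to_numpy() for kmf in kms]
grid_gts = [kmf.survival_function_.index.to_numpy() for kmf in kms]
```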
54 changes: 39 additions & 15 deletions fedeca/utils/survival_utils.py
@@ -414,6 +414,15 @@ def generate_data(
if not reached:
raise ValueError("This should not happen, lower percent_ties")
times = times.reshape((-1))
# 0-1 scale
times /= float(nbins)
# With this Kbins discretizer, times start always at 0.
# 0, 1, 2, 3, ...
# there should be no time at exactly 0. otherwise lifelines
# (rightfully) will act in a weird way. See the birth outer join in
# https://github.com/CamDavidsonPilon/lifelines/blob/4377caf5a6224941ee3ab34c413ad668d4173274/lifelines/utils/__init__.py#L567
# therefore we add a small quantity to every time
times += np.random.uniform(1.0 / nbins, 1.0, size=1)

else:
raise ValueError("Choose a larger number of ties")
@@ -1364,7 +1373,7 @@ def robust_sandwich_variance_pooled(
return np.sqrt(np.diag(tested_var))


def km_curve(t, n, d, tmax=5000):
def km_curve(t, n, d, tmax=None):
"""Compute Kaplan-Meier (KM) curve.
This function is typically used in conjunction with
@@ -1382,7 +1391,7 @@ def km_curve(t, n, d, tmax=5000):
Array containing the number of individuals with an event (death) at
each corresponding time `t`
tmax : int, optional
Maximal time point for the KM curve, by default 5000
Number of grid points; defaults to the number of unique event times + 1.
Returns
-------
@@ -1417,23 +1426,38 @@
https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
https://www.math.wustl.edu/~sawyer/handouts/greenwood.pdf
"""
grid = np.arange(0, tmax + 1)
# precompute for KM
if tmax is None:
# Number of unique events + 0 ("birth")
tmax = len(t)
# We compute the grid on which we will want to plot S(t)
grid = np.linspace(0, t.max(), tmax + 1)
# KM estimator but wo filtering terms out
q = 1.0 - d / n
cprod_q = np.cumprod(q)
# precompute for var

# Same for Greenwood's formula
csum_var = np.cumsum(d / (n * (n - d)))
# initialize

# Now we just need for each point of the grid to filter out terms that are
# bigger than them
# we initialize by filtering out everything
s = np.zeros(grid.shape)
var_s = np.zeros(grid.shape)
# the following sum relies on the assumption that unique_events_times is sorted!
index_closest = np.sum(grid.reshape(-1, 1) - t.reshape(1, -1) >= 0, axis=1) - 1
not_found = index_closest < 0
found = index_closest >= 0
# attribute
s[not_found] = 1.0
s[found] = cprod_q[index_closest[found]]
var_s[found] = (s[found] ** 2) * csum_var[index_closest[found]]

# we need, for each element in the grid, the index of the cumprod/cumsum
# it should go to, which would by design filter out the right terms
# to respect KM's formula
mask = grid.reshape(-1, 1) - t.reshape(1, -1) >= 0 # (grid.shape, t.shape)
index_in_cum_vec = np.sum(mask, axis=1) - 1
# We can now compute the survival function for each point in the grid
# Survival function starts at 1.
s[index_in_cum_vec < 0.0] = 1.0
s[index_in_cum_vec >= 0.0] = cprod_q[index_in_cum_vec[index_in_cum_vec >= 0]]

# And now similarly we derive Greenwood
var_s[index_in_cum_vec >= 0] = (s[index_in_cum_vec >= 0] ** 2) * csum_var[
index_in_cum_vec[index_in_cum_vec >= 0]
]
return grid, s, var_s
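The cumprod-plus-grid-indexing logic above can be checked by hand on a toy input. A minimal sketch, assuming `t` is sorted and `(n, d)` hold the at-risk and death counts at each unique event time; the numbers are illustrative.

```python
# Toy check of the KM product and Greenwood variance on an explicit grid.
import numpy as np

t = np.array([1.0, 2.0, 4.0])   # unique event times (sorted)
n = np.array([10, 7, 3])        # at risk just before each event time
d = np.array([2, 1, 1])         # deaths at each event time

grid = np.linspace(0, t.max(), len(t) + 1)       # [0.0, 1.33, 2.67, 4.0]
cprod_q = np.cumprod(1.0 - d / n)                # KM factors, cumulated
csum_var = np.cumsum(d / (n * (n - d)))          # Greenwood terms, cumulated

# index of the last event time <= each grid point (-1 if none)
idx = np.sum(grid.reshape(-1, 1) - t.reshape(1, -1) >= 0, axis=1) - 1

s = np.ones_like(grid)
var_s = np.zeros_like(grid)
s[idx >= 0] = cprod_q[idx[idx >= 0]]
var_s[idx >= 0] = s[idx >= 0] ** 2 * csum_var[idx[idx >= 0]]
```

On this toy input `s` comes out as roughly [1.0, 0.8, 0.686, 0.457], i.e. the textbook Kaplan-Meier product evaluated at each grid point.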


@@ -1510,7 +1534,7 @@ def aggregate_events_statistics(list_t_n_d):
with an event (death) at each corresponding `t_agg`
"""
# Step 1: get the unique times
unique_times_list = [t for (t, _, _) in list_t_n_d]
unique_times_list = [t for (t, _, _) in list_t_n_d if t.size != 0]
t_agg = np.unique(np.concatenate(unique_times_list))
# Step 2: extend to common grid
n_ext, d_ext = extend_events_to_common_grid(list_t_n_d, t_agg)
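To make the aggregation step concrete, here is a hedged sketch of merging per-center `(t, n, d)` tuples onto the union of event times. The extension rule used here (deaths stay on their own times, at-risk counts are read from the center's next reported event time) is an assumption about what `extend_events_to_common_grid` does, not its actual implementation.

```python
# Hedged sketch of pooling per-center event statistics on a common grid.
import numpy as np

def extend_one_center(t, n, d, t_agg):
    """Place one center's counts on the common grid t_agg (assumed rule)."""
    d_ext = np.zeros_like(t_agg, dtype=float)
    n_ext = np.zeros_like(t_agg, dtype=float)
    d_ext[np.searchsorted(t_agg, t)] = d           # deaths only at this center's times
    pos = np.searchsorted(t, t_agg, side="left")   # next event time >= grid time
    inside = pos < len(t)
    n_ext[inside] = n[pos[inside]]                 # carry that time's at-risk count
    return n_ext, d_ext

# Two toy centers: (unique times, at-risk counts, death counts)
list_t_n_d = [
    (np.array([1.0, 3.0]), np.array([5, 2]), np.array([1, 1])),
    (np.array([2.0, 3.0]), np.array([4, 3]), np.array([1, 2])),
]
t_agg = np.unique(np.concatenate([t for (t, _, _) in list_t_n_d if t.size != 0]))
n_agg, d_agg = map(sum, zip(*(extend_one_center(*c, t_agg) for c in list_t_n_d)))
```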
