Skip to content

Commit

Permalink
Merge pull request #72 from e10v/dev
Browse files Browse the repository at this point in the history
Define Experiment.solve_power
  • Loading branch information
e10v committed Jul 7, 2024
2 parents f785f3d + 65d65df commit 5c882dc
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 7 deletions.
53 changes: 52 additions & 1 deletion src/tea_tasting/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import ibis.expr.types

from tea_tasting.metrics.base import PowerParameter


class ExperimentResult(
UserDict[str, tea_tasting.metrics.MetricResult],
Expand Down Expand Up @@ -329,10 +331,24 @@ def to_html(
"""
return tea_tasting.utils.PrettyDictsMixin.to_html(self, keys, formatter)


# Experiment results keyed by a pair of values.
# NOTE(review): the key pair is presumably (control, treatment) variants —
# TODO confirm against the code that builds this mapping.
ExperimentResults = dict[tuple[Any, Any], ExperimentResult]


class ExperimentPowerResult(
    UserDict[str, tea_tasting.metrics.MetricPowerResults[Any]],
    tea_tasting.utils.PrettyDictsMixin,
):
    """Result of the analysis of power in an experiment.

    Maps a metric name to the power analysis results of that metric.
    """
    default_keys = ("metric", "power", "effect_size", "rel_effect_size", "n_obs")

    def to_dicts(self) -> tuple[dict[str, Any], ...]:
        """Convert the result to a sequence of dictionaries.

        Returns:
            One dictionary per power analysis result, ordered by metric and
            then by the order of that metric's own results. Each dictionary
            holds the metric name under the `"metric"` key merged with the
            items produced by the metric results' `to_dicts()`.
        """
        # Build the tuple in a single pass: re-creating the accumulated tuple
        # for every metric copies all previous rows again on each iteration.
        return tuple(
            {"metric": metric} | row
            for metric, results in self.items()
            for row in results.to_dicts()
        )


class Experiment(tea_tasting.utils.ReprMixin): # noqa: D101
def __init__(
self,
Expand Down Expand Up @@ -575,3 +591,38 @@ def _read_variants(
.to_pandas()
.loc[:, self.variant]
)


def solve_power(
    self,
    data: pd.DataFrame | ibis.expr.types.Table,
    parameter: PowerParameter = "rel_effect_size",
) -> ExperimentPowerResult:
    """Run the power analysis for every metric that supports it.

    Args:
        data: Sample data.
        parameter: Name of the power parameter to solve for.

    Returns:
        Power analysis results keyed by metric name. Metrics that are
        instances of neither `PowerBaseAggregated` nor `PowerBase`
        are left out of the result.
    """
    # Merge the aggregate requirements of all aggregate-based metrics so
    # the aggregates are read with a single call below.
    required_cols = tea_tasting.metrics.AggrCols()
    for metric in self.metrics.values():
        if isinstance(metric, tea_tasting.metrics.PowerBaseAggregated):
            required_cols |= metric.aggr_cols

    if len(required_cols) > 0:
        aggregates = tea_tasting.aggr.read_aggregates(
            data,
            group_col=None,
            **required_cols._asdict(),
        )
    else:
        # Nothing to aggregate: fall back to a default-constructed object.
        aggregates = tea_tasting.aggr.Aggregates()

    solved = ExperimentPowerResult()
    for name, metric in self.metrics.items():
        if isinstance(metric, tea_tasting.metrics.PowerBaseAggregated):
            # Aggregate-based metrics receive the pre-computed aggregates.
            solved[name] = metric.solve_power(aggregates, parameter=parameter)
        elif isinstance(metric, tea_tasting.metrics.PowerBase):
            # Other power-capable metrics receive the data as passed in.
            solved[name] = metric.solve_power(data, parameter=parameter)
    return solved
2 changes: 2 additions & 0 deletions src/tea_tasting/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
MetricBaseGranular,
MetricPowerResults,
MetricResult,
PowerBase,
PowerBaseAggregated,
aggregate_by_variants,
read_dataframes,
)
Expand Down
96 changes: 90 additions & 6 deletions tests/test_experiment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, NamedTuple, TypedDict
from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict

import ibis
import ibis.expr.types
Expand All @@ -11,11 +11,14 @@
import tea_tasting.datasets
import tea_tasting.experiment
import tea_tasting.metrics
import tea_tasting.utils


if TYPE_CHECKING:
from collections.abc import Callable

from tea_tasting.metrics.base import PowerParameter


class _MetricResultTuple(NamedTuple):
control: float
Expand All @@ -27,8 +30,17 @@ class _MetricResultDict(TypedDict):
treatment: float
effect_size: float

class _PowerResult(NamedTuple):
    """Named-tuple record of a single power analysis result used in tests."""
    power: float
    effect_size: float
    rel_effect_size: float
    n_obs: float


class _Metric(tea_tasting.metrics.MetricBase[_MetricResultTuple]):
class _Metric(
tea_tasting.metrics.MetricBase[_MetricResultTuple],
tea_tasting.metrics.PowerBase[tea_tasting.metrics.MetricPowerResults[_PowerResult]],
):
def __init__(self, value: str) -> None:
self.value = value

Expand All @@ -50,13 +62,27 @@ def analyze(
contr_mean = agg_data.loc[control, "mean"]
treat_mean = agg_data.loc[treatment, "mean"]
return _MetricResultTuple(
control=contr_mean,
treatment=treat_mean,
control=contr_mean, # type: ignore
treatment=treat_mean, # type: ignore
effect_size=treat_mean - contr_mean, # type: ignore
)


class _MetricAggregated(tea_tasting.metrics.MetricBaseAggregated[_MetricResultTuple]):
def solve_power(
self,
data: pd.DataFrame | ibis.Table, # noqa: ARG002
parameter: PowerParameter = "rel_effect_size", # noqa: ARG002
) -> tea_tasting.metrics.MetricPowerResults[_PowerResult]:
return tea_tasting.metrics.MetricPowerResults((
_PowerResult(power=0.8, effect_size=1, rel_effect_size=0.05, n_obs=10_000),
_PowerResult(power=0.9, effect_size=2, rel_effect_size=0.1, n_obs=20_000),
))


class _MetricAggregated(
tea_tasting.metrics.MetricBaseAggregated[_MetricResultTuple],
tea_tasting.metrics.PowerBaseAggregated[
tea_tasting.metrics.MetricPowerResults[dict[str, Any]]],
):
def __init__(self, value: str) -> None:
self.value = value

Expand All @@ -77,6 +103,16 @@ def analyze_aggregates(
effect_size=treat_mean - contr_mean,
)

def solve_power_from_aggregates(
self,
data: tea_tasting.aggr.Aggregates, # noqa: ARG002
parameter: PowerParameter = "rel_effect_size", # noqa: ARG002
) -> tea_tasting.metrics.MetricPowerResults[dict[str, Any]]:
return tea_tasting.metrics.MetricPowerResults((
{"power": 0.8, "effect_size": 1, "rel_effect_size": 0.05, "n_obs": 10_000},
{"power": 0.9, "effect_size": 2, "rel_effect_size": 0.1, "n_obs": 20_000},
))


class _MetricGranular(tea_tasting.metrics.MetricBaseGranular[_MetricResultDict]): # type: ignore
def __init__(self, value: str) -> None:
Expand Down Expand Up @@ -279,6 +315,36 @@ def test_experiment_result_to_html(result2: tea_tasting.experiment.ExperimentRes
)).to_html(index=False)


def test_experiment_power_result_to_dicts():
    """Check `ExperimentPowerResult.to_dicts` for dict and named-tuple results."""
    rows = (
        {"power": 0.8, "effect_size": 1, "rel_effect_size": 0.05, "n_obs": 20_000},
        {"power": 0.9, "effect_size": 1, "rel_effect_size": 0.05, "n_obs": 10_000},
        {"power": 0.8, "effect_size": 2, "rel_effect_size": 0.1, "n_obs": 10_000},
        {"power": 0.9, "effect_size": 2, "rel_effect_size": 0.1, "n_obs": 20_000},
    )
    # First two rows stay dictionaries; the last two become named tuples.
    dict_results = tea_tasting.metrics.MetricPowerResults[dict[str, Any]](rows[:2])
    tuple_results = tea_tasting.metrics.MetricPowerResults[_PowerResult](
        [_PowerResult(**row) for row in rows[2:]],
    )
    power_result = tea_tasting.experiment.ExperimentPowerResult({
        "metric_dict": dict_results,
        "metric_tuple": tuple_results,
    })

    expected_keys = ("metric", "power", "effect_size", "rel_effect_size", "n_obs")
    expected_dicts = (
        {"metric": "metric_dict", "power": 0.8, "effect_size": 1,
         "rel_effect_size": 0.05, "n_obs": 20_000},
        {"metric": "metric_dict", "power": 0.9, "effect_size": 1,
         "rel_effect_size": 0.05, "n_obs": 10_000},
        {"metric": "metric_tuple", "power": 0.8, "effect_size": 2,
         "rel_effect_size": 0.1, "n_obs": 10_000},
        {"metric": "metric_tuple", "power": 0.9, "effect_size": 2,
         "rel_effect_size": 0.1, "n_obs": 20_000},
    )

    assert isinstance(power_result, tea_tasting.utils.PrettyDictsMixin)
    assert power_result.default_keys == expected_keys
    assert power_result.to_dicts() == expected_dicts


def test_experiment_init_default():
metrics = {
"avg_sessions": _Metric("sessions"),
Expand Down Expand Up @@ -416,3 +482,21 @@ def test_experiment_analyze_two_treatments(
(0, 1): ref_result,
(0, 2): ref_result,
})


def test_experiment_solve_power(data: ibis.expr.types.Table):
    """Check that `Experiment.solve_power` collects results from both metrics."""
    # Values returned by the stub `solve_power` of `_Metric`.
    expected_metric = tea_tasting.metrics.MetricPowerResults((
        _PowerResult(power=0.8, effect_size=1, rel_effect_size=0.05, n_obs=10_000),
        _PowerResult(power=0.9, effect_size=2, rel_effect_size=0.1, n_obs=20_000),
    ))
    # Values returned by the stub `solve_power_from_aggregates`
    # of `_MetricAggregated`.
    expected_metric_aggr = tea_tasting.metrics.MetricPowerResults((
        {"power": 0.8, "effect_size": 1, "rel_effect_size": 0.05, "n_obs": 10_000},
        {"power": 0.9, "effect_size": 2, "rel_effect_size": 0.1, "n_obs": 20_000},
    ))
    expected = tea_tasting.experiment.ExperimentPowerResult({
        "metric": expected_metric,
        "metric_aggr": expected_metric_aggr,
    })

    experiment = tea_tasting.experiment.Experiment(
        metric=_Metric("sessions"),
        metric_aggr=_MetricAggregated("orders"),
    )
    assert experiment.solve_power(data) == expected

0 comments on commit 5c882dc

Please sign in to comment.