Skip to content

Commit

Permalink
Make default STEP_METHODS a list that can be modified
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 5, 2024
1 parent 22e8f0b commit 034b9a4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
7 changes: 4 additions & 3 deletions pymc/step_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc.step_methods.compound import CompoundStep
from pymc.step_methods.compound import BlockedStep, CompoundStep
from pymc.step_methods.hmc import NUTS, HamiltonianMC
from pymc.step_methods.metropolis import (
BinaryGibbsMetropolis,
Expand All @@ -30,12 +30,13 @@
)
from pymc.step_methods.slicer import Slice

STEP_METHODS = (
# Other step methods can be added by appending to this list
STEP_METHODS: list[type[BlockedStep]] = [
NUTS,
HamiltonianMC,
Metropolis,
BinaryMetropolis,
BinaryGibbsMetropolis,
Slice,
CategoricalGibbsMetropolis,
)
]
16 changes: 11 additions & 5 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,12 +762,18 @@ def kill_grad(x):
steps = assign_step_methods(model, [])
assert isinstance(steps, Slice)

def test_modify_step_methods(self):
@pytest.fixture
def step_methods(self):
"""Make sure we reset the STEP_METHODS after the test is done."""
methods_copy = pm.STEP_METHODS.copy()
yield pm.STEP_METHODS
pm.STEP_METHODS.clear()
for method in methods_copy:
pm.STEP_METHODS.append(method)

def test_modify_step_methods(self, step_methods):
"""Test step methods can be changed"""
# remove nuts from step_methods
step_methods = list(pm.STEP_METHODS)
step_methods.remove(NUTS)
pm.STEP_METHODS = step_methods

with pm.Model() as model:
pm.Normal("x", 0, 1)
Expand All @@ -776,7 +782,7 @@ def test_modify_step_methods(self):
assert not isinstance(steps, NUTS)

# add back nuts
pm.STEP_METHODS = [*step_methods, NUTS]
step_methods.append(NUTS)

with pm.Model() as model:
pm.Normal("x", 0, 1)
Expand Down
6 changes: 1 addition & 5 deletions tests/step_methods/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
Slice,
)
from pymc.step_methods.compound import (
BlockedStep,
StatsBijection,
flatten_steps,
get_stats_dtypes_shapes_from_steps,
Expand All @@ -38,10 +37,7 @@


def test_all_stepmethods_emit_tune_stat():
attrs = [getattr(pm.step_methods, n) for n in dir(pm.step_methods)]
step_types = [
attr for attr in attrs if isinstance(attr, type) and issubclass(attr, BlockedStep)
]
step_types = pm.step_methods.STEP_METHODS
assert len(step_types) > 5
for cls in step_types:
assert "tune" in cls.stats_dtypes_shapes
Expand Down

0 comments on commit 034b9a4

Please sign in to comment.