Skip to content

Commit

Permalink
Fix brittle check in test_marginal_likelihood
Browse files Browse the repository at this point in the history
This would fail when the chains do not have the same length, resulting in the marginal_likelihood being wrapped in a numpy array of `object` dtype. This led the nanmean call to fail with an AttributeError

Closes pymc-devs#5324
  • Loading branch information
ricardoV94 committed Jan 13, 2022
1 parent 1272012 commit cd1b313
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions pymc/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,16 @@ def test_marginal_likelihood(self):
marginals.append(trace.report.log_marginal_likelihood)

# compare to the analytical result
assert abs(np.exp(np.nanmean(marginals[1]) - np.nanmean(marginals[0])) - 4.0) <= 1
assert (
np.abs(
np.exp(
np.nanmean(np.array(marginals[1], dtype=float))
- np.nanmean(np.array(marginals[0], dtype=float))
- 4.0
)
)
<= 1
)

def test_start(self):
with pm.Model() as model:
Expand Down Expand Up @@ -299,7 +308,7 @@ def setup_class(self):
s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data)

def test_one_gaussian(self):
assert self.count_rvs(self.SMABC_test.logpt) == 1
assert self.count_rvs(self.SMABC_test.logpt()) == 1

with self.SMABC_test:
trace = pm.sample_smc(draws=1000, chains=1, return_inferencedata=False)
Expand Down Expand Up @@ -333,7 +342,7 @@ def test_custom_dist_sum_stat(self):
observed=self.data,
)

assert self.count_rvs(m.logpt) == 1
assert self.count_rvs(m.logpt()) == 1

with m:
pm.sample_smc(draws=100)
Expand All @@ -354,7 +363,7 @@ def test_custom_dist_sum_stat_scalar(self):
sum_stat=self.quantiles,
observed=scalar_data,
)
assert self.count_rvs(m.logpt) == 1
assert self.count_rvs(m.logpt()) == 1

with pm.Model() as m:
s = pm.Simulator(
Expand All @@ -366,10 +375,10 @@ def test_custom_dist_sum_stat_scalar(self):
sum_stat="mean",
observed=scalar_data,
)
assert self.count_rvs(m.logpt) == 1
assert self.count_rvs(m.logpt()) == 1

def test_model_with_potential(self):
assert self.count_rvs(self.SMABC_potential.logpt) == 1
assert self.count_rvs(self.SMABC_potential.logpt()) == 1

with self.SMABC_potential:
trace = pm.sample_smc(draws=100, chains=1, return_inferencedata=False)
Expand Down Expand Up @@ -413,7 +422,7 @@ def test_multiple_simulators(self):
observed=data2,
)

assert self.count_rvs(m.logpt) == 2
assert self.count_rvs(m.logpt()) == 2

# Check that the logps use the correct methods
a_val = m.rvs_to_values[a]
Expand Down Expand Up @@ -463,7 +472,7 @@ def test_nested_simulators(self):
observed=data,
)

assert self.count_rvs(m.logpt) == 2
assert self.count_rvs(m.logpt()) == 2

with m:
trace = pm.sample_smc(return_inferencedata=False)
Expand Down

0 comments on commit cd1b313

Please sign in to comment.