Skip to content

Commit

Permalink
Fix in reset of proposal kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
papamarkou committed Dec 14, 2020
1 parent 3fc6eb3 commit d65dcc6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion eeyore/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.14'
__version__ = '0.0.15'
4 changes: 4 additions & 0 deletions eeyore/samplers/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def set_current(self, theta, data=None):
self.current['target_val'], self.current['grad_val'] = \
self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y)

def reset(self, theta, data=None, reset_counter=True, reset_chain=True):
super().reset(theta, data=data, reset_counter=reset_counter, reset_chain=reset_chain)
self.set_kernel(self.current)

def kernel_mean(self, state):
return state['sample'] + 0.5 * self.step * state['grad_val']

Expand Down
4 changes: 4 additions & 0 deletions eeyore/samplers/metropolis_hastings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def set_current(self, theta, data=None):
x, y = super().set_current(theta, data=data)
self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y)

def reset(self, theta, data=None, reset_counter=True, reset_chain=True):
super().reset(theta, data=data, reset_counter=reset_counter, reset_chain=reset_chain)
self.set_kernel(self.current)

def set_kernel(self, state, scale=None, scale_tril=None):
self.kernel.set_density_params(state['sample'].clone().detach())

Expand Down

0 comments on commit d65dcc6

Please sign in to comment.