Implement Rejection Sampling Variational Inference (#659)
* Sketch Rejector classes and .score_function_term() method

* Implement RejectionStandardGamma using ImplicitRejector

* Clean up plots

* Integrate RSVI with Trace_ELBO

* Implement shape-augmented gradients (but not score function)

* Add kl_gamma_gamma for true gradient

* Get ShapeAugmentedStandardGamma working

* Add variance plot

* Add REINFORCE estimator to comparison

* Add check for .stateful classes in RandomPrimitive

* Add beta to rejgamma; add rejgamma to one integration test in test_inference.py

* Clamp in rejector

* Move rejection_gamma, add gamma_dirichlet

* Set ShapeAugmentedGamma.reparameterized = True

* Simplify ShapeAugmentedGamma

* Integrate analytic .entropy() into Trace_ELBO

* Implement ShapeAugmentedDirichlet

* Add g_cor term to ShapeAugmentedDirichlet

* flake8

* Fix bugs in Mixture distribution

* Remove irrelevant research stuff from the branch

* Add tests for Rejector (one still fails)

* Move RejectionExponential into distributions.testing

* Fix broken stale dependency

* Fix syntax errors in shape augmented gamma test

* Fix merge conflict

* Fix Rejector, but ShapeAugmentedGamma is still broken

* Comment on reason for xfailing test

* Address review comments

* Fix beta rate-vs-scale bug in RejectionGamma

* Fix ShapeAugmentedGamma.score_parts()

* Resolve merge conflict

* Work around old pytorch wheel used in testing

* Fix attribute dependency error in rejector tests

* Unmark xfailing rsvi test
fritzo authored and eb8680 committed Jan 25, 2018
1 parent 098f0af commit 9adc077
Showing 12 changed files with 437 additions and 26 deletions.
1 change: 1 addition & 0 deletions pyro/distributions/__init__.py
@@ -19,6 +19,7 @@
from pyro.distributions.distribution import Distribution # noqa: F401
from pyro.distributions.log_normal import LogNormal
from pyro.distributions.random_primitive import RandomPrimitive
from pyro.distributions.rejector import Rejector # noqa: F401

# distribution classes with working torch versions in torch.distributions
from pyro.distributions.torch.bernoulli import Bernoulli
29 changes: 27 additions & 2 deletions pyro/distributions/distribution.py
@@ -2,9 +2,11 @@

from abc import ABCMeta, abstractmethod

import torch
from six import add_metaclass

import torch
from pyro.distributions.score_parts import ScoreParts


@add_metaclass(ABCMeta)
class Distribution(object):
@@ -85,6 +87,7 @@ class Distribution(object):
Take a look at the `examples <http://pyro.ai/examples>`_ to see how they interact
with inference algorithms.
"""
stateful = False
reparameterized = False
enumerable = False

@@ -168,7 +171,8 @@ def log_pdf(self, x, *args, **kwargs):
Evaluates total log probability density of a batch of samples.
:param torch.autograd.Variable x: A value.
:return: total log probability density as a one-dimensional torch.autograd.Variable of size 1.
:return: total log probability density as a one-dimensional
torch.autograd.Variable of size 1.
:rtype: torch.autograd.Variable
"""
return torch.sum(self.batch_log_pdf(x, *args, **kwargs))
@@ -187,6 +191,27 @@ def batch_log_pdf(self, x, *args, **kwargs):
"""
raise NotImplementedError

def score_parts(self, x, *args, **kwargs):
"""
Computes ingredients for stochastic gradient estimators of ELBO.
The default implementation is correct both for non-reparameterized and
for fully reparameterized distributions. Partially reparameterized
distributions should override this method to compute correct
`.score_function` and `.entropy_term` parts.
:param torch.autograd.Variable x: A single value or batch of values.
:return: A `ScoreParts` object containing parts of the ELBO estimator.
:rtype: ScoreParts
"""
log_pdf = self.batch_log_pdf(x, *args, **kwargs)
if self.reparameterized:
return ScoreParts(log_pdf=log_pdf, score_function=0, entropy_term=log_pdf)
else:
# XXX should the user be able to control inclusion of the entropy term?
# See Roeder, Wu, Duvenaud (2017) "Sticking the Landing" https://arxiv.org/abs/1703.09194
return ScoreParts(log_pdf=log_pdf, score_function=log_pdf, entropy_term=0)

def enumerate_support(self, *args, **kwargs):
"""
Returns a representation of the parametrized distribution's support.
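For orientation, here is a minimal sketch (not the actual Trace_ELBO code from this PR) of how the three ScoreParts terms can be combined into one surrogate objective for a single latent site; `surrogate_elbo_particle`, `guide_dist`, and `model_log_pdf` are hypothetical names introduced only for this illustration.

def surrogate_elbo_particle(guide_dist, model_log_pdf, z):
    # Illustrative sketch only.  score_parts() returns
    # (log_pdf, score_function, entropy_term); which terms are nonzero
    # depends on how z was (partially) reparameterized.
    log_q, score_function, entropy_term = guide_dist.score_parts(z)
    log_p = model_log_pdf(z)
    elbo = (log_p - log_q).detach()  # plain Monte Carlo ELBO estimate, no gradient
    # Pathwise part (log_p - entropy_term) plus score-function part
    # (score_function * elbo): this reduces to the reparameterized estimator,
    # the REINFORCE estimator, or the RSVI estimator depending on the site.
    return log_p - entropy_term + score_function * elbo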
6 changes: 6 additions & 0 deletions pyro/distributions/random_primitive.py
@@ -12,6 +12,8 @@ class RandomPrimitive(Distribution):
__slots__ = ['dist_class']

def __init__(self, dist_class):
if dist_class.stateful:
raise TypeError('Cannot wrap stateful class {} in RandomPrimitive.'.format(type(dist_class)))
self.dist_class = dist_class
super(RandomPrimitive, self).__init__()

@@ -53,6 +55,10 @@ def batch_log_pdf(self, x, *args, **kwargs):
kwargs.pop('sample_shape', None)
return self.dist_class(*args, **kwargs).batch_log_pdf(x)

def score_parts(self, x, *args, **kwargs):
kwargs.pop('sample_shape', None)
return self.dist_class(*args, **kwargs).score_parts(x)

def enumerate_support(self, *args, **kwargs):
kwargs.pop('sample_shape', None)
return self.dist_class(*args, **kwargs).enumerate_support()
64 changes: 64 additions & 0 deletions pyro/distributions/rejector.py
@@ -0,0 +1,64 @@
from __future__ import absolute_import, division, print_function

import torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Distribution)
class Rejector(Distribution):
"""
Rejection sampled distribution given an acceptance rate function.
:param Distribution propose: A proposal distribution that samples batched
proposals via `propose()`.
:param callable log_prob_accept: A callable that inputs a batch of
proposals and returns a batch of log acceptance probabilities.
:param log_scale: Total log probability of acceptance.
"""
stateful = True
reparameterized = True

def __init__(self, propose, log_prob_accept, log_scale):
self.propose = propose
self.log_prob_accept = log_prob_accept
self._log_scale = log_scale

# These LRU(1) caches allow work to be shared across different method calls.
self._log_prob_accept_cache = None, None
self._propose_batch_log_pdf_cache = None, None

def _log_prob_accept(self, x):
if x is not self._log_prob_accept_cache[0]:
self._log_prob_accept_cache = x, self.log_prob_accept(x) - self._log_scale
return self._log_prob_accept_cache[1]

def _propose_batch_log_pdf(self, x):
if x is not self._propose_batch_log_pdf_cache[0]:
self._propose_batch_log_pdf_cache = x, self.propose.batch_log_pdf(x)
return self._propose_batch_log_pdf_cache[1]

def sample(self):
# Implements parallel batched accept-reject sampling.
x = self.propose()
log_prob_accept = self.log_prob_accept(x)
probs = torch.exp(log_prob_accept).clamp_(0.0, 1.0)
done = torch.bernoulli(probs).byte()
while not done.all():
proposed_x = self.propose()
log_prob_accept = self.log_prob_accept(proposed_x)
prob_accept = torch.exp(log_prob_accept).clamp_(0.0, 1.0)
accept = torch.bernoulli(prob_accept).byte() & ~done
if accept.any():
x[accept] = proposed_x[accept]
done |= accept
return x

def batch_log_pdf(self, x):
return self._propose_batch_log_pdf(x) + self._log_prob_accept(x)

def score_parts(self, x):
score_function = self._log_prob_accept(x)
log_pdf = self.batch_log_pdf(x)
return ScoreParts(log_pdf, score_function, log_pdf)
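As a reading aid (not part of the diff), the bookkeeping behind Rejector.batch_log_pdf() and Rejector.score_parts() can be summarized in our own notation, following Naesseth, Ruiz, Linderman, Blei (2017): let r(x) be the proposal density, a(x) the acceptance probability, and M = exp(log_scale) the total acceptance probability.

\[
  q(x) = \frac{r(x)\,a(x)}{M},
  \qquad
  \log q(x) = \log r(x) + \log a(x) - \log M ,
\]

which is what `batch_log_pdf` computes. The RSVI gradient of \(\mathbb{E}_{q_\theta}[f(x)]\) then splits as \(g = g_{\mathrm{rep}} + g_{\mathrm{cor}}\), with

\[
  g_{\mathrm{cor}}
    = \mathbb{E}_{q_\theta}\!\left[ f(x)\,\nabla_\theta \log \frac{q_\theta(x)}{r_\theta(x)} \right],
  \qquad
  \log \frac{q_\theta(x)}{r_\theta(x)} = \log a(x) - \log M ,
\]

and \(\log a(x) - \log M\) is exactly the `score_function` term returned by `score_parts()`.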
19 changes: 19 additions & 0 deletions pyro/distributions/score_parts.py
@@ -0,0 +1,19 @@
from __future__ import absolute_import, division, print_function

from collections import namedtuple


class ScoreParts(namedtuple('ScoreParts', ['log_pdf', 'score_function', 'entropy_term'])):
"""
This data structure stores terms used in stochastic gradient estimators that
combine the pathwise estimator and the score function estimator.
"""
def __mul__(self, scale):
"""
Scale appropriate terms of a gradient estimator by a data multiplicity factor.
Note that the `score_function` term should not be scaled.
"""
log_pdf, score_function, entropy_term = self
return ScoreParts(log_pdf * scale, score_function, entropy_term * scale)

__rmul__ = __mul__
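A tiny usage sketch (illustrative values only; assumes the Variable-based tensors of this era):

import torch
from torch.autograd import Variable

from pyro.distributions.score_parts import ScoreParts

log_pdf = Variable(torch.Tensor([-1.5, -0.7]))
# Non-reparameterized case: score_function == log_pdf, entropy_term == 0.
parts = ScoreParts(log_pdf=log_pdf, score_function=log_pdf, entropy_term=0)
scaled = 10 * parts  # equivalent to parts * 10, via __rmul__
# scaled.log_pdf and scaled.entropy_term are multiplied by 10;
# scaled.score_function is intentionally left unscaled.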
12 changes: 9 additions & 3 deletions pyro/distributions/testing/fakes.py
@@ -1,15 +1,20 @@
from __future__ import absolute_import, division, print_function

from pyro.distributions import Beta
from pyro.distributions import Gamma
from pyro.distributions import Normal
from pyro.distributions.random_primitive import RandomPrimitive
from pyro.distributions.torch.beta import Beta
from pyro.distributions.torch.dirichlet import Dirichlet
from pyro.distributions.torch.gamma import Gamma
from pyro.distributions.torch.normal import Normal


class NonreparameterizedBeta(Beta):
reparameterized = False


class NonreparameterizedDirichlet(Dirichlet):
reparameterized = False


class NonreparameterizedGamma(Gamma):
reparameterized = False

@@ -19,5 +24,6 @@ class NonreparameterizedNormal(Normal):


nonreparameterized_beta = RandomPrimitive(NonreparameterizedBeta)
nonreparameterized_dirichlet = RandomPrimitive(NonreparameterizedDirichlet)
nonreparameterized_gamma = RandomPrimitive(NonreparameterizedGamma)
nonreparameterized_normal = RandomPrimitive(NonreparameterizedNormal)
21 changes: 21 additions & 0 deletions pyro/distributions/testing/rejection_exponential.py
@@ -0,0 +1,21 @@
from __future__ import absolute_import, division, print_function

from pyro.distributions import Exponential
from pyro.distributions.rejector import Rejector
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Exponential)
class RejectionExponential(Rejector):
def __init__(self, rate, factor):
assert (factor <= 1).all()
self.rate = rate
self.factor = factor
propose = Exponential(self.factor * self.rate)
log_scale = self.factor.log()
super(RejectionExponential, self).__init__(propose, self.log_prob_accept, log_scale)

def log_prob_accept(self, x):
result = (self.factor - 1) * self.rate * x
assert result.max() <= 0, result.max()
return result
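A short usage sketch for the class above (illustrative values only; assumes the Variable-based API of this era):

import torch
from torch.autograd import Variable

from pyro.distributions.testing.rejection_exponential import RejectionExponential

rate = Variable(torch.Tensor([2.0, 3.0]), requires_grad=True)
factor = Variable(torch.Tensor([0.5, 0.5]))  # must satisfy factor <= 1

dist = RejectionExponential(rate, factor)
x = dist.sample()              # batched accept-reject sampling
log_p = dist.batch_log_pdf(x)  # proposal log-density plus normalized log acceptance probability
parts = dist.score_parts(x)    # ScoreParts(log_pdf, score_function, entropy_term)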
115 changes: 115 additions & 0 deletions pyro/distributions/testing/rejection_gamma.py
@@ -0,0 +1,115 @@
from __future__ import absolute_import, division, print_function

import torch

from pyro.distributions.rejector import Rejector
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.torch.gamma import Gamma
from pyro.distributions.torch.normal import Normal
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Gamma)
class RejectionStandardGamma(Rejector):
"""
Naive Marsaglia & Tsang rejection sampler for the standard Gamma distribution.
This assumes `alpha >= 1` and does not boost `alpha` or augment shape.
"""
def __init__(self, alpha):
if alpha.data.min() < 1:
raise NotImplementedError('alpha < 1 is not supported')
self.alpha = alpha
self._standard_gamma = Gamma(alpha, alpha.new([1]).expand_as(alpha))
# The following are Marsaglia & Tsang's variable names.
self._d = self.alpha - 1.0 / 3.0
self._c = 1.0 / torch.sqrt(9.0 * self._d)
# Compute log scale using Gamma.batch_log_pdf().
x = self._d.detach() # just an arbitrary x.
log_scale = self.propose_batch_log_pdf(x) + self.log_prob_accept(x) - self.batch_log_pdf(x)
super(RejectionStandardGamma, self).__init__(self.propose, self.log_prob_accept, log_scale)

def propose(self):
# Marsaglia & Tsang's x == Naesseth's epsilon
x = self.alpha.new(self.alpha.shape).normal_()
y = 1.0 + self._c * x
v = y * y * y
return (self._d * v).clamp_(1e-30, 1e30)

def propose_batch_log_pdf(self, value):
v = value / self._d
result = -self._d.log()
y = v.pow(1 / 3)
result -= torch.log(3 * y ** 2)
x = (y - 1) / self._c
result -= self._c.log()
result += Normal(torch.zeros_like(self.alpha), torch.ones_like(self.alpha)).batch_log_pdf(x)
return result

def log_prob_accept(self, value):
v = value / self._d
y = torch.pow(v, 1.0 / 3.0)
x = (y - 1.0) / self._c
log_prob_accept = 0.5 * x * x + self._d * (1.0 - v + torch.log(v))
log_prob_accept[y <= 0] = -float('inf')
return log_prob_accept

def batch_log_pdf(self, x):
return self._standard_gamma.batch_log_pdf(x)


@copy_docs_from(Gamma)
class RejectionGamma(Gamma):
stateful = True
reparameterized = True

def __init__(self, alpha, beta):
super(RejectionGamma, self).__init__(alpha, beta)
self._standard_gamma = RejectionStandardGamma(alpha)
self.beta = beta

def sample(self):
return self._standard_gamma.sample() / self.beta

def batch_log_pdf(self, x):
return self._standard_gamma.batch_log_pdf(x * self.beta) + torch.log(self.beta)

def score_parts(self, x):
log_pdf, score_function, _ = self._standard_gamma.score_parts(x * self.beta)
log_pdf = log_pdf + torch.log(self.beta)
return ScoreParts(log_pdf, score_function, log_pdf)


@copy_docs_from(Gamma)
class ShapeAugmentedGamma(Gamma):
"""
This implements the shape augmentation trick of
Naesseth, Ruiz, Linderman, Blei (2017) https://arxiv.org/abs/1610.05683
"""
stateful = True
reparameterized = True

def __init__(self, alpha, beta, boost=1):
if alpha.min() + boost < 1:
raise ValueError('Need to boost at least once for alpha < 1')
super(ShapeAugmentedGamma, self).__init__(alpha, beta)
self.alpha = alpha
self._boost = boost
self._rejection_gamma = RejectionGamma(alpha + boost, beta)
self._unboost_x_cache = None, None

def sample(self):
x = self._rejection_gamma.sample()
boosted_x = x.clone()
for i in range(self._boost):
boosted_x *= (1 - x.new(x.shape).uniform_()) ** (1 / (i + self.alpha))
self._unboost_x_cache = boosted_x, x
return boosted_x

def score_parts(self, boosted_x=None):
if boosted_x is None:
boosted_x = self._unboost_x_cache[0]
assert boosted_x is self._unboost_x_cache[0]
x = self._unboost_x_cache[1]
_, score_function, _ = self._rejection_gamma.score_parts(x)
log_pdf = self.batch_log_pdf(boosted_x)
return ScoreParts(log_pdf, score_function, log_pdf)
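For reference (not part of the diff), the two transforms implemented above, written out in our own notation; this is our reading of the code and of Marsaglia & Tsang (2000) and Naesseth et al. (2017), so treat it as a sketch.

% Marsaglia & Tsang proposal for a standard Gamma with shape \alpha >= 1:
\[
  d = \alpha - \tfrac{1}{3}, \qquad
  c = \tfrac{1}{\sqrt{9d}}, \qquad
  \epsilon \sim \mathcal{N}(0, 1), \qquad
  v = (1 + c\,\epsilon)^3, \qquad
  z = d\,v ,
\]
% accepted with log-probability (the log_prob_accept term above; a = 0 whenever 1 + c*eps <= 0):
\[
  \log a(\epsilon) = \tfrac{1}{2}\epsilon^2 + d\,\bigl(1 - v + \log v\bigr) .
\]
% Shape augmentation: for small \alpha, draw \tilde z ~ Gamma(\alpha + B, \beta) with
% B boost steps (\alpha + B >= 1) and set
\[
  z = \tilde z \prod_{i=1}^{B} u_i^{1/(\alpha + i - 1)},
  \qquad u_i \sim \mathrm{Uniform}(0, 1) ,
\]
% which is Gamma(\alpha, \beta)-distributed while keeping the rejection step at shape \alpha + B.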