Implement Rejection Sampling Variational Inference #659

Merged · 39 commits · Jan 25, 2018
Changes from 26 commits

Commits (39)
21fd1c3
Sketch Rejector classes and .score_function_term() method
fritzo Jan 3, 2018
48ca4df
Implement RejectionStandardGamma using ImplicitRejector
fritzo Jan 4, 2018
d486c03
Clean up plots
fritzo Jan 5, 2018
304fffa
Integrate RSVI with Trace_ELBO
fritzo Jan 6, 2018
2b3c234
Implement shape-augmented gradients (but not score function)
fritzo Jan 6, 2018
51df863
Add kl_gamma_gamma for true gradient
fritzo Jan 7, 2018
56e5311
Get ShapeAugmentedStandardGamma working
fritzo Jan 7, 2018
bfa6db3
Add variance plot
fritzo Jan 8, 2018
4889777
Add REINFORCE estimator to comparison
fritzo Jan 8, 2018
66b94c1
Add check for .stateful classes in RandomPrimitive
fritzo Jan 10, 2018
24e73af
add beta to rejgamma; add rejgamma to one integration test in test_i…
martinjankowiak Jan 10, 2018
2c3c3b9
clamp in rejector
martinjankowiak Jan 11, 2018
23cf6b8
Merge branch 'dev' into rejector
fritzo Jan 14, 2018
531833f
Move rejection_gamma, add gamma_dirichlet
fritzo Jan 14, 2018
12ae69b
Set ShapeAugmentedGamma.reparameterized = True
fritzo Jan 14, 2018
137111d
Simplify ShapeAugmentedGamma
fritzo Jan 14, 2018
b40c1ea
Integrate analytic .entropy() into Trace_ELBO
fritzo Jan 14, 2018
17a80a9
Implement ShapeAugmentedDirichlet
fritzo Jan 14, 2018
2e3e040
Add g_cor term to ShapeAugmentedDirichlet
fritzo Jan 15, 2018
2c4aed8
flake8
fritzo Jan 15, 2018
43c18b5
Fix bugs in Mixture distribution
fritzo Jan 17, 2018
5d1610c
Merge branch 'dev' into rejector
fritzo Jan 23, 2018
f2d9502
Remove irrelevant research stuff from the branch
fritzo Jan 23, 2018
0bc09f6
Add tests for Rejector (one still fails)
fritzo Jan 23, 2018
b6455c0
Move RejectionExponential into distributions.testing
fritzo Jan 23, 2018
20ae559
Fix broken stale dependency
fritzo Jan 23, 2018
3b48953
Fix syntax errors in shape augmented gamma test
fritzo Jan 23, 2018
21c157a
Fix merge conflict
fritzo Jan 23, 2018
1c9397f
Fix Rejector, but ShapeAugmentedGamma is still broken
fritzo Jan 24, 2018
2e34667
Merge branch 'dev' into rejector
fritzo Jan 24, 2018
a20d9b3
Comment on reason for xfailing test
fritzo Jan 24, 2018
d1afe4a
Address review comments
fritzo Jan 24, 2018
9c7fe3c
Merge branch 'dev' into rejector
fritzo Jan 24, 2018
ac20484
Fix beta rate-vs-scale bug in RejectionGamma
fritzo Jan 24, 2018
c6a9abb
Fix ShapeAugmentedGamma.score_parts()
fritzo Jan 24, 2018
f01d2d1
Resolve merge conflict
fritzo Jan 24, 2018
bcfc39d
Work around old pytorch wheel used in testing
fritzo Jan 24, 2018
1a4d88f
Fix attribute dependency error in rejector tests
fritzo Jan 25, 2018
fce5612
Unmark xfailing rsvi test
fritzo Jan 25, 2018
1 change: 1 addition & 0 deletions pyro/distributions/__init__.py
@@ -23,6 +23,7 @@
from pyro.distributions.multivariate_normal import MultivariateNormal
from pyro.distributions.poisson import Poisson
from pyro.distributions.random_primitive import RandomPrimitive
from pyro.distributions.rejector import Rejector # noqa: F401

# distribution classes with working torch versions in torch.distributions
from pyro.distributions.torch.bernoulli import Bernoulli
29 changes: 27 additions & 2 deletions pyro/distributions/distribution.py
@@ -2,9 +2,11 @@

from abc import ABCMeta, abstractmethod

import torch
from six import add_metaclass

import torch
from pyro.distributions.score_parts import ScoreParts


@add_metaclass(ABCMeta)
class Distribution(object):
@@ -85,6 +87,7 @@ class Distribution(object):
Take a look at the `examples <http://pyro.ai/examples>`_ to see how they interact
with inference algorithms.
"""
stateful = False
reparameterized = False
enumerable = False

@@ -168,7 +171,8 @@ def log_pdf(self, x, *args, **kwargs):
Evaluates total log probability density of a batch of samples.

:param torch.autograd.Variable x: A value.
:return: total log probability density as a one-dimensional torch.autograd.Variable of size 1.
:return: total log probability density as a one-dimensional
torch.autograd.Variable of size 1.
:rtype: torch.autograd.Variable
"""
return torch.sum(self.batch_log_pdf(x, *args, **kwargs))
@@ -187,6 +191,27 @@ def batch_log_pdf(self, x, *args, **kwargs):
"""
raise NotImplementedError

def score_parts(self, x, *args, **kwargs):
"""
Computes ingredients for stochastic gradient estimators of ELBO.

The default implementation is correct both for non-reparameterized and
for fully reparameterized distributions. Partially reparameterized
distributions should override this method to compute correct
`.score_function` and `.entropy` parts.

:param torch.autograd.Variable x: A single value or batch of values.
:return: A `ScoreParts` object containing parts of the ELBO estimator.
:rtype: ScoreParts
"""
log_pdf = self.batch_log_pdf(x, *args, **kwargs)
if self.reparameterized:
return ScoreParts(log_pdf, 0, log_pdf)
Review comment (collaborator): maybe explicit names in namedtuple here?

else:
# XXX should the user be able to control inclusion of the entropy term?
# See Roeder, Wu, Duvenaud (2017) "Sticking the Landing" https://arxiv.org/abs/1703.09194
return ScoreParts(log_pdf, log_pdf, 0)

def enumerate_support(self, *args, **kwargs):
"""
Returns a representation of the parametrized distribution's support.
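A minimal sketch of how the new `score_parts` hook behaves for a non-reparameterized site, assuming the Variable-era torch API and the `nonreparameterized_normal` fake distribution added later in this PR:

import torch
from torch.autograd import Variable
from pyro.distributions.testing.fakes import nonreparameterized_normal

mu, sigma = Variable(torch.zeros(3)), Variable(torch.ones(3))
x = nonreparameterized_normal.sample(mu, sigma)

# Non-reparameterized default: ScoreParts(log_pdf, score_function=log_pdf, entropy=0),
# so downstream estimators fall back to a pure score-function (REINFORCE) term.
log_pdf, score_function, entropy = nonreparameterized_normal.score_parts(x, mu, sigma)

# A fully reparameterized distribution would instead return ScoreParts(log_pdf, 0, log_pdf):
# no score-function term, and the entropy part carries the pathwise gradient.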
5 changes: 5 additions & 0 deletions pyro/distributions/random_primitive.py
@@ -10,6 +10,8 @@ class RandomPrimitive(Distribution):
__slots__ = ['dist_class']

def __init__(self, dist_class):
if dist_class.stateful:
raise TypeError('Cannot wrap stateful class {} in RandomPrimitive.'.format(type(dist_class)))
self.dist_class = dist_class
super(RandomPrimitive, self).__init__()

@@ -44,6 +46,9 @@ def log_pdf(self, x, *args, **kwargs):
def batch_log_pdf(self, x, *args, **kwargs):
return self.dist_class(*args, **kwargs).batch_log_pdf(x)

def score_parts(self, x, *args, **kwargs):
return self.dist_class(*args, **kwargs).score_parts(x)

def enumerate_support(self, *args, **kwargs):
return self.dist_class(*args, **kwargs).enumerate_support()

63 changes: 63 additions & 0 deletions pyro/distributions/rejector.py
@@ -0,0 +1,63 @@
from __future__ import absolute_import, division, print_function

import torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Distribution)
class Rejector(Distribution):
"""
Rejection sampled distribution given an acceptance rate function.

:param Distribution propose: A proposal distribution that samples batched
proposals via `propose()`.
:param callable log_prob_accept: A callable that inputs a batch of
proposals and returns a batch of log acceptance probabilities.
:param log_scale: Total log probability of acceptance. This is needed
only if :meth:`batch_log_pdf` needs to be normalized.
"""
stateful = True
reparameterized = True

def __init__(self, propose, log_prob_accept, log_scale=None):
self.propose = propose
self.log_prob_accept = log_prob_accept
self._log_prob_accept_cache = None, None
self._log_scale = 0 if log_scale is None else log_scale
self._propose_batch_log_pdf_cache = None, None

def _log_prob_accept(self, x):
if x is not self._log_prob_accept_cache[0]:
self._log_prob_accept_cache = x, self.log_prob_accept(x)
return self._log_prob_accept_cache[1]

def _propose_batch_log_pdf(self, x):
if x is not self._propose_batch_log_pdf_cache[0]:
self._propose_batch_log_pdf_cache = x, self.propose.batch_log_pdf(x)
return self._propose_batch_log_pdf_cache[1]

def sample(self):
# Implements parallel batched accept-reject sampling.
x = self.propose()
log_prob_accept = self.log_prob_accept(x)
probs = torch.exp(log_prob_accept).clamp_(0.0, 1.0)
done = torch.bernoulli(probs).byte()
while not done.all():
proposed_x = self.propose()
log_prob_accept = self.log_prob_accept(proposed_x)
prob_accept = torch.exp(log_prob_accept).clamp_(0.0, 1.0)
accept = torch.bernoulli(prob_accept).byte() & ~done
if accept.any():
x[accept] = proposed_x[accept]
done |= accept
return x

def batch_log_pdf(self, x):
return self._propose_batch_log_pdf(x) + self._log_prob_accept(x) - self._log_scale

def score_parts(self, x):
score_function = self._log_prob_accept(x)
log_pdf = self.batch_log_pdf(x)
return ScoreParts(log_pdf, score_function, log_pdf)
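A short usage sketch of `Rejector`, assuming the `Exponential` envelope pattern used by the `RejectionExponential` test helper added further down:

import torch
from torch.autograd import Variable
from pyro.distributions import Exponential
from pyro.distributions.rejector import Rejector

rate = Variable(torch.Tensor([2.0, 3.0]))
factor = Variable(torch.Tensor([0.5, 0.5]))     # proposal is Exponential(factor * rate)

def log_prob_accept(x):
    return (factor - 1) * rate * x              # <= 0 elementwise because factor <= 1

rejector = Rejector(Exponential(factor * rate), log_prob_accept, log_scale=factor.log())
x = rejector.sample()             # runs the batched accept-reject loop above
parts = rejector.score_parts(x)   # ScoreParts(log_pdf, log_prob_accept(x), log_pdf)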
19 changes: 19 additions & 0 deletions pyro/distributions/score_parts.py
@@ -0,0 +1,19 @@
from __future__ import absolute_import, division, print_function

from collections import namedtuple


class ScoreParts(namedtuple('ScoreParts', ['log_pdf', 'score_function', 'entropy'])):
"""
This data structure stores terms used in stochastic gradient estimators that
combine the pathwise estimator and the score function estimator.
"""
def __mul__(self, scale):
"""
Scale appropriate terms of a gradient estimator by a data multiplicity factor.
Note that the `score_function` term should not be scaled.
"""
log_pdf, score_function, entropy = self
return ScoreParts(log_pdf * scale, score_function, entropy * scale)

__rmul__ = __mul__
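The scaling convention is easy to check in isolation; a self-contained example of the behaviour documented above:

from pyro.distributions.score_parts import ScoreParts

parts = ScoreParts(log_pdf=-1.5, score_function=-1.5, entropy=0)
# Scaling by a data-multiplicity factor touches log_pdf and entropy only;
# the score_function part is deliberately left unscaled.
assert 10 * parts == ScoreParts(-15.0, -1.5, 0)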
12 changes: 9 additions & 3 deletions pyro/distributions/testing/fakes.py
@@ -1,15 +1,20 @@
from __future__ import absolute_import, division, print_function

from pyro.distributions import Beta
from pyro.distributions import Gamma
from pyro.distributions import Normal
from pyro.distributions.random_primitive import RandomPrimitive
from pyro.distributions.torch.beta import Beta
from pyro.distributions.torch.dirichlet import Dirichlet
from pyro.distributions.torch.gamma import Gamma
from pyro.distributions.torch.normal import Normal


class NonreparameterizedBeta(Beta):
reparameterized = False


class NonreparameterizedDirichlet(Dirichlet):
reparameterized = False


class NonreparameterizedGamma(Gamma):
reparameterized = False

@@ -19,5 +24,6 @@ class NonreparameterizedNormal(Normal):


nonreparameterized_beta = RandomPrimitive(NonreparameterizedBeta)
nonreparameterized_dirichlet = RandomPrimitive(NonreparameterizedDirichlet)
nonreparameterized_gamma = RandomPrimitive(NonreparameterizedGamma)
nonreparameterized_normal = RandomPrimitive(NonreparameterizedNormal)
21 changes: 21 additions & 0 deletions pyro/distributions/testing/rejection_exponential.py
@@ -0,0 +1,21 @@
from __future__ import absolute_import, division, print_function

from pyro.distributions import Exponential
from pyro.distributions.rejector import Rejector
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Exponential)
class RejectionExponential(Rejector):
def __init__(self, rate, factor):
assert (factor <= 1).all()
self.rate = rate
self.factor = factor
propose = Exponential(self.factor * self.rate)
log_scale = self.factor.log()
super(RejectionExponential, self).__init__(propose, self.log_prob_accept, log_scale)

def log_prob_accept(self, x):
result = (self.factor - 1) * self.rate * x
assert result.max() <= 0, result.max()
return result
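The acceptance terms above can be checked by hand. With target density p(x) = rate * exp(-rate * x) and proposal q(x) = factor * rate * exp(-factor * rate * x),

    p(x) / q(x) = (1 / factor) * exp((factor - 1) * rate * x),

so the envelope constant M = 1 / factor gives an acceptance probability of exp((factor - 1) * rate * x), i.e. log_prob_accept(x) = (factor - 1) * rate * x, which is <= 0 whenever factor <= 1. The overall acceptance rate is E_q[p / (M * q)] = 1 / M = factor, which is why log_scale = factor.log() renormalizes batch_log_pdf back to the target density.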
93 changes: 93 additions & 0 deletions pyro/distributions/testing/rejection_gamma.py
@@ -0,0 +1,93 @@
from __future__ import absolute_import, division, print_function

import torch
from pyro.distributions.distribution import Distribution
from pyro.distributions.torch.gamma import Gamma
from pyro.distributions.rejector import Rejector
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Gamma)
class RejectionStandardGamma(Rejector):
"""
Naive Marsaglia & Tsang rejection sampler for the standard Gamma distribution.
This assumes `alpha >= 1` and does not boost `alpha` or augment shape.
"""

Review comment (collaborator): docstring typo

def __init__(self, alpha):
super(RejectionStandardGamma, self).__init__(self.propose, self.log_prob_accept)
if alpha.data.min() < 1:
raise NotImplementedError('alpha < 1 is not supported')
self.alpha = alpha
self._standard_gamma = Gamma(alpha, alpha.new([1]).expand_as(alpha))
# The following are Marsaglia & Tsang's variable names.
self._d = self.alpha - 1.0 / 3.0
self._c = 1.0 / torch.sqrt(9.0 * self._d)

def propose(self):
# Marsaglia & Tsang's x == Naesseth's epsilon
x = self.alpha.new(self.alpha.shape).normal_()
y = 1.0 + self._c * x
v = y * y * y
return (self._d * v).clamp_(1e-30, 1e30)

def log_prob_accept(self, value):
v = value / self._d
y = torch.pow(v, 1.0 / 3.0)
x = (y - 1.0) / self._c
log_prob_accept = 0.5 * x * x + self._d * (1.0 - v + torch.log(v))
log_prob_accept[y <= 0] = -float('inf')
return log_prob_accept

def batch_log_pdf(self, x):
return self._standard_gamma.batch_log_pdf(x)
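Note that the proposal above is a smooth, reparameterized transform of standard normal noise,

    x = d * (1 + c * eps)**3,   eps ~ Normal(0, 1),

so gradients with respect to `alpha` flow pathwise through accepted samples, while the acceptance correction enters only through the `log_prob_accept` score-function term returned by `Rejector.score_parts`. This is the rejection-sampling reparameterization estimator of Naesseth et al. that this PR implements.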


@copy_docs_from(Gamma)
class RejectionGamma(Distribution):
reparameterized = True

def __init__(self, alpha, beta):
self._standard_gamma = RejectionStandardGamma(alpha)
self.beta = beta

def sample(self):
return self._standard_gamma.sample() * self.beta

def batch_log_pdf(self, x):
return self._standard_gamma.batch_log_pdf(x / self.beta) - torch.log(self.beta)

def score_parts(self, x):
log_pdf, score_function, _ = self._standard_gamma.score_parts(x / self.beta)
log_pdf = log_pdf - torch.log(self.beta)
return ScoreParts(log_pdf, score_function, log_pdf)


@copy_docs_from(Gamma)
class ShapeAugmentedGamma(Gamma):
stateful = True
reparameterized = True

def __init__(self, alpha, beta, boost=1):
if alpha.min() + boost < 1:
raise ValueError('Need to boost at least once for alpha < 1')
super(ShapeAugmentedGamma, self).__init__(alpha, beta)
self._boost = boost
self._rejection_gamma = RejectionGamma(alpha + boost, beta)
self._unboost_x_cache = None, None

def sample(self):
x = self._rejection_gamma.sample()
boosted_x = x.clone()
for i in range(self._boost):
boosted_x *= (1 - x.new(x.shape).uniform_()) ** (1 / (i + self.alpha))
self._unboost_x_cache = boosted_x, x
return boosted_x

def score_parts(self, boosted_x=None):
if boosted_x is None:
boosted_x = self._unboost_x_cache[0]
assert boosted_x is self._unboost_x_cache[0]
x = self._unboost_x_cache[1]
return self._rejection_gamma.score_parts(x)
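The `_boost` loop in `sample()` uses the standard Gamma shape-augmentation identity: if x ~ Gamma(alpha + 1, beta) and u ~ Uniform(0, 1) are independent, then x * u**(1 / alpha) ~ Gamma(alpha, beta). Applying it `boost` times,

    boosted_x = x * prod_{i=0..boost-1} u_i ** (1 / (alpha + i)),   u_i ~ Uniform(0, 1),

turns the rejection sample x ~ Gamma(alpha + boost, beta) (valid as long as alpha + boost >= 1) into a Gamma(alpha, beta) sample, while `score_parts()` is evaluated at the cached un-boosted x, since that is the value actually produced by the rejection sampler.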
36 changes: 26 additions & 10 deletions pyro/infer/trace_elbo.py
@@ -15,6 +15,10 @@
from pyro.util import check_model_guide_match


def is_identically_zero(x):
return isinstance(x, numbers.Number) and x == 0


def check_enum_discrete_can_run(model_trace, guide_trace):
"""
Checks whether `enum_discrete` is supported for the given (model, guide) pair.
@@ -74,6 +78,7 @@ def _get_traces(self, model, guide, *args, **kwargs):
model_trace = prune_subsample_sites(model_trace)
check_enum_discrete_can_run(model_trace, guide_trace)

guide_trace.compute_score_parts()
log_r = model_trace.batch_log_pdf() - guide_trace.batch_log_pdf()
weight = scale / self.num_particles
yield weight, model_trace, guide_trace, log_r
@@ -86,6 +91,7 @@ def _get_traces(self, model, guide, *args, **kwargs):
guide_trace = prune_subsample_sites(guide_trace)
model_trace = prune_subsample_sites(model_trace)

guide_trace.compute_score_parts()
log_r = model_trace.log_pdf() - guide_trace.log_pdf()
weight = 1.0 / self.num_particles
yield weight, model_trace, guide_trace, log_r
@@ -140,26 +146,36 @@ def loss_and_grads(self, model, guide, *args, **kwargs):
for weight, model_trace, guide_trace, log_r in self._get_traces(model, guide, *args, **kwargs):
elbo_particle = weight * 0
surrogate_elbo_particle = weight * 0
batched = (self.enum_discrete and weight.size(0) > 1)
# compute elbo and surrogate elbo
if (self.enum_discrete and isinstance(weight, Variable) and weight.size(0) > 1):
log_pdf = "batch_log_pdf"
else:
log_pdf = "log_pdf"
for name, model_site in model_trace.nodes.items():
if model_site["type"] == "sample":
model_log_pdf = model_site[log_pdf]
if model_site["is_observed"]:
elbo_particle += model_site[log_pdf]
surrogate_elbo_particle += model_site[log_pdf]
elbo_particle += model_log_pdf
surrogate_elbo_particle += model_log_pdf
else:
guide_site = guide_trace.nodes[name]
lp_lq = model_site[log_pdf] - guide_site[log_pdf]
elbo_particle += lp_lq
if guide_site["fn"].reparameterized:
surrogate_elbo_particle += lp_lq
else:
# XXX should the user be able to control inclusion of the -logq term below?
guide_log_pdf = guide_site[log_pdf] / guide_site["scale"] # not scaled by subsampling
surrogate_elbo_particle += model_site[log_pdf] + log_r.detach() * guide_log_pdf
guide_log_pdf, score_function_term, entropy_term = guide_site["score_parts"]

if not batched:
guide_log_pdf = guide_log_pdf.sum()
elbo_particle += model_log_pdf - guide_log_pdf
surrogate_elbo_particle += model_log_pdf

if not is_identically_zero(entropy_term):
if not batched:
entropy_term = entropy_term.sum()
surrogate_elbo_particle -= entropy_term

if not is_identically_zero(score_function_term):
if not batched:
score_function_term = score_function_term.sum()
surrogate_elbo_particle += log_r.detach() * score_function_term

# drop terms of weight zero to avoid nans
if isinstance(weight, numbers.Number):
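Summarizing the control flow above, the per-particle surrogate objective assembled here is, schematically,

    surrogate_elbo = sum over sites of [ log p  -  entropy_term  +  detach(log_r) * score_function ].

With the default `score_parts`, a fully reparameterized site contributes log p - log q (pure pathwise gradient), a non-reparameterized site contributes log p + detach(log_r) * log q (the score-function estimator used before this change), and a `Rejector`-style site contributes both a pathwise term through its `batch_log_pdf` and a score-function term through `log_prob_accept`, which is the partially reparameterized estimator this PR adds.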