-
-
Notifications
You must be signed in to change notification settings - Fork 983
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement Rejection Sampling Variational Inference (#659)
* Sketch Rejector classes and .score_function_term() method
* Implement RejectionStandardGamma using ImplicitRejector
* Clean up plots
* Integrate RSVI with Trace_ELBO
* Implement shape-augmented gradients (but not score function)
* Add kl_gamma_gamma for true gradient
* Get ShapeAugmentedStandardGamma working
* Add variance plot
* Add REINFORCE estimator to comparison
* Add check for .stateful classes in RandomPrimitive
* add beta to rejgamma; add rejgamma to one integration test in test_inference.py
* clamp in rejector
* Move rejection_gamma, add gamma_dirichlet
* Set ShapeAugmentedGamma.reparameterized = True
* Simplify ShapeAugmentedGamma
* Integrate analytic .entropy() into Trace_ELBO
* Implement ShapeAugmentedDirichlet
* Add g_cor term to ShapeAugmentedDirichlet
* flake8
* Fix bugs in Mixture distribution
* Remove irrelevant research stuff from the branch
* Add tests for Rejector (one still fails)
* Move RejectionExponential into distributions.testing
* Fix broken stale dependency
* Fix syntax errors in shape augmented gamma test
* Fix merge conflict
* Fix Rejector, but ShapeAugmentedGamma is still broken
* Comment on reason for xfailing test
* Address review comments
* Fix beta rate-vs-scale bug in RejectionGamma
* Fix ShapeAugmentedGamma.score_parts()
* Resolve merge conflict
* Work around old pytorch wheel used in testing
* Fix attribute dependency error in rejector tests
* Unmark xfailing rsvi test
- Loading branch information
Showing
12 changed files
with
437 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import torch | ||
from pyro.distributions.distribution import Distribution | ||
from pyro.distributions.score_parts import ScoreParts | ||
from pyro.distributions.util import copy_docs_from | ||
|
||
|
||
@copy_docs_from(Distribution)
class Rejector(Distribution):
    """
    Rejection sampled distribution given an acceptance rate function.

    :param Distribution propose: A proposal distribution; calling `propose()`
        draws a batch of proposals.
    :param callable log_prob_accept: A callable that maps a batch of proposals
        to a batch of log acceptance probabilities.
    :param log_scale: Total log probability of acceptance.
    """
    stateful = True
    reparameterized = True

    def __init__(self, propose, log_prob_accept, log_scale):
        self.propose = propose
        self.log_prob_accept = log_prob_accept
        self._log_scale = log_scale

        # Single-entry caches keyed on the identity of the most recent
        # argument, so that work is shared across different method calls.
        self._log_prob_accept_cache = (None, None)
        self._propose_batch_log_pdf_cache = (None, None)

    def _log_prob_accept(self, x):
        # Log acceptance probability with the log scale subtracted, memoized on `x`.
        cached_x, _ = self._log_prob_accept_cache
        if cached_x is not x:
            self._log_prob_accept_cache = (x, self.log_prob_accept(x) - self._log_scale)
        return self._log_prob_accept_cache[1]

    def _propose_batch_log_pdf(self, x):
        # Log density of `x` under the proposal, memoized on `x`.
        cached_x, _ = self._propose_batch_log_pdf_cache
        if cached_x is not x:
            self._propose_batch_log_pdf_cache = (x, self.propose.batch_log_pdf(x))
        return self._propose_batch_log_pdf_cache[1]

    def sample(self):
        # Parallel batched accept-reject sampling: keep proposing a full batch
        # and fill in only those entries that have not been accepted yet.
        x = self.propose()
        accept_prob = torch.exp(self.log_prob_accept(x)).clamp_(0.0, 1.0)
        done = torch.bernoulli(accept_prob).byte()
        while not done.all():
            candidate = self.propose()
            accept_prob = torch.exp(self.log_prob_accept(candidate)).clamp_(0.0, 1.0)
            # Entries that are already done must not be overwritten.
            newly_accepted = torch.bernoulli(accept_prob).byte() & ~done
            if newly_accepted.any():
                x[newly_accepted] = candidate[newly_accepted]
                done |= newly_accepted
        return x

    def batch_log_pdf(self, x):
        proposal_term = self._propose_batch_log_pdf(x)
        return proposal_term + self._log_prob_accept(x)

    def score_parts(self, x):
        # The scaled log acceptance probability serves as the score function
        # term; the log density doubles as the entropy term.
        score_function = self._log_prob_accept(x)
        log_pdf = self.batch_log_pdf(x)
        return ScoreParts(log_pdf=log_pdf, score_function=score_function, entropy_term=log_pdf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
from collections import namedtuple | ||
|
||
|
||
class ScoreParts(namedtuple('ScoreParts', ['log_pdf', 'score_function', 'entropy_term'])):
    """
    This data structure stores terms used in stochastic gradient estimators that
    combine the pathwise estimator and the score function estimator.

    Fields are `log_pdf`, `score_function` and `entropy_term`.
    """
    # Prevent creation of a per-instance __dict__ so instances stay as
    # lightweight as the underlying namedtuple.
    __slots__ = ()

    def __mul__(self, scale):
        """
        Scale appropriate terms of a gradient estimator by a data multiplicity factor.
        Note that the `score_function` term should not be scaled.

        :param scale: Factor multiplied into `log_pdf` and `entropy_term`.
        :return: A new `ScoreParts` with the scaled terms.
        """
        log_pdf, score_function, entropy_term = self
        return ScoreParts(log_pdf * scale, score_function, entropy_term * scale)

    # Scaling is symmetric: `scale * parts` behaves like `parts * scale`.
    __rmul__ = __mul__
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
from pyro.distributions import Exponential | ||
from pyro.distributions.rejector import Rejector | ||
from pyro.distributions.util import copy_docs_from | ||
|
||
|
||
@copy_docs_from(Exponential)
class RejectionExponential(Rejector):
    # Rejector that proposes from Exponential(factor * rate) and accepts with
    # log probability (factor - 1) * rate * x; the log scale is log(factor).
    def __init__(self, rate, factor):
        assert (factor <= 1).all()
        self.rate = rate
        self.factor = factor
        proposal = Exponential(self.factor * self.rate)
        scale_log = self.factor.log()
        super(RejectionExponential, self).__init__(proposal, self.log_prob_accept, scale_log)

    def log_prob_accept(self, x):
        log_accept = (self.factor - 1) * self.rate * x
        # A log acceptance probability must never be positive.
        assert log_accept.max() <= 0, log_accept.max()
        return log_accept
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import torch | ||
|
||
from pyro.distributions.rejector import Rejector | ||
from pyro.distributions.score_parts import ScoreParts | ||
from pyro.distributions.torch.gamma import Gamma | ||
from pyro.distributions.torch.normal import Normal | ||
from pyro.distributions.util import copy_docs_from | ||
|
||
|
||
@copy_docs_from(Gamma)
class RejectionStandardGamma(Rejector):
    """
    Naive Marsaglia & Tsang rejection sampler for the standard Gamma distribution.
    This assumes `alpha >= 1` and does not boost `alpha` or augment shape.
    """
    def __init__(self, alpha):
        if alpha.data.min() < 1:
            raise NotImplementedError('alpha < 1 is not supported')
        self.alpha = alpha
        unit_beta = alpha.new([1]).expand_as(alpha)
        self._standard_gamma = Gamma(alpha, unit_beta)
        # `_d` and `_c` follow Marsaglia & Tsang's variable names.
        self._d = self.alpha - 1.0 / 3.0
        self._c = 1.0 / torch.sqrt(9.0 * self._d)
        # The log scale is the gap between (proposal + acceptance) and the
        # Gamma log density, evaluated at an arbitrary point.
        point = self._d.detach()
        log_scale = (self.propose_batch_log_pdf(point) + self.log_prob_accept(point) -
                     self.batch_log_pdf(point))
        super(RejectionStandardGamma, self).__init__(self.propose, self.log_prob_accept, log_scale)

    def propose(self):
        # Marsaglia & Tsang's x == Naesseth's epsilon
        eps = self.alpha.new(self.alpha.shape).normal_()
        y = 1.0 + self._c * eps
        cubed = y * y * y
        return (self._d * cubed).clamp_(1e-30, 1e30)

    def propose_batch_log_pdf(self, value):
        # Invert value = d * (1 + c * x) ** 3 to recover y = 1 + c * x and x.
        v = value / self._d
        y = v.pow(1.0 / 3.0)
        x = (y - 1) / self._c
        # Accumulate the log-Jacobian terms, then the standard normal log density of x.
        total = -self._d.log()
        total -= torch.log(3 * y ** 2)
        total -= self._c.log()
        unit_normal = Normal(torch.zeros_like(self.alpha), torch.ones_like(self.alpha))
        total += unit_normal.batch_log_pdf(x)
        return total

    def log_prob_accept(self, value):
        v = value / self._d
        y = torch.pow(v, 1.0 / 3.0)
        x = (y - 1.0) / self._c
        log_accept = 0.5 * x * x + self._d * (1.0 - v + torch.log(v))
        # Proposals with non-positive y are never accepted.
        log_accept[y <= 0] = float('-inf')
        return log_accept

    def batch_log_pdf(self, x):
        # Delegate to the wrapped Gamma(alpha, 1) distribution.
        return self._standard_gamma.batch_log_pdf(x)
|
||
|
||
@copy_docs_from(Gamma)
class RejectionGamma(Gamma):
    # Gamma(alpha, beta) built on RejectionStandardGamma(alpha): samples are
    # divided by `beta`, and densities are evaluated at `x * beta`.
    stateful = True
    reparameterized = True

    def __init__(self, alpha, beta):
        super(RejectionGamma, self).__init__(alpha, beta)
        self._standard_gamma = RejectionStandardGamma(alpha)
        self.beta = beta

    def sample(self):
        standard_draw = self._standard_gamma.sample()
        return standard_draw / self.beta

    def batch_log_pdf(self, x):
        # log(beta) accounts for the change of variables x -> x * beta.
        standardized = x * self.beta
        return self._standard_gamma.batch_log_pdf(standardized) + torch.log(self.beta)

    def score_parts(self, x):
        standard_log_pdf, score_function, _ = self._standard_gamma.score_parts(x * self.beta)
        log_pdf = standard_log_pdf + torch.log(self.beta)
        return ScoreParts(log_pdf=log_pdf, score_function=score_function, entropy_term=log_pdf)
|
||
|
||
@copy_docs_from(Gamma)
class ShapeAugmentedGamma(Gamma):
    """
    Gamma distribution implementing the shape augmentation trick of
    Naesseth, Ruiz, Linderman, Blei (2017) https://arxiv.org/abs/1610.05683
    """
    stateful = True
    reparameterized = True

    def __init__(self, alpha, beta, boost=1):
        if alpha.min() + boost < 1:
            raise ValueError('Need to boost at least once for alpha < 1')
        super(ShapeAugmentedGamma, self).__init__(alpha, beta)
        self.alpha = alpha
        self._boost = boost
        # Sampling starts from a Gamma whose shape is raised by `boost`.
        self._rejection_gamma = RejectionGamma(alpha + boost, beta)
        # Pair (returned sample, underlying rejection-gamma sample) from the last sample() call.
        self._unboost_x_cache = (None, None)

    def sample(self):
        raw = self._rejection_gamma.sample()
        result = raw.clone()
        # Multiply in one factor (1 - u) ** (1 / (alpha + step)) per boost step.
        for step in range(self._boost):
            uniform = raw.new(raw.shape).uniform_()
            exponent = 1 / (step + self.alpha)
            result *= (1 - uniform) ** exponent
        self._unboost_x_cache = (result, raw)
        return result

    def score_parts(self, boosted_x=None):
        cached_sample, cached_raw = self._unboost_x_cache
        if boosted_x is None:
            boosted_x = cached_sample
        # Only the most recently sampled value can be scored, because the
        # score function needs the matching underlying draw.
        assert boosted_x is cached_sample
        _, score_function, _ = self._rejection_gamma.score_parts(cached_raw)
        log_pdf = self.batch_log_pdf(boosted_x)
        return ScoreParts(log_pdf=log_pdf, score_function=score_function, entropy_term=log_pdf)
Oops, something went wrong.