Implement Rejection Sampling Variational Inference #659
Merged
Commits (39)
21fd1c3 Sketch Rejector classes and .score_function_term() method (fritzo)
48ca4df Implement RejectionStandardGamma using ImplicitRejector (fritzo)
d486c03 Clean up plots (fritzo)
304fffa Integrate RSVI with Trace_ELBO (fritzo)
2b3c234 Implement shape-augmented gradients (but not score function) (fritzo)
51df863 Add kl_gamma_gamma for true gradient (fritzo)
56e5311 Get ShapeAugmentedStandardGamma working (fritzo)
bfa6db3 Add variance plot (fritzo)
4889777 Add REINFORCE estimator to comparison (fritzo)
66b94c1 Add check for .stateful classes in RandomPrimitive (fritzo)
24e73af add beta to rejgamma; add rejgamma to one integration test in test_i… (martinjankowiak)
2c3c3b9 clamp in rejector (martinjankowiak)
23cf6b8 Merge branch 'dev' into rejector (fritzo)
531833f Move rejection_gamma, add gamma_dirichlet (fritzo)
12ae69b Set ShapeAugmentedGamma.reparameterized = True (fritzo)
137111d Simplify ShapeAugmentedGamma (fritzo)
b40c1ea Integrate analytic .entropy() into Trace_ELBO (fritzo)
17a80a9 Implement ShapeAugmentedDirichlet (fritzo)
2e3e040 Add g_cor term to ShapeAugmentedDirichlet (fritzo)
2c4aed8 flake8 (fritzo)
43c18b5 Fix bugs in Mixture distribution (fritzo)
5d1610c Merge branch 'dev' into rejector (fritzo)
f2d9502 Remove irrelevant research stuff from the branch (fritzo)
0bc09f6 Add tests for Rejector (one still fails) (fritzo)
b6455c0 Move RejectionExponential into distributions.testing (fritzo)
20ae559 Fix broken stale dependency (fritzo)
3b48953 Fix syntax errors in shape augmented gamma test (fritzo)
21c157a Fix merge conflict (fritzo)
1c9397f Fix Rejector, but ShapeAugmentedGamma is still broken (fritzo)
2e34667 Merge branch 'dev' into rejector (fritzo)
a20d9b3 Comment on reason for xfailing test (fritzo)
d1afe4a Address review comments (fritzo)
9c7fe3c Merge branch 'dev' into rejector (fritzo)
ac20484 Fix beta rate-vs-scale bug in RejectionGamma (fritzo)
c6a9abb Fix ShapeAugmentedGamma.score_parts() (fritzo)
f01d2d1 Resolve merge conflict (fritzo)
bcfc39d Work around old pytorch wheel used in testing (fritzo)
1a4d88f Fix attribute dependency error in rejector tests (fritzo)
fce5612 Unmark xfailing rsvi test (fritzo)
New file (+63 lines):

```python
from __future__ import absolute_import, division, print_function

import torch

from pyro.distributions.distribution import Distribution
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Distribution)
class Rejector(Distribution):
    """
    Rejection sampled distribution given an acceptance rate function.

    :param Distribution propose: A proposal distribution that samples batched
        proposals via `propose()`.
    :param callable log_prob_accept: A callable that inputs a batch of
        proposals and returns a batch of log acceptance probabilities.
    :param log_scale: Total log probability of acceptance. This is needed
        only if :meth:`batch_log_pdf` needs to be normalized.
    """
    stateful = True
    reparameterized = True

    def __init__(self, propose, log_prob_accept, log_scale=None):
        self.propose = propose
        self.log_prob_accept = log_prob_accept
        self._log_prob_accept_cache = None, None
        self._log_scale = 0 if log_scale is None else log_scale
        self._propose_batch_log_pdf_cache = None, None

    def _log_prob_accept(self, x):
        if x is not self._log_prob_accept_cache[0]:
            self._log_prob_accept_cache = x, self.log_prob_accept(x)
        return self._log_prob_accept_cache[1]

    def _propose_batch_log_pdf(self, x):
        if x is not self._propose_batch_log_pdf_cache[0]:
            self._propose_batch_log_pdf_cache = x, self.propose.batch_log_pdf(x)
        return self._propose_batch_log_pdf_cache[1]

    def sample(self):
        # Implements parallel batched accept-reject sampling.
        x = self.propose()
        log_prob_accept = self.log_prob_accept(x)
        probs = torch.exp(log_prob_accept).clamp_(0.0, 1.0)
        done = torch.bernoulli(probs).byte()
        while not done.all():
            proposed_x = self.propose()
            log_prob_accept = self.log_prob_accept(proposed_x)
            prob_accept = torch.exp(log_prob_accept).clamp_(0.0, 1.0)
            accept = torch.bernoulli(prob_accept).byte() & ~done
            if accept.any():
                x[accept] = proposed_x[accept]
                done |= accept
        return x

    def batch_log_pdf(self, x):
        return self._propose_batch_log_pdf(x) + self._log_prob_accept(x) - self._log_scale

    def score_parts(self, x):
        score_function = self._log_prob_accept(x)
        log_pdf = self.batch_log_pdf(x)
        return ScoreParts(log_pdf, score_function, log_pdf)
```
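The parallel batched accept-reject loop in `Rejector.sample()` can be sketched in a standalone, hypothetical NumPy form (this is an illustration of the technique, not the Pyro API): keep a `done` mask, overwrite only still-rejected entries, and never resample an accepted draw. As a toy target we sample Exponential(rate=1) through an Exponential(rate=0.5) proposal with envelope constant M=2, which gives acceptance log-probability `-0.5 * x`.

```python
import numpy as np

rng = np.random.default_rng(0)

def batched_rejection_sample(propose, log_prob_accept, shape):
    """Parallel batched accept-reject sampling (standalone sketch).

    `propose(shape)` draws a batch of proposals; `log_prob_accept(x)`
    returns per-element log acceptance probabilities (<= 0).
    """
    x = propose(shape)
    probs = np.exp(np.clip(log_prob_accept(x), None, 0.0))
    done = rng.random(shape) < probs
    while not done.all():
        proposed = propose(shape)
        probs = np.exp(np.clip(log_prob_accept(proposed), None, 0.0))
        accept = (rng.random(shape) < probs) & ~done
        x[accept] = proposed[accept]  # fill only still-rejected slots
        done |= accept
    return x

# Toy target: Exponential(rate=1) via an Exponential(rate=0.5) proposal.
# p(x)/q(x) = 2 exp(-0.5 x) <= M = 2, so accept prob = exp(-0.5 x).
propose = lambda shape: rng.exponential(scale=2.0, size=shape)  # rate 0.5
log_prob_accept = lambda x: -0.5 * x

samples = batched_rejection_sample(propose, log_prob_accept, (50000,))
```

The mask-based update is what makes the loop batch-friendly: each iteration proposes a full batch but only writes into the slots that have not yet accepted.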
New file (+19 lines):

```python
from __future__ import absolute_import, division, print_function

from collections import namedtuple


class ScoreParts(namedtuple('ScoreParts', ['log_pdf', 'score_function', 'entropy'])):
    """
    This data structure stores terms used in stochastic gradient estimators that
    combine the pathwise estimator and the score function estimator.
    """
    def __mul__(self, scale):
        """
        Scale appropriate terms of a gradient estimator by a data multiplicity factor.
        Note that the `score_function` term should not be scaled.
        """
        log_pdf, score_function, entropy = self
        return ScoreParts(log_pdf * scale, score_function, entropy * scale)

    __rmul__ = __mul__
```
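The asymmetric scaling in `ScoreParts.__mul__` is easy to demonstrate with plain floats (the class is duplicated here so the sketch runs standalone; the numeric values are arbitrary): multiplying by a data-multiplicity factor rescales `log_pdf` and `entropy` but leaves `score_function` untouched.

```python
from collections import namedtuple

class ScoreParts(namedtuple('ScoreParts', ['log_pdf', 'score_function', 'entropy'])):
    """Standalone copy of the ScoreParts container above, for illustration."""
    def __mul__(self, scale):
        log_pdf, score_function, entropy = self
        # score_function is deliberately NOT scaled.
        return ScoreParts(log_pdf * scale, score_function, entropy * scale)
    __rmul__ = __mul__

parts = ScoreParts(log_pdf=-1.5, score_function=-0.2, entropy=-1.5)
scaled = 10 * parts  # e.g. a subsampling correction factor of 10
```

`__rmul__ = __mul__` makes the scaling commutative, so both `10 * parts` and `parts * 10` work.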
New file (+21 lines):

```python
from __future__ import absolute_import, division, print_function

from pyro.distributions import Exponential
from pyro.distributions.rejector import Rejector
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Exponential)
class RejectionExponential(Rejector):
    def __init__(self, rate, factor):
        assert (factor <= 1).all()
        self.rate = rate
        self.factor = factor
        propose = Exponential(self.factor * self.rate)
        log_scale = self.factor.log()
        super(RejectionExponential, self).__init__(propose, self.log_prob_accept, log_scale)

    def log_prob_accept(self, x):
        result = (self.factor - 1) * self.rate * x
        assert result.max() <= 0, result.max()
        return result
```
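It is worth checking the algebra behind `RejectionExponential` numerically: the proposal log-density (Exponential with rate `factor * rate`), plus the acceptance term `(factor - 1) * rate * x`, minus the normalizer `log(factor)`, should reproduce the target Exponential(rate) log-density exactly. A small NumPy sketch (parameter values are arbitrary) verifies this identity on a grid:

```python
import numpy as np

rate, factor = 1.7, 0.4  # arbitrary illustrative values, with factor <= 1
x = np.linspace(0.1, 5.0, 50)

# Proposal: Exponential(factor * rate), log q(x) = log(f*r) - f*r*x
log_q = np.log(factor * rate) - factor * rate * x
# Acceptance term and normalizer, mirroring RejectionExponential above
log_accept = (factor - 1) * rate * x
log_scale = np.log(factor)

# Rejector.batch_log_pdf combines them as log q + log_accept - log_scale
lhs = log_q + log_accept - log_scale
# Target: Exponential(rate), log p(x) = log(r) - r*x
rhs = np.log(rate) - rate * x
```

Expanding terms: `log(f*r) - log(f) = log(r)` and `-f*r*x + (f-1)*r*x = -r*x`, so `lhs` and `rhs` agree term by term, confirming that the `log_scale` argument normalizes `batch_log_pdf`.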
New file (+93 lines):

```python
from __future__ import absolute_import, division, print_function

import torch

from pyro.distributions.distribution import Distribution
from pyro.distributions.torch.gamma import Gamma
from pyro.distributions.rejector import Rejector
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.util import copy_docs_from


@copy_docs_from(Gamma)
class RejectionStandardGamma(Rejector):
    """
    Naive Marsaglia & Tsang rejection sampler for the standard Gamma
    distribution. This assumes `alpha >= 1` and does not boost `alpha`
    or augment shape.
    """
    def __init__(self, alpha):
        super(RejectionStandardGamma, self).__init__(self.propose, self.log_prob_accept)
        if alpha.data.min() < 1:
            raise NotImplementedError('alpha < 1 is not supported')
        self.alpha = alpha
        self._standard_gamma = Gamma(alpha, alpha.new([1]).expand_as(alpha))
        # The following are Marsaglia & Tsang's variable names.
        self._d = self.alpha - 1.0 / 3.0
        self._c = 1.0 / torch.sqrt(9.0 * self._d)

    def propose(self):
        # Marsaglia & Tsang's x == Naesseth's epsilon
        x = self.alpha.new(self.alpha.shape).normal_()
        y = 1.0 + self._c * x
        v = y * y * y
        return (self._d * v).clamp_(1e-30, 1e30)

    def log_prob_accept(self, value):
        v = value / self._d
        y = torch.pow(v, 1.0 / 3.0)
        x = (y - 1.0) / self._c
        log_prob_accept = 0.5 * x * x + self._d * (1.0 - v + torch.log(v))
        log_prob_accept[y <= 0] = -float('inf')
        return log_prob_accept

    def batch_log_pdf(self, x):
        return self._standard_gamma.batch_log_pdf(x)


@copy_docs_from(Gamma)
class RejectionGamma(Distribution):
    reparameterized = True

    def __init__(self, alpha, beta):
        self._standard_gamma = RejectionStandardGamma(alpha)
        self.beta = beta

    def sample(self):
        return self._standard_gamma.sample() * self.beta

    def batch_log_pdf(self, x):
        return self._standard_gamma.batch_log_pdf(x / self.beta) - torch.log(self.beta)

    def score_parts(self, x):
        log_pdf, score_function, _ = self._standard_gamma.score_parts(x / self.beta)
        log_pdf = log_pdf - torch.log(self.beta)
        return ScoreParts(log_pdf, score_function, log_pdf)


@copy_docs_from(Gamma)
class ShapeAugmentedGamma(Gamma):
    stateful = True
    reparameterized = True

    def __init__(self, alpha, beta, boost=1):
        if alpha.min() + boost < 1:
            raise ValueError('Need to boost at least once for alpha < 1')
        super(ShapeAugmentedGamma, self).__init__(alpha, beta)
        self._boost = boost
        self._rejection_gamma = RejectionGamma(alpha + boost, beta)
        self._unboost_x_cache = None, None

    def sample(self):
        x = self._rejection_gamma.sample()
        boosted_x = x.clone()
        for i in range(self._boost):
            boosted_x *= (1 - x.new(x.shape).uniform_()) ** (1 / (i + self.alpha))
        self._unboost_x_cache = boosted_x, x
        return boosted_x

    def score_parts(self, boosted_x=None):
        if boosted_x is None:
            boosted_x = self._unboost_x_cache[0]
        assert boosted_x is self._unboost_x_cache[0]
        x = self._unboost_x_cache[1]
        return self._rejection_gamma.score_parts(x)
```
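Both tricks in this file can be reproduced in a standalone NumPy sketch (an illustration of the same scheme, not the Pyro code). The first function is the Marsaglia & Tsang sampler for Gamma(alpha, 1) with `alpha >= 1`: propose `d * (1 + c*x)^3` from a Gaussian `x` and accept with log-probability `0.5*x^2 + d*(1 - v + log v)`. The second part is the shape-augmentation ("boost") step used by `ShapeAugmentedGamma.sample()`: a Gamma(alpha + 1) draw times `U^(1/alpha)` is distributed as Gamma(alpha), which extends the sampler below `alpha = 1`.

```python
import numpy as np

rng = np.random.default_rng(0)

def standard_gamma_mt(alpha, size):
    """Marsaglia & Tsang rejection sampler for Gamma(alpha, 1), alpha >= 1."""
    d = alpha - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)
    out = np.empty(size)
    done = np.zeros(size, dtype=bool)
    while not done.all():
        x = rng.standard_normal(size)
        y = 1.0 + c * x
        v = y ** 3
        ok = y > 0                      # proposals with y <= 0 are rejected
        safe_v = np.where(ok, v, 1.0)   # avoid log of non-positive values
        log_accept = np.where(ok, 0.5 * x * x + d * (1.0 - safe_v + np.log(safe_v)),
                              -np.inf)
        u = 1.0 - rng.random(size)      # uniform on (0, 1], so log(u) is finite
        accept = ok & (np.log(u) < log_accept) & ~done
        out[accept] = (d * v)[accept]
        done |= accept
    return out

samples = standard_gamma_mt(alpha=2.5, size=100000)

# Shape augmentation: turn Gamma(alpha + 1) draws into Gamma(alpha) draws
# via x * U^(1/alpha), mirroring the boost loop in ShapeAugmentedGamma.
alpha = 0.5
boosted = standard_gamma_mt(alpha + 1.0, 100000)
unboosted = boosted * rng.random(100000) ** (1.0 / alpha)
```

Note that the acceptance log-probability is exactly zero at the proposal mode (`v = 1`), which is why this sampler has a high acceptance rate for moderate `alpha`.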
Review comment: maybe explicit names in namedtuple here?