Implement Rejection Sampling Variational Inference #659

Merged
merged 39 commits into from
Jan 25, 2018

Member

@fritzo fritzo commented Jan 3, 2018

Addresses #63

This implements Rejection Sampling Variational Inference (RSVI) of Naesseth et al. (2017).

This PR adds a new class Rejector, an abstract base class derived from Distribution. The crux of integrating it with Pyro's SVI in TraceGraph_ELBO is a new method Distribution.score_function_term() that can be overridden to compute g_cor in the paper. This method computes a partial score function for a partially reparameterized distribution: for an unreparameterized distribution it returns the entire score function, and for a fully reparameterized distribution it returns zero.
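
As a rough illustration of the two limiting cases described above, the sketch below splits a log-density into named parts. The field names (log_pdf, score_function, entropy_term), the helper name score_parts, and the zero entropy term in the unreparameterized branch are assumptions made for this example; the PR may use something different.

```python
from collections import namedtuple

# Field names are assumed for illustration only.
ScoreParts = namedtuple("ScoreParts", ["log_pdf", "score_function", "entropy_term"])


def score_parts(log_pdf, reparameterized):
    """Split a log-density into the pieces a gradient estimator would consume."""
    if reparameterized:
        # Fully reparameterized: the score-function part is zero.
        return ScoreParts(log_pdf=log_pdf, score_function=0.0, entropy_term=log_pdf)
    # Unreparameterized: the whole log-density serves as the score-function part
    # (the zero entropy term here is an assumption of this sketch).
    return ScoreParts(log_pdf=log_pdf, score_function=log_pdf, entropy_term=0.0)
```

A partially reparameterized distribution such as a Rejector would presumably sit between these two cases, returning a nonzero score_function that plays the role of g_cor.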

Tasks

  • integrate with Trace_ELBO
  • implement RejectionGamma example, e.g. see @slinderman's notebook
  • implement @naesseth's shape augmentation, e.g. see code
  • write convergence tests
  • write smoke tests
  • clean up PR to remove irrelevant testing stuff
  • fix unit tests

This was joint work by @martinjankowiak and @fritzo

@fritzo
Member Author

fritzo commented Jan 3, 2018

cc @rachtsingh @naesseth

@fritzo fritzo changed the title from "Implement Rejector classes and .score_function_term() method" to "Implement Rejection Sampling Variational Inference" Jan 3, 2018
@slinderman

Glad to see this incorporated into pyro, @fritzo and @martinjankowiak! One of the points discussed in the paper is that the score function term tends to be higher variance, and in practice it can often be safely ignored at the cost of a small bias.

You can also ameliorate it to some extent by "shape augmentation," i.e. introducing auxiliary uniform r.v.'s in order to increase the shape parameter and thereby increase the acceptance probability. I did not implement this in the demo notebook, but we used it in our paper. The code for that is in this repo: https://github.com/blei-lab/ars-reparameterization and @naesseth can probably provide more details on its implementation.
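
A minimal sketch of the shape-augmentation identity mentioned above, assuming the usual construction: draw from Gamma(alpha + B, 1), then multiply by B factors u_i^(1/(alpha + i - 1)) with each u_i uniform on (0, 1], which brings the shape back down to alpha. The function name is hypothetical, and Python's random.gammavariate stands in for the rejection sampler purely to keep the example short.

```python
import random


def shape_augmented_gamma(alpha, boost, rng):
    """Draw one Gamma(alpha, 1) sample by way of a Gamma(alpha + boost, 1) draw."""
    # Boosted draw; in RSVI this is where the rejection sampler would be used,
    # since a larger shape gives a higher acceptance probability.
    z = rng.gammavariate(alpha + boost, 1.0)
    # Each uniform factor lowers the effective shape by one:
    # Gamma(a + 1) * U**(1/a) is distributed as Gamma(a).
    for i in range(boost):
        u = 1.0 - rng.random()  # uniform on (0, 1]
        z *= u ** (1.0 / (alpha + i))
    return z


rng = random.Random(0)
samples = [shape_augmented_gamma(0.5, 3, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)  # should be close to alpha = 0.5
```

With boost set to 0 the function reduces to a plain Gamma(alpha, 1) draw.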

@fritzo
Member Author

fritzo commented Jan 23, 2018

@martinjankowiak This is almost done, there's just one more test to fix.

Could you please review the changes to Rejector? I've simplified it and added a log_scale argument that is used to correctly compute the .batch_log_pdf() scale.
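
To make the role of a log_scale argument concrete, here is a toy rejection sampler in plain Python. ToyRejector and all of its parameter names are hypothetical stand-ins, not Pyro's Rejector; the sketch assumes log_scale is the log of the total acceptance probability, so that subtracting it normalizes the density of accepted samples.

```python
import math
import random


class ToyRejector:
    """Hypothetical stand-in for a Rejector-style distribution (not Pyro's class)."""

    def __init__(self, propose, propose_log_pdf, log_prob_accept, log_scale):
        self.propose = propose                  # rng -> proposed sample
        self.propose_log_pdf = propose_log_pdf  # log-density of the proposal
        self.log_prob_accept = log_prob_accept  # log acceptance probability at x
        self.log_scale = log_scale              # assumed: log total acceptance probability

    def sample(self, rng):
        # Propose until a draw is accepted.
        while True:
            x = self.propose(rng)
            if rng.random() < math.exp(self.log_prob_accept(x)):
                return x

    def log_pdf(self, x):
        # Density of accepted samples: proposal * acceptance / total acceptance.
        return self.propose_log_pdf(x) + self.log_prob_accept(x) - self.log_scale


# Target density 2x on (0, 1]: propose uniformly, accept with probability x.
toy = ToyRejector(
    propose=lambda rng: 1.0 - rng.random(),
    propose_log_pdf=lambda x: 0.0,
    log_prob_accept=math.log,
    log_scale=math.log(0.5),  # the integral of x over (0, 1] is 1/2
)
```

Without the log_scale correction, log_pdf in this toy case would describe the unnormalized density x rather than 2x.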

@martinjankowiak
Collaborator

cool awesome. i'll do so in the morning when i have a fresh brain

"""
log_pdf = self.batch_log_pdf(x, *args, **kwargs)
if self.reparameterized:
    return ScoreParts(log_pdf, 0, log_pdf)
Collaborator

maybe explicit names in namedtuple here?

Collaborator

@martinjankowiak martinjankowiak left a comment

let's investigate the xfail more

class RejectionStandardGamma(Rejector):
"""
Naive Marsaglia & Tsang rejection sampler for standard Gamma distribution.
This assumes `alpha >= 1` and does not boost `alpha` boosting or
Collaborator

docstring typo
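
For readers unfamiliar with the sampler named in the docstring above, this is a plain-Python sketch of the Marsaglia & Tsang method for a standard Gamma with alpha >= 1. It is not the PR's RejectionStandardGamma; the function name is hypothetical and the code returns a single float rather than a tensor.

```python
import math
import random


def standard_gamma_mt(alpha, rng):
    """Draw one Gamma(alpha, 1) sample with Marsaglia & Tsang's method (alpha >= 1)."""
    if alpha < 1:
        raise ValueError("this naive sampler assumes alpha >= 1")
    d = alpha - 1.0 / 3.0
    c = 1.0 / math.sqrt(9.0 * d)
    while True:
        x = rng.gauss(0.0, 1.0)      # standard normal proposal
        v = (1.0 + c * x) ** 3
        if v <= 0.0:
            continue                 # proposal falls outside the support
        u = 1.0 - rng.random()       # uniform on (0, 1], so log(u) is defined
        # Accept when log(u) falls below the log acceptance ratio.
        if math.log(u) < 0.5 * x * x + d - d * v + d * math.log(v):
            return d * v


rng = random.Random(0)
samples = [standard_gamma_mt(3.0, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)  # should be close to alpha = 3.0
```

The accepted value d * v is a deterministic transform of the normal draw x, which is what lets the accepted sample be differentiated with respect to alpha in the RSVI setting.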

@fritzo
Member Author

fritzo commented Jan 24, 2018

@martinjankowiak I believe I've fixed Rejector (sorry for the earlier bug!), but ShapeAugmentedGamma gradients still appear to be wrong. Do you think it's worth merging this as-is and debugging ShapeAugmentedGamma in a follow-up PR that also adds ShapeAugmentedBeta etc?

Collaborator

@martinjankowiak martinjankowiak left a comment

lgtm. @fritzo merge?

@fritzo
Member Author

fritzo commented Jan 25, 2018

Ready to merge, finally.

@eb8680 eb8680 merged commit 9adc077 into dev Jan 25, 2018
@eb8680 eb8680 deleted the rejector branch January 25, 2018 10:41
@rmehta1987

@fritzo are there any examples of using the accept-reject Dirichlet or Gamma?

@martinjankowiak
Collaborator

@rmehta1987 i don't believe there are. but if you changed the Gammas in the guide in examples/sparse_gamma_def.py to ShapeAugmentedGammas (found here: pyro/distributions/testing/rejection_gamma.py) the example should work more or less out of the box (modulo optimization issues)
