
Implement TruncatedDistribution #121

Open · wants to merge 2 commits into master

Conversation

@alicanb (Collaborator) commented Feb 2, 2018

This PR adds:

  • .cdf() and .icdf() methods for Distribution, plus tests (populated only for Normal for now)
  • TruncatedDistribution class
  • TruncatedNormal class

closes #78, touches #120
@tbrx I forgot you volunteered for this; want to work together? This is a very rough sketch at the moment. @fritzo it's not yet at the stage where I'd request a review, but comments are welcome as always.

@vishwakftw

Will the .cdf methods for existing distributions be added in this PR, or in a new PR altogether?

@alicanb (Collaborator, Author) commented Feb 2, 2018

I just added a minimal working example (plus tests) to get TruncatedNormal working. With CDFs for all current distributions, I think this PR would be too large. What do you think?

@vishwakftw

I thought the same too. Maybe after this is merged, I can start working on the populating PR. Hope that is fine.

@alicanb (Collaborator, Author) commented Feb 2, 2018

You can cherry-pick the first commit to your branch and start working in parallel if you want?

@tbrx (Collaborator) commented Feb 2, 2018

I was actually just thinking about this today and was exploring how disastrous it would be to try to implement a "generic" TruncatedDistribution that uses inverse transform sampling to generate samples.

Some plots in this gist: https://gist.github.com/tbrx/18e7579d9b7ff7c2a84c17c300555fc1

Basically, it's pretty bad numerically once you are more than four standard deviations away from the mean on a Gaussian, and it falls apart entirely a little past five. This doesn't give me high hopes for e.g. Gamma…

I looked at the SciPy code this morning, and it actually appears to use inverse transform sampling for truncated normals. Higher-precision floating point, though, means that it can get quite far from the mean before this becomes an issue.
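
For reference, here is a minimal sketch of the generic inverse-transform scheme under discussion, written against the current torch.distributions API (the helper name and the Normal example are illustrative, not part of this PR): draw u uniformly between cdf(lower) and cdf(upper) and map it back through icdf. The precision loss described above comes from the cdf saturating near 1 in the far tail.

import torch
from torch.distributions import Normal

def truncated_sample(base, low, high, sample_shape=torch.Size()):
    # Inverse-transform sampling from `base` restricted to [low, high]:
    # u ~ Uniform(cdf(low), cdf(high)), then x = icdf(u).
    cdf_low, cdf_high = base.cdf(low), base.cdf(high)
    u = cdf_low + torch.rand(sample_shape) * (cdf_high - cdf_low)
    return base.icdf(u)

# A Normal(0, 1) truncated to [4, 6]; accuracy degrades past roughly 4-5 sigma.
samples = truncated_sample(Normal(0.0, 1.0), torch.tensor(4.0), torch.tensor(6.0), (10000,))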

@fritzo left a comment

Looks clean but we need a safer way to do .new()

set_rng_seed(0)  # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
    samples = pytorch_dist.sample((5,))
    try:

It's safest to enclose as little as possible in a try-except:

try:
    pytorch_cdf = pytorch_dist.cdf(samples)
except NotImplementedError:
    continue
self.assertEqual(pytorch_cdf, scipy_dist.cdf(samples), message=pytorch_dist)

set_rng_seed(0)  # see Note [Randomized statistical tests]
for pytorch_dist, scipy_dist in self.distribution_pairs:
    samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape))
    try:

ditto, enclose as little as possible

super(TruncatedDistribution, self).__init__(*args, **kwargs)
self.base_dist = base_distribution
self.lower_bound, self.upper_bound, _ = broadcast_all(lower_bound, upper_bound,
                                                       getattr(self.base_dist,

This looks really dangerous. Why do we need to broadcast? Can we simply set

self.lower_bound = lower_bound
self.upper_bound = upper_bound

@alicanb (Collaborator, Author)

I was thinking about supporting batched bounds while writing that part, but I gave up on that idea later and forgot to change it.

is a generic sampler which is not the most efficient or accurate around the tails of the base distribution.
"""
shape = self._extended_shape(sample_shape)
u = getattr(self.base_dist, list(self.base_dist.params.keys())[0]).new(shape).uniform_()
@fritzo Feb 2, 2018

This looks dangerous. I wish we had a .new() method to create a correctly-placed tensor from given distribution.

@apaszke Is there an established pattern to do this? Can we define a .new_tensor() method or something? This has been coming up often. Some of our distributions define a private ._new() but we haven't exposed this as a general interface.


It seems simple and safe to define a method-as-property like

class Distribution(object):
    @property
    def new_tensor(self):
        raise NotImplementedError

class Normal(Distribution):
    @property
    def new_tensor(self):
        return self.loc.new


We don't have a common pattern except for new on tensors, but we never needed anything else

@fritzo commented Feb 2, 2018

Re: numerical stability, one option would be to use rejection sampling to draw samples and merely use the cdf derivative to compute reparameterized gradients:

def sample(self):
    ...use rejection sampling...

def rsample(self):
    x = self.sample()  # detached
    cdf = self.cdf(x)
    pdf = self.log_prob(x).exp()
    return x + (cdf.detach() - cdf) / pdf.detach()  # or something like this...

@alicanb (Collaborator, Author) commented Feb 2, 2018 via email

@fritzo commented Feb 2, 2018

Do you think rejection sampling would be fast enough?

It would be cheap if we rejection sampled when cdf(lower_bound) < 0.02 and used inverse cdf otherwise. This would require four branches: rejection-sample + inverse-cdf for each of the lower_bound and upper_bound. We would not need any low level code; see e.g. pyro.distributions.Rejector.sample().
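
A rough sketch of that dispatch, simplified to two branches on the probability mass kept by the truncation rather than the four branches described above (the threshold, helper name, and Normal usage are illustrative, not from this PR):

import torch
from torch.distributions import Normal

KEEP_MASS_THRESHOLD = 0.02  # illustrative cut-off

def sample_truncated(base, low, high, sample_shape=torch.Size()):
    kept_mass = base.cdf(high) - base.cdf(low)
    if kept_mass.item() > KEEP_MASS_THRESHOLD:
        # Cheap case: most base samples already land in [low, high], so plain
        # rejection (resample only the out-of-bounds entries) terminates quickly.
        x = base.sample(sample_shape)
        out_of_bounds = (x < low) | (x > high)
        while out_of_bounds.any():
            x = torch.where(out_of_bounds, base.sample(sample_shape), x)
            out_of_bounds = (x < low) | (x > high)
        return x
    # Heavy truncation: fall back to inverse-cdf sampling.
    u = base.cdf(low) + torch.rand(sample_shape) * kept_mass
    return base.icdf(u)

x = sample_truncated(Normal(0.0, 1.0), torch.tensor(-1.0), torch.tensor(2.0), (1000,))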

@alicanb (Collaborator, Author) commented Feb 2, 2018

@fritzo thanks, I'll take a look and try to come up with something.

@fritzo commented Feb 2, 2018

I think .cdf() methods have many uses outside of TruncatedDistribution, and they are worth adding even if TruncatedDistribution itself ends up being too tricky to add in PyTorch 0.4.

@tbrx (Collaborator) commented Feb 3, 2018

Oh, I completely agree that adding .cdf and its inverse is useful independently of truncated distributions! I'm just wary of releasing a "general" TruncatedDistribution without some caveats or warnings about the lack of precision.

For the truncated normal example, it seems like there are only 58 distinct floating-point values between 4.5 and infinity. If your bounds are within ±4 standard deviations, though, this would work pretty much fine! Maybe that is a more common case than sampling or evaluating tail probabilities anyway.
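
A quick way to sanity-check that figure, assuming it refers to single precision and to the representable values of the normal CDF on [4.5, ∞), which is the resolution an inverse-cdf sampler has available in that tail (this snippet is an aside, not from the thread):

import numpy as np
from scipy.stats import norm

# For positive IEEE-754 floats, the number of representable values between two
# numbers equals the difference of their integer bit patterns.
lo = np.array(norm.cdf(4.5), dtype=np.float32).view(np.int32)
hi = np.array(1.0, dtype=np.float32).view(np.int32)
print(int(hi - lo))  # on the order of 57-58 distinct float32 values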

@alicanb (Collaborator, Author) commented Feb 5, 2018

Here are two gists for TruncatedNormal, one with lower_bound=3 and one with lower_bound=4. Something doesn't add up with rejection sampling: @fritzo @tbrx, rejection sampling for tails takes a lot of time, most likely due to a bad proposal. Do you have any ideas for a "one size fits all" proposal distribution?

@fritzo commented Feb 6, 2018

Do I understand correctly that the difficult case is when you're truncating e.g. a Normal(0, 1) to its tail like [10, float('inf'))? If inverse-cdf doesn't work here, then I don't know what will.

@alicanb (Collaborator, Author) commented Feb 8, 2018

Here is an updated gist. Sampling from 4.5 sigma looks problematic, but sampling from 4 sigma looks OK-ish.

@tbrx (Collaborator) commented Feb 8, 2018

Those plots look good! But I agree, I don't think inverse CDF sampling will work very well for a Normal(0,1) outside of the region [-4, 4] or maybe [-4.5, 4.5] in a pinch…

There are algorithms for sampling from the tail of a Gaussian (e.g. on [4, ∞)) in chapter 9 of http://www.nrbook.com/devroye/. This doesn't help, though, with computing the .log_prob once we get > 5 or so. Sorry I've been busy with a deadline; I will try to look at this more closely over the weekend or on Monday.
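
For context, here is a sketch of the exponential-proposal rejection sampler for the standard-normal tail [a, ∞) (Robert, 1995), one of the algorithms in the family that chapter covers; it is an illustration in the current torch API, not the code from the gist:

import torch

def sample_normal_tail(a, sample_shape=torch.Size()):
    # Sample X ~ N(0, 1) conditioned on X >= a (a > 0, a python float) by rejection:
    # propose z = a + Exponential(alpha), accept with probability exp(-(z - alpha)^2 / 2).
    alpha = (a + (a * a + 4.0) ** 0.5) / 2.0  # proposal rate that minimizes rejections
    out = torch.empty(sample_shape)
    done = torch.zeros(sample_shape, dtype=torch.bool)
    while not bool(done.all()):
        z = a + torch.distributions.Exponential(alpha).sample(sample_shape)
        accept = torch.rand(sample_shape) < torch.exp(-0.5 * (z - alpha) ** 2)
        out = torch.where(accept & ~done, z, out)
        done = done | accept
    return out

tail_samples = sample_normal_tail(4.5, (10000,))  # acceptance stays high even far out in the tail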

@alicanb (Collaborator, Author) commented Feb 12, 2018

I implemented the sampling-from-the-tail algorithm; it's fast and looks good! Here's a gist. The precision problem with erf is still there though, not sure how to fix that...

@tbrx (Collaborator) commented Feb 13, 2018

This is cool!! That would work really well for .sample, at least. For .rsample maybe we can handle a loss of precision.

I was wondering if there was a way of maybe directly approximating Phi(x) - Phi(4.0), or something like that, to help compute the denominator…?

EDIT: looking at these numeric approximations, in particular the fourth, maybe it's possible to get approximations for erf(b) - erf(a) by taking the difference of two of these expansions and canceling / rearranging terms. (All we really need here is a stable way to compute log(erf(b) - erf(a))…)
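
One possible route (a sketch, not something proposed in the thread): since erf(x) = 1 - erfc(x), the difference can be rewritten as erfc(a) - erfc(b), which avoids subtracting two numbers that are both ≈ 1 when a and b sit far out in the right tail:

import math

def log_erf_diff(a, b):
    # Stable log(erf(b) - erf(a)) for 0 <= a < b: rewrite it as
    # erfc(a) - erfc(b) = erfc(a) * (1 - erfc(b) / erfc(a)), then take logs.
    assert 0 <= a < b
    return math.log(math.erfc(a)) + math.log1p(-math.erfc(b) / math.erfc(a))

# e.g. log(Phi(5.5) - Phi(4.5)) for a standard normal:
print(math.log(0.5) + log_erf_diff(4.5 / math.sqrt(2), 5.5 / math.sqrt(2)))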

@ragulpr commented Feb 14, 2018

I'm very interested in this; I'm working on a similar thing. I'm calling it ConditionalExcessDistribution, but I'm really only looking at right-censored (truncated) things at the moment.
Truncation and discretization are closely related (discretizing is roughly truncating into many intervals), so I'm currently thinking about an API where the two work hand in hand.

@alicanb (Collaborator, Author) commented Mar 22, 2018

New gist time 😄 This time I also have sampling times: https://gist.github.com/alicanb/c9e6567b7c512140ed43916b4dd30106. At this point I'm inclined towards having TruncatedDistribution implement only inverse-cdf sampling, and TruncatedNormal use Robert's algorithm. Still working on the precision problem with erf...

@fritzo commented Mar 25, 2018

I'm inclined towards having TruncatedDistribution implement only inverse-cdf sampling, and TruncatedNormal use Robert's algorithm

That sounds reasonable, implementing one new generic distribution and one specific special-case distribution. It even makes sense to send them in the same PR.

@fritzo left a comment

Sorry I let this drop; it would be nice to get this in before the 0.4 release.

self.base_dist = base_distribution
cdf_low, cdf_high = self.base_dist.cdf(self.lower_bound), self.base_dist.cdf(self.upper_bound)
if sample_method in ['rejection', 'inversion']:
    self.sample = {'rejection': self._rejection_sample, 'inversion': self._inversion_sample}[sample_method]

This creates a circular reference and leaks memory.

def event_shape(self):
    return self.base_dist.event_shape

def _inversion_sample(self, sample_shape=torch.Size()):

I'm inclined to implement Inversion and Rejection as different classes because the interfaces differ: Inversion allows reparametrization and hence an .rsample(), whereas Rejection is not reparametrizable and hence only implements .sample() (it can be partially reparametrized via RSVI, but that requires yet another interface). Also, Pyro defines a different Rejector class to do rejection sampling given a more general rejection criterion.

@alicanb (Collaborator, Author)

I'm actually inclined to omit rejection sampling from PyTorch. It's hard to make it work efficiently out of the box for a range of distributions. What do you think?


I agree, let's omit rejection sampling.

def __init__(self, loc, scale, lower_bound=-float('inf'), upper_bound=float('inf'), sample_method='robert', *args, **kwargs):
    super(TruncatedNormal, self).__init__(Normal(loc, scale), lower_bound, upper_bound, *args, **kwargs)
    if sample_method in {'exp', 'robert'}:
        self.sample = {'exp': self._exp_proposal, 'robert': self._robert_sample}[sample_method]

This creates a circular reference. It's better to simply define an if statement in an .rsample() method.
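
A minimal sketch of that alternative, reusing the names from this diff (TruncatedDistribution, _robert_sample and _exp_proposal are the ones defined elsewhere in the PR; this is illustrative, not a drop-in patch): store the method name as a plain string and branch inside rsample(), so the instance never holds one of its own bound methods.

import torch
from torch.distributions import Normal

class TruncatedNormal(TruncatedDistribution):
    def __init__(self, loc, scale, lower_bound=-float('inf'), upper_bound=float('inf'),
                 sample_method='robert', *args, **kwargs):
        super(TruncatedNormal, self).__init__(Normal(loc, scale), lower_bound, upper_bound,
                                              *args, **kwargs)
        assert sample_method in {'exp', 'robert'}
        self.sample_method = sample_method  # plain string, no bound-method reference cycle

    def rsample(self, sample_shape=torch.Size()):
        if self.sample_method == 'robert':
            return self._robert_sample(sample_shape)
        return self._exp_proposal(sample_shape)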

@@ -23,6 +23,7 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
export ASAN_OPTIONS=detect_leaks=0:symbolize=1
export PYTORCH_TEST_WITH_ASAN=1
# TODO: Figure out how to avoid hard-coding these paths
export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-5.0/bin/llvm-symbolizer

Looks like the diff was tainted.
