Implement truncated normal distribution #78

Open
tbrx opened this issue Jan 11, 2018 · 7 comments · May be fixed by #121

@tbrx
Collaborator

tbrx commented Jan 11, 2018

I recently found myself wanting a truncated normal distribution (e.g. a univariate normal which has parameters for specifying minimum and/or maximum values).

Is this a sufficiently common need to include here? If so, I'll volunteer.

@fritzo

fritzo commented Jan 11, 2018

Ha ha @martinjankowiak and I were just talking about this last night 😄 PyTorch has a differentialbe erfinv so reparameterization should be easy. Sounds like a good idea to me.

@alicanb
Collaborator

alicanb commented Jan 11, 2018

Can we make it more general than TruncatedNormal? I had written it in the design doc but we didn't discuss it, and honestly I don't know if there's any other way than explicitly providing truncated pdfs.

@fritzo

fritzo commented Jan 11, 2018

Reparameterization is tricky for bounded distributions. I see a few possible ways to handle this in a more general way:

  1. Support truncation only for univariate distributions that provide .cdf() and .icdf() methods.
  2. Use rejection sampling variational inference as in pyro-ppl/pyro#659 (Implement Rejection Sampling Variational Inference). This doesn't fit into the current torch.distributions API, since the partially reparameterized gradient requires an additional log_prob term g_cor. In Pyro we're trying to break this out into a ScoreParts object. I'd like to see what we get wrong in Pyro before moving a more reasonable interface upstream to PyTorch.
  3. Support arbitrary truncation but do not implement .rsample(). Maybe we can provide .rsample() if both .cdf() and .icdf() are available, and provide .sample() otherwise (a rough sketch of this follows below).
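
A rough sketch of how options 1 and 3 might look, wrapping any univariate base distribution that exposes .cdf() and .icdf() (the Truncated class here is hypothetical):

    import torch
    from torch.distributions import Normal

    class Truncated:
        def __init__(self, base, low, high):
            # base: a univariate distribution providing .cdf() and .icdf()
            self.base = base
            self._plow = base.cdf(low)
            self._phigh = base.cdf(high)

        def rsample(self, sample_shape=torch.Size()):
            u = torch.rand(sample_shape + self._plow.shape)
            # Reparameterized: gradients flow through the base's cdf/icdf.
            return self.base.icdf(self._plow + u * (self._phigh - self._plow))

        def log_prob(self, value):
            # Renormalize by the mass remaining after truncation
            # (no support check on value in this sketch).
            return self.base.log_prob(value) - torch.log(self._phigh - self._plow)

    d = Truncated(Normal(0., 1.), torch.tensor(-1.), torch.tensor(2.))
    x = d.rsample(torch.Size([5]))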

@alicanb
Collaborator

alicanb commented Jan 11, 2018

I think it's a good idea to implement .log_cdf() and .quantile() anyway. Shall I add them to the to-do list?

@tbrx
Collaborator Author

tbrx commented Jan 11, 2018

Even if we have .cdf and .icdf (or .log_cdf and .quantile), I'm still not sure there is a good automatic way of handling general bounds, particularly when the bounds isolate tail probabilities, where the numerical precision of these functions is not particularly great.

For example, I think for truncated normal we could use inverse transform sampling if we are within ±5 standard deviations (for Double anyway, it's probably less for Float…). But, if we are sampling from the tails (say, a standard normal truncated to have a lower bound of 6) we'd need to do something else. (e.g. I believe there's a rejection sampling algorithm for this in Devroye's book.)
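
For reference, a sketch of that kind of tail sampler, using the shifted-exponential rejection scheme usually attributed to Robert (1995), which is close to the algorithm in Devroye's book (illustration only, not proposed API):

    import torch

    def sample_normal_tail(a, n):
        # Sample X ~ N(0, 1) conditioned on X >= a, for a > 0.
        # Proposal: a + Exponential(alpha) with the optimal rate alpha.
        alpha = (a + (a ** 2 + 4.0) ** 0.5) / 2.0
        out = []
        while len(out) < n:
            u1, u2 = torch.rand(2)
            x = a - torch.log(u1) / alpha
            # Accept with probability exp(-(x - alpha)^2 / 2).
            if u2 <= torch.exp(-0.5 * (x - alpha) ** 2):
                out.append(x)
        return torch.stack(out)

    samples = sample_normal_tail(6.0, 100)  # lower bound of 6, as above

(For two-sided bounds like [6, 8], one would additionally reject any x above the upper bound.)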

@fritzo

fritzo commented Jan 11, 2018

Hmm if we use rejection sampling to construct a sample, can we use .cdf to compute the reparameterized gradient as something like

dx/dt = dx_untruncated/dt
      + (cdf(x) - cdf(UB))/(cdf(UB) - cdf(LB)) pdf(LB)/pdf(x) dLB/dt
      - (cdf(LB) - cdf(x))/(cdf(UB) - cdf(LB)) pdf(UB)/pdf(x) dUB/dt 

where t is the parameter being varied?

@tbrx
Collaborator Author

tbrx commented Jan 12, 2018

Something along those lines, yes — we just need to be able to re-normalize the part of the distribution which remains after truncation.

There is still the problem of tails. For example, suppose I would like to sample from a standard normal distribution with bounds lower=6, upper=8.

The cdf(x) of a standard normal distribution is given by 0.5*(1+torch.erf(x / math.sqrt(2))). For a FloatTensor, this is numerically 1.0 for both cdf(6) and cdf(8).

Maybe this is a bit of an edge case…?
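
Concretely:

    import math
    import torch

    def cdf(x):
        return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

    cdf(torch.tensor(6.))                       # exactly 1.0 in float32
    cdf(torch.tensor(6., dtype=torch.float64))  # ~ 1 - 9.9e-10

A standard remedy for the upper tail is to compute 1 - cdf(x) as 0.5 * torch.erfc(x / math.sqrt(2)), which stays accurate where cdf(x) itself rounds to 1.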

@alicanb alicanb linked a pull request Feb 2, 2018 that will close this issue