Skip to content

Commit

Permalink
Cleaning up Gibbs with Gradients example.
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Jun 14, 2021
1 parent 51ba7a7 commit 2d46773
Showing 1 changed file with 135 additions and 105 deletions.
240 changes: 135 additions & 105 deletions examples/gibbs_with_gradients.dx
Original file line number Diff line number Diff line change
@@ -1,43 +1,51 @@
'Main algorithm from "Oops I Took A Gradient: Scalable Sampling for Discrete Distributions"
by Will Grathwohl, Kevin Swersky, Milad Hashemi, David Duvenaud, Chris J. Maddison
[Arxiv Link](https://arxiv.org/abs/2102.04509) demonstrated on an Ising model.
'This is a demo of an MCMC sampler from the paper:
[Oops I Took A Gradient: Scalable Sampling for Discrete Distributions](https://arxiv.org/abs/2102.04509)
demonstrated on an Ising model.

import plot
'The algorithm looks a lot like standard [Gibbs sampling](https://en.wikipedia.org/wiki/Gibbs_sampling),
but the dimension to be flipped is chosen based on the gradient
of the unnormalized density with respect to its inputs.
Although the inputs are discrete, the main idea of the
paper is to cheat and cast the function to one that
has continuous inputs, so that the gradient is well-defined.
This might sound a bit hacky, but the resulting MCMC operator
still has the correct stationary distribution, since we apply
a [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) correction step.

'### Helper Functions

def hue2rgb (p:Float) (q:Float) (t:Float) : Float =
t = t - floor t
if t < (1.0/6.0)
then p + (q - p) * 6.0 * t
else if t < (1.0/2.0)
then q
else if t < (2.0/3.0)
then p + (q - p) * (2.0/3.0 - t) * 6.0
else p
import plot

def hslToRgb (h:Float) (s:Float) (l:Float) : (Fin 3)=>Float =
if s == 0.0
then [l, l, l] -- achromatic
else
q = select (l < 0.5) (l * (1.0 + s)) (l + s - l * s)
p = 2.0 * l - q
r = hue2rgb p q (h + 1.0/3.0)
g = hue2rgb p q h
b = hue2rgb p q (h - 1.0/3.0)
[r, g, b]
'### Helper Functions

instance Arbitrary Bool
instance Arbitrary Bool -- Todo: Move to prelude.
arb = \key. rand key < 0.5

def grad_and_value (f:a->Float) (x:a) : (a & Float) =
  -- Evaluates `f` at `x` once and returns the pair (gradient, value).
  -- The gradient is obtained by feeding a unit cotangent to the pullback
  -- produced by `vjp`.
  (primal, pullback) = vjp f x
  gradient = pullback 1.0
  (gradient, primal)

-- '## Standard Gibbs
def flipEntry (x:n=>Bool) (flip_ix:n) : n=>Bool =
  -- Returns a copy of `x` whose entry at `flip_ix` is negated; every other
  -- entry is left as it was.
  negated = not x.flip_ix
  yieldState x \cells.
    cells!flip_ix := negated


'## Standard Gibbs Sampler

def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool =
  -- One single-site Metropolis step on the bit-vector `x`.
  --   x   : current state
  --   f   : unnormalized log-density, evaluated on boolean states
  --   key : PRNG key, split so that the site choice and the accept test
  --         use independent randomness
  -- Returns the proposed state if it is accepted, otherwise `x` itself.
  [k_site, k_accept] = splitKey key

  -- Proposal: pick one coordinate with `randIdx` and flip it.
  site = randIdx k_site
  proposal = flipEntry x site

  -- Flipping one coordinate is a symmetric proposal, so the acceptance
  -- ratio involves only the change in log-density.
  log_ratio = f proposal - f x
  select (rand k_accept < exp log_ratio) proposal x


--def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool =
-- [key_sample, key_accept] = splitKey key

'## Gibbs with Gradients Sampler

Expand All @@ -46,17 +54,10 @@ def boolToFloat (x:Bool) : Float =
True -> 1.0
False -> -1.0

def flipEntry (x:n=>Bool) (i:n) : n=>Bool =
yieldState x \xref.
xref!i := not x.i


def floatToBool (x:Float) : Bool =
  -- True exactly when `x` is strictly positive; zero and negatives map to
  -- False. The comparison already yields a Bool, so no `select` is needed.
  x > 0.0

def gibbsWithGradients
(x:n=>Bool)
(f:n=>Float->Float)
(key:Key)
: n=>Bool =
def gibbsWithGradients (x:n=>Bool) (f:n=>Float->Float) (key:Key) : n=>Bool =

[key_sample, key_accept] = splitKey key

Expand All @@ -75,118 +76,147 @@ def gibbsWithGradients
diff_x' = -xFloat' * dfdx
log_qxx' = logsoftmax (diff_x' / 2.0)

-- Compute MH acceptance ratio.
log_acceptance_rate = (f xFloat' - fx) + log_qxx'.i - log_qx'x.i

-- accept / reject step.
if rand key_accept < exp log_acceptance_rate
-- MH accept/reject.
acceptance_rate = exp (f xFloat' - fx + log_qxx'.i - log_qx'x.i)
if rand key_accept < acceptance_rate
then x'
else x




def runSampler (init:n=>Bool) (f:n=>Float -> Float) (iters:Int) : n=>Bool =
snd $ yieldState (0, init) \iref.
for i:(Fin iters).
(i, x) = get iref
iref := (i + 1, gibbsWithGradients x f (newKey i))


'## Ising Model

def wrapidx (n:Type) (i:Int) : n =
  -- Converts an integer position into an index of type `n`, wrapping
  -- around at the ends by reducing it modulo the size of `n`.
  wrapped = mod i (size n)
  asidx wrapped

def incwrap (i:n) : n = -- Increment index, wrapping around at ends.
asidx $ mod ((ordinal i) + 1) $ size n
-- Cyclic successor of an index: steps one position forward and wraps
-- from the last index back to the first.
def incwrap (i:n) : n =
  next = (ordinal i) + 1
  asidx $ mod next $ size n

-- Cyclic predecessor of an index: steps one position backward, reducing
-- the result modulo the size of `n`.
def decwrap (i:n) : n =
  prev = (ordinal i) - 1
  asidx $ mod prev $ size n

def decwrap (i:n) : n = -- Decrement index, wrapping around at ends.
asidx $ mod ((ordinal i) - 1) $ size n

def ising_logprob (x:n=>m=>Float) (theta:Float) : Float =
def ising_logprob (x:n=>m=>Float) (bias:n=>m=>Float) (theta:Float) : Float =
-- x is -1 or 1
theta * sum for (i, j).
sum for (i, j).
t1 = x.i.j * x.(incwrap i).j
t2 = x.i.j * x.(decwrap i).j
t3 = x.i.j * x.i.(incwrap j)
t4 = x.i.j * x.i.(decwrap j)
t1 + t2 + t3 + t4
theta * (t1 + t2 + t3 + t4) + bias.i.j * x.i.j


'## Plotting utilities

'## Generate images
def hue2rgb (p:Float) (q:Float) (t:Float) : Float =
  -- Piecewise-linear ramp used by the HSL-to-RGB conversion: returns one
  -- colour channel given the two lightness bounds `p` and `q` and a hue
  -- offset `t`. Only the fractional part of `t` matters.
  frac = t - floor t
  span = q - p
  if frac < (1.0/6.0)
    then p + span * 6.0 * frac
    else if frac < (1.0/2.0)
      then q
      else if frac < (2.0/3.0)
        then p + span * (2.0/3.0 - frac) * 6.0
        else p

def hslToRgb (h:Float) (s:Float) (l:Float) : Color =
  -- Converts a hue/saturation/lightness triple into [red, green, blue].
  -- With zero saturation the result is the grey level `l` in all channels.
  if s == 0.0
    then [l, l, l]
    else
      -- `hi` and `lo` are the two bounds that hue2rgb blends between.
      hi = select (l < 0.5) (l * (1.0 + s)) (l + s - l * s)
      lo = 2.0 * l - hi
      -- The three channels sample the same ramp a third of a turn apart.
      red   = hue2rgb lo hi (h + 1.0/3.0)
      green = hue2rgb lo hi h
      blue  = hue2rgb lo hi (h - 1.0/3.0)
      [red, green, blue]

def probToColor (x:Bool) (grad:Float) : Color =
  -- Colours one pixel to visualise how likely its bit is to flip.
  --   x    : current value of the bit
  --   grad : the flip weight for this bit; it is scaled by 100 and
  --          clipped to [0, 1] before use
  -- The hue is fixed at red with full saturation. The lightness is the
  -- scaled weight for True bits and one minus it for False bits.
  weight = clip (0.0, 1.0) (100.0 * grad)
  lightness = select x weight (1.0 - weight)
  hslToRgb 0.0 1.0 lightness


-- Writes every PNG in `pngs` to its own temporary file, then runs the
-- `convert` command (presumably ImageMagick — confirm it is on PATH) to
-- combine them into `<outFileName>.gif`.
--   delay       : passed verbatim as the argument of `-delay`
--   pngs        : the frames, in the order they are passed to `convert`
--   outFileName : output path without the ".gif" suffix, which is appended
-- NOTE(review): the result is whatever `shellOut` yields, not the contents
-- of the written file — confirm that matches the `Gif` annotation.
-- NOTE(review): `outFileName` is concatenated into the shell command
-- unquoted; presumably callers only pass trusted names without spaces —
-- verify against callers.
def pngsToSavedGif (delay:Int) (pngs:t=>Png) (outFileName:String) : Gif =
unsafeIO \().
withTempFiles \pngFiles.
for i.
writeFile pngFiles.i pngs.i
-- Build one command line: convert -delay <delay> png:<f0> png:<f1> ... gif:<out>.gif
shellOut $
"convert" <> " -delay " <> show delay <> " " <>
concat (for i. "png:" <> pngFiles.i <> " ") <>
"gif:" <> outFileName <> ".gif"

def boolMatToImage (x:n=>m=>Bool) : n=>m=>(Fin 3)=>Float =
  -- Turns a boolean grid into a three-channel image in which every pixel
  -- carries the same value, `boolToFloat` of its bit, in all channels.
  for i j.
    level = boolToFloat x.i.j
    for c:(Fin 3). level

'## Set up a particular Ising model

png = unsafeIO do
readFile "examples/peace.png"

:t png

theta = 0.5
N = Fin 60
x : (N & N)=>Bool = arb (newKey 0)

def f (x:n=>Float) : Float =
xu:N=>N=>Float = for i j.
def flattened_ising (x:n=>Float) : Float =
  -- Log-density of the Ising model as a function of a flat spin vector,
  -- which is the shape the samplers work with.
  -- Unflatten x into an N-by-N grid by matching ordinals.
  xu = for i:N. for j:N.
    x.(unsafeFromOrdinal _ (ordinal (i,j)))
  -- ising_logprob takes (x, bias, theta). With a zero bias the per-site
  -- bias term vanishes, leaving only the theta-weighted neighbour terms.
  -- NOTE(review): a bias derived from the loaded image may have been
  -- intended here — confirm.
  zero_bias = for i:N. for j:N. 0.0
  ising_logprob xu zero_bias theta

x' = runSampler x f 10000

x'flat:N=>N=>Bool = for i j.
x'.((ordinal (i, j))@_)
'## Generate images

:html imshow $ boolMatToImage x'flat
'### Gibbs with Gradients

def runSampler2 (init:n=>Bool) (f:n=>Float -> Float)
def runSampler (samplerStep: n=>Bool -> (n=>Float->Float) -> Key -> (n=>Bool))
(init:n=>Bool) (f:n=>Float -> Float)
(iters:Int) (writePeriod:Int) : List (n=>Bool) =
yieldAccum (ListMonoid (n=>Bool)) \list.
yieldState init \state.
for i:(Fin iters).
x = get state
state := gibbsWithGradients x f (newKey (ordinal i))
state := samplerStep x f (newKey (ordinal i))
if mod (ordinal i) writePeriod == 0 then
append list x

num_iters = 10000
xmovie' = runSampler2 x f num_iters 1000
num_iters = 1000
write_period = 100

xmovie' = runSampler gibbsWithGradients x flattened_ising num_iters write_period
(AsList _ xmovie) = xmovie'

xmovieflat:(Fin _)=>N=>N=>(Fin 3)=>Float =
xmovieflat:(Fin _)=>N=>N=>Color =
for i.
xf = map boolToFloat xmovie.i
dfdx = grad f xf
gc = map exp $ logsoftmax (-xf * dfdx / 2.0)
for j k.
h = 0.0
s = 1.0
scaled_change_prob = clip (0.0, 1.0) (100.0 * gc.(j, k))
l = case xmovie.i.(j, k) of
True -> scaled_change_prob
False -> 1.0 - scaled_change_prob
--l = scaled_change_prob
hslToRgb h s l
--1.0 .* cv + 0.5 .* [xf.(j, k), xf.(j, k), xf.(j, k)]
--[xf.(j, k), (1.0 - 1010.0 * xf.(j, k)) * gc.(j, k), xf.(j, k)]
dfdx = grad flattened_ising xf
flip_prob = softmax (-xf * dfdx / 2.0)
for j k. probToColor xmovie.i.(j, k) flip_prob.(j, k)

:html imseqshow xmovieflat
pngsToSavedGif 1 (map imgToPng xmovieflat) "gwg"

def withTempFile' (action: FilePath -> {IO} a) : {IO} a =
tmpFile = newTempFile ()
result = action tmpFile
--deleteFile tmpFile
result

def pngsToSavedGif (delay:Int) (pngs:t=>Png) : Gif = unsafeIO \().
withTempFiles \pngFiles.
for i. writeFile pngFiles.i pngs.i
withTempFile' \gifFile.
shellOut $
"convert" <> " -delay " <> show delay <> " " <>
concat (for i. "png:" <> pngFiles.i <> " ") <>
"gif:" <> gifFile
readFile gifFile
'### Standard Gibbs

def wrappedGibbs (x:n=>Bool) (f:n=>Float->Float) (key:Key) : n=>Bool =
  -- Adapter giving gibbsUpdate the same signature as gibbsWithGradients:
  -- `f` expects float inputs, so each boolean state is converted with
  -- boolToFloat before `f` is evaluated on it.
  gibbsUpdate x (\bits. f (map boolToFloat bits)) key

xmovie'' = runSampler wrappedGibbs x flattened_ising num_iters write_period
(AsList _ xmovie''') = xmovie''
-- Frames for the standard-Gibbs movie. Every pixel is coloured with the
-- same flip weight, one over the number of grid sites (size N squared).
xmovieflat':(Fin _)=>N=>N=>Color =
  for i.
    flip_prob = 1.0 / ( sq $ IToF (size N))
    for j k. probToColor xmovie'''.i.(j, k) flip_prob

pngsToSavedGif 5 $ map imgToPng xmovieflat
:html imseqshow xmovieflat'
pngsToSavedGif 1 (map imgToPng xmovieflat') "gibbs"

0 comments on commit 2d46773

Please sign in to comment.