Skip to content

Commit

Permalink
Added WIP Gibbs with Gradients demo
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Apr 13, 2021
1 parent c8a5296 commit 51ba7a7
Showing 1 changed file with 192 additions and 0 deletions.
192 changes: 192 additions & 0 deletions examples/gibbs_with_gradients.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
'Main algorithm from "Oops I Took A Gradient: Scalable Sampling for Discrete Distributions"
by Will Grathwohl, Kevin Swersky, Milad Hashemi, David Duvenaud, Chris J. Maddison
[Arxiv Link](https://arxiv.org/abs/2102.04509) demonstrated on an Ising model.

import plot

'### Helper Functions

def hue2rgb (p:Float) (q:Float) (t:Float) : Float =
t = t - floor t
if t < (1.0/6.0)
then p + (q - p) * 6.0 * t
else if t < (1.0/2.0)
then q
else if t < (2.0/3.0)
then p + (q - p) * (2.0/3.0 - t) * 6.0
else p

def hslToRgb (h:Float) (s:Float) (l:Float) : (Fin 3)=>Float =
if s == 0.0
then [l, l, l] -- achromatic
else
q = select (l < 0.5) (l * (1.0 + s)) (l + s - l * s)
p = 2.0 * l - q
r = hue2rgb p q (h + 1.0/3.0)
g = hue2rgb p q h
b = hue2rgb p q (h - 1.0/3.0)
[r, g, b]

instance Arbitrary Bool
arb = \key. rand key < 0.5

def grad_and_value (f:a->Float) (x:a) : (a & Float) =
(val, vjpfun) = vjp f x
(vjpfun 1.0, val)

-- '## Standard Gibbs

--def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool =
-- [key_sample, key_accept] = splitKey key

'## Gibbs with Gradients Sampler

def boolToFloat (x:Bool) : Float =
case x of
True -> 1.0
False -> -1.0

def flipEntry (x:n=>Bool) (i:n) : n=>Bool =
yieldState x \xref.
xref!i := not x.i



def gibbsWithGradients
(x:n=>Bool)
(f:n=>Float->Float)
(key:Key)
: n=>Bool =

[key_sample, key_accept] = splitKey key

-- Compute proposal distribution.
xFloat = map boolToFloat x
(dfdx, fx) = grad_and_value f xFloat
diff_x = -xFloat * dfdx
log_qx'x = logsoftmax (diff_x / 2.0)

-- sample which dimension to change and flip it.
i = categorical log_qx'x key_sample
x' = flipEntry x i

-- Compute reverse transition distribution.
xFloat' = map boolToFloat x'
diff_x' = -xFloat' * dfdx
log_qxx' = logsoftmax (diff_x' / 2.0)

-- Compute MH acceptance ratio.
log_acceptance_rate = (f xFloat' - fx) + log_qxx'.i - log_qx'x.i

-- accept / reject step.
if rand key_accept < exp log_acceptance_rate
then x'
else x




def runSampler (init:n=>Bool) (f:n=>Float -> Float) (iters:Int) : n=>Bool =
snd $ yieldState (0, init) \iref.
for i:(Fin iters).
(i, x) = get iref
iref := (i + 1, gibbsWithGradients x f (newKey i))


'## Ising Model

def wrapidx (n:Type) (i:Int) : n =
asidx $ mod i $ size n -- Index wrapping around at ends.

def incwrap (i:n) : n = -- Increment index, wrapping around at ends.
asidx $ mod ((ordinal i) + 1) $ size n

def decwrap (i:n) : n = -- Decrement index, wrapping around at ends.
asidx $ mod ((ordinal i) - 1) $ size n

def ising_logprob (x:n=>m=>Float) (theta:Float) : Float =
-- x is -1 or 1
theta * sum for (i, j).
t1 = x.i.j * x.(incwrap i).j
t2 = x.i.j * x.(decwrap i).j
t3 = x.i.j * x.i.(incwrap j)
t4 = x.i.j * x.i.(decwrap j)
t1 + t2 + t3 + t4



'## Generate images

def boolMatToImage (x:n=>m=>Bool) : n=>m=>(Fin 3)=>Float =
for i j.
xf = boolToFloat x.i.j
[xf, xf, xf]

theta = 0.5
N = Fin 60
x : (N & N)=>Bool = arb (newKey 0)

def f (x:n=>Float) : Float =
xu:N=>N=>Float = for i j.
x.(unsafeFromOrdinal _ (ordinal (i,j)))
ising_logprob xu theta

x' = runSampler x f 10000

x'flat:N=>N=>Bool = for i j.
x'.((ordinal (i, j))@_)

:html imshow $ boolMatToImage x'flat

def runSampler2 (init:n=>Bool) (f:n=>Float -> Float)
(iters:Int) (writePeriod:Int) : List (n=>Bool) =
yieldAccum (ListMonoid (n=>Bool)) \list.
yieldState init \state.
for i:(Fin iters).
x = get state
state := gibbsWithGradients x f (newKey (ordinal i))
if mod (ordinal i) writePeriod == 0 then
append list x

num_iters = 10000
xmovie' = runSampler2 x f num_iters 1000

(AsList _ xmovie) = xmovie'

xmovieflat:(Fin _)=>N=>N=>(Fin 3)=>Float =
for i.
xf = map boolToFloat xmovie.i
dfdx = grad f xf
gc = map exp $ logsoftmax (-xf * dfdx / 2.0)
for j k.
h = 0.0
s = 1.0
scaled_change_prob = clip (0.0, 1.0) (100.0 * gc.(j, k))
l = case xmovie.i.(j, k) of
True -> scaled_change_prob
False -> 1.0 - scaled_change_prob
--l = scaled_change_prob
hslToRgb h s l
--1.0 .* cv + 0.5 .* [xf.(j, k), xf.(j, k), xf.(j, k)]
--[xf.(j, k), (1.0 - 1010.0 * xf.(j, k)) * gc.(j, k), xf.(j, k)]

:html imseqshow xmovieflat

def withTempFile' (action: FilePath -> {IO} a) : {IO} a =
tmpFile = newTempFile ()
result = action tmpFile
--deleteFile tmpFile
result

def pngsToSavedGif (delay:Int) (pngs:t=>Png) : Gif = unsafeIO \().
withTempFiles \pngFiles.
for i. writeFile pngFiles.i pngs.i
withTempFile' \gifFile.
shellOut $
"convert" <> " -delay " <> show delay <> " " <>
concat (for i. "png:" <> pngFiles.i <> " ") <>
"gif:" <> gifFile
readFile gifFile

pngsToSavedGif 5 $ map imgToPng xmovieflat

0 comments on commit 51ba7a7

Please sign in to comment.