Skip to content

Commit

Permalink
Changed Gibbs with gradients demo into an image denoising task.
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Jun 26, 2021
1 parent 67515db commit 6211884
Showing 1 changed file with 116 additions and 49 deletions.
165 changes: 116 additions & 49 deletions examples/gibbs_with_gradients.dx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
'# Gibbs with Gradients

'This is a demo of an MCMC sampler from the paper:
[Oops I Took A Gradient: Scalable Sampling for Discrete Distributions](https://arxiv.org/abs/2102.04509)
demonstrated on an Ising model.
Expand All @@ -12,13 +14,14 @@ This might sound a bit hacky, but the resulting MCMC operator
still has the correct marginal distribution, since we apply
a [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) correction step.

'This demo uses the Ising model for image denoising.


import parser
import plot

'### Helper Functions

instance Arbitrary Bool -- Todo: Move to prelude.
arb = \key. rand key < 0.5
'### Helper Functions

def grad_and_value (f:a->Float) (x:a) : (a & Float) =
(val, vjpfun) = vjp f x
Expand All @@ -35,11 +38,11 @@ def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool =

[key_sample, key_accept] = splitKey key

-- sample which dimension to change and flip it.
-- Sample which dimension to change and flip it.
flip_ix = randIdx key_sample
x' = flipEntry x flip_ix

-- accept / reject step.
-- Accept / reject step.
acceptance_rate = exp (f x' - f x)
if rand key_accept < acceptance_rate
then x'
Expand All @@ -49,6 +52,12 @@ def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool =

'## Gibbs with Gradients Sampler

'The Gibbs with Gradients sampler has a slightly different function signature
than standard Gibbs. Instead of its log probability function taking in a
discrete array, it takes in an array of floats of the same size. This is necessary
so that it can be differentiated with respect to its input, even though we will
only call it on discrete inputs.

def boolToFloat (x:Bool) : Float =
case x of
True -> 1.0
Expand All @@ -58,7 +67,6 @@ def floatToBool (x:Float) : Bool =
select (x > 0.0) True False

def gibbsWithGradients (x:n=>Bool) (f:n=>Float->Float) (key:Key) : n=>Bool =

[key_sample, key_accept] = splitKey key

-- Compute proposal distribution.
Expand Down Expand Up @@ -104,6 +112,17 @@ def ising_logprob (x:n=>m=>Float) (bias:n=>m=>Float) (theta:Float) : Float =

'## Plotting utilities

def runSampler (samplerStep: n=>Bool -> (n=>Float->Float) -> Key -> (n=>Bool))
(init:n=>Bool) (f:n=>Float -> Float)
(iters:Int) (writePeriod:Int) : List (n=>Bool) =
yieldAccum (ListMonoid (n=>Bool)) \list.
yieldState init \state.
for i:(Fin iters).
x = get state
state := samplerStep x f (newKey (ordinal i))
if mod (ordinal i) writePeriod == 0 then
append list x

def hue2rgb (p:Float) (q:Float) (t:Float) : Float =
t = t - floor t
if t < (1.0/6.0)
Expand Down Expand Up @@ -151,72 +170,120 @@ def pngsToSavedGif (delay:Int) (pngs:t=>Png) (outFileName:String) : Gif =
"gif:" <> outFileName <> ".gif"


'## Set up a particular Ising model

png = unsafeIO do
readFile "examples/peace.png"

:t png

theta = 0.5
N = Fin 60
x : (N & N)=>Bool = arb (newKey 0)
'### Image Loading Utilities
This is a nice chance to try out Dex's simpler parser combinator library.

def parseP6header : Parser (Int & Int & Int) = MkParser \h.
-- Loads a raw PPM file in P6 format.
-- The header will look something like:
--P6
--220 220 (height, width)
--255 (max color value)
-- followed by a flat block of height x width x 3 chars.
parse h $ pChar 'P'
parse h $ pChar '6'
parse h $ parseAny
rows = parse h $ parseUnsignedInt
parse h $ parseAny
cols = parse h $ parseUnsignedInt
parse h $ parseAny
colorsize = parse h $ parseUnsignedInt
(rows, cols, colorsize)

def parseP6 (rows:Int) (cols:Int) : Parser ((Fin rows)=>(Fin cols)=>(Fin 3)=>Char) = MkParser \h.
parse h $ parseP6header
parse h $ parseAny
for r:(Fin rows).
for c:(Fin cols).
for c:(Fin 3).
parse h parseAny

def pixelToBool (x:Char) : Bool = (W8ToI x) < 0


'### Load image

image_raw = unsafeIO do readFile "examples/peace.ppm"
(rows, cols, _) = fromJust $ runParserPartial image_raw parseP6header
image = fromJust $ runParserPartial image_raw (parseP6 rows cols)
image_bool = for i j.
pixelToBool image.i.j.(1@_)

-- Add noise
noisefrac = 0.1
image_noisy = for i j.
addnoise = rand (ixkey2 (newKey 0) i j) < noisefrac
case addnoise of
True -> not image_bool.i.j
False -> image_bool.i.j

imcol = for i j. for c:(Fin 3).
(IToF $ BToI $ not image_bool.i.j)

'### Set up an Ising model to denoise that image.
'The model simply encodes that nearby pixels usually have the same color.
'The bias term makes it more likely that the pixels will match the noisy image.

theta = 0.5 -- Coupling constant between neighbouring pixels.
bias = for i j. -- Bias for individual pixels.
case image_noisy.i.j of
True -> 1.0
False -> -1.0

def flattened_ising (x:n=>Float) : Float =
-- unflattens x.
xu = for i:N. for j:N.
-- unflattens x.
x_unflattend = for i j.
x.(unsafeFromOrdinal _ (ordinal (i,j)))
ising_logprob xu theta
ising_logprob x_unflattend bias theta

'Normally for image denoising, we start with the noisy image.
However, to simulate a more realistic inference problem,
we'll start far from the mode at a completely random initialization.

'## Generate images
init_field =
for (i, j):((Fin rows) & (Fin cols)).
rand (ixkey2 (newKey 0) i j) < 0.5

'### Gibbs with Gradients
'## Generate animations

def runSampler (samplerStep: n=>Bool -> (n=>Float->Float) -> Key -> (n=>Bool))
(init:n=>Bool) (f:n=>Float -> Float)
(iters:Int) (writePeriod:Int) : List (n=>Bool) =
yieldAccum (ListMonoid (n=>Bool)) \list.
yieldState init \state.
for i:(Fin iters).
x = get state
state := samplerStep x f (newKey (ordinal i))
if mod (ordinal i) writePeriod == 0 then
append list x
'### Run Gibbs with Gradients
We'll color the pixels by the probability that they'll be proposed to flip.
In an Ising model, this creates an outline around the edges of homoeneous regions.

num_iters = 1000
write_period = 100
num_iters = 175 -- Change to 17500 for full animation.
write_period = 50

xmovie' = runSampler gibbsWithGradients x flattened_ising num_iters write_period
(AsList _ xmovie) = xmovie'
xmovieflat:(Fin _)=>N=>N=>Color =
for i.
frameList = runSampler gibbsWithGradients init_field flattened_ising num_iters write_period
(AsList _ xmovie) = frameList
xmovieflat = for i.
xf = map boolToFloat xmovie.i
dfdx = grad flattened_ising xf
flip_prob = softmax (-xf * dfdx / 2.0)
flip_prob = softmax (-xf * dfdx / 2.0) -- Color pixels by probability of flipping.
for j k. probToColor xmovie.i.(j, k) flip_prob.(j, k)

:html imseqshow xmovieflat
pngsToSavedGif 1 (map imgToPng xmovieflat) "gwg"


'### Standard Gibbs
'### Run Standard Gibbs

'So that we can re-use the helper function for Gibbs with Gradients,
we need to change the signature of standard Gibbs to match.

def wrappedGibbs (x:n=>Bool) (f:n=>Float->Float) (key:Key) : n=>Bool =
boolf = \x'.
boolf:(n=>Bool->Float) = \x'.
f $ map boolToFloat x'

gibbsUpdate x boolf key

xmovie'' = runSampler wrappedGibbs x flattened_ising num_iters write_period
(AsList _ xmovie''') = xmovie''
xmovieflat':(Fin _)=>N=>N=>Color =
for i.
xf = map boolToFloat xmovie'''.i
dfdx = 1.0
flip_prob = 1.0 / ( sq $ IToF (size N))
for j k. probToColor xmovie'''.i.(j, k) flip_prob
frameList' = runSampler wrappedGibbs init_field flattened_ising num_iters write_period
(AsList _ xmovie') = frameList'
xmovieflat' = for i.
xf = map boolToFloat xmovie'.i
flip_prob = 1.0 / (IToF (rows * cols)) -- Uniform probability of flipping.
for j k. probToColor xmovie'.i.(j, k) flip_prob

:html imseqshow xmovieflat'
pngsToSavedGif 1 (map imgToPng xmovieflat') "gibbs"

'And we see Gibbs with gradients mixes faster.
The improvement will generally grow with problem size.

0 comments on commit 6211884

Please sign in to comment.