diff --git a/examples/gibbs_with_gradients.dx b/examples/gibbs_with_gradients.dx index 4dd370d54..55ec08360 100644 --- a/examples/gibbs_with_gradients.dx +++ b/examples/gibbs_with_gradients.dx @@ -1,3 +1,5 @@ +'# Gibbs with Gradients + 'This is a demo of an MCMC sampler from the paper: [Oops I Took A Gradient: Scalable Sampling for Discrete Distributions](https://arxiv.org/abs/2102.04509) demonstrated on an Ising model. @@ -12,13 +14,14 @@ This might sound a bit hacky, but the resulting MCMC operator still has the correct marginal distribution, since we apply a [Metropolis-Hastings](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) correction step. +'This demo uses the Ising model for image denoising. + +import parser import plot -'### Helper Functions -instance Arbitrary Bool -- Todo: Move to prelude. - arb = \key. rand key < 0.5 +'### Helper Functions def grad_and_value (f:a->Float) (x:a) : (a & Float) = (val, vjpfun) = vjp f x @@ -35,11 +38,11 @@ def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool = [key_sample, key_accept] = splitKey key - -- sample which dimension to change and flip it. + -- Sample which dimension to change and flip it. flip_ix = randIdx key_sample x' = flipEntry x flip_ix - -- accept / reject step. + -- Accept / reject step. acceptance_rate = exp (f x' - f x) if rand key_accept < acceptance_rate then x' @@ -49,6 +52,12 @@ def gibbsUpdate (x:n=>Bool) (f:n=>Bool->Float) (key:Key) : n=>Bool = '## Gibbs with Gradients Sampler +'The Gibbs with Gradients sampler has a slightly different function signature +than standard Gibbs. Instead of its log probability function taking in a +discrete array, it takes in an array of floats of the same size. This is necessary +so that it can be differentiated with respect to its input, even though we will +only call it on discrete inputs. 
+ def boolToFloat (x:Bool) : Float = case x of True -> 1.0 @@ -58,7 +67,6 @@ def floatToBool (x:Float) : Bool = select (x > 0.0) True False def gibbsWithGradients (x:n=>Bool) (f:n=>Float->Float) (key:Key) : n=>Bool = - [key_sample, key_accept] = splitKey key -- Compute proposal distribution. @@ -104,6 +112,17 @@ def ising_logprob (x:n=>m=>Float) (bias:n=>m=>Float) (theta:Float) : Float = '## Plotting utilities +def runSampler (samplerStep: n=>Bool -> (n=>Float->Float) -> Key -> (n=>Bool)) + (init:n=>Bool) (f:n=>Float -> Float) + (iters:Int) (writePeriod:Int) : List (n=>Bool) = + yieldAccum (ListMonoid (n=>Bool)) \list. + yieldState init \state. + for i:(Fin iters). + x = get state + state := samplerStep x f (newKey (ordinal i)) + if mod (ordinal i) writePeriod == 0 then + append list x + def hue2rgb (p:Float) (q:Float) (t:Float) : Float = t = t - floor t if t < (1.0/6.0) @@ -151,72 +170,120 @@ def pngsToSavedGif (delay:Int) (pngs:t=>Png) (outFileName:String) : Gif = "gif:" <> outFileName <> ".gif" -'## Set up a particular Ising model - -png = unsafeIO do - readFile "examples/peace.png" - -:t png - -theta = 0.5 -N = Fin 60 -x : (N & N)=>Bool = arb (newKey 0) +'### Image Loading Utilities +This is a nice chance to try out Dex's simpler parser combinator library. + +def parseP6header : Parser (Int & Int & Int) = MkParser \h. + -- Loads a raw PPM file in P6 format. + -- The header will look something like: + --P6 + --220 220 (height, width) + --255 (max color value) + -- followed by a flat block of height x width x 3 chars. + parse h $ pChar 'P' + parse h $ pChar '6' + parse h $ parseAny + rows = parse h $ parseUnsignedInt + parse h $ parseAny + cols = parse h $ parseUnsignedInt + parse h $ parseAny + colorsize = parse h $ parseUnsignedInt + (rows, cols, colorsize) + +def parseP6 (rows:Int) (cols:Int) : Parser ((Fin rows)=>(Fin cols)=>(Fin 3)=>Char) = MkParser \h. + parse h $ parseP6header + parse h $ parseAny + for r:(Fin rows). + for c:(Fin cols). + for c:(Fin 3). 
+ parse h parseAny + +def pixelToBool (x:Char) : Bool = (W8ToI x) < 0 + + +'### Load image + +image_raw = unsafeIO do readFile "examples/peace.ppm" +(rows, cols, _) = fromJust $ runParserPartial image_raw parseP6header +image = fromJust $ runParserPartial image_raw (parseP6 rows cols) +image_bool = for i j. + pixelToBool image.i.j.(1@_) + +-- Add noise +noisefrac = 0.1 +image_noisy = for i j. + addnoise = rand (ixkey2 (newKey 0) i j) < noisefrac + case addnoise of + True -> not image_bool.i.j + False -> image_bool.i.j + +imcol = for i j. for c:(Fin 3). + (IToF $ BToI $ not image_bool.i.j) + +'### Set up an Ising model to denoise that image. +'The model simply encodes that nearby pixels usually have the same color. +'The bias term makes it more likely that the pixels will match the noisy image. + +theta = 0.5 -- Coupling constant between neighbouring pixels. +bias = for i j. -- Bias for individual pixels. + case image_noisy.i.j of + True -> 1.0 + False -> -1.0 def flattened_ising (x:n=>Float) : Float = - -- unflattens x. - xu = for i:N. for j:N. + -- unflattens x. + x_unflattend = for i j. x.(unsafeFromOrdinal _ (ordinal (i,j))) - ising_logprob xu theta + ising_logprob x_unflattend bias theta +'Normally for image denoising, we start with the noisy image. +However, to simulate a more realistic inference problem, +we'll start far from the mode at a completely random initialization. -'## Generate images +init_field = + for (i, j):((Fin rows) & (Fin cols)). + rand (ixkey2 (newKey 0) i j) < 0.5 -'### Gibbs with Gradients +'## Generate animations -def runSampler (samplerStep: n=>Bool -> (n=>Float->Float) -> Key -> (n=>Bool)) - (init:n=>Bool) (f:n=>Float -> Float) - (iters:Int) (writePeriod:Int) : List (n=>Bool) = - yieldAccum (ListMonoid (n=>Bool)) \list. - yieldState init \state. - for i:(Fin iters). 
- x = get state - state := samplerStep x f (newKey (ordinal i)) - if mod (ordinal i) writePeriod == 0 then - append list x +'### Run Gibbs with Gradients +We'll color the pixels by the probability that they'll be proposed to flip. +In an Ising model, this creates an outline around the edges of homogeneous regions. -num_iters = 1000 -write_period = 100 +num_iters = 175 -- Change to 17500 for full animation. +write_period = 50 -xmovie' = runSampler gibbsWithGradients x flattened_ising num_iters write_period -(AsList _ xmovie) = xmovie' -xmovieflat:(Fin _)=>N=>N=>Color = - for i. +frameList = runSampler gibbsWithGradients init_field flattened_ising num_iters write_period +(AsList _ xmovie) = frameList +xmovieflat = for i. xf = map boolToFloat xmovie.i dfdx = grad flattened_ising xf - flip_prob = softmax (-xf * dfdx / 2.0) + flip_prob = softmax (-xf * dfdx / 2.0) -- Color pixels by probability of flipping. for j k. probToColor xmovie.i.(j, k) flip_prob.(j, k) :html imseqshow xmovieflat pngsToSavedGif 1 (map imgToPng xmovieflat) "gwg" -'### Standard Gibbs +'### Run Standard Gibbs + +'So that we can re-use the helper function for Gibbs with Gradients, +we need to change the signature of standard Gibbs to match. def wrappedGibbs (x:n=>Bool) (f:n=>Float->Float) (key:Key) : n=>Bool = - boolf = \x'. + boolf:(n=>Bool->Float) = \x'. f $ map boolToFloat x' - gibbsUpdate x boolf key -xmovie'' = runSampler wrappedGibbs x flattened_ising num_iters write_period -(AsList _ xmovie''') = xmovie'' -xmovieflat':(Fin _)=>N=>N=>Color = - for i. - xf = map boolToFloat xmovie'''.i - dfdx = 1.0 - flip_prob = 1.0 / ( sq $ IToF (size N)) - for j k. - probToColor xmovie'''.i.(j, k) flip_prob +frameList' = runSampler wrappedGibbs init_field flattened_ising num_iters write_period +(AsList _ xmovie') = frameList' +xmovieflat' = for i. + xf = map boolToFloat xmovie'.i + flip_prob = 1.0 / (IToF (rows * cols)) -- Uniform probability of flipping. + for j k. 
probToColor xmovie'.i.(j, k) flip_prob :html imseqshow xmovieflat' pngsToSavedGif 1 (map imgToPng xmovieflat') "gibbs" +'And we see Gibbs with gradients mixes faster. +The improvement will generally grow with problem size.