diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..cc27b731f --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +custom: https://numfocus.salsalabs.org/donate-to-julia/index.html diff --git a/.travis.yml b/.travis.yml index eea787950..85e1d8593 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,9 @@ language: julia os: - linux julia: + - 1.0 - 1.1 + - 1.2 - nightly notifications: email: false @@ -27,9 +29,7 @@ jobs: julia: 1.0 os: linux script: - - julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); - Pkg.instantiate()' - - julia --project=docs/ docs/make.jl + - julia --color=yes --project=docs/ docs/make.jl after_success: skip ## uncomment and modify the following lines to manually install system packages @@ -39,7 +39,3 @@ jobs: # - gfortran #before_script: # homebrew for mac # - if [ $TRAVIS_OS_NAME = osx ]; then brew install gcc; fi - -## uncomment the following lines to override the default test script -script: - - julia --color=yes -e 'using Pkg; Pkg.activate(); Pkg.instantiate(); Pkg.test()' diff --git a/CITATION.bib b/CITATION.bib new file mode 100644 index 000000000..b7124280e --- /dev/null +++ b/CITATION.bib @@ -0,0 +1,13 @@ +@article{Zygote.jl-2018, + author = {Michael Innes}, + title = {Don't Unroll Adjoint: Differentiating SSA-Form Programs}, + journal = {CoRR}, + volume = {abs/1810.07951}, + year = {2018}, + url = {http://arxiv.org/abs/1810.07951}, + archivePrefix = {arXiv}, + eprint = {1810.07951}, + timestamp = {Tue, 30 Oct 2018 20:39:56 +0100}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1810-07951}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} diff --git a/LICENSE.md b/LICENSE.md index 7e86adaae..077c15235 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,6 +1,6 @@ The Zygote.jl package is licensed under the MIT "Expat" License: -> Copyright (c) 2018: Mike J Innes. +> Copyright (c) 2018-19: Julia Computing, Inc., Mike J Innes and contributors > > Permission is hereby granted, free of charge, to any person obtaining a copy > of this software and associated documentation files (the "Software"), to deal diff --git a/Manifest.toml b/Manifest.toml index 28b648c22..7e613d82e 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,5 +1,11 @@ # This file is machine-generated - editing it directly is not advised +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "0.4.1" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -10,10 +16,16 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" version = "0.8.10" [[BinaryProvider]] -deps = ["Libdl", "Pkg", "SHA", "Test"] -git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.3" +version = "0.5.6" + +[[CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "0ff80f68f55fcde2ed98d7b24d7abaf20727f3f8" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "0.6.1" [[CommonSubexpressions]] deps = ["Test"] @@ -23,9 +35,27 @@ version = "0.2.0" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "195a3ffcb8b0762684b6821de18f83a16455c6ea" +git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "2.0.0" +version = "2.1.0" + +[[Conda]] +deps = ["JSON", "VersionParsing"] +git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032" +uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" +version = "1.3.0" + +[[Crayons]] +deps = ["Test"] +git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.0.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.17.0" [[Dates]] deps = ["Printf"] @@ -51,6 +81,18 @@ version = "0.0.10" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[FFTW]] +deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] +git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "0.3.0" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Test"] +git-tree-sha1 = "9ab8f76758cbabba8d7f103c51dce7f73fcf8e92" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.6.3" + [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" @@ -59,14 +101,20 @@ version = "0.10.3" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53" +git-tree-sha1 = "a9b1fc7745ae4745a634bbb6d1cb7fd64e37248a" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.1.2" +version = "0.2.2" [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "b34d7cef7b337321e97d22242c3c2b91f476748e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.0" + [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -81,10 +129,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[MacroTools]] -deps = ["Compat"] -git-tree-sha1 = "3fd1a3022952128935b449c33552eb65895380c1" +deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"] +git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.4.5" +version = "0.5.1" [[Markdown]] deps = ["Base64"] @@ -94,10 +142,10 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] -deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "51330bb45927379007e089997bf548fbe232589d" +deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] +git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.4.3" +version = "0.6.0" [[NaNMath]] deps = ["Compat"] @@ -105,6 +153,18 @@ git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "0.3.2" +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.1.0" + +[[Parsers]] +deps = ["Dates", "Test"] +git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "0.3.6" + [[Pkg]] deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -121,6 +181,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + [[Requires]] deps = ["Test"] git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" @@ -151,10 +217,10 @@ uuid = "276daf66-3868-5448-9aa4-cd146d93841b" version = "0.7.2" [[StaticArrays]] -deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] -git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.10.3" +version = "0.11.0" [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -164,6 +230,17 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[TimerOutputs]] +deps = ["Crayons", "Printf", "Test", "Unicode"] +git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.0" + +[[Tokenize]] +git-tree-sha1 = "c8a8b00ae44a94950814ff77850470711a360225" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.5" + [[URIParser]] deps = ["Test", "Unicode"] git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" @@ -176,3 +253,17 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[VersionParsing]] +deps = ["Compat"] +git-tree-sha1 = "c9d5aa108588b978bd859554660c8a5c4f2f7669" +uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" +version = "1.1.3" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "24edeee7ac0e5537ce87054b0a6eda1c30cabfac" +repo-rev = "master" +repo-url = "https://github.com/FluxML/ZygoteRules.jl" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.1.0" diff --git a/Project.toml b/Project.toml index c87728fc7..37f9192ff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,11 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.3.2" [deps] DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -13,6 +16,13 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" + +[compat] +IRTools = "0.2" +NNlib = "0.6" +julia = "1" [extras] Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" diff --git a/README.md b/README.md index 5fb1d09b8..4b70fa38d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![Build Status](https://travis-ci.org/FluxML/Zygote.jl.svg?branch=master)](https://travis-ci.org/FluxML/Zygote.jl) [![Dev Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://fluxml.ai/Zygote.jl/dev) -`] add Zygote#master` +`] add Zygote` Zygote is a working prototype for source-to-source automatic differentiation (AD) in Julia, and the next-gen AD system for the [Flux](https://github.com/FluxML/Flux.jl) differentiable programming framework. For more details and benchmarks of Zygote's technique, see [our paper](https://arxiv.org/abs/1810.07951). @@ -32,7 +32,7 @@ Without compromising on performance, Zygote supports the full flexibility and dy ```julia julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan); -julia> derivative(x -> fs[readline()](x), 1) +julia> gradient(x -> fs[readline()](x), 1) sin 0.5403023058681398 ``` diff --git a/REQUIRE b/REQUIRE deleted file mode 100644 index 826112a04..000000000 --- a/REQUIRE +++ /dev/null @@ -1,9 +0,0 @@ -julia 1.0 -DiffRules -ForwardDiff -MacroTools -NNlib -NaNMath -Requires -SpecialFunctions -IRTools diff --git a/docs/make.jl b/docs/make.jl index f25c86019..9217b0d96 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,3 +1,9 @@ +using Pkg; +Pkg.activate(joinpath(@__DIR__, "..")); Pkg.instantiate() +Pkg.activate(); Pkg.instantiate() + +pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) + using Documenter, Zygote makedocs( @@ -6,8 +12,13 @@ makedocs( pages = [ "Home" => "index.md", "Custom Adjoints" => "adjoints.md", + "Utilities" => "utils.md", + "Complex Differentiation" => "complex.md", + "Flux" => "flux.md", "Profiling" => "profiling.md", - "Internals" => "internals.md"]) + "Internals" => "internals.md", + "Glossary" => "glossary.md"], + format = Documenter.HTML(prettyurls = haskey(ENV, "CI"))) deploydocs( repo = "github.com/FluxML/Zygote.jl.git", diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md index afa21d488..8206fbdeb 100644 --- a/docs/src/adjoints.md +++ b/docs/src/adjoints.md @@ -1,5 +1,7 @@ # Custom Adjoints +The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are. + ## Pullbacks `gradient` is really just syntactic sugar around the more fundamental function `forward`. @@ -52,6 +54,12 @@ julia> mygradient(sin, 0.5) (0.8775825618903728,) ``` +The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user. + +If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = sin(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints". + +Zygote has many adjoints for non-mathematical operations such as for indexing and data structures. Though these can still be seen as linear functions of vectors, it's not particularly enlightening to implement them with an actual matrix multiply. In these cases it's easiest to think of the adjoint as a kind of inverse. For example, the gradient of a function that takes a tuple to a struct (e.g. `y = Complex(a, b)`) will generally take a struct to a tuple (`(ȳ.re, ȳ.im)`). The gradient of a `getindex` `y = x[i...]` is a `setindex!` `x̄[i...] = ȳ`, etc. + ## Custom Adjoints We can extend Zygote to a new function with the `@adjoint` function. diff --git a/docs/src/complex.md b/docs/src/complex.md new file mode 100644 index 000000000..edde3a3e3 --- /dev/null +++ b/docs/src/complex.md @@ -0,0 +1,58 @@ +# Complex Differentiation + +Complex numbers add some difficulty to the idea of a "gradient". To talk about `gradient(f, x)` here we need to talk a bit more about `f`. + +If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box. + +```julia +julia> gradient(c -> abs2(c), 1+2im) +(2 + 4im,) +``` + +However, while this is a very pragmatic definition that works great for gradient descent, it's not quite aligned with the mathematical notion of the derivative: i.e. ``f(c + \epsilon) \approx f(c) + \bar c \epsilon``. In general, such a ``\bar c`` is not possible for complex numbers except when `f` is *holomorphic* (or *analytic*). Roughly speaking this means that the function is defined over `c` as if it were a normal real number, without exploiting its complex structure – it can't use `real`, `imag`, `conj`, or anything that depends on these like `abs2` (`abs2(x) = x*x'`). (This constraint also means there's no overlap with the Real case above; holomorphic functions always return complex numbers for complex input.) But most "normal" numerical functions – `exp`, `log`, anything that can be represented by a Taylor series – are fine. + +Fortunately it's also possible to get these derivatives; they are the conjugate of the gradients for the real part. + +```julia +julia> gradient(x -> real(log(x)), 1+2im)[1] |> conj +0.2 - 0.4im +``` + +We can check that this function is holomorphic – and thus that the gradient we got out is sensible – by checking the Cauchy-Riemann equations. In other words this should give the same answer: + +```julia +julia> -im*gradient(x -> imag(log(x)), 1+2im)[1] |> conj +0.2 - 0.4im +``` + +Notice that this fails in a non-holomorphic case, `f(x) = log(x')`: + +```julia +julia> gradient(x -> real(log(x')), 1+2im)[1] |> conj +0.2 - 0.4im + +julia> -im*gradient(x -> imag(log(x')), 1+2im)[1] |> conj +-0.2 + 0.4im +``` + +In cases like these, all bets are off. The gradient can only be described with more information; either a 2x2 Jacobian (a generalisation of the Real case, where the second column is now non-zero), or by the two Wirtinger derivatives (a generalisation of the holomorphic case, where ``\frac{∂ f}{∂ z'}`` is now non-zero). To get these efficiently, as we would a Jacobian, we can just call the backpropagators twice. + +```julia +function jacobi(f, x) + y, back = Zygote.forward(f, x) + back(1)[1], back(im)[1] +end + +function wirtinger(f, x) + du, dv = jacobi(f, x) + (du' + im*dv')/2, (du + im*dv)/2 +end +``` + +```julia +julia> wirtinger(x -> 3x^2 + 2x + 1, 1+2im) +(8.0 + 12.0im, 0.0 + 0.0im) + +julia> wirtinger(x -> abs2(x), 1+2im) +(1.0 - 2.0im, 1.0 + 2.0im) +``` diff --git a/docs/src/flux.md b/docs/src/flux.md new file mode 100644 index 000000000..cd9e96650 --- /dev/null +++ b/docs/src/flux.md @@ -0,0 +1,27 @@ +# Flux + +It's easy to use Zygote in place of Flux's default AD, Tracker, just by changing `Tracker.gradient` to `Zygote.gradient`. The API is otherwise the same. + +```julia +julia> using Flux, Zygote + +julia> m = Chain(Dense(10, 5, relu), Dense(5, 2)) +Chain(Dense(10, 5, NNlib.relu), Dense(5, 2)) + +julia> x = rand(10); + +julia> gs = gradient(() -> sum(m(x)), params(m)) +Grads(...) + +julia> gs[m[1].W] +5×10 Array{Float32,2}: + -0.255175 -1.2295 ... +``` + +You can use optimisers and update gradients as usual. + +```julia +julia> opt = ADAM(); + +julia> Flux.Optimise.update!(opt, params(m), gs) +``` diff --git a/docs/src/glossary.md b/docs/src/glossary.md new file mode 100644 index 000000000..1f9f8e35d --- /dev/null +++ b/docs/src/glossary.md @@ -0,0 +1,37 @@ +# Glossary + +Differentiation is a minefield of conflicting and overlapping terminology, partly because the ideas have been re-discovered in many different fields (e.g. calculus and differential geometry, the traditional AD community, deep learning, finance, etc.) Many of these terms are not well-defined and others may disagree on the details. Nevertheless, we aim to at least say how *we* use these terms, which will be helpful when reading over Zygote issues, discussions and source code. + +The list is certainly not complete; if you see new terms you'd like defined, or would like to add one yourself, please do open an issue or PR. + +**Adjoint**: See *pullback*. Used when defining new pullbacks (i.e. the `@adjoint` macro) since this involves defining the adjoint of the Jacobian, in most cases. + +**Backpropagation**: Essentially equivalent to "reverse-mode AD". Used particularly in the machine learning world to refer to simple chains of functions `f(g(h(x)))`, but has generalised beyond that. + +**Derivative**: Given a scalar function ``y = f(x)``, the derivative is ``\frac{\partial y}{\partial x}``. "Partial" is taken for granted in AD; there's no interesting distinction between partial and total derivatives for our purposes. It's all in the eye of the beholder. + +**Differential**: Given a function ``f(x)``, the linearisation ``\partial f`` such that ``f(x + \epsilon) \approx f(x) + \partial f \epsilon``. This is a generalisation of the derivative since it applies to, for example, vector-to-vector functions (``\partial f`` is a Jacobian) and holomorphic complex functions (``\partial f`` is the first Wirtinger derivative). This is *not*, in general, what Zygote calculates, though differentials can usually be derived from gradients. + +**IR**: Intermediate Representation. Essentially source code, but usually lower level – e.g. control flow constructs like loops and branches have all been replaced by `goto`s. The idea is that it's harder for humans to read/write but easier to manipulate programmatically. Worth looking at SSA form as a paradigmatic example. + +**Gradient**: See *sensitivity*. There is no technical difference in Zygote's view, though "gradient" sometimes distinguishes the sensitivity we actually want from e.g. the internal ones that Zygote produces as it backpropagates. + +**Graph**: ML people tend to think of models as "computation graphs", but this is no more true than any program is a graph. In fact, pretty much anything is a graph if you squint hard enough. This also refers to the data structure that e.g. TensorFlow and PyTorch build to represent your model, but see *trace* for that. + +**Pullback**: Given ``y = f(x)`` the function ``\bar x = back(̄\bar y)``. In other words, the function `back` in `y, back = Zygote.forward(f, x)`. + +**Sensitivity**: Used to refer to the gradient ``\bar x = \frac{\partial l}{\partial x}`` with some scalar loss ``l``. In other words, you have a value ``x`` (which need not be scalar) at some point in your program, and ``\bar x`` tells you how you should change that value to decrease the loss. In the AD world, sometimes used to refer to adjoint rules. + +**Source to Source Differentiation**: Or Source Code Transformation (SCT). As opposed to *tracing* programs to simplify them, an alternative is to operate directly on a language's source code or IR, generating new source code for pullbacks. This describes Zygote, Swift for TensorFlow, Tapenade and a few other old ADs that worked on C source files. Zygote and Swift are unusual in that they work on in-memory IR rather than text source. + +To an extent, tracing ADs can be viewed as source transform of a Wengert list / trace. The key difference is that the trace is a lossy representation of the original semantics, which causes problems with e.g. control flow. Systems which can preserve some of those semantics (e.g. autograph) begin to blur the line here, though they are still not nearly as expressive as language IRs. + +**Symbolic Differentiation**: Used to refer to differentiation of "mathematical expressions", that is, things like `3x^2 + sin(x)`. Often distinguished from AD, though this is somewhat arbitrary; you can happily produce a symbolic adjoint for a Wengert list, the only difference being that you're allowed to make variable bindings. So it's really just a special case of AD on an unusually limited language. + +**Tape**: This term can refer to pretty much any part of an AD implementation. In particular confusion is caused by conflating the *trace* with the set of values sometimes closed over by a *pullback*. Autograd has a combined trace/closure data structure which is usually described as the tape. On the other hand, PyTorch described their implementation as tape-free because the trace/closure is stored as a DAG rather than a vector, so basically all bets are off here. + +**Trace**: A recording of each mathematical operation used by a program, made at runtime and usually forming a Wengert list. Traces may or may not also record actual runtime values (e.g. PyTorch vs. TensorFlow). They can often be treated as an IR and compiled, but are distinguished from true IRs in that they unroll and inline all control flow, functions and data structures. The tracing process can be thought of as a kind of partial evaluation, though tracers are typically much less worried about losing information. + +**vector-Jacobian product**: see *pullback*. So called because all pullbacks are linear functions that can be represented by (left) multiplication with the Jacobian matrix. + +**Wengert List**: A set of simple variable assignments and mathematical expressions, forming a directed graph. Can be thought of as a limited programming language with variable bindings and numerical functions but no control flow or data structures. If you *trace* a program for AD it will typically take this form. diff --git a/docs/src/index.md b/docs/src/index.md index 0de1c671e..bd67d7d37 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -61,7 +61,7 @@ pow (generic function with 1 method) julia> gradient(x -> pow(x, 3), 5) (75,) -julia> pow2(x, n) = n <= 0 ? 1 : x*pow(x, n-1) +julia> pow2(x, n) = n <= 0 ? 1 : x*pow2(x, n-1) pow2 (generic function with 1 method) julia> gradient(x -> pow2(x, 3), 5) @@ -117,6 +117,18 @@ julia> gradient(a -> dist(a + b), a)[1] Zygote's default representation of the "point adjoint" is a named tuple with gradients for both fields, but this can of course be customised too. +This means we can do something very powerful: differentiating through Julia libraries, even if they weren't designed for this. For example, `colordiff` might be a smarter loss function on colours than simple mean-squared-error: + +```julia +julia> using Colors + +julia> colordiff(RGB(1, 0, 0), RGB(0, 1, 0)) +86.60823557376344 + +julia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0)) +((r = 0.4590887719632896, g = -9.598786801605689, b = 14.181383399012862), (r = -1.7697549557037275, g = 28.88472330558805, b = -0.044793892637761346)) +``` + ## Gradients of ML models It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights. diff --git a/docs/src/profiling.md b/docs/src/profiling.md index af1959abc..28a169835 100644 --- a/docs/src/profiling.md +++ b/docs/src/profiling.md @@ -1,3 +1,70 @@ -# Profiling +# Debugging in Time and Space -WIP +Because Zygote generates Julia code for the backwards pass, many of Julia's +normal profiling and performance debugging tools work well on it out of the box. + +## Performance Profiling + +Julia's [sampling profiler](https://docs.julialang.org/en/v1/manual/profile/) is +useful for understanding performance. We recommend [running the profiler in +Juno](http://docs.junolab.org/latest/man/juno_frontend/#Profiler-1), but the +terminal or [ProfileView.jl](https://github.com/timholy/ProfileView.jl) also +work well. + +![](https://i.imgur.com/saYm3Uo.png) + +The bars indicate time taken in both the forwards and backwards passes at that +line. The canopy chart on the right shows us each function call as a block, +arranged so that when `f` calls `g`, `g` gets a block just below `f`, which is +bigger the longer it took to run. If we dig down the call stack we'll eventually +find the adjoints for things like `matmul`, which we can click on to view. + +![](https://i.imgur.com/ypLQZlu.png) + +The trace inside the adjoint can be used to distinguish time taken by the forwards and backwards passes. + +## Memory Profiling + +Reverse-mode AD typically uses memory proportional to the number of operations +in the program, so long-running programs can also suffer memory usage issues. +Zygote includes a space profiler to help debug these issues. Like the time +profiler, it shows a canopy chart, but this time hovering over it displays the +number of bytes stored by each line of the program. + +![](https://i.imgur.com/pd2P4W4.png) + +Note that this currently only works inside Juno. + +## Reflection + +Julia's code and type inference reflection tools can also be useful, though +Zygote's use of closures can make the output noisy. To see the code Julia runs +you should use the low-level `_forward` method and the pullback it returns. +This will directly show either the derived adjoint code or the code for a custom +adjoint, if there is one. + +```julia +julia> using Zygote: Context, _forward + +julia> add(a, b) = a+b + +julia> @code_typed _forward(Context(), add, 1, 2) +CodeInfo( +1 ─ %1 = (Base.getfield)(args, 1)::Int64 +│ %2 = (Base.getfield)(args, 2)::Int64 +│ %3 = (Base.add_int)(%1, %2)::Int64 +│ %4 = (Base.tuple)(%3, $(QuoteNode(∂(add))))::PartialTuple(Tuple{Int64,typeof(∂(add))}, Any[Int64, Const(∂(add), false)]) +└── return %4 +) => Tuple{Int64,typeof(∂(add))} + +julia> y, back = _forward(Context(), add, 1, 2) +(3, ∂(add)) + +julia> @code_typed back(1) +CodeInfo( +1 ─ %1 = (Base.mul_int)(Δ, 1)::Int64 +│ %2 = (Base.mul_int)(Δ, 1)::Int64 +│ %3 = (Zygote.tuple)(nothing, %1, %2)::PartialTuple(Tuple{Nothing,Int64,Int64}, Any[Const(nothing, false), Int64, Int64]) +└── return %3 +) => Tuple{Nothing,Int64,Int64} +``` diff --git a/docs/src/utils.md b/docs/src/utils.md new file mode 100644 index 000000000..bceedb43b --- /dev/null +++ b/docs/src/utils.md @@ -0,0 +1,14 @@ +# Utilities + +Zygote provides a set of helpful utilities. These are all "user-level" tools – +in other words you could have written them easily yourself, but they live in +Zygote for convenience. + +```@docs +Zygote.@showgrad +Zygote.hook +Zygote.dropgrad +Zygote.hessian +Zygote.Buffer +Zygote.forwarddiff +``` diff --git a/examples/Manifest.toml b/examples/Manifest.toml new file mode 100644 index 000000000..71c78ffad --- /dev/null +++ b/examples/Manifest.toml @@ -0,0 +1,316 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractTrees]] +deps = ["Markdown", "Test"] +git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.2.1" + +[[Adapt]] +deps = ["LinearAlgebra", "Test"] +git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "0.4.2" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BinDeps]] +deps = ["Compat", "Libdl", "SHA", "URIParser"] +git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9" +uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" +version = "0.8.10" + +[[BinaryProvider]] +deps = ["Libdl", "Pkg", "SHA", "Test"] +git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.3" + +[[CSTParser]] +deps = ["LibGit2", "Test", "Tokenize"] +git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "0.5.2" + +[[CodecZlib]] +deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] +git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.5.2" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random", "Test"] +git-tree-sha1 = "f73b0e10f2a5756de7019818a41654686da06b09" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.7.5" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] +git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.9.5" + +[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "2.1.0" + +[[DataStructures]] +deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] +git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.15.0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["Compat", "StaticArrays"] +git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "0.0.4" + +[[DiffRules]] +deps = ["Random", "Test"] +git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "0.0.10" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FixedPointNumbers]] +deps = ["Test"] +git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.5.3" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Test", "Tracker", "ZipFile"] +git-tree-sha1 = "75e5a6850ad9d6129773171d9ba66be899a515ec" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.8.2" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] +git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.3" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.1.2" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile", "Test"] +git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.7.0" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["CSTParser", "Compat", "DataStructures", "Test"] +git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.0" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] +git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "0.4.0" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[NNlib]] +deps = ["Libdl", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "9ac5cd21484189339b27840818c4882d1b6df7fd" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.5.0" + +[[NaNMath]] +deps = ["Compat"] +git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.2" + +[[OrderedCollections]] +deps = ["Random", "Serialization", "Test"] +git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.1.0" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[ProgressMeter]] +deps = ["Distributed", "Printf", "Random", "Test"] +git-tree-sha1 = "48058bc11607676e5bbc0b974af79106c6200787" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "0.9.0" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["Test"] +git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "0.5.2" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures", "Random", "Test"] +git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "0.3.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] +git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.7.2" + +[[StaticArrays]] +deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] +git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.10.3" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsBase]] +deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] +git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.30.0" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[Tokenize]] +deps = ["Printf", "Test"] +git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.3" + +[[Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] +git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.1.0" + +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "a25d8e5a28c3b1b06d3859f30757d43106791919" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.4" + +[[URIParser]] +deps = ["Test", "Unicode"] +git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" +uuid = "30578b45-9adc-5946-b283-645ec420af67" +version = "0.4.0" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["BinaryProvider", "Libdl", "Printf", "Test"] +git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.8.1" + +[[Zygote]] +deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "d3c2ae55d116b5360a73b1e88d1a974b446d933a" +repo-rev = "ffc50480ff8f7662110bfb82b0b6d4f9cef6e59d" +repo-url = "https://github.com/FluxML/Zygote.jl.git" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.2.0+" diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 000000000..541d5a4f5 --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,5 @@ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/linear_regression.jl b/examples/linear_regression.jl new file mode 100644 index 000000000..8b1e2cfe8 --- /dev/null +++ b/examples/linear_regression.jl @@ -0,0 +1,99 @@ +# Initialize environment in current directory +@info("Ensuring example environment instantiated...") +import Pkg +Pkg.activate(@__DIR__) +Pkg.instantiate() + +@info("Loading Zygote...") +using Zygote, LinearAlgebra + +# This example will showcase how we do a simple linear fit with Zygote, making +# use of complex datastructures, a home-grown stochastic gradient descent +# optimizer, and some good old-fashioned math. We start with the problem +# statement: We wish to learn the mapping `f(X) -> Y`, where `X` is a matrix +# of vector observations, `f()` is a linear mapping function and `Y` is a +# vector of scalar observations. + +# Because we like complex objects, we will define our linear regression as the +# following object: +mutable struct LinearRegression + # These values will be implicitly learned + weights::Matrix + bias::Float64 + + # These values will not be learned + name::String +end +LinearRegression(nparams, name) = LinearRegression(randn(1, nparams), 0.0, name) + +# Our linear prediction looks very familiar; w*X + b +function predict(model::LinearRegression, X) + return model.weights * X .+ model.bias +end + +# Our "loss" that must be minimized is the l2 norm between our current +# prediction and our ground-truth Y +function loss(model::LinearRegression, X, Y) + return norm(predict(model, X) .- Y, 2) +end + + +# Our "ground truth" values (that we will learn, to prove that this works) +weights_gt = [1.0, 2.7, 0.3, 1.2]' +bias_gt = 0.4 + +# Generate a dataset of many observations +X = randn(length(weights_gt), 10000) +Y = weights_gt * X .+ bias_gt + +# Add a little bit of noise to `X` so that we do not have an exact solution, +# but must instead do a least-squares fit: +X .+= 0.001.*randn(size(X)) + + +# Now we begin our "training loop", where we take examples from `X`, +# calculate loss with respect to the corresponding entry in `Y`, find the +# gradient upon our model, update the model, and continue. Before we jump +# in, let's look at what `Zygote.gradient()` gives us: +@info("Building model...") +model = LinearRegression(size(X, 1), "Example") + +# Calculate gradient upon `model` for the first example in our training set +@info("Calculating gradient (the first time can take a while to compile...)") +grads = Zygote.gradient(model) do m + return loss(m, X[:,1], Y[1]) +end + +# The `grads` object is a Tuple containing one element per argument to +# `gradient()`, so we take the first one to get the gradient upon `model`: +grads = grads[1] + +# Because our LinearRegression object is mutable, the gradient holds a +# reference to it, which we peel via `grads[]`: +grads = grads[] + +# We now get a `NamedTuple` so we can now do things like `grads.weight`. Let's +# print it out, just to see what it looks like. Note that while `weights` and +# `bias` have gradients, `name` just naturally has a gradient of `nothing`, +# because it was not involved in the calculation of the output loss. +@info grads + +# Let's define an update rule that will allow us to modify the weights +# of our model a tad bit according to the gradients +function sgd_update!(model::LinearRegression, grads, η = 0.001) + model.weights .-= η .* grads.weights + model.bias -= η * grads.bias +end + +# Now let's do that for each example in our training set: +@info("Running train loop for $(size(X,2)) iterations") +for idx in 1:size(X, 2) + grads = Zygote.gradient(m -> loss(m, X[:, idx], Y[idx]), model)[1][] + sgd_update!(model, grads) +end + +# Now let's look at how well we've approximated the ground truth weights/bias: +@info("Ground truth weights: $(weights_gt)") +@info("Learned weights: $(round.(model.weights; digits=3))") +@info("Ground truth bias: $(bias_gt)") +@info("Learned bias: $(round(model.bias; digits=3))") diff --git a/examples/mnist_mlp.jl b/examples/mnist_mlp.jl new file mode 100644 index 000000000..60c3e7e02 --- /dev/null +++ b/examples/mnist_mlp.jl @@ -0,0 +1,107 @@ +# Initialize environment in current directory +@info("Ensuring example environment instantiated...") +import Pkg +Pkg.activate(@__DIR__) +Pkg.instantiate() + +@info("Loading Zygote and Flux...") +using Zygote, Flux, Random, Statistics +using Flux.Data.MNIST + +# We're going to showcase how to use Zygote with Flux; we'll create a simple +# Multi-Layer Perceptron network to do digit classification upon the MNIST +# dataset. We start with some setup that is ripped straight from the Flux +# model zoo: + +# First, we load the MNIST images and flatten them into a giant matrix: +@info("Loading dataset...") +X = hcat(float.(reshape.(MNIST.images(), :))...) + +# Load labels as well, one-hot encoding them +Y = float.(Flux.onehotbatch(MNIST.labels(), 0:9)) + +# Do the same for the test data/labels: +X_test = hcat(float.(reshape.(MNIST.images(:test), :))...) +Y_test = float.(Flux.onehotbatch(MNIST.labels(:test), 0:9)) + +@info("Constructing MLP model...") +model = Chain( + Dense(28^2, 32, relu), + Dense(32, 10), + softmax, +) + +# Until Flux drops Tracker as its default Automatic Differentiation library, +# strip it out with this line: +model = Flux.mapleaves(Flux.data, model) + +# Our loss is the classical multiclass crossentropy loss +loss(model, X, Y) = Flux.crossentropy(model(X), Y) + +# Helper function to calculate accuracy of our model +accuracy(model, X, Y) = mean(Flux.onecold(model(X)) .== Flux.onecold(Y)) + + +# Recursive zygote update method, this is the general recursion case: +function zyg_update!(opt, model, updates) + # If this `model` node has no fields, then just return it + if nfields(model) == 0 + return model + end + + # If it does have fields, recurse into them: + for field_idx in 1:nfields(model) + zyg_update!(opt, getfield(model, field_idx), getfield(updates, field_idx)) + end + + # In the end, return the `model` + return model +end +# If the `updates` is set to `Nothing`, then just return `model`; this means +# that there were no changes to be applied to this piece of the model. +zyg_update!(opt, model, updates::Nothing) = model + +# If `model` is an `AbstractArray` and `updates` is too, then apply our Flux +# optimizer to the incoming gradients and apply them to the model! +function zyg_update!(opt, model::AbstractArray, updates::AbstractArray) + # Sub off to Flux's ADAM optimizer + Flux.Optimise.apply!(opt, model, updates) + return model .-= updates +end + + +# We will train for a number of epochs, with minibatches, using the `ADAM` +# optimizer to nudge our weights toward perfection. +opt = ADAM(0.001) +num_epochs = 10 +@info("Training for $(num_epochs) epochs...") +for epoch_idx in 1:num_epochs + # "global" here to dodgescoping issues with for loops at top-level + global X, Y, model + + # Shuffle the data each epoch: + perm = shuffle(1:size(X,2)) + X = X[:, perm] + Y = Y[:, perm] + + # Iterate over batches + batch_size = 512 + batch_idxs = 1:batch_size:(size(X,2) - batch_size) + for bidx in batch_idxs + # Calculate gradients upon the model for this batch + grads = Zygote.gradient(model) do model + return loss(model, X[:, bidx:bidx+batch_size], + Y[:, bidx:bidx+batch_size]) + end + + # Peel outer Tuple to access gradient of first parameter + grads = grads[1] + + # Apply recursive update to our model: + zyg_update!(opt, model, grads) + end + + # After each epoch, report our accuracy on the test set: + acc = accuracy(model, X_test, Y_test) + @info("[$(epoch_idx)] Accuracy: $(round(100*acc; digits=1))%") +end diff --git a/examples/profiler.jl b/examples/profiler.jl index 340479331..513fd6933 100644 --- a/examples/profiler.jl +++ b/examples/profiler.jl @@ -1,3 +1,10 @@ +# Initialize environment in current directory +@info("Ensuring example environment instantiated...") +import Pkg +Pkg.activate(@__DIR__) +Pkg.instantiate() + +@info("Loading Zygote...") using Zygote function f(x) diff --git a/src/Zygote.jl b/src/Zygote.jl index e0d78e2e2..ac9684017 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -1,7 +1,9 @@ module Zygote -using LinearAlgebra -using LinearAlgebra: copytri! +using LinearAlgebra, Statistics +using LinearAlgebra: copytri!, AbstractTriangular + +import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _forward # This flag enables Zygote to grab extra type inference information during # compiles. When control flow is present, this can give gradient code a @@ -16,12 +18,9 @@ using IRTools using MacroTools, Requires using MacroTools: @forward -export Params, gradient, derivative, forward, @code_grad +export Params, gradient, forward, @code_grad include("tools/idset.jl") -include("tools/ir.jl") -include("tools/reflection.jl") -include("tools/fillarray.jl") include("compiler/reverse.jl") include("compiler/emit.jl") @@ -30,10 +29,10 @@ include("compiler/show.jl") include("lib/grad.jl") include("lib/lib.jl") -include("lib/real.jl") -include("lib/complex.jl") +include("lib/number.jl") include("lib/base.jl") include("lib/array.jl") +include("lib/buffer.jl") include("lib/nnlib.jl") include("lib/broadcast.jl") include("lib/forward.jl") @@ -47,8 +46,8 @@ usetyped || include("precompile.jl") include("profiler/Profile.jl") -@init @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin - isdefined(Flux, :Tracker) && include("flux.jl") +@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin + include("flux.jl") end # helps to work around 265-y issues diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index 6d03c82e2..84cf94990 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -14,8 +14,6 @@ function Base.pop!(stk::Stack) @inbounds return stk.data[i] end -xstack(T) = (Vector{T}, Expr(:call, Vector{T})) - function _push!(a::Vector{T}, x::T) where T Base._growend!(a, 1) @inbounds a[end] = x @@ -24,12 +22,12 @@ end # Emit -function alphauses(ir, bi) - us = [] - for i = range(ir.cfg.blocks[bi]), u in userefs(ir.stmts[i]) - u[] isa Alpha && push!(us, SSAValue(u[].id)) - end - return unique(us) +xstack(T) = stmt(Expr(:call, Vector{T}), type = Vector{T}) + +function alphauses(b) + us = Set{Alpha}() + postwalk(x -> x isa Alpha && push!(us, x), b) + return us end xtuple(xs...) = xcall(:tuple, xs...) @@ -38,95 +36,93 @@ concrete(T::DataType) = T concrete(::Type{Type{T}}) where T = typeof(T) concrete(T) = Any -function stacklines(adj::Adjoint) - recs = [] - for fb in adj.perm, α in alphauses(adj.back, invperm(adj.perm)[fb]) - pushfirst!(recs, adj.forw.linetable[adj.forw.lines[α.id]]) - end - return recs -end +runonce(b) = b.id in (1, length(b.ir.blocks)) function forward_stacks!(adj, F) stks, recs = [], [] - for fb in adj.perm, α in alphauses(adj.back, invperm(adj.perm)[fb]) - if fb == 1 - pushfirst!(recs, α) + pr = adj.primal + for b in blocks(pr), α in alphauses(block(adj.adjoint, b.id)) + if runonce(b) + push!(recs, Variable(α)) else - T = exprtype(adj.forw, α) - stk = insert_node!(adj.forw, 1, xstack(T)...) - pushfirst!(recs, stk) - insert_blockend!(adj.forw, blockidx(adj.forw, α.id), Any, xcall(Zygote, :_push!, stk, α)) + T = exprtype(pr, Variable(α)) + stk = pushfirst!(pr, xstack(T)) + push!(recs, stk) + push!(b, xcall(Zygote, :_push!, stk, Variable(α))) end - pushfirst!(stks, (invperm(adj.perm)[fb], alpha(α))) + push!(stks, (b.id, alpha(α))) end - args = [Argument(i) for i = 3:length(adj.forw.argtypes)] - T = Tuple{concrete.(exprtype.((adj.forw,), recs))...} + args = arguments(pr)[3:end] + T = Tuple{concrete.(exprtype.((pr,), recs))...} isconcretetype(T) || (T = Any) - rec = insert_node!(adj.forw, length(adj.forw.stmts), T, - xtuple(recs...)) - if usetyped - rec = insert_node!(adj.forw, length(adj.forw.stmts), Pullback{F,T}, - Expr(:call, Pullback{F,T}, rec)) + rec = push!(pr, xtuple(recs...)) + if usetyped && length(pr.blocks) > 1 + rec = push!(pr, Expr(:call, Pullback{F,T}, rec)) else - P = length(adj.perm) == 1 ? Pullback{F} : Pullback{F,Any} - rec = insert_node!(adj.forw, length(adj.forw.stmts), Any, - Expr(:call, P, rec)) + P = length(pr.blocks) == 1 ? Pullback{F} : Pullback{F,Any} + rec = push!(pr, Expr(:call, P, rec)) end - ret = xtuple(adj.forw.stmts[end].val, rec) - R = exprtype(adj.forw, adj.forw.stmts[end].val) - ret = insert_node!(adj.forw, length(adj.forw.stmts), Tuple{R,Pullback{F,T}}, ret) - adj.forw.stmts[end] = ReturnNode(ret) - forw = compact!(adj.forw) - return forw, stks + ret = xtuple(pr.blocks[end].branches[end].args[1], rec) + ret = push!(pr, ret) + pr.blocks[end].branches[end].args[1] = ret + return pr, stks end function reverse_stacks!(adj, stks) - ir = adj.back - t = insert_node!(ir, 1, Any, xcall(Base, :getfield, Argument(1), QuoteNode(:t))) - for b = 1:length(ir.cfg.blocks) - repl = Dict() + ir = adj.adjoint + entry = blocks(ir)[end] + self = argument!(entry, at = 1) + t = pushfirst!(blocks(ir)[end], xcall(:getfield, self, QuoteNode(:t))) + repl = Dict() + runonce(b) = b.id in (1, length(ir.blocks)) + for b in blocks(ir) for (i, (b′, α)) in enumerate(stks) - b == b′ || continue - loc, attach_after = afterphi(ir, range(ir.cfg.blocks[b])[1]) - loc = max(2, loc) - if adj.perm[b′] == 1 - val = insert_node!(ir, loc, Any, xcall(:getindex, t, i), attach_after) + b.id == b′ || continue + if runonce(b) + val = insertafter!(ir, t, xcall(:getindex, t, i)) else - stk = insert_node!(ir, 1, Any, xcall(:getindex, t, i)) - stk = insert_node!(ir, 1, Any, xcall(Zygote, :Stack, stk)) - val = insert_node!(ir, loc, Any, xcall(:pop!, stk), attach_after) + stk = push!(entry, xcall(:getindex, t, i)) + stk = push!(entry, xcall(Zygote, :Stack, stk)) + val = pushfirst!(b, xcall(:pop!, stk)) end repl[α] = val end - for i in range(ir.cfg.blocks[b]), u in userefs(ir.stmts[i]) - if u.stmt == Expr(:call, :Δ) - u.stmt = Argument(2) - elseif haskey(repl, u[]) - u[] = repl[u[]] - else continue - end - ir.stmts[i] = u.stmt - end end - return compact!(ir) + return IRTools.prewalk!(x -> get(repl, x, x), ir) end function stacks!(adj, T) forw, stks = forward_stacks!(adj, T) back = reverse_stacks!(adj, stks) + permute!(back, length(back.blocks):-1:1) + IRTools.domorder!(back) return forw, back end varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing +meta(T) = (usetyped ? IRTools.typed_meta : IRTools.meta)(T) + +function getmeta(T) + m = meta(T) + (usetyped && m != nothing) || return m + any(x -> isexpr(x, :goto, :gotoifnot), m.code.code) || return IRTools.meta(T) + return m +end + function _lookup_grad(T) - (m = meta(T)) == nothing && return - usetyped && m.ret == Union{} && return + (m = getmeta(T)) == nothing && return + m isa IRTools.TypedMeta && m.ret == Union{} && return va = varargs(m.method, length(T.parameters)) - forw, back = stacks!(Adjoint(IRCode(m), varargs = va), T) - # verify_ir(forw) - # verify_ir(back) + forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T) m, forw, back end -stacklines(T::Type) = stacklines(Adjoint(IRCode(meta(T)))) +function stacklines(T::Type) + adj = Adjoint(IR(meta(T)), normalise = false) + recs = [] + for b in blocks(adj.adjoint), α in alphauses(b) + push!(recs, IRTools.exprline(adj.primal, Variable(α))) + end + return recs +end diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 3707b256d..0a964f1c8 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -1,4 +1,4 @@ -mutable struct Context +mutable struct Context <: AContext cache::Union{IdDict{Any,Any},Nothing} globals::Union{Dict{GlobalRef,Any},Nothing} end @@ -38,15 +38,16 @@ function forward(f, args...) y, Δ -> tailmemaybe(back(Δ)) end +sensitivity(y::Number) = one(y) +sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.") +sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") + function gradient(f, args...) y, back = forward(f, args...) - y isa Real || error("Function output is not scalar") - return back(Int8(1)) + return back(sensitivity(y)) end -derivative(f::F, x) where F = gradient(f, x)[1] - -Base.adjoint(f::Function) = x -> derivative(f, x) +Base.adjoint(f::Function) = x -> gradient(f, x)[1] # Param-style wrappers @@ -83,7 +84,12 @@ end Base.show(io::IO, ps::Grads) = print(io, "Grads(...)") -@forward Grads.grads Base.setindex!, Base.getindex, Base.haskey +@forward Grads.grads Base.getindex, Base.haskey + +function Base.getindex(gs::Grads, x) + isbits(x) && error("Only reference types can be differentiated with `Params`.") + return gs.grads[x] +end function forward(f, ps::Params) cx = Context() @@ -96,3 +102,28 @@ function forward(f, ps::Params) Grads(cx.cache) # TODO make a copy end end + +# Code Reflection + +using InteractiveUtils +using InteractiveUtils: typesof +using Core: Typeof + +function code_ir(f, T) + m = meta(Tuple{Typeof(f),T.parameters...}) + return IR(m) +end + +function code_irm(ex) + isexpr(ex, :call) || error("@code_ir f(args...)") + f, args = ex.args[1], ex.args[2:end] + :(code_ir($(esc(f)), typesof($(esc.(args)...)))) +end + +macro code_ir(ex) + code_irm(ex) +end + +macro code_adjoint(ex) + :(Adjoint($(code_irm(ex)), varargs = varargs($(esc(:($InteractiveUtils.@which $ex))), length(($(esc.(ex.args)...),))))) +end diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 33d50e3c2..936367718 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -1,6 +1,8 @@ +using IRTools: argnames!, varargs!, inlineable!, pis!, slots! + ignore(T) = all(T -> T <: Type, T.parameters) -@generated function _forward(ctx::Context, f, args...) +@generated function _forward(ctx::AContext, f, args...) T = Tuple{f,args...} ignore(T) && return :(f(args...), Pullback{$T}(())) g = try _lookup_grad(T) catch e e end @@ -8,6 +10,7 @@ ignore(T) = all(T -> T <: Type, T.parameters) meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) + # IRTools.verify(forw) forw = slots!(pis!(inlineable!(forw))) return IRTools.update!(meta, forw) end @@ -20,11 +23,11 @@ end end if g == nothing Δ == Nothing && return :nothing - return :(error("Non-differentiable function $(j.t[1])")) + return :(error("Non-differentiable function $(repr(j.t[1]))")) end meta, _, back = g - resize!(back.argtypes, 2) argnames!(meta, Symbol("#self#"), :Δ) + # IRTools.verify(back) back = slots!(inlineable!(back)) return IRTools.update!(meta, back) end diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 38e0c25b2..e3bb16b11 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -1,3 +1,7 @@ +using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk, + blocks, predecessors, successors, argument!, arguments, branches, + exprtype, insertafter!, finish, expand!, prune!, substitute!, substitute, + block, block!, branch!, return!, stmt using Base: @get! @inline tuple_va(N, xs) = xs @@ -6,102 +10,56 @@ using Base: @get! iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n) -function isassert(ir, i) - ex = ir.stmts[i+3] - iscall(ex, Zygote, :typeassert) -end +gradindex(x, i) = x[i] +gradindex(::Nothing, i) = nothing +xgetindex(x, i...) = xcall(Base, :getindex, x, i...) +xgradindex(x, i) = xcall(Zygote, :gradindex, x, i) -# TODO: Move this to Base -function append_node!(ir, @nospecialize(typ), @nospecialize(node), line) - @assert isempty(ir.new_nodes) - push!(ir.stmts, node) - push!(ir.types, typ) - push!(ir.lines, line) - push!(ir.flags, 0) - last_bb = ir.cfg.blocks[end] - ir.cfg.blocks[end] = BasicBlock(StmtRange(first(last_bb.stmts):length(ir.stmts)), - last_bb.preds, - last_bb.succs) - return SSAValue(length(ir.stmts)) -end +normalise!(ir) = ir |> IRTools.merge_returns! -function merge_returns(ir) - any(x -> x == unreachable, ir.stmts) && error("Unsupported control flow") - rs = findall(x -> x isa ReturnNode && isdefined(x, :val), ir.stmts) - length(rs) <= 1 && return ir - bs = blockidx.(Ref(ir), rs) - xs = Any[] - bb = length(ir.cfg.blocks)+1 - @assert length(unique(bs)) == length(bs) - push!(ir.cfg.blocks, BasicBlock(StmtRange(length(ir.stmts)+1, length(ir.stmts)), bs, [])) - push!(ir.cfg.index, length(ir.stmts) + 1) - r = append_node!(ir, Any, nothing, ir.lines[end]) - append_node!(ir, Any, ReturnNode(r), ir.lines[end]) - for r in rs - x = ir.stmts[r].val - x = insert_node!(ir, r, Any, x) - push!(xs, x) - ir.stmts[r] = GotoNode(bb) - end - for b in bs - push!(ir.cfg.blocks[b].succs, bb) - end - ir.stmts[r.id] = PhiNode(bs, xs) - return compact!(ir) +function instrument_new!(ir, v, ex) + isexpr(ex, :new) ? (ir[v] = xcall(Zygote, :__new__, ex.args...)) : + isexpr(ex, :splatnew) ? (ir[v] = xcall(Zygote, :__splatnew__, ex.args...)) : + ex end -function merge_entry(ir) - isempty(ir.cfg.blocks[1].preds) && return ir - ir = IncrementalCompact(ir) - insert_node_here!(ir, nothing, Nothing, Int32(0)) - foreach(_ -> nothing, ir) - ir = finish(ir) - for i in 1:length(ir.cfg.blocks) - old = ir.cfg.blocks[i] - ir.cfg.blocks[i] = BasicBlock(old.stmts, old.preds .+ 1, old.succs .+ 1) - end - for i = 1:length(ir.stmts) - ex = ir.stmts[i] - ir.stmts[i] = - ex isa GotoNode ? GotoNode(ex.label+1) : - ex isa GotoIfNot ? GotoIfNot(ex.cond, ex.dest+1) : - ex - end - pushfirst!(ir.cfg.blocks, BasicBlock(StmtRange(1:1), [], [2])) - old = ir.cfg.blocks[2] - ir.cfg.blocks[2] = BasicBlock(StmtRange(2:last(old.stmts)), old.preds, old.succs) - pushfirst!(ir.cfg.blocks[2].preds, 1) - return ir -end +# Hack to work around fragile constant prop through overloaded functions +unwrapquote(x) = x +unwrapquote(x::QuoteNode) = x.value -normalise(ir) = ir |> merge_entry |> merge_returns +is_literal_getproperty(ex) = + (iscall(ex, Base, :getproperty) || iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) && + ex.args[3] isa Union{QuoteNode,Integer} -struct Alpha - id::Int +function instrument_getproperty!(ir, v, ex) + is_literal_getproperty(ex) ? + (ir[v] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(unwrapquote(ex.args[3])))) : + ex end -Base.show(io::IO, x::Alpha) = print(io, "@", x.id) +is_literal_getindex(ex) = + iscall(ex, Base, :getindex) && length(ex.args) == 3 && ex.args[3] isa Union{Integer,QuoteNode} -alpha(x) = x -alpha(x::SSAValue) = Alpha(x.id) +function instrument_getindex!(ir, v, ex) + is_literal_getindex(ex) ? + (ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3])))) : + ex +end -gradindex(x, i) = x[i] -gradindex(::Nothing, i) = nothing -xgetindex(x, i...) = Expr(:call, GlobalRef(Base, :getindex), x, i...) -xgradindex(x, i) = xcall(Zygote, :gradindex, x, i) +is_literal_iterate(ex) = + iscall(ex, Base, :indexed_iterate) && length(ex.args) >= 3 && ex.args[3] isa Union{Integer,QuoteNode} -function record_branches!(ir::IRCode) - ir = IncrementalCompact(ir) - offset = 0 - for (i, x) in ir - bi = findfirst(x -> x == i+1-offset, ir.ir.cfg.index) - bi == nothing && continue - preds = ir.ir.cfg.blocks[bi+1].preds - length(preds) > 1 || continue - insert_node_here!(ir, PhiNode(sort(preds), Int8.(1:length(preds))), Int8, ir.result_lines[i]) - offset += 1 - end - return finish_dc(ir) +function instrument_iterate!(ir, v, ex) + is_literal_iterate(ex) ? + (ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2], + Val(unwrapquote(ex.args[3])), ex.args[4:end]...)) : + ex +end + +function instrument_literals!(ir, v, ex) + ex = instrument_getproperty!(ir, v, ex) + ex = instrument_getindex!(ir, v, ex) + ex = instrument_iterate!(ir, v, ex) end function istrackable(x) @@ -111,18 +69,45 @@ function istrackable(x) !(x isa Type || sizeof(x) == 0) end -function record_globals!(ir::IRCode) - for i = 1:length(ir.stmts) - ex = ir[SSAValue(i)] - # TODO general globalrefs - if isexpr(ex, :call) - for j = 1:length(ex.args) - istrackable(ex.args[j]) || continue - ex.args[j] = insert_node!(ir, i, Any, xcall(Zygote, :unwrap, QuoteNode(ex.args[j]), ex.args[j])) - end +function instrument_global!(ir, v, ex) + if istrackable(ex) + ir[v] = xcall(Zygote, :unwrap, QuoteNode(ex), ex) + else + ir[v] = prewalk(ex) do x + istrackable(x) || return x + insert!(ir, v, stmt(xcall(Zygote, :unwrap, QuoteNode(x), x), type = exprtype(x))) + end + end +end + +function instrument(ir::IR) + pr = Pipe(ir) + for (v, st) in pr + ex = st.expr + isexpr(ex, :foreigncall) && continue + isexpr(ex, :enter, :leave) && error("try/catch is not supported.") + ex = instrument_new!(pr, v, ex) + ex = instrument_literals!(pr, v, ex) + ex = instrument_global!(pr, v, ex) + end + return finish(pr) +end + +const BranchNumber = UInt8 + +function record_branches!(ir::IR) + brs = Dict{Int,Variable}() + for bb in blocks(ir) + preds = predecessors(bb) + length(preds) > 1 || continue + brs[bb.id] = argument!(bb, BranchNumber(0), BranchNumber) + i = length(arguments(bb)) + n = 0 + for aa in blocks(ir), br in branches(aa) + br.block == bb.id && (arguments(br)[i] = BranchNumber(n += 1)) end end - return compact!(ir) + return ir, brs end ignored_f(f) = f in (GlobalRef(Base, :not_int), @@ -132,29 +117,13 @@ ignored_f(f) = f in (GlobalRef(Base, :not_int), GlobalRef(Core, :typeof), GlobalRef(Core, :throw), GlobalRef(Base, :kwerr), - GlobalRef(Core, :kwfunc)) + GlobalRef(Core, :kwfunc), + GlobalRef(Core, :isdefined)) ignored_f(ir, f) = ignored_f(f) -ignored_f(ir, f::SSAValue) = ignored_f(ir[f]) +ignored_f(ir, f::Variable) = ignored_f(get(ir, f, nothing)) ignored(ir, ex) = isexpr(ex, :call) && ignored_f(ir, ex.args[1]) -ignored(ir, ex::SSAValue) = ignored(ir, ir[ex]) - -function valid_usages(ir) - r = Dict() - for (x, us) in usages(ir) - x isa Union{SSAValue,Argument} || continue - us′ = filter(i -> !ignored(ir, i), us) - isempty(us′) || (r[x] = us′) - end - return r -end - -reachable(ir) = keys(valid_usages(ir)) - -# Hack to work around fragile constant prop through overloaded functions -is_literal_getproperty(ex) = - (iscall(ex, Base, :getproperty) || iscall(ex, Core, :getfield)) && - ex.args[3] isa QuoteNode +ignored(ir, ex::Variable) = ignored(ir, ir[ex]) # TODO: remove this once we don't mess with type inference function _forward_type(Ts) @@ -166,274 +135,165 @@ end isvalidtype(jT, yT) = jT <: Tuple && length(jT.parameters) == 2 && jT.parameters[1] <: yT -function record!(ir::IRCode) - pushfirst!(ir.argtypes, typeof(_forward), Context) - xs = reachable(ir) - for i = 1:length(ir.stmts) - ex = argmap(x -> Argument(x.n+2), ir[SSAValue(i)]) - isexpr(ex, :new) && (ex = ir[SSAValue(i)] = xcall(Zygote, :__new__, ex.args...)) - isexpr(ex, :splatnew) && (ex = ir[SSAValue(i)] = xcall(Zygote, :__splatnew__, ex.args...)) - is_literal_getproperty(ex) && - (ex = ir[SSAValue(i)] = xcall(Zygote, :literal_getproperty, ex.args[2], Val(ex.args[3].value))) +function primal(ir::IR) + pr = Pipe(ir) + pbs = Dict{Variable,Variable}() + argument!(pr, at = 1) + cx = argument!(pr, Context, at = 2) + for (v, st) in pr + ex = st.expr if isexpr(ex, :call) && !ignored(ir, ex) - yT = widenconst(types(ir)[i]) - T = _forward_type(exprtype.(Ref(ir), ex.args)) + yT = exprtype(ir, v) + T = _forward_type(exprtype.((ir,), ex.args)) if yT == Any || isvalidtype(T, yT) - yJ = insert_node!(ir, i, T, xcall(Zygote, :_forward, Argument(2), ex.args...)) - ir[SSAValue(i)] = xgetindex(yJ, 1) - insert_node!(ir, i, T == Any ? Any : T.parameters[2], xgetindex(yJ, 2), true) + yJ = insert!(pr, v, stmt(xcall(Zygote, :_forward, cx, ex.args...), + line = ir[v].line)) + pr[v] = xgetindex(yJ, 1) + J = insertafter!(pr, v, stmt(xgetindex(yJ, 2), + type = T == Any ? Any : T.parameters[2], + line = ir[v].line)) + pbs[v] = substitute(pr, J) else - yJ = insert_node!(ir, i, Any, xcall(Zygote, :_forward, Argument(2), ex.args...)) - y = insert_node!(ir, i, Any, xgetindex(yJ, 1)) - J = insert_node!(ir, i, Any, xgetindex(yJ, 2)) - ir[SSAValue(i)] = xcall(Zygote, :typeassert, y, yT) + yJ = insert!(pr, v, xcall(Zygote, :_forward, cx, ex.args...)) + y = insert!(pr, v, xgetindex(yJ, 1)) + J = insert!(pr, v, stmt(xgetindex(yJ, 2), line = ir[v].line)) + pr[v] = xcall(Zygote, :typeassert, y, yT) + pbs[v] = substitute(pr, J) end - else - ir[SSAValue(i)] = ex end end - ir, m = _compact!(ir) - return ir, Set(x isa Argument ? Argument(x.n+2) : x for x in rename(xs, m)) -end - -# Backwards Pass - -function reverse_cfg(cfg, perm) - newidx(i) = invperm(perm)[i] - CFG([BasicBlock(StmtRange(1,0),newidx.(b.succs),newidx.(b.preds)) for b in cfg.blocks[perm]]) -end - -function reverse_order(cfg) - n = length(cfg.blocks) - perm = n:-1:1 - guess = reverse_cfg(cfg, perm) - dt = construct_domtree(guess) - perm[sortperm(1:n, by = x -> dt.nodes[x].level)] + pr = finish(pr) + pr, brs = record_branches!(pr) + return pr, brs, pbs end struct Primal - forw::IRCode - perm::Vector{Int} - wrt::Set{Any} + ir::IR + pr::IR varargs::Union{Int,Nothing} + branches::Dict{Int,Variable} + pullbacks::Dict{Variable,Variable} end -Primal(ir::IRCode, xs, vs) = Primal(ir, reverse_order(ir.cfg), xs, vs) +function Primal(ir::IR; varargs = nothing) + ir = instrument(normalise!(ir)) + pr, brs, pbs = primal(ir) + Primal(expand!(ir), pr, varargs, brs, pbs) +end + +# Backwards Pass -function Primal(ir::IRCode; varargs = nothing) - ir = normalise(ir) - forw, xs = record!(record_branches!(record_globals!(ir))) - Primal(forw, xs, varargs) +struct Alpha + id::Int end -newblock(pr::Primal, b) = invperm(pr.perm)[b] -oldblock(pr::Primal, b) = pr.perm[b] - -function blockinfo(pr::Primal) - preds(b) = pr.forw.cfg.blocks[b].preds - info = Dict(b => (phis=Dict(),partials=[],grads=[]) for b in 1:length(pr.forw.cfg.blocks)) - append!(info[1].grads, filter(x -> x isa Argument, pr.wrt)) - for b in 1:length(pr.forw.cfg.blocks), i in pr.forw.cfg.blocks[b].stmts - ex = pr.forw[SSAValue(i)] - if ex isa ReturnNode - ex.val in pr.wrt && push!(info[b].partials, ex.val) - elseif ex isa PiNode - (SSAValue(i) in pr.wrt && ex.val in pr.wrt) || continue - push!(info[b].grads, SSAValue(i)) - push!(info[b].partials, ex.val) - elseif ex isa PhiNode - any(x -> x in pr.wrt, ex.values) && push!(info[b].grads, SSAValue(i)) - for (c, x) in zip(ex.edges, ex.values) - x in pr.wrt && push!(@get!(info[b].phis, c, []), x) - end - elseif iscall(ex, Zygote, :_forward) - y = isassert(pr.forw, i) ? SSAValue(i+3) : SSAValue(i+1) - push!(info[b].grads, y) - for x in ex.args[3:end] - x in pr.wrt && push!(info[b].partials, x) - end +Base.show(io::IO, x::Alpha) = print(io, "@", x.id) + +alpha(x) = x +alpha(x::Variable) = Alpha(x.id) +Variable(a::Alpha) = Variable(a.id) + +sig(b::IRTools.Block) = unique([arg for br in branches(b) for arg in br.args if arg isa Variable]) +sig(pr::Primal) = Dict(b.id => sig(b) for b in blocks(pr.ir)) + +# TODO unreachables? +function adjointcfg(pr::Primal) + ir = empty(pr.ir) + return!(ir, nothing) + for b in blocks(pr.ir)[2:end] + block!(ir) + preds = predecessors(b) + rb = block(ir, b.id) + for i = 1:length(preds) + cond = i == length(preds) ? nothing : + push!(rb, xcall(Base, :(!==), alpha(pr.branches[b.id]), BranchNumber(i))) + branch!(rb, preds[i].id, unless = cond) end - end - worklist = collect(1:length(pr.forw.cfg.blocks)) - while !isempty(worklist) - b = pop!(worklist) - for c in preds(b) - in = union(get(info[b].phis, c, []), setdiff(info[b].partials, info[b].grads)) - out = union(info[c].partials, info[c].grads) - new = setdiff(in, out) - if !isempty(new) - append!(info[c].partials, new) - c ∉ worklist && push!(worklist, c) - end + if !isempty(branches(b)) && branches(b)[end] == IRTools.unreachable + branch!(rb, 0) end end - return info -end - -function IRCode(ir::Primal) - stmts = [] - blocks = [] - newidx(i) = invperm(ir.perm)[i] - for block in ir.perm - old = ir.forw.cfg.blocks[block] - start = length(stmts)+1 - block == length(ir.perm) && push!(stmts, :(Δ())) - preds, succs = newidx.(old.succs), newidx.(sort(old.preds)) - if isempty(succs) - push!(stmts, nothing) - else - for (i, b) in enumerate(succs[1:end-1]) - push!(stmts, xcall(Base, :(!==), Alpha(range(old)[1]), Int8(i))) - push!(stmts, GotoIfNot(SSAValue(length(stmts)), b)) - end - push!(stmts, GotoNode(succs[end])) - end - push!(blocks, BasicBlock(StmtRange(start,length(stmts)), preds, succs)) + sigs = sig(pr) + for b in blocks(ir)[1:end-1], i = 1:length(sigs[b.id]) + argument!(b) end - ir = IRCode(ir.forw, stmts, Any[Any for _ in stmts], Int32[0 for _ in stmts], - [0x00 for _ in stmts], CFG(blocks), NewNode[]) + argument!(blocks(ir)[end]) + return ir, sigs end -function reverse_ir(pr::Primal) - ir = IRCode(pr) - grads = Dict() - partials = Dict(x => [] for x in pr.wrt) - phis = Dict() - for b in pr.perm - j = ir.cfg.blocks[newblock(pr, b)].stmts[1] - j = max(j, 2) - for i in reverse(pr.forw.cfg.blocks[b].stmts) - ex = pr.forw[SSAValue(i)] - if ex isa ReturnNode - ex.val in pr.wrt && push!(partials[ex.val], SSAValue(1)) - elseif ex isa PiNode - (SSAValue(i) in pr.wrt && ex.val in pr.wrt) || continue - Δ = insert_node!(ir, j, Any, xcall(Zygote, :accum)) - ir.lines[j] = pr.forw.lines[i] - grads[SSAValue(i)] = Δ - push!(partials[ex.val], Δ) - elseif ex isa PhiNode - any(x -> x in pr.wrt, ex.values) || continue - Δ = insert_node!(ir, j, Any, xcall(Zygote, :accum)) - ir.lines[j] = pr.forw.lines[i] - grads[SSAValue(i)] = Δ - for (c, x) in zip(ex.edges, ex.values) - x in pr.wrt || continue - @assert !haskey(phis, (newblock(pr, b), newblock(pr, c), x)) "not implemented" - phis[(newblock(pr, b), newblock(pr, c), x)] = Δ - end - elseif iscall(ex, Zygote, :_forward) - # TODO remove with type hacks above - y = isassert(pr.forw, i) ? SSAValue(i+3) : SSAValue(i+1) - J = Alpha(i+2) - dy = insert_node!(ir, j, Any, xcall(Zygote, :accum)) - ir.lines[j] = pr.forw.lines[i] - dxs = insert_node!(ir, j, Any, Expr(:call, J, dy)) - ir.lines[j] = pr.forw.lines[i] - grads[y] = dy - for (a, x) in enumerate(ex.args[3:end]) - x in pr.wrt || continue - dx = insert_node!(ir, j, Any, xgradindex(dxs, a)) - ir.lines[j] = pr.forw.lines[i] - push!(partials[x], dx) - end - elseif isexpr(ex, :call, :isdefined, GotoIfNot, GotoNode, Nothing, GlobalRef) - # ignore it - else - desc = isexpr(ex) ? "$(ex.head) expression" : ex - insert_node!(ir, j, Any, xcall(Base, :error, "Can't differentiate $desc")) - ir.lines[j] = pr.forw.lines[i] - end +branchfor(ir, (from,to)) = + get(filter(br -> br.block == to, branches(block(ir, from))), 1, nothing) + +xaccum(ir) = nothing +xaccum(ir, x) = x +xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...)) + +function adjoint(pr::Primal) + ir, sigs = adjointcfg(pr) + for b in reverse(blocks(pr.ir)) + rb = block(ir, b.id) + grads = Dict() + grad(x, x̄) = push!(get!(grads, x, []), x̄) + grad(x) = xaccum(rb, get(grads, x, [])...) + # Backprop through (successor) branch arguments + for i = 1:length(sigs[b.id]) + grad(sigs[b.id][i], arguments(rb)[i]) end - if b == 1 - gs = [] - for i = 3:length(pr.forw.argtypes) - Argument(i) in pr.wrt || (push!(gs, nothing); continue) - dx = insert_node!(ir, j, Any, xcall(Zygote, :accum)) - grads[Argument(i)] = dx - push!(gs, dx) - end - if pr.varargs == nothing - Δ = insert_node!(ir, j, Any, xcall(Zygote, :tuple, gs...)) - else - Δ = insert_node!(ir, j, Any, xcall(Zygote, :tuple_va, Val(pr.varargs), gs...)) + # Backprop through statements + for v in reverse(keys(b)) + ex = b[v].expr + if haskey(pr.pullbacks, v) + g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), + line = b[v].line)) + for (i, x) in enumerate(ex.args) + x isa Variable || continue + grad(x, push!(rb, stmt(xgradindex(g, i), + line = b[v].line))) + end + elseif ex isa Core.PiNode + grads[ex.val] = grads[v] + elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta) + elseif isexpr(ex) + push!(rb, stmt(xcall(Base, :error, "Can't differentiate $(ex.head) expression"), + line = b[v].line)) + else # A literal value + continue end - insert_node!(ir, j, Any, ReturnNode(Δ)) end - end - ir, m = _compact!(ir) - return ir, rename(grads, m), rename(partials, m), rename(phis, m) -end - -function simplify!(ir) - ir = IncrementalCompact(ir) - for (i, x) in ir - iscall(x, Zygote, :accum) || continue - filter!(x -> x != nothing, x.args) - nargs = length(x.args)-1 - ir[i] = nargs == 0 ? nothing : nargs == 1 ? x.args[end] : x - end - return finish(ir) -end - -function accumulators!(pr::Primal, ir::IRCode, grads, partials, phis) - blockpartials(b, x) = filter(x -> x.id in ir.cfg.blocks[b].stmts, get(partials, x, [])) - accums = Dict() - info = blockinfo(pr) - for b = 1:length(ir.cfg.blocks), x in setdiff(info[b].partials, info[b].grads) - ps = blockpartials(newblock(pr, b), x) - p = insert_blockend!(ir, newblock(pr, b), Any, xcall(Zygote, :accum, ps...)) - setdiff!(partials[x], ps) - push!(partials[x], p) - accums[(newblock(pr, b),x)] = p - end - - # Work around ordering issues with `accum` stmts and phis - ir, m = _compact!(ir) - accums, phis, grads, partials = rename((accums, phis, grads, partials), m) - - function predpartial(b, x) - function blockpartial(b, c, x) - if haskey(accums, (b, x)) - @assert !haskey(phis, (b, c, x)) "not implemented" - return accums[(b, x)] - elseif haskey(phis, (b, c, x)) - return phis[(b, c, x)] + if b.id > 1 # Backprop through (predecessor) branch arguments + gs = grad.(arguments(b)) + for br in branches(rb) + br.block == 0 && continue + br′ = branchfor(pr.ir, br.block=>b.id) + br′ == nothing && continue + ins = br′.args + for i = 1:length(br.args) + ā = [gs[j] for j = 1:length(ins) if ins[j] == sigs[br.block][i]] + br.args[i] = xaccum(rb, ā...) + end end + else # Backprop function arguments + gs = [grad(arg) for arg = arguments(pr.ir)] + Δ = push!(rb, pr.varargs == nothing ? + xcall(Zygote, :tuple, gs...) : + xcall(Zygote, :tuple_va, Val(pr.varargs), gs...)) + branches(rb)[1].args[1] = Δ end - preds = ir.cfg.blocks[b].preds - isempty(preds) && return - ps = map(c -> blockpartial(c, b, x), preds) - all(==(nothing), ps) && return - length(ps) == 1 ? ps[1] : insert_blockstart!(ir, b, Any, PhiNode(preds, ps)) - end - - for ((b, x), p) in accums - push!(ir[p].args, predpartial(b, x)) - end - for (x, dx) in grads - b = blockidx(ir, dx) - append!(ir[dx].args, blockpartials(b, x)) - push!(ir[dx].args, predpartial(b, x)) end - return simplify!(ir) + return ir end struct Adjoint - forw::IRCode - back::IRCode - perm::Vector{Int} + primal::IR + adjoint::IR end -function Adjoint(pr::Primal) - back = accumulators!(pr, reverse_ir(pr)...) - Adjoint(pr.forw, compact!(compact!(back)), pr.perm) -end - -Adjoint(ir::IRCode; varargs = nothing) = Adjoint(Primal(ir, varargs = varargs)) - -using InteractiveUtils - -macro code_adjoint(ex) - :(Adjoint($(code_irm(ex)), varargs = varargs($(esc(:($InteractiveUtils.@which $ex))), length(($(esc.(ex.args)...),))))) +function Adjoint(ir::IR; varargs = nothing, normalise = true) + pr = Primal(ir, varargs = varargs) + adj = adjoint(pr) |> prune! + if normalise + permute!(adj, length(adj.blocks):-1:1) + adj = IRTools.domorder!(adj) |> IRTools.renumber + end + Adjoint(pr.pr, adj) end diff --git a/src/compiler/show.jl b/src/compiler/show.jl index ea4f5453c..d318e44c8 100644 --- a/src/compiler/show.jl +++ b/src/compiler/show.jl @@ -9,4 +9,5 @@ function funcname(T) end Base.show(io::IO, j::Pullback{S}) where S = print(io, "∂($(funcname(S.parameters[1])))") -Base.show(io::IO, P::Type{<:Pullback{S}}) where S = print(io, "typeof(∂($(funcname(S.parameters[1]))))") + +Base.show(io::IO, P::Type{<:Pullback{S}}) where S<:Tuple = print(io, "typeof(∂($(funcname(S.parameters[1]))))") diff --git a/src/flux.jl b/src/flux.jl index 85048fafc..8f85ab9a8 100644 --- a/src/flux.jl +++ b/src/flux.jl @@ -1,8 +1,7 @@ -using .Flux -using .Flux.Tracker: TrackedArray, TrackedReal +using .Tracker: TrackedArray, TrackedReal if !usetyped - unwrap(x::Union{TrackedArray,TrackedReal}) = Flux.data(x) + unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x) end -forward(f, ps::Flux.Tracker.Params) = forward(f, Params(ps)) +forward(f, ps::Tracker.Params) = forward(f, Params(ps)) diff --git a/src/lib/array.jl b/src/lib/array.jl index 1c792ec70..aeb305848 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -1,13 +1,26 @@ -@adjoint (::Type{T})(args...) where T<:Array = T(args...), Δ -> nothing +using FillArrays, FFTW +using FillArrays: AbstractFill, getindex_value +using Base.Broadcast: broadcasted, broadcast_shape + +@adjoint (::Type{T})(::UndefInitializer, args...) where T<:Array = T(undef, args...), Δ -> nothing + +@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,) @nograd size, length, eachindex, Colon(), findfirst, randn, ones, zeros, one, zero, - print, println + print, println, any, all @adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,) -@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,) +@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,) +# Array Constructors +@adjoint (::Type{T})(x::T) where T<:Array = T(x), ȳ -> (ȳ,) +@adjoint (::Type{T})(x::Number, sz) where {T <: Fill} = Fill(x, sz), Δ -> (sum(Δ), nothing) +@adjoint (::Type{T})(sz) where {T<:Zeros} = Zeros(sz), Δ->(nothing,) +@adjoint (::Type{T})(sz) where {T<:Ones} = Ones(sz), Δ->(nothing,) + +_zero(xs::AbstractArray{<:Integer}) = fill!(similar(xs, float(eltype(xs))), false) _zero(xs::AbstractArray{<:Number}) = zero(xs) _zero(xs::AbstractArray) = Any[nothing for x in xs] @@ -22,18 +35,35 @@ end @adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...), _ -> error("Mutating arrays is not supported") +@adjoint function view(x::AbstractArray, inds...; kw...) + view(x, inds...; kw...), dy -> begin + dx = _zero(x) + copyto!(view(dx, inds...; kw...), dy) + (dx, map(_->nothing, inds)...) + end +end + # General -@adjoint collect(x) = collect(x), Δ -> (Δ,) +@adjoint collect(x::Array) = collect(x), Δ -> (Δ,) + +@adjoint fill(x::Real, dims...) = fill(x, dims...), Δ->(sum(Δ), map(_->nothing, dims)...) + +@adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),) + +@adjoint permutedims(xs::AbstractVector) = permutedims(xs), Δ -> (vec(permutedims(Δ)),) @adjoint permutedims(xs, dims) = permutedims(xs, dims), Δ -> (permutedims(Δ, invperm(dims)), nothing) +@adjoint PermutedDimsArray(xs, dims) = PermutedDimsArray(xs, dims), + Δ -> (PermutedDimsArray(Δ, invperm(dims)), nothing) + @adjoint reshape(xs, dims...) = reshape(xs, dims...), Δ -> (reshape(Δ, size(xs)),map(_->nothing,dims)...) @adjoint function hvcat(rows::Tuple{Vararg{Int}}, xs::T...) where T<:Number - hvcat(rows, xs...), ȳ -> (nothing, ȳ...) + hvcat(rows, xs...), ȳ -> (nothing, permutedims(ȳ)...) end pull_block_vert(sz, Δ, A::AbstractVector) = Δ[sz-length(A)+1:sz] @@ -67,19 +97,59 @@ end end end +@adjoint getindex(i::Int, j::Int) = i[j], _ -> nothing + +function unzip(tuples) + map(1:length(first(tuples))) do i + map(tuple -> tuple[i], tuples) + end +end +function ∇map(cx, f, args...) + ys_and_backs = map((args...) -> _forward(cx, f, args...), args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> nothing + else + ys, backs = unzip(ys_and_backs) + ys, function (Δ) + Δf_and_args_zipped = map((f, δ) -> f(δ), backs, Δ) + Δf_and_args = unzip(Δf_and_args_zipped) + Δf = reduce(accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + end +end + +@adjoint function map(f, args::Union{AbstractArray,Tuple}...) + ∇map(__context__, f, args...) +end + +function _forward(cx::AContext, ::typeof(collect), g::Base.Generator) + y, back = ∇map(cx, g.f, g.iter) + y, function (ȳ) + f̄, x̄ = back(ȳ) + (nothing, (f = f̄, iter = x̄),) + end +end + +@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing + # Reductions @adjoint function sum(xs::AbstractArray; dims = :) if dims === (:) - sum(xs), Δ -> (FillArray(Δ, size(xs)),) + sum(xs), Δ -> (Fill(Δ, size(xs)),) else sum(xs, dims = dims), Δ -> (similar(xs) .= Δ,) end end -function _forward(cx::Context, ::typeof(sum), f, xs::AbstractArray) +function _forward(cx::AContext, ::typeof(sum), f, xs::AbstractArray) y, back = forward(cx, (xs -> sum(f.(xs))), xs) - y, ȳ -> (nothing, nothing, back(ȳ)...) + y, ȳ -> (nothing, nothing, back(ȳ)...) +end + +@adjoint function sum(::typeof(abs2), X::AbstractArray; dims = :) + return sum(abs2, X; dims=dims), Δ::Union{Number, AbstractArray}->(nothing, ((2Δ) .* X)) end @adjoint function prod(xs; dims = :) @@ -87,10 +157,15 @@ end p, Δ -> (p ./ xs .* Δ,) end +function _forward(cx::AContext, ::typeof(prod), f, xs::AbstractArray) + y, back = forward(cx, (xs -> prod(f.(xs))), xs) + y, ȳ -> (nothing, nothing, back(ȳ)...) +end + @adjoint function maximum(xs; dims = :) max, i = findmax(xs, dims = dims) max, function (Δ) - Δ isa Real && Δ <= sqrt(eps(float(Δ))) && return nothing + Δ isa Real && abs(Δ) <= sqrt(eps(float(Δ))) && return nothing Δ′ = zero(xs) Δ′[i] = Δ return (Δ′,) @@ -106,17 +181,46 @@ end end end +@adjoint function mean(xs::AbstractArray; dims = :) + return mean(xs, dims=dims), Δ -> (_backmean(xs,Δ,dims),) +end +_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs) +_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(xs,i),*,dims) + # LinAlg +# ====== -@adjoint function(a::AbstractVecOrMat * b::AbstractVecOrMat) - return a * b, function(Δ) - return (reshape(Δ * transpose(b), size(a)), reshape(transpose(a) * Δ, size(b))) - end +@adjoint function(A::AbstractMatrix * B::AbstractMatrix) + return A * B, Δ::AbstractMatrix->(Δ * B', A' * Δ) +end + +@adjoint function(A::AbstractMatrix * x::AbstractVector) + return A * x, Δ::AbstractVector->(Δ * x', A' * Δ) +end + +@adjoint function *(x::Union{Transpose{<:Any, <:AbstractVector}, + LinearAlgebra.Adjoint{<:Any, <:AbstractVector}}, + y::AbstractVector) + return x * y, Δ->(Δ * y', x' * Δ) +end + +@adjoint function(a::AbstractVector * x::AbstractMatrix) + return a * x, Δ::AbstractMatrix->(vec(Δ * x'), a' * Δ) +end + +@adjoint function transpose(x) + back(Δ) = (transpose(Δ),) + back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,) + return transpose(x), back end -@adjoint transpose(x) = transpose(x), Δ -> (transpose(Δ),) -@adjoint Base.adjoint(x) = x', Δ -> (Δ',) -@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),) +@adjoint function Base.adjoint(x) + back(Δ) = (Δ',) + back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,) + return x', back +end + +@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),) @adjoint dot(x::AbstractArray, y::AbstractArray) = dot(x, y), Δ->(Δ .* y, Δ .* x) @@ -132,49 +236,117 @@ end @adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = forward(_kron, a, b) -@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing +@adjoint function Diagonal(d::AbstractVector) + back(Δ::NamedTuple) = (Δ.diag,) + back(Δ::AbstractMatrix) = (diag(Δ),) + return Diagonal(d), back +end @adjoint diag(A::AbstractMatrix) = diag(A), Δ->(Diagonal(Δ),) -@adjoint function \(A::AbstractMatrix, B::AbstractVecOrMat) - Y = A \ B - return Y, function(Ȳ) - B̄ = A' \ Ȳ - return (-B̄ * Y', B̄) +@adjoint det(xs) = det(xs), Δ -> (Δ * det(xs) * transpose(inv(xs)),) + +@adjoint logdet(xs) = logdet(xs), Δ -> (Δ * transpose(inv(xs)),) + +@adjoint logabsdet(xs) = logabsdet(xs), Δ -> (Δ[1] * transpose(inv(xs)),) + +@adjoint function inv(A) + return inv(A), function (Δ) + Ainv = inv(A) + ∇A = - Ainv' * Δ * Ainv' + return (∇A, ) end end -# This is basically a hack while we don't have a working `ldiv!`. -@adjoint function \(A::Cholesky, B::AbstractVecOrMat) - Y, back = Zygote.forward((U, B)->U \ (U' \ B), A.U, B) - return Y, function(Ȳ) - Ā_factors, B̄ = back(Ȳ) - return ((uplo=nothing, status=nothing, factors=Ā_factors), B̄) +# Defaults for atol and rtol copied directly from LinearAlgebra. See the following for +# derivation: +# Golub, Gene H., and Victor Pereyra. "The differentiation of pseudo-inverses and nonlinear +# least squares problems whose variables separate." SIAM Journal on numerical analysis 10.2 +# (1973): 413-432. +@adjoint function pinv( + A::AbstractMatrix{T}; + atol::Real = 0.0, + rtol::Real = (eps(real(float(one(T))))*min(size(A)...))*iszero(atol), +) where {T} + Y = pinv(A) + return Y, Δ->(-Y' * Δ * Y' + (I - A * Y) * Δ' * Y * Y' + Y' * Y * Δ' * (I - Y * A),) +end + +@adjoint function \(A::Union{Diagonal, AbstractTriangular}, B::AbstractVecOrMat) + Y = A \ B + return Y, function(Ȳ) + B̄ = A' \ Ȳ + return (-B̄ * Y', B̄) end end -@adjoint function /(A::AbstractMatrix, B::AbstractMatrix) +@adjoint function /(A::AbstractMatrix, B::Union{Diagonal, AbstractTriangular}) Y = A / B - return Y, function(Ȳ) - Ā = Ȳ / B' - return (Ā, -Y' * Ā) + return Y, function(Ȳ) + Ā = Ȳ / B' + return (Ā, -Y' * Ā) + end +end + +@adjoint function \(A::AbstractMatrix, B::AbstractVecOrMat) + Z = A \ B + return Z, function(Z̄) + B̄ = A' \ Z̄ + if size(A, 1) == size(A, 2) + return (-B̄ * Z', B̄) + else + a = -B̄ * Z' + b = (B - A * Z) * B̄' / A' + c = A' \ Z * (Z̄' - B̄' * A) + return (a + b + c, B̄) + end + end +end + +function _forward(cx::AContext, ::typeof(norm), x::AbstractArray, p::Real = 2) + fallback = (x, p) -> sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 + _forward(cx, fallback, x, p) +end + +# LinAlg Matrix Types +# =================== + +@adjoint LinearAlgebra.LowerTriangular(A) = LowerTriangular(A), Δ->(LowerTriangular(Δ),) +@adjoint LinearAlgebra.UpperTriangular(A) = UpperTriangular(A), Δ->(UpperTriangular(Δ),) + +# This is basically a hack while we don't have a working `ldiv!`. +@adjoint function \(A::Cholesky, B::AbstractVecOrMat) + Y, back = Zygote.forward((U, B)->U \ (U' \ B), A.U, B) + return Y, function(Ȳ) + Ā_factors, B̄ = back(Ȳ) + return ((uplo=nothing, status=nothing, factors=Ā_factors), B̄) end end _symmetric_back(Δ) = UpperTriangular(Δ) + LowerTriangular(Δ)' - Diagonal(Δ) -_symmetric_back(Δ::UpperTriangular) = Δ +_symmetric_back(Δ::Union{Diagonal, UpperTriangular}) = Δ @adjoint function Symmetric(A::AbstractMatrix) back(Δ::AbstractMatrix) = (_symmetric_back(Δ),) back(Δ::NamedTuple) = (_symmetric_back(Δ.data),) return Symmetric(A), back end +@adjoint function cholesky(Σ::Real) + C = cholesky(Σ) + return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),) +end + +@adjoint function cholesky(Σ::Diagonal) + C = cholesky(Σ) + return C, Δ::NamedTuple->(Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)),) +end + # Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra." @adjoint function cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}) C = cholesky(Σ) - return C, function(Δ) - U, Ū = C.U, Δ.factors - Σ̄ = Ū * U' + return C, function(Δ::NamedTuple) + U, Ū = C.U, Δ.factors + Σ̄ = Ū * U' Σ̄ = copytri!(Σ̄, 'U') Σ̄ = ldiv!(U, Σ̄) BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) @@ -185,9 +357,36 @@ end end end -@adjoint function cholesky(Σ::Real) - C = cholesky(Σ) - return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),) +@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix) + X = lyap(A, C) + return X, function (X̄) + C̄ = lyap(collect(A'), X̄) + Ā = C̄*X' + C̄'*X + return (Ā, C̄) + end +end + +# Adjoint based on the Theano implementation, which uses the differential as described +# in Brančík, "Matlab programs for matrix exponential function derivative evaluation" +@adjoint exp(A::AbstractMatrix) = exp(A), function(F̄) + n = size(A, 1) + E = eigen(A) + w = E.values + ew = exp.(w) + X = [i==j ? ew[i] : (ew[i]-ew[j])/(w[i]-w[j]) for i in 1:n,j=1:n] + VT = transpose(E.vectors) + VTF = factorize(collect(VT)) + Ā = real.(VTF\(VT*F̄/VTF.*X)*VT) + (Ā, ) +end + +Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) + # x is a squre matrix checked by tr, + # so we could just use Eye(size(x, 1)) + # to create a Diagonal + tr(x), function (Δ::Number) + (Diagonal(Fill(Δ, (size(x, 1), ))), ) + end end # Various sensitivities for `literal_getproperty`, depending on the 2nd argument. @@ -223,3 +422,65 @@ end @adjoint function +(A::AbstractMatrix, S::UniformScaling) return A + S, Δ->(Δ, (λ=sum(view(Δ, diagind(Δ))),)) end + +@adjoint +(A::AbstractArray, B::AbstractArray) = A + B, Δ->(Δ, Δ) +@adjoint -(A::AbstractArray, B::AbstractArray) = A - B, Δ->(Δ, -Δ) +@adjoint -(A::AbstractArray) = -A, Δ->(-Δ,) + +# FFTW +# =================== + +# FFTW functions do not work with FillArrays, which are needed +# for some functionality of Zygote. To make it work with FillArrays +# as well, overload the relevant functions +FFTW.fft(x::Fill, dims...) = FFTW.fft(collect(x), dims...) +FFTW.ifft(x::Fill, dims...) = FFTW.ifft(collect(x), dims...) + + +# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the +# gradient of its inputs, but with different normalization factor +@adjoint function FFTW.fft(xs) + return FFTW.fft(xs), function(Δ) + N = length(xs) + return (N * FFTW.ifft(Δ),) + end +end + +@adjoint function FFTW.ifft(xs) + return FFTW.ifft(xs), function(Δ) + N = length(xs) + return (1/N* FFTW.fft(Δ),) + end +end + +@adjoint function FFTW.fft(xs, dims) + return FFTW.fft(xs, dims), function(Δ) + # dims can be int, array or tuple, + # convert to collection for use as index + dims = collect(dims) + # we need to multiply by all dimensions that we FFT over + N = prod(collect(size(xs))[dims]) + return (N * FFTW.ifft(Δ, dims), nothing) + end +end + +@adjoint function FFTW.ifft(xs,dims) + return FFTW.ifft(xs, dims), function(Δ) + # dims can be int, array or tuple, + # convert to collection for use as index + dims = collect(dims) + # we need to divide by all dimensions that we FFT over + N = prod(collect(size(xs))[dims]) + return (1/N * FFTW.fft(Δ, dims),nothing) + end +end + +# FillArray functionality +# ======================= + +@adjoint function broadcasted(op, r::AbstractFill{<:Real}) + y, _back = Zygote.forward(op, getindex_value(r)) + back(Δ::AbstractFill) = (nothing, Fill(_back(getindex_value(Δ))[1], size(r))) + back(Δ::AbstractArray) = (nothing, getindex.(_back.(Δ), 1)) + return Fill(y, size(r)), back +end diff --git a/src/lib/base.jl b/src/lib/base.jl index 61e82a477..91c2fca5d 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -40,7 +40,9 @@ end end @adjoint! function setindex!(d::AbstractDict, v, k) - setindex!(d, v, k), function (Δ) - (nothing, get(grad_mut(__context__, d), k, nothing), nothing) + setindex!(d, v, k), function (_) + Δ = get(grad_mut(__context__, d), k, nothing) + delete!(grad_mut(__context__, d), k) + (nothing, Δ, nothing) end end diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 601f7b5a6..7acb0a720 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -13,103 +13,153 @@ # `--' `" `--' `" `'-' using Base.Broadcast -using Base.Broadcast: Broadcasted, DefaultArrayStyle, instantiate +using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize +using NNlib -# Structural utilities +# There's a saying that debugging code is about twice as hard as writing it in +# the first place. So if you're as clever as you can be when writing code, how +# will you ever debug it? -using Base: tail +# AD faces a similar dilemma: if you write code that's as clever as the compiler +# can handle, how will you ever differentiate it? Differentiating makes clever +# code that bit more complex and the compiler gives up, usually resulting in +# 100x worse performance. -tcat(x) = x -tcat(x, y, z...) = tcat((x..., y...), z...) +# Base's broadcasting is very cleverly written, and this makes differentiating +# it... somewhat tricky. -broadcast_args(x) = (x,) -broadcast_args(bc::Broadcasted) = tcat(map(broadcast_args, bc.args)...) +# Utilities +# ========= -accum_sum(xs, dims = :) = reduce(accum, xs, dims = dims) +accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims) + +# Work around reducedim_init issue +accum_sum(xs::AbstractArray{Nothing}; dims = :) = nothing +accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims) +accum_sum(xs::Number; dims = :) = xs trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) +unbroadcast(x::AbstractArray, x̄) = + size(x) == size(x̄) ? x̄ : + length(x) == length(x̄) ? trim(x, x̄) : + trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) -unbroadcast(x::Union{Number,Ref}, Δ) = sum(Δ) +unbroadcast(x::Union{Number,Ref}, x̄) = accum_sum(x̄) -# Trivial Special Cases -# TODO fix this up and use it +# Split Reverse Mode +# ================== -Jtrivial(f, a...) = nothing -Jtrivial(::typeof(+), a...) = a -Jtrivial(::typeof(-), a, b) = (a..., .-b...) +# TODO: use DiffRules here. It's complicated a little by the fact that we need +# to do CSE, then broadcast-ify the expression so that the closure captures the +# right arrays. -trivia(_) = (1,) -function trivia(bc::Broadcasted) - t = map(trivia, bc.args) - any(t -> t === nothing, t) && return - Jtrivial(bc.f, t...) -end +Numeric{T<:Number} = Union{T,AbstractArray{<:T}} -Joutput(f, a...) = nothing -Joutput(::typeof(exp), x) = map(t -> y -> y*t, x) +@adjoint broadcasted(::typeof(+), xs::Numeric...) = + broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) -function Jbroadcast(bc::Broadcasted) - t = map(trivia, bc.args) - any(t -> t === nothing, t) && return - Joutput(bc.f, t...) -end +@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, + z̄ -> (nothing, unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))) -@inline function unbroadcast_t(x, y, ȳ, j::J) where J - trim(x, j.(y).*ȳ) +@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric) + res = x ./ y + res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y)) end -@inline function unbroadcast_t(x::Number, y, ȳ, j::J) where J - x̄ = zero(float(x)) - @simd for I in eachindex(y) - @inbounds x̄ += j(y[I])*ȳ[I] - end - return x̄ +@adjoint function broadcasted(::typeof(σ), x::Numeric) + y = σ.(x) + y, ȳ -> (nothing, ȳ .* conj.(y .* (1 .- y))) end -function ∇broadcast_t(bc::Broadcasted, J) - y = copy(instantiate(bc)) - back(ȳ) = map(unbroadcast_t, broadcast_args(bc), map(_ -> y, J), map(_ -> ȳ, J), J) - return y, back +@adjoint function broadcasted(::typeof(tanh), x::Numeric) + y = tanh.(x) + y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2)) end -# Reverse Mode +@adjoint broadcasted(::typeof(conj), x::Numeric) = + conj.(x), z̄ -> (nothing, conj.(z̄)) + +@adjoint broadcasted(::typeof(real), x::Numeric) = + real.(x), z̄ -> (nothing, real.(z̄)) + +@adjoint broadcasted(::typeof(imag), x::Numeric) = + imag.(x), z̄ -> (nothing, im .* real.(z̄)) + +# General Fallback +# ================ + +# The fused reverse mode implementation is the most general but currently has +# poor performance. It works by flattening the broadcast and mapping the call to +# `_forward` over the input. + +# However, the core call +# broadcast(_forward, (cx,), f, args...) +# is already 10x slower than a simple broadcast (presumably due to inlining +# issues, or something similar) and the other operations needed take it to about +# 100x overhead. + +@generated inclen(::NTuple{N,Any}) where N = Val(N+1) + +# Avoid hitting special cases for `Adjoint` etc. +_broadcast(f::F, x...) where F = materialize(broadcasted(f, x...)) -# TODO: forward context appropriately -# multi-output map for better performance -function ∇broadcast_r(bc::Broadcasted) - bc′, unflatten = _forward(Broadcast.flatten, bc) - len = Val(length(bc′.args)+1) - y∂b = broadcast(_forward, bc′.f, bc′.args...) +@adjoint function broadcasted(::AbstractArrayStyle, f, args...) + len = inclen(args) + y∂b = _broadcast((x...) -> _forward(__context__, f, x...), args...) y = map(x -> x[1], y∂b) ∂b = map(x -> x[2], y∂b) y, function (ȳ) dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ) dxs = ntuple(i -> map(x -> x[i], dxs_zip), len) - (f = accum_sum(dxs[1]), - args = map(unbroadcast, bc′.args, Base.tail(dxs)), - axes = nothing) |> unflatten |> Base.tail + (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...) end end -function ∇broadcast_r(bc::Broadcasted{<:DefaultArrayStyle{0}}) - bc′, unflatten = _forward(Broadcast.flatten, bc) - len = Val(length(bc′.args)+1) - y, ∂b = broadcast(_forward, bc′.f, bc′.args...) +@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...) + len = inclen(args) + y, ∂b = _broadcast((x...) -> _forward(__context__, f, x...), args...) y, function (ȳ) dxs = ∂b(ȳ) - (f = dxs[1], - args = Base.tail(dxs), - axes = nothing) |> unflatten |> Base.tail + (nothing, dxs...) end end -∇broadcast(bc::Broadcasted, ::Nothing) = ∇broadcast_r(bc) -∇broadcast(bc::Broadcasted, J) = ∇broadcast_t(bc, J) -∇broadcast(bc::Broadcasted) = ∇broadcast(bc, Jbroadcast(bc)) +@adjoint! (b::typeof(broadcast))(f, args...) = _forward(__context__, broadcasted, f, args...) -@adjoint Broadcast.materialize(bc::Broadcasted{<:DefaultArrayStyle}) = ∇broadcast_r(bc) +# Forward Mode (mainly necessary for CUDA) + +import ForwardDiff +using ForwardDiff: Dual + +dual(x, p) = x +dual(x::Real, p) = Dual(x, p) + +dualtype(::Type{Dual{G,T,P}}) where {G,T,P} = T +dualtype(T) = T + +function dual_function(f::F) where F + function (args::Vararg{Any,N}) where N + ds = map(args, ntuple(identity,Val(N))) do x, i + dual(x, ntuple(j -> i==j, Val(N))) + end + return f(ds...) + end +end + +@inline function broadcast_forward(f, args::Vararg{Any,N}) where N + T = Broadcast.combine_eltypes(f, args) + out = dual_function(f).(args...) + eltype(out) <: Dual || return (out, _ -> nothing) + y = map(x -> x.value, out) + _back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out)) + back(ȳ) = ntuple(i -> _back(ȳ, i), N) + return y, back +end + +@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin + @adjoint function broadcasted(::Broadcast.ArrayStyle{CuArrays.CuArray}, f, args...) + y, back = broadcast_forward(f, args...) + y, ȳ -> (nothing, nothing, back(ȳ)...) + end +end diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl new file mode 100644 index 000000000..58f3c323b --- /dev/null +++ b/src/lib/buffer.jl @@ -0,0 +1,91 @@ +""" + Buffer(xs, ...) + +`Buffer` is an array-like type which is mutable when taking gradients. You can +construct a `Buffer` with the same syntax as `similar` (e.g. `Buffer(xs, 5)`) +and then use normal indexing. Finally, use `copy` to get back a normal array. + +For example: + +```julia +julia> function vstack(xs) + buf = Buffer(xs, length(xs), 5) + for i = 1:5 + buf[:, i] = xs + end + return copy(buf) + end +vstack (generic function with 1 method) + +julia> vstack([1, 2, 3]) +3×5 Array{Int64,2}: + 1 1 1 1 1 + 2 2 2 2 2 + 3 3 3 3 3 + +julia> gradient(x -> sum(vstack(x)), [1, 2, 3]) +([5.0, 5.0, 5.0],) +``` + +`Buffer` is not an `AbstractArray` and can't be used for linear algebra +operations like matrix multiplication. This prevents it from being captured by +pullbacks. + +`copy` is a semantic copy, but does not allocate memory. Instead the `Buffer` +is made immutable after copying. +""" +mutable struct Buffer{T,A<:AbstractArray{T}} + data::A + freeze::Bool +end + +Buffer(xs::AbstractArray, args...) = + Buffer(similar(xs, args...), false) + +Base.getindex(b::Buffer, i...) = b.data[i...] + +function Base.setindex!(b::Buffer, v, i...) + b.freeze && error("Buffer is frozen") + b.data[i...] = v +end + +function Base.copy(b::Buffer) + b.freeze = true + return b.data +end + +@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes, Base.eachindex, Base.stride, Base.strides + +grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing) +grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0) + +@nograd Buffer + +@adjoint function getindex(b::Buffer, i...) + d[k], function (Δ) + grad = grad_mut(__context__, d) + grad[i...] = accum(grad[i...], Δ) + return + end +end + +@adjoint! function setindex!(b::Buffer, v, i...) + setindex!(b, v, i...), function (_) + grad = grad_mut(__context__, b) + v̄ = grad[i...] + zero = eltype(grad) <: Number ? 0 : nothing + if i isa NTuple{N,Integer} where N + grad[i...] = zero + else + grad[i...] .= zero + end + (nothing, v̄, map(_->nothing, i)...) + end +end + +@adjoint function copy(b::Buffer) + copy(b), function (b̄) + grad_mut(__context__, b)[:] = b̄ + return + end +end diff --git a/src/lib/complex.jl b/src/lib/complex.jl deleted file mode 100644 index 02aa43277..000000000 --- a/src/lib/complex.jl +++ /dev/null @@ -1,10 +0,0 @@ -@adjoint real(x::Complex) = real(x), r̄ -> (r̄ + zero(r̄)*im,) -@adjoint imag(x::Complex) = imag(x), ī -> (zero(ī) + ī*im,) - -# The adjoint of the map z -> g*z is given by y -> g' * y. -# Therefore, for holomorphic functions (for which the differential is given by a complex multiplication), -# the gradient map is given by a multiplication with the conjugate of the derivative (in the holomorphic sense) -@adjoint log(x::Complex) = log(x), ȳ -> (ȳ/conj(x),) -@adjoint exp(x::Complex) = exp(x), ȳ -> (ȳ*conj(exp(x)),) -@adjoint sin(x::Complex) = sin(x), ȳ -> (ȳ*conj(cos(x)),) -@adjoint cos(x::Complex) = cos(x), ȳ -> (-ȳ*conj(sin(x)),) diff --git a/src/lib/distances.jl b/src/lib/distances.jl index dcee2f9d4..c4e3825f7 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -15,18 +15,32 @@ end end end -@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix) - return pairwise(s, x, y), function(Δ) - x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * Δ') - ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ) - return nothing, x̄, ȳ +@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix, y::AbstractMatrix; dims::Int=2) + if dims==1 + return pairwise(s, x, y; dims=1), ∇pairwise(s, transpose(x), transpose(y), transpose) + else + return pairwise(s, x, y; dims=dims), ∇pairwise(s, x, y, identity) end end + +∇pairwise(s, x, y, f) = + function(Δ) + x̄ = 2 .* (x * Diagonal(vec(sum(Δ; dims=2))) .- y * transpose(Δ)) + ȳ = 2 .* (y * Diagonal(vec(sum(Δ; dims=1))) .- x * Δ) + return (nothing, f(x̄), f(ȳ)) + end -@adjoint function pairwise(s::SqEuclidean, X::AbstractMatrix) - D = pairwise(s, X) - return D, function(Δ) - d1, d2 = Diagonal(vec(sum(Δ; dims=1))), Diagonal(vec(sum(Δ; dims=2))) - return (nothing, X * (2 .* (d1 .+ d2 .- Δ .- Δ'))) +@adjoint function pairwise(s::SqEuclidean, x::AbstractMatrix; dims::Int=2) + if dims==1 + return pairwise(s, x; dims=1), ∇pairwise(s, transpose(x), transpose) + else + return pairwise(s, x; dims=dims), ∇pairwise(s, x, identity) end end + +∇pairwise(s, x, f) = + function(Δ) + d1 = Diagonal(vec(sum(Δ; dims=1))) + d2 = Diagonal(vec(sum(Δ; dims=2))) + return (nothing, x * (2 .* (d1 .+ d2 .- Δ .- transpose(Δ))) |> f) + end diff --git a/src/lib/forward.jl b/src/lib/forward.jl index 0d67c06ea..7a6e125ff 100644 --- a/src/lib/forward.jl +++ b/src/lib/forward.jl @@ -46,6 +46,55 @@ vec_scalar(x::Real) = [x] reshape_scalar(x, y) = reshape(y, size(x)) reshape_scalar(x::Real, y) = y[] +""" + forwarddiff(f, x) -> f(x) + +Runs `f(x)` as usual, but instructs Zygote to differentiate `f` using forward +mode, rather than the usual reverse mode. + +Forward mode takes time linear in `length(x)` but only has constant memory +overhead, and is very efficient for scalars, so in some cases this can be a +useful optimisation. + +```julia +julia> function pow(x, n) + r = one(x) + for i = 1:n + r *= x + end + return r + end +pow (generic function with 1 method) + +julia> gradient(5) do x + forwarddiff(x) do x + pow(x, 2) + end + end +(10,) +``` + +Note that the function `f` will *drop gradients* for any closed-over values. + +```julia +julia> gradient(2, 3) do a, b + forwarddiff(a) do a + a*b + end + end +(3, nothing) +``` + +This can be rewritten by explicitly passing through `b`, i.e. + +```julia +gradient(2, 3) do a, b + forwarddiff([a, b]) do (a, b) + a*b + end +end +``` +""" forwarddiff(f, x) = f(x) @adjoint function forwarddiff(f, x) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 6cadd4b34..796a91c41 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -1,64 +1,3 @@ -using MacroTools -using MacroTools: combinedef - -named(arg) = isexpr(arg, :(::)) && length(arg.args) == 1 ? :($(gensym())::$(arg.args[1])) : arg - -typeless(x) = MacroTools.prewalk(x -> isexpr(x, :(::)) ? x.args[1] : x, x) - -for n = 0:3 - gradtuple = Symbol(:gradtuple, n) - @eval begin - $gradtuple(x::Tuple) = ($(ntuple(_->:nothing,n)...), x...) - $gradtuple(x::Nothing) = nothing - $gradtuple(x) = error("Gradient $x should be a tuple") - end -end - -function adjoint end - -function gradm(ex, mut = false) - @capture(shortdef(ex), (name_(args__) = body_) | - (name_(args__) where {Ts__} = body_)) || error("Need a function definition") - kw = length(args) > 1 && isexpr(args[1], :parameters) ? esc(popfirst!(args)) : nothing - isclosure = isexpr(name, :(::)) && length(name.args) > 1 - f, T = isexpr(name, :(::)) ? - (length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) : - (esc(gensym()), :(Core.Typeof($(esc(name))))) - kT = :(Core.kwftype($T)) - Ts == nothing && (Ts = []) - args = esc.(named.(args)) - argnames = typeless.(args) - Ts = esc.(Ts) - cx = :($(esc(:__context__))::Context) - fargs = kw == nothing ? [cx, :($f::$T), args...] : [kw, cx, :($f::$T), args...] - gradtuple = isclosure ? gradtuple0 : gradtuple1 - gradtuplekw = isclosure ? gradtuple2 : gradtuple3 - quote - @inline Zygote.adjoint($(fargs...)) where $(Ts...) = $(esc(body)) - @inline function Zygote._forward($cx, $f::$T, $(args...)) where $(Ts...) - y, _back = adjoint(__context__, $f, $(argnames...)) - $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuple(_back(Δ)) - return y, back - end - @inline function Zygote._forward($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) - y, _back = adjoint(__context__, $f, $(argnames...); kw...) - $(mut ? nothing : :(back(::Nothing) = nothing)) - back(Δ) = $gradtuplekw(_back(Δ)) - return y, back - end - nothing - end -end - -macro adjoint(ex) - gradm(ex) -end - -macro adjoint!(ex) - gradm(ex, true) -end - macro nograd(ex) isexpr(ex, :tuple) || (ex = Expr(:tuple, ex)) blk = :(;) @@ -68,3 +7,8 @@ macro nograd(ex) end return blk end + +macro which(ex) + @capture(ex, f_(args__)) || error("Zygote.@which f(args...)") + :(InteractiveUtils.@which adjoint(Context(), $(esc(f)), $(esc.(args)...))) +end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index b727716e3..57285814a 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -1,3 +1,5 @@ +using Base: RefValue + # Interfaces accum() = nothing @@ -18,16 +20,24 @@ accum(x::AbstractArray, y::AbstractArray) = accum.(x, y) Expr(:tuple, [:($f=accum(x.$f, $(grad(f)))) for f in fieldnames(x)]...) end +function accum(x::RefValue, y::RefValue) + @assert x === y + return x +end + # Core functions @nograd Core.apply_type, Core.typeof, nfields, fieldtype, - (==), (===), (>=), (<), (>), isempty, supertype, Base.typename, Base.parameter_upper_bound + (==), (===), (>=), (<), (>), isempty, supertype, Base.typename, + Base.parameter_upper_bound, eps, Meta.parse, Base.eval + +@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,) @adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing @adjoint ifelse(cond::Bool, t, f) = ifelse(cond, t, f), - Δ -> cond ? (Δ, zero(Δ)) : (zero(Δ), Δ) + Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ) @adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing) @@ -60,11 +70,34 @@ end # Tuples +using Base: tail + @adjoint tuple(xs...) = xs, identity +literal_getindex(x, ::Val{i}) where i = getindex(x, i) +literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i) +literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state) + +@adjoint literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} = + (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) + @adjoint getindex(xs::NTuple{N,Any}, i::Integer) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) +function _forward(cx::Context, ::typeof(literal_indexed_iterate), xs::Tuple, ::Val{i}) where i + y, b = _forward(cx, literal_getindex, xs, Val(i)) + back(::Nothing) = nothing + back(ȳ) = b(ȳ[1]) + (y, i+1), back +end + +function _forward(cx::Context, ::typeof(literal_indexed_iterate), xs::Tuple, ::Val{i}, st) where i + y, b = _forward(cx, literal_getindex, xs, Val(i)) + back(::Nothing) = nothing + back(ȳ) = (b(ȳ[1])..., nothing) + (y, i+1), back +end + # Needed for iteration lowering @adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) @@ -94,7 +127,8 @@ unapply(t, xs) = _unapply(t, xs)[1] st = map(_empty, args) y, function (Δ) Δ = back(Δ) - (first(Δ), unapply(st, Base.tail(Δ))...) + Δ === nothing ? nothing : + (first(Δ), unapply(st, Base.tail(Δ))...) end end @@ -112,23 +146,6 @@ end @generated pair(::Val{k}, v) where k = :($k = v,) -# TODO make this inferrable -# Right now constant prop is too fragile ... -@adjoint function getfield(x, f::Symbol) - val = getfield(x, f) - unwrap(val), function (Δ) - accum_param(__context__, val, Δ) - if isimmutable(x) - ((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing) - else - dx = getfield(grad_mut(__context__, x), f) - dx[] = accum(dx[], Δ) - return - end - end -end - -# ... so we have Zygote call this version where we can. literal_getproperty(x, ::Val{f}) where f = getproperty(x, f) @adjoint function literal_getproperty(x, ::Val{f}) where f @@ -138,23 +155,32 @@ literal_getproperty(x, ::Val{f}) where f = getproperty(x, f) if isimmutable(x) ((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing) else - dx = getfield(grad_mut(__context__, x), f) - dx[] = accum(dx[], Δ) - return + dx = grad_mut(__context__, x) + dx[] = (;dx[]...,pair(Val(f),accum(getfield(dx[], f), Δ))...) + return (dx,nothing) end end unwrap(val), back end -@generated function grad_mut(x) - Expr(:tuple, [:($f = Ref{Any}(nothing)) for f in fieldnames(x)]...) -end +_forward(cx::Context, ::typeof(getproperty), x, f::Symbol) = + _forward(cx, literal_getproperty, x, Val(f)) + +_forward(cx::Context, ::typeof(getfield), x, f::Symbol) = + _forward(cx, literal_getproperty, x, Val(f)) + +_forward(cx::Context, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) where f = + _forward(cx, literal_getproperty, x, Val(f)) + +_forward(cx::Context, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = + _forward(cx, literal_getindex, x, Val(f)) + +grad_mut(x) = Ref{Any}(nt_nothing(x)) function grad_mut(cx::Context, x) - T = Core.Compiler.return_type(grad_mut, Tuple{typeof(x)}) ch = cache(cx) if haskey(ch, x) - ch[x]::T + ch[x] else ch[x] = grad_mut(x) end @@ -164,8 +190,8 @@ end y = setfield!(x, f, val) g = grad_mut(__context__, x) y, function (_) - r = getfield(g, f) - Δ = deref!(r) + Δ = getfield(g[], f) + g[] = (;g[]...,pair(Val(f),nothing)...) (nothing, nothing, Δ) end end @@ -203,14 +229,24 @@ end end # TODO captured mutables + multiple calls to `back` -@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing}) where {T,G} +@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G} !T.mutable && Δ == Nothing && return :nothing - Δ = G == Nothing ? :Δ : :(back.g) - :(nothing, $(map(f -> :(deref!($Δ.$f)), fieldnames(T))...)) + Δ = G == Nothing ? :Δ : :(back.g[]) + quote + x̄ = $Δ + $(G == Nothing || :($Δ = nt_nothing($Δ))) + (nothing, $(map(f -> :(x̄.$f), fieldnames(T))...)) + end end -@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing}) where {T,G} +@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G} !T.mutable && Δ == Nothing && return :nothing Δ = G == Nothing ? :Δ : :(back.g) - :(nothing, ($(map(f -> :(deref!($Δ.$f)), fieldnames(T))...),)) + quote + x̄ = $Δ + $(G == Nothing || :($Δ = nt_nothing($Δ))) + (nothing, ($(map(f -> :(x̄.$f), fieldnames(T))...),)) + end end + +(back::Jnew{T})(Δ) where T = error("Need an adjoint for constructor $T. Gradient is of type $(typeof(Δ))") diff --git a/src/lib/nnlib.jl b/src/lib/nnlib.jl index b46c3f53a..d4298ced9 100644 --- a/src/lib/nnlib.jl +++ b/src/lib/nnlib.jl @@ -1,22 +1,65 @@ using NNlib -import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, meanpool +import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, depthwiseconv, ∇conv_data, ∇depthwiseconv_data, maxpool, meanpool, σ + +@adjoint function σ(x::Real) + y = σ(x) + return y, Δ -> (Δ * y * (1 - y),) +end @adjoint softmax(xs) = softmax(xs), Δ -> (∇softmax(Δ, xs),) @adjoint logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),) -@adjoint conv(x, w; kw...) = - conv(x, w; kw...), - Δ -> - (NNlib.∇conv_data(Δ, x, w; kw...), - NNlib.∇conv_filter(Δ, x, w; kw...)) +@adjoint NNlib.DenseConvDims(args...; kwargs...) = NNlib.DenseConvDims(args...; kwargs...), _ -> nothing +@adjoint NNlib.DepthwiseConvDims(args...; kwargs...) = NNlib.DepthwiseConvDims(args...; kwargs...), _ -> nothing +@adjoint NNlib.PoolDims(args...; kwargs...) = NNlib.PoolDims(args...; kwargs...), _ -> nothing + +@adjoint conv(x, w, cdims; kw...) = + conv(x, w, cdims; kw...), + Δ -> begin + return ( + NNlib.∇conv_data(Δ, w, cdims; kw...), + NNlib.∇conv_filter(x, Δ, cdims; kw...), + nothing, + ) + end + +@adjoint ∇conv_data(x, w, cdims; kw...) = + ∇conv_data(x, w, cdims; kw...), + Δ -> begin + return ( + NNlib.conv(Δ, w, cdims; kw...), + NNlib.∇conv_filter(Δ, x, cdims; kw...), + nothing, + ) + end + +@adjoint depthwiseconv(x, w, cdims; kw...) = + depthwiseconv(x, w, cdims; kw...), + Δ -> begin + return ( + NNlib.∇depthwiseconv_data(Δ, w, cdims; kw...), + NNlib.∇depthwiseconv_filter(x, Δ, cdims; kw...), + nothing, + ) + end + +@adjoint ∇depthwiseconv_data(x, w, cdims; kw...) = + ∇depthwiseconv_data(x, w, cdims; kw...), + Δ -> begin + return ( + NNlib.depthwiseconv(Δ, w, cdims; kw...), + NNlib.∇depthwiseconv_filter(Δ, x, cdims; kw...), + nothing, + ) + end -@adjoint function maxpool(x, k; kw...) - y = maxpool(x, k; kw...) - y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing) +@adjoint function maxpool(x, pdims; kw...) + y = maxpool(x, pdims; kw...) + y, Δ -> (NNlib.∇maxpool(Δ, y, x, pdims; kw...), nothing) end -@adjoint function meanpool(x, k; kw...) - y = meanpool(x, k; kw...) - y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing) +@adjoint function meanpool(x, pdims; kw...) + y = meanpool(x, pdims; kw...) + y, Δ -> (NNlib.∇meanpool(Δ, y, x, pdims; kw...), nothing) end diff --git a/src/lib/number.jl b/src/lib/number.jl new file mode 100644 index 000000000..be83e9bd4 --- /dev/null +++ b/src/lib/number.jl @@ -0,0 +1,60 @@ +using DiffRules, SpecialFunctions, NaNMath + +@nograd isinf, isnan, isfinite, div + +# TODO use CSE here + +for (M, f, arity) in DiffRules.diffrules() + arity == 1 || continue + Δ = :Δ + dx = DiffRules.diffrule(M, f, :x) + if f in [:abs, :abs2] + Δ = :(real($Δ)) + else + dx = :(conj($dx)) + end + @eval begin + @adjoint $M.$f(x::Number) = $M.$f(x), + Δ -> ($Δ * $dx,) + end +end + +for (M, f, arity) in DiffRules.diffrules() + arity == 2 || continue + da, db = DiffRules.diffrule(M, f, :a, :b) + @eval begin + @adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b), + Δ -> (Δ * conj($da), Δ * conj($db)) + end +end + +@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ) +@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ) + +for T in Base.uniontypes(Core.BuiltinInts) + @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,) +end + +@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs) + +@adjoint Base.muladd(x::Number, y::Number, z::Number) = + Base.muladd(x, y, z), ō -> (y'ō, x'ō, ō) + +@adjoint function sincos(x) + s, c = sincos(x) + (s, c), ((s̄, c̄),) -> (s̄*c - c̄*s,) +end + +@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b)) + +@nograd floor, ceil, trunc, round, hash + +# Complex Numbers + +@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄)) + +@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),) +@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) +@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) + +DiffRules._abs_deriv(x::Complex) = x/abs(x) diff --git a/src/lib/real.jl b/src/lib/real.jl deleted file mode 100644 index ef3c9c453..000000000 --- a/src/lib/real.jl +++ /dev/null @@ -1,35 +0,0 @@ -using DiffRules, SpecialFunctions, NaNMath - -for (M, f, arity) in DiffRules.diffrules() - arity == 1 || continue - @eval begin - @adjoint $M.$f(x::Real) = $M.$f(x), - Δ -> (Δ * $(DiffRules.diffrule(M, f, :x)),) - end -end - -for (M, f, arity) in DiffRules.diffrules() - arity == 2 || continue - da, db = DiffRules.diffrule(M, f, :a, :b) - @eval begin - @adjoint $M.$f(a::Real, b::Real) = $M.$f(a, b), - Δ -> (Δ * $da, Δ * $db) - end -end - -@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), Δ -> (nothing, Δ) - -for T in Base.uniontypes(Core.BuiltinInts) - @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,) -end - -@adjoint Base.:+(xs...) = +(xs...), Δ -> map(_ -> Δ, xs) - -@adjoint function sincos(x) - s, c = sincos(x) - (s, c), ((s̄, c̄),) -> (s̄*c - c̄*s,) -end - -@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b)) - -@nograd floor, ceil, trunc, round, hash diff --git a/src/lib/statsfuns.jl b/src/lib/statsfuns.jl index 929a24846..40c7fddf4 100644 --- a/src/lib/statsfuns.jl +++ b/src/lib/statsfuns.jl @@ -1,11 +1,11 @@ import .StatsFuns -using StatsFuns: xlogx, logistic, logit, log1psq, log1pexp, logsumexp +using .StatsFuns: xlogx, logistic, logit, log1psq, log1pexp, logsumexp @adjoint function xlogx(x::Real) y = xlogx(x) return y, function(Δ::Real) return (x > zero(x) ? Δ * (log(x) + one(y)) : zero(y),) - end + end end @adjoint function logistic(x::Real) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index add8c7a4e..aa42be3ad 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -1,10 +1,81 @@ +""" + dropgrad(x) -> x + +Drop the gradient of `x`. + + julia> gradient(2, 3) do a, b + dropgrad(a)*b + end + (nothing, 2) +""" +dropgrad(x) = x +@adjoint dropgrad(x) = dropgrad(x), _ -> nothing + +""" + hook(x̄ -> ..., x) -> x + +Gradient hooks. Allows you to apply an arbitrary function to the gradient for +`x`. + + julia> gradient(2, 3) do a, b + hook(ā -> @show(ā), a)*b + end + ā = 3 + (3, 2) + + julia> gradient(2, 3) do a, b + hook(-, a)*b + end + (-3, 2) +""" hook(f, x) = x @adjoint! hook(f, x) = x, x̄ -> (nothing, f(x̄),) +""" + @showgrad(x) -> x + +Much like `@show`, but shows the gradient about to accumulate to `x`. Useful for +debugging gradients. + + julia> gradient(2, 3) do a, b + @showgrad(a)*b + end + ∂(a) = 3 + (3, 2) + +Note that the gradient depends on how the output of `@showgrad` is *used*, and is +not the *overall* gradient of the variable `a`. For example: + + julia> gradient(2) do a + @showgrad(a)*a + end + ∂(a) = 2 + (4,) + + julia> gradient(2, 3) do a, b + @showgrad(a) # not used, so no gradient + a*b + end + ∂(a) = nothing + (3, 2) +""" macro showgrad(x) :(hook($(esc(x))) do x̄ - println($"D($x) = ", repr(x̄)) + println($"∂($x) = ", repr(x̄)) x̄ end) end + +""" + hessian(f, x) + +Construct the Hessian of `f`, where `x` is a real or real array and `f(x)` is +a real. + + julia> hessian(((a, b),) -> a*b, [2, 3]) + 2×2 Array{Int64,2}: + 0 1 + 1 0 +""" +hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] diff --git a/src/profiler/Profile.jl b/src/profiler/Profile.jl index e968e16d7..6fa6b14dd 100644 --- a/src/profiler/Profile.jl +++ b/src/profiler/Profile.jl @@ -69,7 +69,7 @@ end function profile(x::Pullback{T}, seen) where T ls = [] for (c, l) in zip(x.t, stacklines(T)) - c isa Vector{<:Integer} && continue + c isa Union{Integer,Vector{<:Integer}} && continue cs = c isa Vector ? merge(vcat(map(x -> profile(x, seen),c)...)) : profile(c, seen) push!(ls, Node(loc(x)[1],String(l.file),l.line,cs)) end diff --git a/src/tools/fillarray.jl b/src/tools/fillarray.jl deleted file mode 100644 index c1d6e2072..000000000 --- a/src/tools/fillarray.jl +++ /dev/null @@ -1,8 +0,0 @@ -struct FillArray{T,N} <: AbstractArray{T,N} - value::T - size::NTuple{N,Int} -end - -Base.size(xs::FillArray) = xs.size - -Base.getindex(xs::FillArray, ::Int...) = xs.value diff --git a/src/tools/idset.jl b/src/tools/idset.jl index 2e6b38b3e..d9f0ceb04 100644 --- a/src/tools/idset.jl +++ b/src/tools/idset.jl @@ -11,7 +11,7 @@ Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s) Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s) Base.in(x, s::IdSet) = haskey(s.dict, x) -(::Type{IdSet{T}})(xs) where T = push!(IdSet{T}(), xs...) +IdSet{T}(xs) where T = push!(IdSet{T}(), xs...) IdSet(xs) = IdSet{eltype(xs)}(xs) diff --git a/src/tools/ir.jl b/src/tools/ir.jl deleted file mode 100644 index 7afe68ef8..000000000 --- a/src/tools/ir.jl +++ /dev/null @@ -1,115 +0,0 @@ -import Core: SSAValue, GotoNode, Compiler -import Core: Typeof -import Core.Compiler: CodeInfo, IRCode, CFG, BasicBlock, Argument, ReturnNode, - just_construct_ssa, compact!, NewNode, InferenceState, OptimizationState, - GotoIfNot, PhiNode, PiNode, StmtRange, IncrementalCompact, insert_node!, insert_node_here!, - compact!, finish, DomTree, construct_domtree, dominates, userefs, widenconst, types, verify_ir -using InteractiveUtils: typesof - -function afterphi(ir, loc) - if isa(ir.stmts[loc], PhiNode) - if isa(ir.stmts[loc+1], PhiNode) - return afterphi(ir, loc+1) - end - return (loc, true) - end - return (loc, false) -end - -function insert_blockstart!(ir::IRCode, pos, typ, val) - (loc, attach_after) = afterphi(ir, ir.cfg.blocks[pos].stmts[1]) - insert_node!(ir, loc, typ, val, attach_after) -end - -function insert_blockend!(ir::IRCode, pos, typ, val) - i = first(ir.cfg.blocks[pos].stmts) - j = last(ir.cfg.blocks[pos].stmts) - if !(ir.stmts[j] isa Union{GotoNode,GotoIfNot,ReturnNode}) - return insert_node!(ir, j, typ, val, true) - end - while i < j && !(ir.stmts[i] isa Union{GotoNode,GotoIfNot,ReturnNode}) - i += 1 - end - insert_node!(ir, i, typ, val) -end - -function finish_dc(ic::IncrementalCompact) - Compiler.non_dce_finish!(ic) - return Compiler.complete(ic) -end - -function _compact!(code::IRCode) - compact = IncrementalCompact(code) - foreach(x -> nothing, compact) - return finish_dc(compact), compact.ssa_rename -end - -function argmap(f, @nospecialize(stmt)) - urs = userefs(stmt) - for op in urs - val = op[] - if isa(val, Argument) - op[] = f(val) - end - end - return urs[] -end - -exprtype(ir::IRCode, x::Argument) = widenconst(ir.argtypes[x.n]) -exprtype(ir::IRCode, x::SSAValue) = widenconst(types(ir)[x]) -exprtype(ir::IRCode, x::GlobalRef) = isconst(x.mod, x.name) ? Typeof(getfield(x.mod, x.name)) : Any -exprtype(ir::IRCode, x::QuoteNode) = Typeof(x.value) -# probably can fall back to any here -exprtype(ir::IRCode, x::Union{Type,Number,Nothing,Tuple,Function,Val,String,Char,Module}) = Typeof(x) -exprtype(ir::IRCode, x::Expr) = error(x) - -rename(x, m) = x -rename(x::SSAValue, m) = m[x.id] -rename(xs::AbstractVector, m) = map(x -> rename(x, m), xs) -rename(xs::Tuple, m) = map(x -> rename(x, m), xs) -rename(xs::AbstractSet, m) = Set(rename(x, m) for x in xs) -rename(d::AbstractDict, m) = Dict(k => rename(v, m) for (k, v) in d) - -function usages(ir) - us = Dict() - for i = 1:length(ir.stmts), u in userefs(ir.stmts[i]) - push!(get!(us, u[], []), SSAValue(i)) - end - return us -end - -function blockidx(ir, i::Integer) - i = findlast(x -> x <= i, ir.cfg.index) - i == nothing ? 1 : i+1 -end - -blockidx(ir, i::SSAValue) = blockidx(ir, i.id) - -if VERSION > v"1.1.0-DEV.560" - Base.range(b::BasicBlock) = b.stmts.start:b.stmts.stop -else - Base.range(b::BasicBlock) = b.stmts.first:b.stmts.last -end - -xcall(mod::Module, f::Symbol, args...) = Expr(:call, GlobalRef(mod, f), args...) -xcall(f::Symbol, args...) = xcall(Base, f, args...) - -const unreachable = ReturnNode() - -# Dominance frontiers - -function domfront(cfg, dt = construct_domtree(cfg)) - fronts = [Set{Int}() for _ in cfg.blocks] - for b = 1:length(cfg.blocks) - length(cfg.blocks[b].preds) >= 2 || continue - for p in cfg.blocks[b].preds - runner = p - while runner != dt.idoms[b] - runner == b && break - push!(fronts[runner], b) - runner = dt.idoms[runner] - end - end - end - return fronts -end diff --git a/src/tools/reflection.jl b/src/tools/reflection.jl deleted file mode 100644 index c7e684278..000000000 --- a/src/tools/reflection.jl +++ /dev/null @@ -1,105 +0,0 @@ -meta(T) = (usetyped ? IRTools.typed_meta : IRTools.meta)(T) - -function code_ir(f, T) - m = meta(Tuple{Typeof(f),T.parameters...}) - return IRCode(m) -end - -function code_irm(ex) - isexpr(ex, :call) || error("@code_ir f(args...)") - f, args = ex.args[1], ex.args[2:end] - :(code_ir($(esc(f)), typesof($(esc.(args)...)))) -end - -macro code_ir(ex) - code_irm(ex) -end - -function argnames!(meta, names...) - meta.code.slotnames = [names...] - meta.code.slotflags = [0x00 for name in names] -end - -function spliceargs!(meta, ir::IRCode, args...) - for i = 1:length(ir.stmts) - ir[SSAValue(i)] = argmap(x -> Argument(x.n+length(args)), ir[SSAValue(i)]) - end - for (name, T) in reverse(args) - pushfirst!(ir.argtypes, T) - pushfirst!(meta.code.slotnames, name) - end - return ir -end - -# Behave as if the function signature is f(args...) -function varargs!(meta, ir::IRCode, n = 1) - isva = meta.method.isva - Ts = widenconst.(ir.argtypes[n+1:end]) - argtypes = !isva ? - Any[ir.argtypes[1:n]..., Tuple{Ts...}] : - Any[ir.argtypes[1:n]..., Tuple{Ts[1:end-1]...,Ts[end].parameters...}] - empty!(ir.argtypes); append!(ir.argtypes, argtypes) - ir = IncrementalCompact(ir) - map = Dict{Argument,Any}() - for i = 1:(length(Ts)-isva) - map[Argument(i+n)] = insert_node_here!(ir, xcall(Base, :getfield, Argument(n+1), i), Ts[i], Int32(0)) - end - if isva - i = length(Ts) - xs, T = Argument(n+1), argtypes[end] - for _ = 1:i-1 - T = Tuple{T.parameters[2:end]...} - xs = insert_node_here!(ir, xcall(Base, :tail, xs), T, Int32(0)) - end - map[Argument(i+n)] = xs - end - for (i, x) in ir - ir[i] = argmap(a -> get(map, a, a), x) - end - return finish_dc(ir) -end - -function pis!(ir::IRCode) - for i = 1:length(ir.stmts) - ex = ir.stmts[i] - ex isa PiNode || continue - ir.stmts[i] = xcall(Core, :typeassert, ex.val, ex.typ) - end - return ir -end - -function slots!(ir::IRCode) - n = 0 - for b = 1:length(ir.cfg.blocks) - i = first(ir.cfg.blocks[b].stmts) - while (phi = ir[SSAValue(i)]) isa PhiNode - slot = IRTools.Slot(Symbol(:phi, n+=1)) - ir[SSAValue(i)] = slot - for (pred, val) in zip(phi.edges, phi.values) - insert_blockend!(ir, pred, Any, :($slot = $val)) - end - i += 1 - end - end - return compact!(ir) -end - -@generated function roundtrip(f, args...) - m = meta(Tuple{f,args...}) - ir = IRCode(m) - ir = varargs!(m, ir) - argnames!(m, :f, :args) - ir = spliceargs!(m, ir, (Symbol("#self#"), typeof(roundtrip))) - ir = slots!(pis!(ir)) - return IRTools.update!(m, ir) -end - -function inlineable!(ir) - insert_node!(ir, 1, Any, Expr(:meta, :inline)) - compact!(ir) -end - -function log!(ir, msg) - insert_node!(ir, 1, Any, xcall(Core, :println, msg)) - compact!(ir) -end diff --git a/test/compiler.jl b/test/compiler.jl index 8604c761c..00ea2bbeb 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -71,11 +71,11 @@ getx(x) = x.x y, back = @test_inferred forward(getx, (x=1,y=2.0)) @test_inferred back(1) -# TODO -# MRP: -# foo(f) = Ref((f,)) -# @code_typed foo(Complex) -# @test_inferred forward(Complex, 1, 2) +y, back = @test_inferred forward(x->x[1], (5,:a)) +@test_inferred back(1) + +y, back = @test_inferred forward(((a,b),) -> a, (5, 10)) +@test_inferred back(1) # Checks that use control flow if Zygote.usetyped diff --git a/test/complex.jl b/test/complex.jl new file mode 100644 index 000000000..40f7930b9 --- /dev/null +++ b/test/complex.jl @@ -0,0 +1,73 @@ +using Zygote, Test + +@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1 +@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0 +@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im +@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im + +@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im +@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im +@test gradient(a -> real(([a].*conj([a])))[], 0.3im)[1] == 0.6im +@test gradient(a -> real(([a].*conj.([a])))[], 0.3im)[1] == 0.6im +@test gradient(a -> real.(([a].*conj.([a])))[], 0.3im)[1] == 0.6im + +fs_C_to_R = (real, + imag, + abs, + abs2, + z -> abs(z)*cos(im*angle(z)), + z->abs(cos(exp(z))), + z->3*real(z)^3-2*imag(z)^5 + ) +@testset "C->R" begin + for f in fs_C_to_R + for z in (1.0+2.0im, -2.0+pi*im) + grad_zygote = gradient(real∘f, z)[1] + ε = 1e-8 + grad_fd = (f(z+ε)-f(z))/ε + im*(f(z+ε*im)-f(z))/ε + @test abs(grad_zygote - grad_fd) < sqrt(ε) + end + end +end + +fs_C_to_C_holomorphic = (cos, + exp, + log, + z->z^2, + z->(real(z)+im*imag(z))^2, + z->real(z)^2 - imag(z)^2 +2im*(real(z)*imag(z)), + z->exp(cos(log(z))), + z->abs(z)*exp(im*angle(z)), + ) +@testset "C->C holomorphic" begin + for f in fs_C_to_C_holomorphic + for z in (1.0+2.0im, -2.0+pi*im) + grad_zygote = gradient(real∘f, z)[1] + ε = 1e-8 + grad_fd_r = (f(z+ε)-f(z))/ε + grad_fd_i = (f(z+ε*im)-f(z))/(ε*im) + @assert abs(grad_fd_r - grad_fd_i) < sqrt(ε) # check the function is indeed holomorphic + @test abs(grad_zygote - conj(grad_fd_r)) < sqrt(ε) + end + end +end + + +fs_C_to_C_non_holomorphic = (conj, + z->abs(z)+0im, + z->im*abs(z), + z->abs2(z)+0im, + z->im*abs2(z), + z->z'z, + z->conj(z)*z^2, + ) +@testset "C->C non-holomorphic" begin + for f in (fs_C_to_C_holomorphic...,fs_C_to_C_holomorphic...) + for z in (1.0+2.0im, -2.0+pi*im) + grad_zygote = gradient(real∘f, z)[1] + ε = 1e-8 + grad_fd = real(f(z+ε)-f(z))/ε + im*real(f(z+ε*im)-f(z))/ε + @test abs(grad_zygote - grad_fd) < sqrt(ε) + end + end +end diff --git a/test/features.jl b/test/features.jl index 5e931ed79..76c23100c 100644 --- a/test/features.jl +++ b/test/features.jl @@ -1,15 +1,10 @@ using Zygote, Test -using Zygote: Params, gradient, derivative, roundtrip, forwarddiff +using Zygote: Params, gradient, forwarddiff add(a, b) = a+b _relu(x) = x > 0 ? x : 0 f(a, b...) = +(a, b...) -@test roundtrip(add, 1, 2) == 3 -@test roundtrip(_relu, 1) == 1 -@test roundtrip(Complex, 1, 2) == 1+2im -@test roundtrip(f, 1, 2, 3) == 6 - y, back = forward(identity, 1) dx = back(2) @test y == 1 @@ -155,11 +150,11 @@ end D(f, x) = grad(f, x)[1] @test D(x -> D(sin, x), 0.5) == -sin(0.5) +@test D(x -> x*D(y -> x+y, 1), 1) == 1 +@test D(x -> x*D(y -> x*y, 1), 4) == 8 -if VERSION > v"1.2-" - @test D(x -> x*D(y -> x+y, 1), 1) == 1 - @test D(x -> x*D(y -> x*y, 1), 4) == 8 - @test_broken sin'''(1.0) == -cos(1.0) +if VERSION >= v"1.1" + @test sin'''(1.0) == -cos(1.0) end f(x) = throw(DimensionMismatch("fubar")) @@ -187,32 +182,32 @@ y, back = forward(() -> layer(x), Params([W])) @test gradient(() -> sum(W * x), Params([W]))[W] == [1 2; 1 2] -@test derivative(2) do x +@test gradient(2) do x H = [1 x; 3 4] sum(H) -end == 1 +end[1] == 1 # FIXME if !Zygote.usetyped - @test derivative(2) do x + @test gradient(2) do x if x < 0 throw("foo") end return x*5 - end == 5 + end[1] == 5 - @test derivative(x -> one(eltype(x)), rand(10)) == nothing + @test gradient(x -> one(eltype(x)), rand(10))[1] == nothing end # Thre-way control flow merge -@test derivative(1) do x +@test gradient(1) do x if x > 0 x *= 2 elseif x < 0 x *= 3 end x -end == 2 +end[1] == 2 # Gradient of closure grad_closure(x) = 2x @@ -240,23 +235,19 @@ function f(x) end end -if VERSION >= v"1.1" - @test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint -end +@test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint @test_throws ErrorException Zygote.gradient(1) do x push!([], x) return x end -if VERSION >= v"1.1" - @test gradient(1) do x - stk = [] - Zygote._push!(stk, x) - stk = Zygote.Stack(stk) - pop!(stk) - end == (1,) -end +@test gradient(1) do x + stk = [] + Zygote._push!(stk, x) + stk = Zygote.Stack(stk) + pop!(stk) +end == (1,) @test gradient(x -> [x][1].a, Foo(1, 1)) == ((a=1, b=nothing),) @@ -283,3 +274,23 @@ global_param = 3 @test back(1) == (nothing, 3) Zygote.globals(cx)[GlobalRef(Main, :global_param)] == 2 end + +function pow_try(x) + try + 2x + catch e + println("error") + end +end + +@test_broken gradient(pow_try, 1) == (2,) + +function pow_simd(x, n) + r = 1 + @simd for i = 1:n + r *= x + end + return r +end + +@test_broken gradient(pow_simd, 2, 3) == (12,nothing) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 1605f0226..e5a35e825 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1,7 +1,7 @@ -using Zygote, NNlib, Test, Random, LinearAlgebra +using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays, FFTW using Zygote: gradient -using NNlib: conv -import Random +using NNlib: conv, ∇conv_data, depthwiseconv +using Base.Broadcast: broadcast_shape function ngradient(f, xs::AbstractArray...) grads = zero.(xs) @@ -38,6 +38,8 @@ Random.seed!(0) @test gradtest((w, x) -> transpose(w)*x, randn(5,5), randn(5,5)) @test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5)) +@test gradtest(x -> sum(abs2, x), randn(4, 3, 2)) +@test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2)) @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5)) @test gradtest(x -> prod(x), (3,4,5)) @@ -46,22 +48,59 @@ Random.seed!(0) @test gradtest(x -> logsoftmax(x).*(1:3), 3) @test gradtest(x -> logsoftmax(x).*(1:3), (3,5)) -@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2)) -@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2)) -@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2)) - -@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2)) -@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2)) +@test gradtest(x -> x', rand(5)) + +@test gradtest(det, (4, 4)) +@test gradtest(logdet, map(x -> x*x', (rand(4, 4),))[1]) +@test gradtest(x -> logabsdet(x)[1], (4, 4)) + +@test gradtest(x -> view(x,:,2,:), (3,4,5)) +@test gradtest(x -> view(x,1:2,3:4), (3,4)) + +@testset "conv" begin + for spatial_rank in (1, 2, 3) + x = rand(repeat([10], spatial_rank)..., 3, 2) + w = rand(repeat([3], spatial_rank)..., 3, 3) + cdims = DenseConvDims(x, w) + @test gradtest((x, w) -> conv(x, w, cdims), x, w) + y = conv(x, w, cdims) + @test gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w) + dcdims = DepthwiseConvDims(x, w) + @test gradtest((x, w) -> depthwiseconv(x, w, dcdims), x, w) + end +end -@test gradtest(x -> meanpool(x, (2,2)), rand(10, 10, 3, 2)) -@test gradtest(x -> meanpool(x, (2,2,2)), rand(5, 5, 5, 3, 2)) +@testset "pooling" begin + for spatial_rank in (1, 2) + x = rand(repeat([10], spatial_rank)..., 3, 2) + pdims = PoolDims(x, 2) + @test gradtest(x -> maxpool(x, pdims), x) + @test gradtest(x -> meanpool(x, pdims), x) + end +end +@test gradtest(x -> permutedims(x), rand(2)) +@test gradtest(x -> permutedims(x), rand(2,3)) @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) +@test gradtest(x -> PermutedDimsArray(x, (3,1,2)), rand(4,5,6)) +let + y, back = Zygote.forward(permutedims, randn(3)) + @test first(back(randn(1, 3))) isa Vector +end @test gradtest(x -> repeat(x; inner=2), rand(5)) @test gradtest(x -> repeat(x; inner=2, outer=3), rand(5)) @test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3)) +@test gradtest(tr, rand(4, 4)) + +@testset "fill" begin + rng, N, M, P = MersenneTwister(123456), 11, 6, 5 + @test gradtest(x->fill(first(x), N), randn(rng, 1)) + @test gradtest(x->fill(first(x), N, M), randn(rng, 1)) + @test gradtest(x->fill(first(x), N, M, P), randn(rng, 1)) +end + @testset "dot" begin rng = MersenneTwister(123456) @test gradtest((x, y)->dot(x[1], y[1]), [randn(rng)], [randn(rng)]) @@ -75,6 +114,27 @@ end @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) +@testset "map" begin + @test gradtest(xs -> sum(map(x -> x^2, xs)), rand(2,3)) + @test gradtest((xss...) -> sum(map((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...) + function foo(y) + bar = (x) -> x*y + sum(map(bar, 1:5)) + end + @test gradtest(foo, 3) + @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],) +end + +@testset "mean" begin + @test gradtest(mean, rand(2, 3)) + + @test gradtest(x -> mean(x, dims=1), rand(2, 3)) + @test gradtest(x -> mean(x, dims=2), rand(2, 3)) + @test gradtest(x -> mean(x, dims=3), rand(2, 3, 4)) + + @test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4)) +end + @testset "maximum" begin @test gradtest(maximum, rand(2, 3)) @@ -83,6 +143,8 @@ end @test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4)) @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) + + @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9] end @testset "minimum" begin @@ -95,21 +157,99 @@ end @test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4)) end +@testset "(p)inv" begin + rng, P, Q = MersenneTwister(123456), 13, 11 + A, B, C = randn(rng, P, Q), randn(rng, P, P), randn(Q, P) + @test gradtest(pinv, A) + @test gradtest(inv, B) + @test gradtest(pinv, C) +end + +@testset "multiplication" begin + @testset "matrix-matrix" begin + rng, M, P, Q = MersenneTwister(123456), 13, 7, 11 + @test gradtest(*, randn(rng, M, P), randn(rng, P, Q)) + @test gradtest(*, randn(rng, M, P), randn(rng, P)) + @test gradtest(*, randn(rng, M, 1), randn(rng, 1, Q)) + @test gradtest(*, randn(rng, M), randn(rng, 1, Q)) + @test gradtest(*, randn(rng, 10)', randn(rng, 10)) + + let + y, back = Zygote.forward(*, randn(rng, M, P), randn(rng, P)) + @test last(back(randn(rng, M))) isa Vector + end + let + y, back = Zygote.forward(*, randn(rng, M), randn(rng, 1, P)) + @test first(back(randn(rng, M, P))) isa Vector + end + end +end + @testset "backsolve" begin - rng, P, Q = MersenneTwister(123456), 10, 9 + rng, M, P, Q = MersenneTwister(123456), 13, 10, 9 X, Y, y = randn(rng, P, P), randn(rng, P, Q), randn(rng, P) - - # \ - @test gradtest(X -> X \ Y, X) - @test gradtest(Y -> X \ Y, Y) - @test gradtest(X -> X \ y, X) - @test gradtest(y -> X \ y, y) + A, B = randn(rng, P, M), randn(P, Q) + D = collect(Diagonal(randn(rng, P))) + L = collect(LowerTriangular(randn(rng, P, P))) + L[diagind(L)] .= 1 .+ 0.01 .* randn(rng, P) + U = collect(UpperTriangular(randn(rng, P, P))) + U[diagind(U)] .= 1 .+ 0.01 .* randn(rng, P) + + # \ (Dense square) + @test gradtest(\, X, Y) + @test gradtest(\, X, y) + + # \ (Dense rectangular) + @test gradtest(\, A, Y) + @test gradtest(\, A, y) + @test gradtest(\, B, Y) + @test gradtest(\, B, y) + + # \ (Diagonal) + @test gradtest(\, D, Y) + @test gradtest(\, D, y) + @test gradtest((D, Y)-> Diagonal(D) \ Y, D, Y) + @test gradtest((D, Y)-> Diagonal(D) \ Y, D, y) + + # \ (LowerTriangular) + @test gradtest(\, L, Y) + @test gradtest(\, L, y) + @test gradtest((L, Y) -> LowerTriangular(L) \ Y, L, Y) + @test gradtest((L, Y) -> LowerTriangular(L) \ Y, L, y) + + # \ (UpperTriangular) + @test gradtest(\, U, Y) + @test gradtest(\, U, y) + @test gradtest((U, Y) -> UpperTriangular(U) \ Y, U, Y) + @test gradtest((U, Y) -> UpperTriangular(U) \ Y, U, y) # / - @test gradtest(X -> Y' / X, X) - @test gradtest(Y -> Y' / X, Y) - @test gradtest(X -> y' / X, X) - @test gradtest(y -> y' / X, y) + @test gradtest(/, Y', X) + @test gradtest((y, X)->y' / X, y, X) + + # / (rectangular) + @test gradtest(/, Y', A') + @test gradtest((y, A)->y' / A', y, A) + @test gradtest(/, Y', B') + @test gradtest((y, A)->y' / A', y, B) + + # / (Diagonal) + @test gradtest((D, Y) -> Y' / D, D, Y) + @test gradtest((D, Y) -> Y' / D, D, y) + @test gradtest((D, Y)-> Y' / Diagonal(D), D, Y) + @test gradtest((D, Y)-> Y' / Diagonal(D), D, y) + + # / (LowerTriangular) + @test gradtest((L, Y) -> Y' / L, L, Y) + @test gradtest((L, Y) -> Y' / L, L, y) + @test gradtest((L, Y) -> Y' / LowerTriangular(L), L, Y) + @test gradtest((L, Y) -> Y' / LowerTriangular(L), L, y) + + # / (UpperTriangular) + @test gradtest((U, Y) -> Y' / U, U, Y) + @test gradtest((U, Y) -> Y' / U, U, y) + @test gradtest((U, Y) -> Y' / UpperTriangular(U), U, Y) + @test gradtest((U, Y) -> Y' / UpperTriangular(U), U, y) @testset "Cholesky" begin @@ -126,6 +266,13 @@ end rng, P = MersenneTwister(123456), 7 A = randn(rng, P, P) @test gradtest(Symmetric, A) + y, back = Zygote.forward(Symmetric, A) + + @testset "back(::Diagonal)" begin + D̄ = Diagonal(randn(rng, P)) + @test back(Diagonal(D̄))[1] isa Diagonal + @test back(Diagonal(D̄))[1] ≈ back(Matrix(D̄))[1] + end end @testset "diag" begin @@ -134,6 +281,16 @@ end @test gradtest(diag, A) end +@testset "Diagonal" begin + rng, P = MersenneTwister(123456), 10 + d = randn(rng, P) + @test gradtest(Diagonal, d) + y, back = Zygote.forward(Diagonal, d) + D̄ = randn(rng, P, P) + @test back(D̄) == back(Diagonal(D̄)) + @test back(D̄) == back((diag=diag(D̄),)) +end + @testset "dense + UniformScaling" begin rng = MersenneTwister(123456) A, λ = randn(rng, 10, 10), randn(rng) @@ -142,16 +299,51 @@ end end @testset "cholesky" begin - rng, N = MersenneTwister(123456), 5 - A = randn(rng, N, N) - @test gradtest(A->logdet(cholesky(A' * A + 1e-6I)), A) + @testset "cholesky - dense" begin + rng, N = MersenneTwister(123456), 5 + A = randn(rng, N, N) + @test cholesky(A' * A + I) == first(Zygote.forward(A->cholesky(A' * A + I), A)) + @test gradtest(A->cholesky(A' * A + I).U, A) + @test gradtest(A->logdet(cholesky(A' * A + I)), A) + @test gradtest(B->cholesky(Symmetric(B)).U, A * A' + I) + @test gradtest(B->logdet(cholesky(Symmetric(B))), A * A' + I) + end @testset "cholesky - scalar" begin + rng = MersenneTwister(123456) y, back = Zygote.forward(cholesky, 5.0 * ones(1, 1)) y′, back′ = Zygote.forward(cholesky, 5.0) C̄ = randn(rng, 1, 1) @test back′((factors=C̄,))[1] isa Real @test back′((factors=C̄,))[1] ≈ back((factors=C̄,))[1][1, 1] end + @testset "cholesky - Diagonal" begin + rng, N = MersenneTwister(123456), 3 + D = Diagonal(exp.(randn(rng, N))) + Dmat = Matrix(D) + y, back = Zygote.forward(cholesky, Dmat) + y′, back′ = Zygote.forward(cholesky, D) + C̄ = (factors=randn(rng, N, N),) + @test back′(C̄)[1] isa Diagonal + @test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1]) + end +end + +@testset "lyap" begin + rng, N = MersenneTwister(6865943), 5 + for i = 1:5 + A = randn(rng, N, N) + C = randn(rng, N, N) + @test gradtest(lyap, A, C) + end + @test gradcheck(x->lyap(x[1],x[2]),[3.1,4.6]) +end + +@testset "matrix exponential" begin + rng, N = MersenneTwister(6865931), 8 + for i = 1:5 + A = randn(rng, N, N) + @test gradtest(exp, A) + end end using Distances @@ -178,19 +370,25 @@ Zygote.refresh() # Check binary pairwise. let X, Y = randn(rng, D, P), randn(rng, D, Q) - @test gradtest(X->pairwise(SqEuclidean(), X, Y), X) - @test gradtest(Y->pairwise(SqEuclidean(), X, Y), Y) + @test gradtest(X->pairwise(SqEuclidean(), X, Y; dims=2), X) + @test gradtest(Y->pairwise(SqEuclidean(), X, Y; dims=2), Y) + end + let + Xt, Yt = randn(rng, P, D), randn(rng, Q, D) + @test gradtest(Xt->pairwise(SqEuclidean(), Xt, Yt; dims=1), Xt) + @test gradtest(Yt->pairwise(SqEuclidean(), Xt, Yt; dims=1), Yt) end # Check unary pairwise. - @test gradtest(X->pairwise(SqEuclidean(), X), randn(rng, D, P)) + @test gradtest(X->pairwise(SqEuclidean(), X; dims=2), randn(rng, D, P)) + @test gradtest(Xt->pairwise(SqEuclidean(), Xt; dims=1), randn(rng, P, D)) end function cat_test(f, A::Union{AbstractVector, AbstractMatrix}...) @test gradtest(f, A...) Z, back = Zygote.forward(f, A...) - Ā = back(randn(size(Z))) - @test all(map((a, ā)->ā isa typeof(a), A, Ā)) + Ā = back(randn(size(Z))) + @test all(map((a, ā)->ā isa typeof(a), A, Ā)) end @testset "vcat" begin @@ -239,6 +437,13 @@ end end end +@testset "hvcat" begin + @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == (1,0,0,0) + @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == (0,0,1,0) + @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == (0,1,0,0) + @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == (0,0,0,1) +end + @testset "one(s) and zero(s)" begin @test Zygote.gradient(x->sum(ones(size(x))), randn(5))[1] isa Nothing @test Zygote.gradient(x->sum(one(x)), randn(3, 3))[1] isa Nothing @@ -317,3 +522,132 @@ end @test size(Zygote.gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[1]) == (1, 1) @test size(Zygote.gradient((x, y)->sum(x * y), randn(1, 1), randn(1, 10))[2]) == (1, 10) end + +@testset "broadcast" begin + if !Zygote.usetyped + @test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1 + end +end + +using Zygote: Buffer + +@testset "Buffer" begin + @test gradient([1, 2, 3]) do x + b = Buffer(x) + b[:] = x + return sum(copy(b)) + end == ([1,1,1],) + + function vstack(xs) + buf = Buffer(xs, length(xs), 5) + for i = 1:5 + buf[:, i] = xs + end + return copy(buf) + end + + @test gradient(x -> sum(vstack(x)), [1, 2, 3]) == ([5, 5, 5],) + + buf = Buffer([1, 2, 3]) + buf[1] = 1 + copy(buf) + @test_throws ErrorException buf[1] = 1 + @test eltype(buf) === Int + @test length(buf) === 3 + @test ndims(buf) === 1 + @test size(buf) === (3, ) + @test size(buf, 2) === 1 + @test axes(buf) == (1:3, ) + @test axes(buf, 2) == 1:1 + @test eachindex(buf) == 1:3 + @test stride(buf, 2) === 3 + @test strides(buf) === (1, ) +end + +@testset "FillArrays" begin + @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1]) + @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing + @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing +end + +@testset "AbstractArray Addition / Subtraction / Negation" begin + rng, M, N, P = MersenneTwister(123567), 3, 7, 11 + A, B = randn(rng, M, N, P), randn(rng, M, N, P) + @test gradtest(+, A, B) + @test gradtest(-, A, B) + @test gradtest(-, A) +end + +@testset "FFTW" begin + x=[-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982] + # gradient of ifft(rfft) must be 1 + @test gradient((x)->real(ifft(fft(x))[1]),x)[1][1] == 1.0+0.0im + @test gradient((x)->real(fft(ifft(x))[1]),x)[1][1] == 1.0+0.0im + + # check ffts for individual dimensions + @test gradient((x)->sum(abs.(FFTW.fft(x))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.fft(FFTW.fft(x,1),2))),x)[1] + @test gradient((x)->abs(sum((FFTW.fft(x)))),x)[1] ≈ gradient((x)->abs(sum(FFTW.fft(FFTW.fft(x,1),2))),x)[1] + @test gradient((x, dims)->sum(abs.(FFTW.fft(x,dims))),x,(1,2))[1] ≈ gradient((x)->sum(abs.(FFTW.fft(x))),x)[1] + @test gradient((x)->sum(abs.(FFTW.fft(x,(1,2)))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.fft(FFTW.fft(x,1),2))),x)[1] + @test gradient((x, dims)->sum(abs.(FFTW.ifft(x,dims))),x,(1,2))[1] ≈ gradient((x)->sum(abs.(FFTW.ifft(x))),x)[1] + @test gradient((x)->sum(abs.(FFTW.ifft(x,(1,2)))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.ifft(FFTW.ifft(x,1),2))),x)[1] + + @test gradcheck(x->sum(abs.(FFTW.fft(x))), x) + @test gradcheck(x->sum(abs.(FFTW.ifft(x))), x) + @test gradcheck(x->sum(abs.(FFTW.fft(x, 1))), x) + @test gradcheck(x->sum(abs.(FFTW.ifft(x, 1))), x) + +end + +@testset "FillArrays" begin + rng, M, N = MersenneTwister(123456), 7, 11 + x, y = randn(rng), randn(rng) + @test Zygote.gradient(x->sum(Fill(x, N)), x)[1] == N + @test Zygote.gradient(x->sum(Fill(x, N, 3, 4)), x)[1] == N * 3 * 4 + @test Zygote.gradient((x, y)->sum(Fill(x, N)), x, y) == (N, nothing) + + let + out, back = Zygote.forward(sum, Fill(x, N)) + @test back(nothing) isa Nothing + end + + z = randn(rng, N) + @test gradtest(x->Fill(first(x), N), [x]) + let + out, back = Zygote.forward(x->Fill(x, N), x) + @test out == Fill(x, N) + @test first(back(Fill(y, N))) ≈ y * N + end + + # Test unary broadcasting gradients. + out, back = Zygote.forward(x->exp.(x), Fill(x, N)) + @test out isa Fill + @test out == Fill(exp(x), N) + @test back(Ones(N))[1] isa Fill + @test back(Ones(N))[1] == Ones(N) .* exp(x) + @test back(ones(N))[1] isa Vector + @test back(ones(N))[1] == ones(N) .* exp(x) + @test gradtest(x->exp.(Fill(3 * first(x), N)), [x]) + + @testset "broadcast + and *" begin + for sx in [(M, N), (M, 1), (1, N), (1, 1)] + for sy in [(M, N), (M, 1), (1, N), (1, 1)] + z = randn(rng, broadcast_shape(sx, sy)) + + # Addition + @test gradtest((x, y)->Fill(first(x), sx...) .+ Fill(first(y), sy...), [x], [y]) + @test gradtest(x->Fill(first(x), sx...) .+ Ones(sy...), [x]) + @test gradtest(x->Fill(first(x), sx...) .+ Zeros(sy...), [x]) + @test gradtest(y->Ones(sx...) .+ Fill(first(y), sy...), [y]) + @test gradtest(y->Zeros(sx...) .+ Fill(first(y), sy...), [y]) + + # Multiplication + @test gradtest((x, y)->Fill(first(x), sx...) .* Fill(first(y), sy...), [x], [y]) + @test gradtest(x->Fill(first(x), sx...) .* Ones(sy...), [x]) + @test gradtest(x->Fill(first(x), sx...) .* Zeros(sy...), [x]) + @test gradtest(y->Ones(sx...) .* Fill(first(y), sy...), [y]) + @test gradtest(y->Zeros(sx...) .* Fill(first(y), sy...), [y]) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 152de4361..0142a6467 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,12 @@ using Zygote, Test using Zygote: gradient +if Zygote.usetyped + @info "Testing Zygote in type-hacks mode." +else + @info "Testing Zygote in normal mode." +end + @testset "Zygote" begin @testset "Features" begin @@ -11,6 +17,10 @@ end include("gradcheck.jl") end +@testset "Complex" begin + include("complex.jl") +end + @testset "Compiler" begin include("compiler.jl") end diff --git a/test/typed.jl b/test/typed.jl index d9f2ea6c4..f97f8c9a8 100644 --- a/test/typed.jl +++ b/test/typed.jl @@ -1,5 +1,5 @@ using Zygote, Test -using Zygote: gradient, derivative, forward +using Zygote: gradient, forward dpow(n, p) = something(gradient(pow, n, p)[1], zero(n)) @@ -7,6 +7,6 @@ dpow(n, p) = something(gradient(pow, n, p)[1], zero(n)) @test_inferred dpow(2, 3) cube(x) = pow(x, 3) -dcube(x) = something(derivative(cube, x), zero(x)) +dcube(x) = something(gradient(cube, x)[1], zero(x)) y, back = @test_inferred forward(cube, 2) @test_inferred dcube(2)