diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 7588a0718..fcecda2f1 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -105,6 +105,9 @@ end @adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) +@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Integer) where {K,N} = + (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing)) + @adjoint function Base.first(xs::Tuple) drest = map(_->nothing, tail(xs)) first(xs), Δ -> ((Δ, drest...),)