Skip to content

Commit

Permalink
Merge pull request #230 from JuliaDiff/ox/relax
Browse files Browse the repository at this point in the history
Relax type constraints to permit AbstractVectors
  • Loading branch information
oxinabox committed Jan 23, 2024
2 parents 99ad77f + 252d17b commit a370b9b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Approximate the Jacobian of `f` at `x` using `fdm`. Results will be returned as
version of `x`, and `y_vec` the flattened version of `f(x...)`. Flattening performed by
[`to_vec`](@ref).
"""
function jacobian(fdm, f, x::Vector{<:Real}; len=nothing)
function jacobian(fdm, f, x::AbstractVector{<:Real}; len=nothing)
len !== nothing && Base.depwarn(
"`len` keyword argument to `jacobian` is no longer required " *
"and will not be permitted in the future.",
Expand Down Expand Up @@ -40,11 +40,11 @@ end
replace_arg(x, xs::Tuple, k::Int) = ntuple(p -> p == k ? x : xs[p], length(xs))

"""
_jvp(fdm, f, x::Vector{<:Real}, ẋ::AbstractVector{<:Real})
_jvp(fdm, f, x::AbstractVector{<:Real}, ẋ::AbstractVector{<:Real})
Convenience function to compute `jacobian(f, x) * ẋ`.
"""
function _jvp(fdm, f, x::Vector{<:Real}, ẋ::Vector{<:Real})
function _jvp(fdm, f, x::AbstractVector{<:Real}, ẋ::AbstractVector{<:Real})
return fdm-> f(x .+ ε .* ẋ), zero(eltype(x)))
end

Expand Down Expand Up @@ -79,7 +79,7 @@ end

j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1]

function _j′vp(fdm, f, ȳ::Vector{<:Real}, x::Vector{<:Real})
function _j′vp(fdm, f, ȳ::AbstractVector{<:Real}, x::AbstractVector{<:Real})
isempty(x) && return eltype(ȳ)[] # if x is empty, then so is the jacobian and x̄
return transpose(first(jacobian(fdm, f, x))) *
end
Expand Down

0 comments on commit a370b9b

Please sign in to comment.