Skip to content

Commit

Permalink
pk/rsa: Fix GC rooting issues and fill out API
Browse files Browse the repository at this point in the history
MbedTLS is passing raw pointers to C without rooting ownership
all over the place. This fixes these issues in the pk/rsa code
and fills out the API a bit for round-trip testing, as well
as adding GC-safe wrappers over the internal mp integers for
downstream code to use. Of course, the unsafe GC pattern is
repeated elsewhere in this package, so this is just a first
PR that fixed the API surface that I happened to need.
  • Loading branch information
Keno committed Nov 22, 2023
1 parent 41d1897 commit 868c9ea
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 31 deletions.
50 changes: 40 additions & 10 deletions src/pk.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
@enum(PKType,
PK_NONE=0,
PK_RSA,
PK_ECKEY,
PK_ECKEY_DH,
PK_ECDSA,
PK_RSA_ALT,
PK_RSASSA_PSS,
PK_OPAQUE)

mutable struct PKContext
data::Ptr{Cvoid}

Expand All @@ -15,15 +25,17 @@ mutable struct PKContext
end
end

Base.unsafe_convert(::Type{Ptr{Cvoid}}, ctx::PKContext) = ctx.data

const MBEDTLSLOCK = ReentrantLock()

function parse_keyfile!(ctx::PKContext, path, password="")
function parse_keyfile!(ctx::PKContext, path, password=C_NULL)
@err_check ccall((:mbedtls_pk_parse_keyfile, libmbedcrypto), Cint,
(Ptr{Cvoid}, Cstring, Cstring),
ctx.data, path, password)
ctx, path, password)
end

function parse_keyfile(path, password="")
function parse_keyfile(path, password=C_NULL)
ctx = PKContext()
parse_keyfile!(ctx, path, password)
ctx
Expand All @@ -32,7 +44,7 @@ end
function parse_public_keyfile!(ctx::PKContext, path)
@err_check ccall((:mbedtls_pk_parse_public_keyfile, libmbedcrypto), Cint,
(Ptr{Cvoid}, Cstring),
ctx.data, path)
ctx, path)
end

function parse_public_keyfile(path)
Expand All @@ -45,7 +57,7 @@ function parse_public_key!(ctx::PKContext, key)
key_bs = String(key)
@err_check ccall((:mbedtls_pk_parse_public_key, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{Cuchar}, Csize_t),
ctx.data, key_bs, sizeof(key_bs) + 1)
ctx, key_bs, sizeof(key_bs) + 1)
end

function parse_key!(ctx::PKContext, key, maybe_pw = nothing)
Expand All @@ -64,7 +76,7 @@ end

function bitlength(ctx::PKContext)
sz = ccall((:mbedtls_pk_get_bitlen, libmbedcrypto), Csize_t,
(Ptr{Cvoid},), ctx.data)
(Ptr{Cvoid},), ctx)
sz >= 0 || mbed_err(sz)
Int(sz)
end
Expand All @@ -74,7 +86,7 @@ function decrypt!(ctx::PKContext, input, output, rng)
Base.@lock MBEDTLSLOCK begin
@err_check ccall((:mbedtls_pk_decrypt, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{UInt8}, Csize_t, Ptr{Cvoid}, Ref{Cint}, Csize_t, Ptr{Cvoid}, Any),
ctx.data, input, sizeof(input), output, outlen_ref, sizeof(output), c_rng[], rng)
ctx, input, sizeof(input), output, outlen_ref, sizeof(output), c_rng[], rng)
end
outlen = outlen_ref[]
Int(outlen)
Expand All @@ -85,7 +97,7 @@ function encrypt!(ctx::PKContext, input, output, rng)
Base.@lock MBEDTLSLOCK begin
@err_check ccall((:mbedtls_pk_encrypt, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{UInt8}, Csize_t, Ptr{Cvoid}, Ref{Cint}, Csize_t, Ptr{Cvoid}, Any),
ctx.data, input, sizeof(input), output, outlen_ref, sizeof(output), c_rng[], rng)
ctx, input, sizeof(input), output, outlen_ref, sizeof(output), c_rng[], rng)
end
outlen = outlen_ref[]
Int(outlen)
Expand All @@ -96,7 +108,7 @@ function sign!(ctx::PKContext, hash_alg::MDKind, hash, output, rng)
Base.@lock MBEDTLSLOCK begin
@err_check ccall((:mbedtls_pk_sign, libmbedcrypto), Cint,
(Ptr{Cvoid}, Cint, Ptr{UInt8}, Csize_t, Ptr{UInt8}, Ref{Csize_t}, Ptr{Cvoid}, Any),
ctx.data, hash_alg, hash, sizeof(hash), output, outlen_ref, c_rng[], rng)
ctx, hash_alg, hash, sizeof(hash), output, outlen_ref, c_rng[], rng)
end
outlen = outlen_ref[]
Int(outlen)
Expand All @@ -116,6 +128,24 @@ function verify(ctx::PKContext, hash_alg::MDKind, hash, signature)
end

function get_name(ctx::PKContext)
ptr = ccall((:mbedtls_pk_get_name, libmbedcrypto), Ptr{Cchar}, (Ptr{Cvoid},), ctx.data)
ptr = ccall((:mbedtls_pk_get_name, libmbedcrypto), Ptr{Cchar}, (Ptr{Cvoid},), ctx)
unsafe_string(convert(Ptr{UInt8}, ptr))
end

function get_type(ctx::PKContext)
ccall((:mbedtls_pk_get_type, libmbedcrypto), PKType, (Ptr{Cvoid},), ctx)
end

# Access as RSA key
function RSA(pk::PKContext)
@assert get_type(pk) == PK_RSA
# We would like to do the following, but unfortunately, it's static_inline
# in the headers.
# ptr = ccall((:mbedtls_pk_rsa, libmbedcrypto), Ptr{mbedtls_rsa_context},
# (Ptr{Cvoid},), pk)
@GC.preserve pk begin
ptr = unsafe_load(Ptr{Ptr{mbedtls_rsa_context}}(pk.data), 2)
end
return RSA(ptr, pk)
end

93 changes: 72 additions & 21 deletions src/rsa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,47 @@ struct mbedtls_rsa_context
# are not required for this wrapper
end

struct MPI
ptr::Ptr{mbedtls_mpi}

# Used for rooting only
owner::Any
end
Base.unsafe_convert(::Type{Ptr{mbedtls_mpi}}, mpi::MPI) = mpi.ptr

mutable struct RSA
data::Ptr{mbedtls_rsa_context}

# Used for rooting only
owner::Any

function RSA(padding=MBEDTLS_RSA_PKCS_V21, hash_id=MD_MD5)
ctx = new()
ctx.data = Libc.malloc(1000)
ccall((:mbedtls_rsa_init, libmbedcrypto), Cvoid,
(Ptr{Cvoid}, Cint, Cint),
ctx.data, padding, hash_id)
(Ptr{mbedtls_rsa_context}, Cint, Cint),
ctx, padding, hash_id)
finalizer(ctx->begin
ccall((:mbedtls_rsa_free, libmbedcrypto), Cvoid, (Ptr{Cvoid},), ctx.data)
ccall((:mbedtls_rsa_free, libmbedcrypto), Cvoid, (Ptr{mbedtls_rsa_context},), ctx)
Libc.free(ctx.data)
end, ctx)
ctx
end

RSA(data::Ptr{mbedtls_rsa_context}, @nospecialize(owner)) = new(data, owner)
end

function mpi_import!(mpi::Ptr{mbedtls_mpi}, b::BigInt)
function Base.getproperty(ctx::RSA, s::Symbol)
if s in (:N, :E, :D, :P, :Q)
return MPI(Ptr{mbedtls_mpi}(getfield(ctx, :data) +
fieldoffset(mbedtls_rsa_context, Base.fieldindex(mbedtls_rsa_context, s))), ctx)
end
return getfield(ctx, s)
end

Base.unsafe_convert(::Type{Ptr{mbedtls_rsa_context}}, rsa::RSA) = rsa.data

function mpi_import!(mpi::Union{Ptr{mbedtls_mpi}, MPI}, b::BigInt)
# Export from GMP
size = ndigits(b, base=2)
nbytes = div(size+8-1,8)
Expand All @@ -50,44 +73,72 @@ function mpi_import!(mpi::Ptr{mbedtls_mpi}, b::BigInt)
mpi, data, nbytes)
end

function mpi_size(mpi::Ptr{mbedtls_mpi})
function mpi_export!(vec::Union{Vector{UInt8}, SubArray{1, UInt8, Vector{UInt8}}}, mpi::Union{Ptr{mbedtls_mpi}, MPI})
@err_check ccall((:mbedtls_mpi_write_binary, libmbedcrypto), Cint,

Check warning on line 77 in src/rsa.jl

View check run for this annotation

Codecov / codecov/patch

src/rsa.jl#L76-L77

Added lines #L76 - L77 were not covered by tests
(Ptr{mbedtls_mpi}, Ptr{UInt8}, Csize_t),
mpi, data, sizeof(vec))
return nothing

Check warning on line 80 in src/rsa.jl

View check run for this annotation

Codecov / codecov/patch

src/rsa.jl#L80

Added line #L80 was not covered by tests
end

function mpi_export!(to::IOBuffer, mpi::Union{Ptr{mbedtls_mpi}, MPI})
sz = mpi_size(mpi)
Base.ensureroom(to, sz)
ptr = (to.append ? to.size+1 : to.ptr)
@GC.preserve to begin
@err_check ccall((:mbedtls_mpi_write_binary, libmbedcrypto), Cint,
(Ptr{mbedtls_mpi}, Ptr{UInt8}, Csize_t),
mpi, pointer(to.data, ptr), sz)
ptr += sz
end
to.size = max(to.size, ptr - 1)
if !to.append
to.ptr += sz
end
return sz
end

function mpi_size(mpi::Union{Ptr{mbedtls_mpi}, MPI})
ccall((:mbedtls_mpi_size, libmbedcrypto), Csize_t, (Ptr{mbedtls_mpi},), mpi)
end

function pubkey_from_vals!(ctx::RSA, e::BigInt, n::BigInt)
Nptr = Ptr{mbedtls_mpi}(ctx.data+fieldoffset(mbedtls_rsa_context,3 #= :N =#))
mpi_import!(Nptr, n)
mpi_import!(Ptr{mbedtls_mpi}(ctx.data+fieldoffset(mbedtls_rsa_context,4 #= :E =#)), e)
nptr_size = mpi_size(Nptr)
unsafe_store!(Ptr{Csize_t}(ctx.data+fieldoffset(mbedtls_rsa_context,2 #=:len =#)),
nptr_size)
mpi_import!(ctx.N, n)
mpi_import!(ctx.E, e)
@GC.preserve ctx begin
nptr_size = mpi_size(ctx.N)
unsafe_store!(Ptr{Csize_t}(ctx.data+fieldoffset(mbedtls_rsa_context, 2 #= :len =#)), nptr_size)
end
@err_check ccall((:mbedtls_rsa_check_pubkey, libmbedcrypto), Cint,
(Ptr{Cvoid},), ctx.data)
(Ptr{mbedtls_rsa_context},), ctx)
ctx
end

function complete!(ctx::RSA)
@err_check ccall((:mbedtls_rsa_complete, libmbedcrypto), Cint,
(Ptr{mbedtls_rsa_context},), ctx)
return nothing
end

function verify(ctx::RSA, hash_alg::MDKind, hash, signature, rng = nothing; using_public=true)
(!using_public && rng == nothing) &&
error("Private key verification requires the rng")
# All errors, including validation errors throw
@err_check ccall((:mbedtls_rsa_pkcs1_verify, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Any, Cint, Cint, Csize_t, Ptr{UInt8}, Ptr{UInt8}),
ctx.data,
(Ptr{mbedtls_rsa_context}, Ptr{Cvoid}, Any, Cint, Cint, Csize_t, Ptr{UInt8}, Ptr{UInt8}),
ctx,
rng == nothing ? C_NULL : c_rng[],
rng == nothing ? Ref{Any}() : rng,
using_public ? 0 : 1,
hash_alg, sizeof(hash), hash, signature)
end


function gen_key!(ctx::RSA, f_rng, rng, nbits, exponent)
@err_check ccall((:mbedtls_rsa_gen_key, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Any, Cint, Cint),
ctx.data, f_rng, rng, nbits, exponent)
(Ptr{mbedtls_rsa_context}, Ptr{Cvoid}, Any, Cint, Cint),
ctx, f_rng, rng, nbits, exponent)
ctx
end


function gen_key(rng::AbstractRNG, nbits=2048, exponent=65537)
ctx = RSA()
gen_key!(ctx, c_rng[], rng, nbits, exponent)
Expand All @@ -96,14 +147,14 @@ end

function public(ctx::RSA, input, output)
@err_check ccall((:mbedtls_rsa_public, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Ptr{Cvoid}), ctx.data, input, output)
(Ptr{mbedtls_rsa_context}, Ptr{Cvoid}, Ptr{Cvoid}), ctx, input, output)
output
end

function private(ctx::RSA, f_rng, rng, input, output)
@err_check ccall((:mbedtls_rsa_private, libmbedcrypto), Cint,
(Ptr{Cvoid}, Ptr{Cvoid}, Any, Ptr{Cvoid}, Ptr{Cvoid}),
ctx.data, f_rng, rng, input, output)
(Ptr{mbedtls_rsa_context}, Ptr{Cvoid}, Any, Ptr{Cvoid}, Ptr{Cvoid}),
ctx, f_rng, rng, input, output)
output
end

Expand Down
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ let
@test MbedTLS.bitlength(key) == 2048
@test MbedTLS.get_name(key) == "RSA"

let rsa = MbedTLS.RSA(key)
MbedTLS.complete!(rsa)
buf = IOBuffer()
MbedTLS.mpi_export!(buf, rsa.E)
MbedTLS.mpi_export!(buf, rsa.N)
arr = take!(buf)
@test sizeof(arr) == 259
end

pubkey = MbedTLS.parse_public_keyfile(joinpath(@__DIR__, "public_key.pem"))
@test MbedTLS.bitlength(pubkey) == 2048
@test MbedTLS.get_name(pubkey) == "RSA"
Expand Down

0 comments on commit 868c9ea

Please sign in to comment.