Skip to content

Commit

Permalink
Allow to use ktls for send and receive
Browse files Browse the repository at this point in the history
  • Loading branch information
craff committed Jun 20, 2023
1 parent 9e12c50 commit b2a4e45
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 10 deletions.
17 changes: 13 additions & 4 deletions src/ssl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,20 @@ type context_type =
| Both_context

external create_context :
protocol
-> context_type
protocol
-> context_type -> bool
-> context
= "ocaml_ssl_create_context"

external ktls_send_available : socket -> bool = "caml_ssl_ktls_send_available"
[@@noalloc]

external ktls_recv_available : socket -> bool = "caml_ssl_ktls_recv_available"
[@@noalloc]

let create_context ?(ktls=false) protocol typ =
create_context protocol typ ktls

external set_min_protocol_version :
context
-> protocol
Expand Down Expand Up @@ -716,9 +725,9 @@ module Make (Ssl_base : Ssl_base) = struct
Unix.close sock;
raise exn

let open_connection ssl_method sockaddr =
let open_connection ?(ktls=false) ssl_method sockaddr =
open_connection_with_context
(create_context ssl_method Client_context)
(create_context ~ktls ssl_method Client_context)
sockaddr

let close_notify = ssl_shutdown
Expand Down
24 changes: 20 additions & 4 deletions src/ssl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,24 @@ type context_type =
| Server_context (** Server connections. *)
| Both_context (** Client and server connections. *)

val create_context : protocol -> context_type -> context
(** Create a context. *)
val create_context : ?ktls:bool -> protocol -> context_type -> context
(** Create a context.
The ktls optional parameter ([false] by default) allows for using kernel
TLS function. You must use [ktls_send_available] and [ktls_recv_available]
to check that ktls is enabled before you can use [Unix.read] and
[Unix.single_write] directly to write to the SSL buffer *)

val ktls_send_available : socket -> bool
(** Checks if ktls is available to write. This allows for using
[Unix.single_write] on the file descriptor underlying the ssl socket. *)

val ktls_recv_available : socket -> bool
(** Checks if ktls is available to read. This allows for using [Unix.read] on
the file descriptor underlying the ssl socket.
NOTE: recv is supported for TLS1_3 only with recent version of openssl, requiring the recent commit: https://github.com/openssl/openssl/commit/7c78932b9a4330fb7c8db72b3fb37cbff1401f8b
*)

val set_min_protocol_version : context -> protocol -> unit
(** [set_min_protocol_version ctx proto] sets the minimum supported protocol
Expand Down Expand Up @@ -547,7 +563,7 @@ val connect : socket -> unit
val accept : socket -> unit
(** Accept an SSL connection. *)

val open_connection : protocol -> Unix.sockaddr -> socket
val open_connection : ?ktls:bool -> protocol -> Unix.sockaddr -> socket
(** Open an SSL connection. *)

val open_connection_with_context : context -> Unix.sockaddr -> socket
Expand Down Expand Up @@ -620,7 +636,7 @@ module Runtime_lock : sig
val accept : socket -> unit
(** Accept an SSL connection. *)

val open_connection : protocol -> Unix.sockaddr -> socket
val open_connection : ?ktls:bool -> protocol -> Unix.sockaddr -> socket
(** Open an SSL connection. *)

val open_connection_with_context : context -> Unix.sockaddr -> socket
Expand Down
17 changes: 15 additions & 2 deletions src/ssl_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,20 @@ static void set_protocol(SSL_CTX *ssl_context, int protocol) {
}
}

CAMLprim value ocaml_ssl_create_context(value protocol, value type) {
CAMLparam2(protocol, type);
CAMLprim value caml_ssl_ktls_send_available(value out_fd) {
CAMLparam1(out_fd);
int r = BIO_get_ktls_send(SSL_get_wbio(SSL_val(out_fd)));
CAMLreturn(Val_int(r));
}

CAMLprim value caml_ssl_ktls_recv_available(value out_fd) {
CAMLparam1(out_fd);
int r = BIO_get_ktls_recv(SSL_get_wbio(SSL_val(out_fd)));
CAMLreturn(Val_int(r));
}

CAMLprim value ocaml_ssl_create_context(value protocol, value type, value ktls) {
CAMLparam3(protocol, type, ktls);
CAMLlocal1(block);
SSL_CTX *ctx;
const SSL_METHOD *method = get_method(Int_val(type));
Expand All @@ -559,6 +571,7 @@ CAMLprim value ocaml_ssl_create_context(value protocol, value type) {
mode, hide SSL_ERROR_WANT_(READ|WRITE) from us. */
SSL_CTX_set_mode(ctx,
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY);
if Int_val(ktls) SSL_CTX_set_options(ctx, SSL_OP_ENABLE_KTLS);
caml_acquire_runtime_system();

block = caml_alloc_custom(&ctx_ops, sizeof(SSL_CTX *), 0, 1);
Expand Down
11 changes: 11 additions & 0 deletions tests/dune
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
(modules util)
(libraries ssl threads str alcotest))

(library
(name util_ktls)
(modules util_ktls)
(libraries ssl threads str alcotest))

(test
(name ssl_test)
(modules ssl_test)
Expand Down Expand Up @@ -43,3 +48,9 @@
(modules ssl_io)
(libraries ssl alcotest util)
(deps ca.pem ca.key server.key server.pem))

(test
(name ssl_io_ktls)
(modules ssl_io_ktls)
(libraries ssl alcotest util_ktls)
(deps ca.pem ca.key server.key server.pem))
86 changes: 86 additions & 0 deletions tests/ssl_io_ktls.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
open Alcotest

module Util = Util_ktls

let test_verify () =
let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 11342) in
Util.server_thread addr None |> ignore;

let context = Ssl.create_context ~ktls:true TLSv1_2 Client_context in
let ssl = Ssl.open_connection_with_context context addr in
assert(Ssl.ktls_send_available ssl);
assert(Ssl.ktls_recv_available ssl);
let verify_result =
try
Ssl.verify ssl;
""
with
| e -> Printexc.to_string e
in
Ssl.shutdown_connection ssl;
check
bool
"no verify errors"
true
(Str.search_forward
(Str.regexp_string "error:00:000000:lib(0)")
verify_result
0
> 0)

let test_set_host () =
let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 11343) in
Util.server_thread addr None |> ignore;

let context = Ssl.create_context ~ktls:true TLSv1_2 Client_context in
let domain = Unix.domain_of_sockaddr addr in
let sock = Unix.socket domain Unix.SOCK_STREAM 0 in
let ssl = Ssl.embed_socket sock context in
Ssl.set_host ssl "localhost";
Unix.connect sock addr;
Ssl.connect ssl;
let verify_result =
try
Ssl.verify ssl;
assert(Ssl.ktls_send_available ssl);
assert(Ssl.ktls_recv_available ssl);
""
with
| e -> Printexc.to_string e
in
Ssl.shutdown_connection ssl;
check
bool
"no verify errors"
true
(Str.search_forward
(Str.regexp_string "error:00:000000:lib(0)")
verify_result
0
> 0)

let test_read_write () =
let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 11344) in
Util.server_thread addr (Some (fun _ -> "received")) |> ignore;

let context = Ssl.create_context ~ktls:true TLSv1_2 Client_context in
let ssl = Ssl.open_connection_with_context context addr in
assert(Ssl.ktls_send_available ssl);
assert(Ssl.ktls_recv_available ssl);
let send_msg = "send" in
let write_buf = Bytes.create (String.length send_msg) in
Unix.single_write (Ssl.file_descr_of_socket ssl) write_buf 0 4 |> ignore;
let read_buf = Bytes.create 8 in
Unix.read (Ssl.file_descr_of_socket ssl) read_buf 0 8 |> ignore;
Ssl.shutdown_connection ssl;
check string "received message" "received" (Bytes.to_string read_buf)

let () =
run
"Ssl io functions"
[ ( "IO"
, [ test_case "Verify" `Quick test_verify
; test_case "Set host" `Quick test_set_host
; test_case "Read write" `Quick test_read_write
] )
]
96 changes: 96 additions & 0 deletions tests/util_ktls.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
module Ssl = struct
include Ssl

let[@ocaml.alert "-deprecated"] get_error_string = get_error_string
end

open Ssl

type server_args =
{ address : Unix.sockaddr
; condition : Condition.t
; mutex : Mutex.t
; parser : (string -> string) option
}

let server_rw_loop ssl parser_func =
let fd = Ssl.file_descr_of_socket ssl in
let rw_loop = ref true in
while !rw_loop do
try
let read_buf = Bytes.create 256 in
let read_bytes = Unix.read fd read_buf 0 256 in
if read_bytes > 0
then (
let input = Bytes.to_string read_buf in
let response = parser_func input in
Unix.write_substring fd response 0 (String.length response) |> ignore;
Ssl.close_notify ssl |> ignore;
rw_loop := false)
with
| Read_error read_error ->
(match read_error with Error_ssl -> rw_loop := false | _ -> ())
done

let server_init args =
try
(* Server initialization *)
Mutex.lock args.mutex;
let socket = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in
Unix.setsockopt socket Unix.SO_REUSEADDR true;
Unix.bind socket args.address;
let context = create_context ~ktls:true TLSv1_2 Server_context in
use_certificate context "server.pem" "server.key";
Ssl.set_context_alpn_select_callback context (fun client_protos ->
List.find_opt (fun opt -> opt = "http/1.1") client_protos);
(* Signal ready and listen for connection *)
Unix.listen socket 1;
Some (socket, context)
with
| exn ->
Printexc.to_string exn |> print_endline;
None

let server_listen args =
match server_init args with
| None ->
Mutex.unlock args.mutex;
Condition.signal args.condition;
Thread.exit () [@warning "-3"]
| Some (socket, context) ->
Mutex.unlock args.mutex;
Condition.signal args.condition;
let listen = Unix.accept socket in
let ssl = embed_socket (fst listen) context in
accept ssl;
assert(ktls_send_available ssl);
assert(ktls_recv_available ssl);
(* Exit right away unless we need to rw *)
(match args.parser with
| Some parser_func -> server_rw_loop ssl parser_func
| None ->
();
shutdown ssl;
Thread.exit () [@warning "-3"])

let server_thread addr parser =
let mutex = Mutex.create () in
Mutex.lock mutex;
let condition = Condition.create () in
let args = { address = addr; condition; mutex; parser } in
let thread = Thread.create server_listen args in
Condition.wait condition mutex;
thread

let check_ssl_no_error err =
Str.string_partial_match (Str.regexp_string "error:00000000:lib(0)") err 0

let[@ocaml.alert "-deprecated"] pp_protocol ppf = function
| SSLv23 -> Format.fprintf ppf "SSLv23"
| SSLv3 -> Format.fprintf ppf "SSLv3"
| TLSv1 -> Format.fprintf ppf "TLSv1"
| TLSv1_1 -> Format.fprintf ppf "TLSv1_1"
| TLSv1_2 -> Format.fprintf ppf "TLSv1_2"
| TLSv1_3 -> Format.fprintf ppf "TLSv1_3"

let protocol_testable = Alcotest.testable pp_protocol (fun r1 r2 -> r1 == r2)

0 comments on commit b2a4e45

Please sign in to comment.