From b2a4e45ca7e13072ebb28b7fb3717eebfeeb02c4 Mon Sep 17 00:00:00 2001 From: Christophe Raffalli Date: Tue, 20 Jun 2023 13:27:51 -1000 Subject: [PATCH] Allow to use ktls for send and receive --- src/ssl.ml | 17 ++++++-- src/ssl.mli | 24 +++++++++-- src/ssl_stubs.c | 17 +++++++- tests/dune | 11 +++++ tests/ssl_io_ktls.ml | 86 +++++++++++++++++++++++++++++++++++++++ tests/util_ktls.ml | 96 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 241 insertions(+), 10 deletions(-) create mode 100644 tests/ssl_io_ktls.ml create mode 100644 tests/util_ktls.ml diff --git a/src/ssl.ml b/src/ssl.ml index 56d1f26..69d2cd5 100644 --- a/src/ssl.ml +++ b/src/ssl.ml @@ -212,11 +212,20 @@ type context_type = | Both_context external create_context : - protocol - -> context_type + protocol + -> context_type -> bool -> context = "ocaml_ssl_create_context" +external ktls_send_available : socket -> bool = "caml_ssl_ktls_send_available" + [@@noalloc] + +external ktls_recv_available : socket -> bool = "caml_ssl_ktls_recv_available" + [@@noalloc] + +let create_context ?(ktls=false) protocol typ = + create_context protocol typ ktls + external set_min_protocol_version : context -> protocol @@ -716,9 +725,9 @@ module Make (Ssl_base : Ssl_base) = struct Unix.close sock; raise exn - let open_connection ssl_method sockaddr = + let open_connection ?(ktls=false) ssl_method sockaddr = open_connection_with_context - (create_context ssl_method Client_context) + (create_context ~ktls ssl_method Client_context) sockaddr let close_notify = ssl_shutdown diff --git a/src/ssl.mli b/src/ssl.mli index f8c24f1..5b3f989 100644 --- a/src/ssl.mli +++ b/src/ssl.mli @@ -311,8 +311,24 @@ type context_type = | Server_context (** Server connections. *) | Both_context (** Client and server connections. *) -val create_context : protocol -> context_type -> context -(** Create a context. *) +val create_context : ?ktls:bool -> protocol -> context_type -> context +(** Create a context. + + The ktls optional parameter ([false] by default) allows for using kernel + TLS function. You must use [ktls_send_available] and [ktls_recv_available] + to check that ktls is enabled before you can use [Unix.read] and + [Unix.single_write] directly to write to the SSL buffer *) + +val ktls_send_available : socket -> bool +(** Checks if ktls is available to write. This allows for using + [Unix.single_write] on the file descriptor underlying the ssl socket. *) + +val ktls_recv_available : socket -> bool +(** Checks if ktls is available to read. This allows for using [Unix.read] on + the file descriptor underlying the ssl socket. + + NOTE: recv is supported for TLS1_3 only with recent version of openssl, requiring the recent commit: https://github.com/openssl/openssl/commit/7c78932b9a4330fb7c8db72b3fb37cbff1401f8b + *) val set_min_protocol_version : context -> protocol -> unit (** [set_min_protocol_version ctx proto] sets the minimum supported protocol @@ -547,7 +563,7 @@ val connect : socket -> unit val accept : socket -> unit (** Accept an SSL connection. *) -val open_connection : protocol -> Unix.sockaddr -> socket +val open_connection : ?ktls:bool -> protocol -> Unix.sockaddr -> socket (** Open an SSL connection. *) val open_connection_with_context : context -> Unix.sockaddr -> socket @@ -620,7 +636,7 @@ module Runtime_lock : sig val accept : socket -> unit (** Accept an SSL connection. *) - val open_connection : protocol -> Unix.sockaddr -> socket + val open_connection : ?ktls:bool -> protocol -> Unix.sockaddr -> socket (** Open an SSL connection. *) val open_connection_with_context : context -> Unix.sockaddr -> socket diff --git a/src/ssl_stubs.c b/src/ssl_stubs.c index 180cdaf..36ccb61 100644 --- a/src/ssl_stubs.c +++ b/src/ssl_stubs.c @@ -540,8 +540,20 @@ static void set_protocol(SSL_CTX *ssl_context, int protocol) { } } -CAMLprim value ocaml_ssl_create_context(value protocol, value type) { - CAMLparam2(protocol, type); +CAMLprim value caml_ssl_ktls_send_available(value out_fd) { + CAMLparam1(out_fd); + int r = BIO_get_ktls_send(SSL_get_wbio(SSL_val(out_fd))); + CAMLreturn(Val_int(r)); +} + +CAMLprim value caml_ssl_ktls_recv_available(value out_fd) { + CAMLparam1(out_fd); + int r = BIO_get_ktls_recv(SSL_get_wbio(SSL_val(out_fd))); + CAMLreturn(Val_int(r)); +} + +CAMLprim value ocaml_ssl_create_context(value protocol, value type, value ktls) { + CAMLparam3(protocol, type, ktls); CAMLlocal1(block); SSL_CTX *ctx; const SSL_METHOD *method = get_method(Int_val(type)); @@ -559,6 +571,7 @@ CAMLprim value ocaml_ssl_create_context(value protocol, value type) { mode, hide SSL_ERROR_WANT_(READ|WRITE) from us. */ SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY); + if Int_val(ktls) SSL_CTX_set_options(ctx, SSL_OP_ENABLE_KTLS); caml_acquire_runtime_system(); block = caml_alloc_custom(&ctx_ops, sizeof(SSL_CTX *), 0, 1); diff --git a/tests/dune b/tests/dune index 26d43f5..6f9b384 100644 --- a/tests/dune +++ b/tests/dune @@ -3,6 +3,11 @@ (modules util) (libraries ssl threads str alcotest)) +(library + (name util_ktls) + (modules util_ktls) + (libraries ssl threads str alcotest)) + (test (name ssl_test) (modules ssl_test) @@ -43,3 +48,9 @@ (modules ssl_io) (libraries ssl alcotest util) (deps ca.pem ca.key server.key server.pem)) + +(test + (name ssl_io_ktls) + (modules ssl_io_ktls) + (libraries ssl alcotest util_ktls) + (deps ca.pem ca.key server.key server.pem)) diff --git a/tests/ssl_io_ktls.ml b/tests/ssl_io_ktls.ml new file mode 100644 index 0000000..d95abc8 --- /dev/null +++ b/tests/ssl_io_ktls.ml @@ -0,0 +1,86 @@ +open Alcotest + +module Util = Util_ktls + +let test_verify () = + let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 11342) in + Util.server_thread addr None |> ignore; + + let context = Ssl.create_context ~ktls:true TLSv1_2 Client_context in + let ssl = Ssl.open_connection_with_context context addr in + assert(Ssl.ktls_send_available ssl); + assert(Ssl.ktls_recv_available ssl); + let verify_result = + try + Ssl.verify ssl; + "" + with + | e -> Printexc.to_string e + in + Ssl.shutdown_connection ssl; + check + bool + "no verify errors" + true + (Str.search_forward + (Str.regexp_string "error:00:000000:lib(0)") + verify_result + 0 + > 0) + +let test_set_host () = + let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 11343) in + Util.server_thread addr None |> ignore; + + let context = Ssl.create_context ~ktls:true TLSv1_2 Client_context in + let domain = Unix.domain_of_sockaddr addr in + let sock = Unix.socket domain Unix.SOCK_STREAM 0 in + let ssl = Ssl.embed_socket sock context in + Ssl.set_host ssl "localhost"; + Unix.connect sock addr; + Ssl.connect ssl; + let verify_result = + try + Ssl.verify ssl; + assert(Ssl.ktls_send_available ssl); + assert(Ssl.ktls_recv_available ssl); + "" + with + | e -> Printexc.to_string e + in + Ssl.shutdown_connection ssl; + check + bool + "no verify errors" + true + (Str.search_forward + (Str.regexp_string "error:00:000000:lib(0)") + verify_result + 0 + > 0) + +let test_read_write () = + let addr = Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1", 11344) in + Util.server_thread addr (Some (fun _ -> "received")) |> ignore; + + let context = Ssl.create_context ~ktls:true TLSv1_2 Client_context in + let ssl = Ssl.open_connection_with_context context addr in + assert(Ssl.ktls_send_available ssl); + assert(Ssl.ktls_recv_available ssl); + let send_msg = "send" in + let write_buf = Bytes.create (String.length send_msg) in + Unix.single_write (Ssl.file_descr_of_socket ssl) write_buf 0 4 |> ignore; + let read_buf = Bytes.create 8 in + Unix.read (Ssl.file_descr_of_socket ssl) read_buf 0 8 |> ignore; + Ssl.shutdown_connection ssl; + check string "received message" "received" (Bytes.to_string read_buf) + +let () = + run + "Ssl io functions" + [ ( "IO" + , [ test_case "Verify" `Quick test_verify + ; test_case "Set host" `Quick test_set_host + ; test_case "Read write" `Quick test_read_write + ] ) + ] diff --git a/tests/util_ktls.ml b/tests/util_ktls.ml new file mode 100644 index 0000000..cd47887 --- /dev/null +++ b/tests/util_ktls.ml @@ -0,0 +1,96 @@ +module Ssl = struct + include Ssl + + let[@ocaml.alert "-deprecated"] get_error_string = get_error_string +end + +open Ssl + +type server_args = + { address : Unix.sockaddr + ; condition : Condition.t + ; mutex : Mutex.t + ; parser : (string -> string) option + } + +let server_rw_loop ssl parser_func = + let fd = Ssl.file_descr_of_socket ssl in + let rw_loop = ref true in + while !rw_loop do + try + let read_buf = Bytes.create 256 in + let read_bytes = Unix.read fd read_buf 0 256 in + if read_bytes > 0 + then ( + let input = Bytes.to_string read_buf in + let response = parser_func input in + Unix.write_substring fd response 0 (String.length response) |> ignore; + Ssl.close_notify ssl |> ignore; + rw_loop := false) + with + | Read_error read_error -> + (match read_error with Error_ssl -> rw_loop := false | _ -> ()) + done + +let server_init args = + try + (* Server initialization *) + Mutex.lock args.mutex; + let socket = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in + Unix.setsockopt socket Unix.SO_REUSEADDR true; + Unix.bind socket args.address; + let context = create_context ~ktls:true TLSv1_2 Server_context in + use_certificate context "server.pem" "server.key"; + Ssl.set_context_alpn_select_callback context (fun client_protos -> + List.find_opt (fun opt -> opt = "http/1.1") client_protos); + (* Signal ready and listen for connection *) + Unix.listen socket 1; + Some (socket, context) + with + | exn -> + Printexc.to_string exn |> print_endline; + None + +let server_listen args = + match server_init args with + | None -> + Mutex.unlock args.mutex; + Condition.signal args.condition; + Thread.exit () [@warning "-3"] + | Some (socket, context) -> + Mutex.unlock args.mutex; + Condition.signal args.condition; + let listen = Unix.accept socket in + let ssl = embed_socket (fst listen) context in + accept ssl; + assert(ktls_send_available ssl); + assert(ktls_recv_available ssl); + (* Exit right away unless we need to rw *) + (match args.parser with + | Some parser_func -> server_rw_loop ssl parser_func + | None -> + (); + shutdown ssl; + Thread.exit () [@warning "-3"]) + +let server_thread addr parser = + let mutex = Mutex.create () in + Mutex.lock mutex; + let condition = Condition.create () in + let args = { address = addr; condition; mutex; parser } in + let thread = Thread.create server_listen args in + Condition.wait condition mutex; + thread + +let check_ssl_no_error err = + Str.string_partial_match (Str.regexp_string "error:00000000:lib(0)") err 0 + +let[@ocaml.alert "-deprecated"] pp_protocol ppf = function + | SSLv23 -> Format.fprintf ppf "SSLv23" + | SSLv3 -> Format.fprintf ppf "SSLv3" + | TLSv1 -> Format.fprintf ppf "TLSv1" + | TLSv1_1 -> Format.fprintf ppf "TLSv1_1" + | TLSv1_2 -> Format.fprintf ppf "TLSv1_2" + | TLSv1_3 -> Format.fprintf ppf "TLSv1_3" + +let protocol_testable = Alcotest.testable pp_protocol (fun r1 r2 -> r1 == r2)