diff --git a/src/lib/tls/msg_cert_req.cpp b/src/lib/tls/msg_cert_req.cpp index 864fe4791d9..dccbf3e0ec3 100644 --- a/src/lib/tls/msg_cert_req.cpp +++ b/src/lib/tls/msg_cert_req.cpp @@ -49,9 +49,7 @@ Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version, Handshake_Hash& hash, const Policy& policy, const std::vector& ca_certs) : - m_impl( protocol_version == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create() - : TLS_Message_Factory::create(io, hash, policy, ca_certs)) + m_impl(MessageFactory::create(protocol_version, io, hash, policy, ca_certs)) { } @@ -59,9 +57,7 @@ Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version, * Deserialize a Certificate Request message */ Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version, const std::vector& buf) : - m_impl( protocol_version == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create() - : TLS_Message_Factory::create(buf)) + m_impl(MessageFactory::create(protocol_version, buf)) { } diff --git a/src/lib/tls/msg_cert_verify.cpp b/src/lib/tls/msg_cert_verify.cpp index 43119a35cbf..7fd8f062845 100644 --- a/src/lib/tls/msg_cert_verify.cpp +++ b/src/lib/tls/msg_cert_verify.cpp @@ -28,9 +28,7 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io, const Policy& policy, RandomNumberGenerator& rng, const Private_Key* priv_key) : - m_impl( state.version() == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(io, state, policy, rng, priv_key) - : TLS_Message_Factory::create(io, state, policy, rng, priv_key)) + m_impl(MessageFactory::create(state.version(), io, state, policy, rng, priv_key)) { } @@ -38,9 +36,7 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io, * Deserialize a Certificate Verify message */ Certificate_Verify::Certificate_Verify(const Protocol_Version& protocol_version, const std::vector& buf) : - m_impl( protocol_version == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(buf) - : TLS_Message_Factory::create(buf)) + m_impl(MessageFactory::create(protocol_version, buf)) { } diff --git a/src/lib/tls/msg_cert_verify_impl.cpp b/src/lib/tls/msg_cert_verify_impl.cpp index 7994a65679d..2d4e1e470a0 100644 --- a/src/lib/tls/msg_cert_verify_impl.cpp +++ b/src/lib/tls/msg_cert_verify_impl.cpp @@ -58,8 +58,6 @@ Certificate_Verify_Impl::Certificate_Verify_Impl(const std::vector& buf reader.assert_done(); } -Certificate_Verify_Impl::~Certificate_Verify_Impl() = default; - /* * Serialize a Certificate Verify message */ diff --git a/src/lib/tls/msg_cert_verify_impl.h b/src/lib/tls/msg_cert_verify_impl.h index 7a3728a524d..785405051bb 100644 --- a/src/lib/tls/msg_cert_verify_impl.h +++ b/src/lib/tls/msg_cert_verify_impl.h @@ -52,8 +52,6 @@ class Certificate_Verify_Impl : public Handshake_Message explicit Certificate_Verify_Impl(const std::vector& buf); - virtual ~Certificate_Verify_Impl() = 0; - std::vector serialize() const override; private: std::vector m_signature; diff --git a/src/lib/tls/msg_certificate.cpp b/src/lib/tls/msg_certificate.cpp index 65c914f27c8..b3c1e564e97 100644 --- a/src/lib/tls/msg_certificate.cpp +++ b/src/lib/tls/msg_certificate.cpp @@ -45,9 +45,7 @@ Certificate::Certificate(const Protocol_Version& protocol_version, Handshake_IO& io, Handshake_Hash& hash, const std::vector& cert_list) : - m_impl( protocol_version == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create() - : TLS_Message_Factory::create(io, hash, cert_list)) + m_impl(MessageFactory::create(protocol_version, io, hash, cert_list)) { } @@ -56,9 +54,7 @@ Certificate::Certificate(const Protocol_Version& protocol_version, */ Certificate::Certificate(const Protocol_Version& protocol_version, const std::vector& buf, const Policy& policy) : - m_impl( protocol_version == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create() - : TLS_Message_Factory::create(buf, policy)) + m_impl(MessageFactory::create(protocol_version, buf, policy)) { } diff --git a/src/lib/tls/msg_client_hello.cpp b/src/lib/tls/msg_client_hello.cpp index 396b1d07129..09c56b596ce 100644 --- a/src/lib/tls/msg_client_hello.cpp +++ b/src/lib/tls/msg_client_hello.cpp @@ -65,9 +65,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, const std::vector& reneg_info, const Client_Hello::Settings& client_settings, const std::vector& next_protocols) : - m_impl(client_settings.protocol_version() == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols) - : TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols)) + m_impl(MessageFactory::create(client_settings.protocol_version(), io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols)) { } @@ -82,9 +80,7 @@ Client_Hello::Client_Hello(Handshake_IO& io, const std::vector& reneg_info, const Session& session, const std::vector& next_protocols) : - m_impl(session.version() == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, session, next_protocols) - : TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, session, next_protocols)) + m_impl(MessageFactory::create(session.version(), io, hash, policy, cb, rng, reneg_info, session, next_protocols)) { } @@ -95,9 +91,12 @@ Client_Hello::Client_Hello(const std::vector& buf) { auto supported_versions = Client_Hello_Impl(buf).supported_versions(); - m_impl = value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13)) - ? TLS_Message_Factory::create(buf) - : TLS_Message_Factory::create(buf); + const auto protocol_version = + value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13)) + ? Protocol_Version::TLS_V13 + : Protocol_Version::TLS_V12; + + m_impl = MessageFactory::create(protocol_version, buf); } Client_Hello::~Client_Hello() = default; diff --git a/src/lib/tls/msg_finished.cpp b/src/lib/tls/msg_finished.cpp index d4808a694be..2a75edc27d9 100644 --- a/src/lib/tls/msg_finished.cpp +++ b/src/lib/tls/msg_finished.cpp @@ -25,9 +25,7 @@ namespace TLS { Finished::Finished(Handshake_IO& io, Handshake_State& state, Connection_Side side) : - m_impl( state.version() == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(io, state, side) - : TLS_Message_Factory::create(io, state, side)) + m_impl(MessageFactory::create(state.version(), io, state, side)) { } @@ -43,9 +41,7 @@ std::vector Finished::serialize() const * Deserialize a Finished message */ Finished::Finished(const Protocol_Version& protocol_version, const std::vector& buf): - m_impl( protocol_version == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(buf) - : TLS_Message_Factory::create(buf)) + m_impl(MessageFactory::create(protocol_version, buf)) { } diff --git a/src/lib/tls/msg_finished_impl.cpp b/src/lib/tls/msg_finished_impl.cpp index 4b7836485ce..7ac4eee0789 100644 --- a/src/lib/tls/msg_finished_impl.cpp +++ b/src/lib/tls/msg_finished_impl.cpp @@ -74,8 +74,6 @@ std::vector Finished_Impl::serialize() const Finished_Impl::Finished_Impl(const std::vector& buf) : m_verification_data(buf) {} -Finished_Impl::~Finished_Impl() = default; - std::vector Finished_Impl::verify_data() const { return m_verification_data; diff --git a/src/lib/tls/msg_finished_impl.h b/src/lib/tls/msg_finished_impl.h index c43004fd9e3..1e41c514946 100644 --- a/src/lib/tls/msg_finished_impl.h +++ b/src/lib/tls/msg_finished_impl.h @@ -40,8 +40,6 @@ class Finished_Impl : public Handshake_Message explicit Finished_Impl(const std::vector& buf); - virtual ~Finished_Impl() = 0; - std::vector serialize() const override; private: std::vector m_verification_data; diff --git a/src/lib/tls/msg_server_hello.cpp b/src/lib/tls/msg_server_hello.cpp index af49fe882f3..0dbd55d1f33 100644 --- a/src/lib/tls/msg_server_hello.cpp +++ b/src/lib/tls/msg_server_hello.cpp @@ -34,9 +34,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, const Client_Hello& client_hello, const Server_Hello::Settings& server_settings, const std::string next_protocol) : - m_impl(server_settings.protocol_version() == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, client_hello, server_settings, next_protocol) - : TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, client_hello, server_settings, next_protocol)) + m_impl(MessageFactory::create(client_hello.version(), io, hash, policy, cb, rng, reneg_info, client_hello, server_settings, next_protocol)) { } @@ -51,9 +49,7 @@ Server_Hello::Server_Hello(Handshake_IO& io, Session& resumed_session, bool offer_session_ticket, const std::string& next_protocol) : - m_impl(client_hello.version() == Protocol_Version::TLS_V13 - ? TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, client_hello, resumed_session, offer_session_ticket, next_protocol) - : TLS_Message_Factory::create(io, hash, policy, cb, rng, reneg_info, client_hello, resumed_session, offer_session_ticket, next_protocol)) + m_impl(MessageFactory::create(client_hello.version(), io, hash, policy, cb, rng, reneg_info, client_hello, resumed_session, offer_session_ticket, next_protocol)) { } @@ -64,9 +60,12 @@ Server_Hello::Server_Hello(const std::vector& buf) { auto supported_versions = Server_Hello_Impl(buf).supported_versions(); - m_impl = value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13)) - ? TLS_Message_Factory::create(buf) - : TLS_Message_Factory::create(buf); + const auto protocol_version = + value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13)) + ? Protocol_Version::TLS_V13 + : Protocol_Version::TLS_V12; + + m_impl = MessageFactory::create(protocol_version, buf); } Server_Hello::~Server_Hello() = default; diff --git a/src/lib/tls/tls_message_factory.h b/src/lib/tls/tls_message_factory.h index 124e50078f9..779863b90e6 100644 --- a/src/lib/tls/tls_message_factory.h +++ b/src/lib/tls/tls_message_factory.h @@ -36,93 +36,74 @@ class Certificate_Verify_Impl_12; class Certificate_Impl_12; class Finished_Impl_12; -class TLS_Message_Factory - { - public: - template - struct Impl_Version_Trait{}; - - template - static std::unique_ptr create(Args&& ... args) - { - return std::make_unique::Ver_Impl>(std::forward(args) ... ); - } - }; - -template<> -struct TLS_Message_Factory::Impl_Version_Trait - { - using Ver_Impl = Server_Hello_Impl_12; - }; +namespace { -template<> -struct TLS_Message_Factory::Impl_Version_Trait - { - using Ver_Impl = Server_Hello_Impl_12; // TODO using Ver_Impl = Server_Hello_Impl_13 - }; - -template<> -struct TLS_Message_Factory::Impl_Version_Trait - { - using Ver_Impl = Client_Hello_Impl_12; - }; +template +struct implementation_trait{}; template<> -struct TLS_Message_Factory::Impl_Version_Trait +struct implementation_trait { - using Ver_Impl = Mock_Impl_13; // TODO using Ver_Impl = Client_Hello_Impl_13 + using v12 = Server_Hello_Impl_12; + using v13 = Mock_Impl_13; }; template<> -struct TLS_Message_Factory::Impl_Version_Trait +struct implementation_trait { - using Ver_Impl = Certificate_Req_Impl_12; + using v12 = Client_Hello_Impl_12; + using v13 = Mock_Impl_13; }; template<> -struct TLS_Message_Factory::Impl_Version_Trait +struct implementation_trait { - using Ver_Impl = Mock_Certificate_Req_Impl_13; // TODO using Ver_Impl = Certificate_Req_Impl_13 + using v12 = Certificate_Req_Impl_12; + using v13 = Mock_Impl_13; }; template<> -struct TLS_Message_Factory::Impl_Version_Trait +struct implementation_trait { - using Ver_Impl = Certificate_Verify_Impl_12; + using v12 = Certificate_Verify_Impl_12; + using v13 = Mock_Impl_13; }; template<> -struct TLS_Message_Factory::Impl_Version_Trait +struct implementation_trait { - using Ver_Impl = Mock_Impl_13; // TODO using Ver_Impl = Certificate_Verify_Impl_13 + using v12 = Certificate_Impl_12; + using v13 = Mock_Impl_13; }; template<> -struct TLS_Message_Factory::Impl_Version_Trait +struct implementation_trait { - using Ver_Impl = Certificate_Impl_12; + using v12 = Finished_Impl_12; + using v13 = Mock_Impl_13; }; -template<> -struct TLS_Message_Factory::Impl_Version_Trait - { - using Ver_Impl = Mock_Certificate_Impl_13; // TODO using Ver_Impl = Certificate_Impl_13 - }; +} -template<> -struct TLS_Message_Factory::Impl_Version_Trait - { - using Ver_Impl = Finished_Impl_12; - }; +namespace MessageFactory { -template<> -struct TLS_Message_Factory::Impl_Version_Trait +template +std::unique_ptr create(const Protocol_Version &protocol_version, ParamTs&&... parameters) { - using Ver_Impl = Mock_Impl_13; // TODO using Ver_Impl = Finished_Impl_13 - }; + using impl_t = implementation_trait; + + if (protocol_version == Protocol_Version::TLS_V13) + { + return std::make_unique(std::forward(parameters)...); + } + else + { + return std::make_unique(std::forward(parameters)...); + } + } } } - +} #endif diff --git a/src/lib/tls/tls_mock_msg_impl_13.h b/src/lib/tls/tls_mock_msg_impl_13.h index 692bee30f9a..e97293d9a41 100644 --- a/src/lib/tls/tls_mock_msg_impl_13.h +++ b/src/lib/tls/tls_mock_msg_impl_13.h @@ -8,6 +8,7 @@ #ifndef BOTAN_TLS_MOCK_MSG_IMPL_13_H_ #define BOTAN_TLS_MOCK_MSG_IMPL_13_H_ +#include #include #include #include @@ -15,64 +16,78 @@ #include #include +#include namespace Botan { namespace TLS { -#include -template< typename T > -class Mock_Impl_13: public T +namespace { + +template +[[noreturn]] RetT nyi() + { + throw Not_Implemented("Implementation for TLSv1.3 not ready yet. You are welcome to implement it."); + } + +template +inline constexpr bool must_be_upcalled = !std::is_abstract_v && !std::is_default_constructible_v; + +template +class Mock_Impl_13_Internal; + +template +class Mock_Impl_13_Internal>> : public T { - public: - template - explicit Mock_Impl_13(Args&& ... args) - : T(std::forward(args) ... ) +public: + template + Mock_Impl_13_Internal(Args&&...) { - // TODO throw std::runtime_error("Implemenation for TLSv1.3 not ready yet. You are welcome to implement it."); + nyi(); } + }; -class Mock_Certificate_Impl_13 : public Certificate_Impl +template +class Mock_Impl_13_Internal>> : public T { - public: - template - explicit Mock_Certificate_Impl_13(Args&& ... args) - : Certificate_Impl(std::forward(args) ... ) +public: + template + Mock_Impl_13_Internal(Args&&... args) + : T(std::forward(args)...) { - // TODO throw std::runtime_error("Implemenation for TLSv1.3 not ready yet. You are welcome to implement it."); + nyi(); } - // from Certificate_Impl - std::vector serialize() const override { return {}; } - const std::vector& cert_chain() const override { return m_mock_cert_chain; } - std::size_t count() const override { return {}; } - bool empty() const override { return {}; } +}; + +} - private: - std::vector m_mock_cert_chain; +template +class Mock_Impl_13 : public Mock_Impl_13_Internal { + using Mock_Impl_13_Internal::Mock_Impl_13_Internal; }; -class Mock_Certificate_Req_Impl_13 : public Certificate_Req_Impl -{ - public: - template - explicit Mock_Certificate_Req_Impl_13(Args&& ... args) - : Certificate_Req_Impl(std::forward(args) ... ) - { - // throw std::runtime_error("Implemenation for TLSv1.3 not ready yet. You are welcome to implement it."); - } +template<> +class Mock_Impl_13 : public Mock_Impl_13_Internal { +public: + using Mock_Impl_13_Internal::Mock_Impl_13_Internal; + + const std::vector& cert_chain() const override { return nyi&>(); } + size_t count() const override { return nyi(); } + bool empty() const override { return nyi(); } + std::vector serialize() const override { return nyi>(); } +}; - // from Certificate_Req_Impl - std::vector serialize() const override { return {}; } - const std::vector& acceptable_cert_types() const override { return m_acceptable_cert_types; } - const std::vector& acceptable_CAs() const override { return m_mock_acceptable_CAs; } - const std::vector& signature_schemes() const override { return m_mock_signature_schemes; } +template<> +class Mock_Impl_13 : public Mock_Impl_13_Internal { +public: + using Mock_Impl_13_Internal::Mock_Impl_13_Internal; - private: - std::vector m_acceptable_cert_types; - std::vector m_mock_acceptable_CAs; - std::vector m_mock_signature_schemes; + const std::vector& acceptable_cert_types() const override { return nyi&>(); } + const std::vector& acceptable_CAs() const override { return nyi&>(); } + const std::vector& signature_schemes() const override { return nyi&>(); } + std::vector serialize() const override { return nyi>(); } }; }