Skip to content

Commit

Permalink
Refactor: simpler interface for MessageFactory and better 1.3 message…
Browse files Browse the repository at this point in the history
… mock
  • Loading branch information
reneme committed Oct 11, 2021
1 parent e6bbf61 commit 775610b
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 145 deletions.
8 changes: 2 additions & 6 deletions src/lib/tls/msg_cert_req.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,15 @@ Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version,
Handshake_Hash& hash,
const Policy& policy,
const std::vector<X509_DN>& ca_certs) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V12>(io, hash, policy, ca_certs))
m_impl(MessageFactory::create<Certificate_Req_Impl>(protocol_version, io, hash, policy, ca_certs))
{
}

/**
* Deserialize a Certificate Request message
*/
Certificate_Req::Certificate_Req(const Protocol_Version& protocol_version, const std::vector<uint8_t>& buf) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Req_Impl, Protocol_Version::TLS_V12>(buf))
m_impl(MessageFactory::create<Certificate_Req_Impl>(protocol_version, buf))
{
}

Expand Down
8 changes: 2 additions & 6 deletions src/lib/tls/msg_cert_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,15 @@ Certificate_Verify::Certificate_Verify(Handshake_IO& io,
const Policy& policy,
RandomNumberGenerator& rng,
const Private_Key* priv_key) :
m_impl( state.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V13>(io, state, policy, rng, priv_key)
: TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V12>(io, state, policy, rng, priv_key))
m_impl(MessageFactory::create<Certificate_Verify_Impl>(state.version(), io, state, policy, rng, priv_key))
{
}

/*
* Deserialize a Certificate Verify message
*/
Certificate_Verify::Certificate_Verify(const Protocol_Version& protocol_version, const std::vector<uint8_t>& buf) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Certificate_Verify_Impl, Protocol_Version::TLS_V12>(buf))
m_impl(MessageFactory::create<Certificate_Verify_Impl>(protocol_version, buf))
{
}

Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_cert_verify_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ Certificate_Verify_Impl::Certificate_Verify_Impl(const std::vector<uint8_t>& buf
reader.assert_done();
}

Certificate_Verify_Impl::~Certificate_Verify_Impl() = default;

/*
* Serialize a Certificate Verify message
*/
Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_cert_verify_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ class Certificate_Verify_Impl : public Handshake_Message

explicit Certificate_Verify_Impl(const std::vector<uint8_t>& buf);

virtual ~Certificate_Verify_Impl() = 0;

std::vector<uint8_t> serialize() const override;
private:
std::vector<uint8_t> m_signature;
Expand Down
8 changes: 2 additions & 6 deletions src/lib/tls/msg_certificate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ Certificate::Certificate(const Protocol_Version& protocol_version,
Handshake_IO& io,
Handshake_Hash& hash,
const std::vector<X509_Certificate>& cert_list) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V12>(io, hash, cert_list))
m_impl(MessageFactory::create<Certificate_Impl>(protocol_version, io, hash, cert_list))
{
}

Expand All @@ -56,9 +54,7 @@ Certificate::Certificate(const Protocol_Version& protocol_version,
*/
Certificate::Certificate(const Protocol_Version& protocol_version,
const std::vector<uint8_t>& buf, const Policy& policy) :
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V13>()
: TLS_Message_Factory::create<Certificate_Impl, Protocol_Version::TLS_V12>(buf, policy))
m_impl(MessageFactory::create<Certificate_Impl>(protocol_version, buf, policy))
{
}

Expand Down
17 changes: 8 additions & 9 deletions src/lib/tls/msg_client_hello.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ Client_Hello::Client_Hello(Handshake_IO& io,
const std::vector<uint8_t>& reneg_info,
const Client_Hello::Settings& client_settings,
const std::vector<std::string>& next_protocols) :
m_impl(client_settings.protocol_version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V13>(io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols)
: TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V12>(io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols))
m_impl(MessageFactory::create<Client_Hello_Impl>(client_settings.protocol_version(), io, hash, policy, cb, rng, reneg_info, client_settings, next_protocols))
{
}

Expand All @@ -82,9 +80,7 @@ Client_Hello::Client_Hello(Handshake_IO& io,
const std::vector<uint8_t>& reneg_info,
const Session& session,
const std::vector<std::string>& next_protocols) :
m_impl(session.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V13>(io, hash, policy, cb, rng, reneg_info, session, next_protocols)
: TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V12>(io, hash, policy, cb, rng, reneg_info, session, next_protocols))
m_impl(MessageFactory::create<Client_Hello_Impl>(session.version(), io, hash, policy, cb, rng, reneg_info, session, next_protocols))
{
}

Expand All @@ -95,9 +91,12 @@ Client_Hello::Client_Hello(const std::vector<uint8_t>& buf)
{
auto supported_versions = Client_Hello_Impl(buf).supported_versions();

m_impl = value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13))
? TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Client_Hello_Impl, Protocol_Version::TLS_V12>(buf);
const auto protocol_version =
value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13))
? Protocol_Version::TLS_V13
: Protocol_Version::TLS_V12;

m_impl = MessageFactory::create<Client_Hello_Impl>(protocol_version, buf);
}

Client_Hello::~Client_Hello() = default;
Expand Down
8 changes: 2 additions & 6 deletions src/lib/tls/msg_finished.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ namespace TLS {
Finished::Finished(Handshake_IO& io,
Handshake_State& state,
Connection_Side side) :
m_impl( state.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V13>(io, state, side)
: TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V12>(io, state, side))
m_impl(MessageFactory::create<Finished_Impl>(state.version(), io, state, side))
{
}

Expand All @@ -43,9 +41,7 @@ std::vector<uint8_t> Finished::serialize() const
* Deserialize a Finished message
*/
Finished::Finished(const Protocol_Version& protocol_version, const std::vector<uint8_t>& buf):
m_impl( protocol_version == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Finished_Impl, Protocol_Version::TLS_V12>(buf))
m_impl(MessageFactory::create<Finished_Impl>(protocol_version, buf))
{
}

Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_finished_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ std::vector<uint8_t> Finished_Impl::serialize() const
Finished_Impl::Finished_Impl(const std::vector<uint8_t>& buf) : m_verification_data(buf)
{}

Finished_Impl::~Finished_Impl() = default;

std::vector<uint8_t> Finished_Impl::verify_data() const
{
return m_verification_data;
Expand Down
2 changes: 0 additions & 2 deletions src/lib/tls/msg_finished_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ class Finished_Impl : public Handshake_Message

explicit Finished_Impl(const std::vector<uint8_t>& buf);

virtual ~Finished_Impl() = 0;

std::vector<uint8_t> serialize() const override;
private:
std::vector<uint8_t> m_verification_data;
Expand Down
17 changes: 8 additions & 9 deletions src/lib/tls/msg_server_hello.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ Server_Hello::Server_Hello(Handshake_IO& io,
const Client_Hello& client_hello,
const Server_Hello::Settings& server_settings,
const std::string next_protocol) :
m_impl(server_settings.protocol_version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Server_Hello_Impl, Protocol_Version::TLS_V13>(io, hash, policy, cb, rng, reneg_info, client_hello, server_settings, next_protocol)
: TLS_Message_Factory::create<Server_Hello_Impl, Protocol_Version::TLS_V12>(io, hash, policy, cb, rng, reneg_info, client_hello, server_settings, next_protocol))
m_impl(MessageFactory::create<Server_Hello_Impl>(client_hello.version(), io, hash, policy, cb, rng, reneg_info, client_hello, server_settings, next_protocol))
{
}

Expand All @@ -51,9 +49,7 @@ Server_Hello::Server_Hello(Handshake_IO& io,
Session& resumed_session,
bool offer_session_ticket,
const std::string& next_protocol) :
m_impl(client_hello.version() == Protocol_Version::TLS_V13
? TLS_Message_Factory::create<Server_Hello_Impl, Protocol_Version::TLS_V13>(io, hash, policy, cb, rng, reneg_info, client_hello, resumed_session, offer_session_ticket, next_protocol)
: TLS_Message_Factory::create<Server_Hello_Impl, Protocol_Version::TLS_V12>(io, hash, policy, cb, rng, reneg_info, client_hello, resumed_session, offer_session_ticket, next_protocol))
m_impl(MessageFactory::create<Server_Hello_Impl>(client_hello.version(), io, hash, policy, cb, rng, reneg_info, client_hello, resumed_session, offer_session_ticket, next_protocol))
{
}

Expand All @@ -64,9 +60,12 @@ Server_Hello::Server_Hello(const std::vector<uint8_t>& buf)
{
auto supported_versions = Server_Hello_Impl(buf).supported_versions();

m_impl = value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13))
? TLS_Message_Factory::create<Server_Hello_Impl, Protocol_Version::TLS_V13>(buf)
: TLS_Message_Factory::create<Server_Hello_Impl, Protocol_Version::TLS_V12>(buf);
const auto protocol_version =
value_exists(supported_versions, Protocol_Version(Protocol_Version::TLS_V13))
? Protocol_Version::TLS_V13
: Protocol_Version::TLS_V12;

m_impl = MessageFactory::create<Server_Hello_Impl>(protocol_version, buf);
}

Server_Hello::~Server_Hello() = default;
Expand Down
93 changes: 37 additions & 56 deletions src/lib/tls/tls_message_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,93 +36,74 @@ class Certificate_Verify_Impl_12;
class Certificate_Impl_12;
class Finished_Impl_12;

class TLS_Message_Factory
{
public:
template<typename Message_Base_Type, Protocol_Version::Version_Code Version>
struct Impl_Version_Trait{};

template <typename Message_Base_Type, Protocol_Version::Version_Code Version, typename ... Args>
static std::unique_ptr<Message_Base_Type> create(Args&& ... args)
{
return std::make_unique<typename Impl_Version_Trait<Message_Base_Type, Version>::Ver_Impl>(std::forward<Args>(args) ... );
}
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Server_Hello_Impl, Protocol_Version::TLS_V12>
{
using Ver_Impl = Server_Hello_Impl_12;
};
namespace {

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Server_Hello_Impl, Protocol_Version::TLS_V13>
{
using Ver_Impl = Server_Hello_Impl_12; // TODO using Ver_Impl = Server_Hello_Impl_13
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Client_Hello_Impl, Protocol_Version::TLS_V12>
{
using Ver_Impl = Client_Hello_Impl_12;
};
template<typename Message_Base_Type>
struct implementation_trait{};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Client_Hello_Impl, Protocol_Version::TLS_V13>
struct implementation_trait<Server_Hello_Impl>
{
using Ver_Impl = Mock_Impl_13<Client_Hello_Impl>; // TODO using Ver_Impl = Client_Hello_Impl_13
using v12 = Server_Hello_Impl_12;
using v13 = Mock_Impl_13<Server_Hello_Impl>;
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Certificate_Req_Impl, Protocol_Version::TLS_V12>
struct implementation_trait<Client_Hello_Impl>
{
using Ver_Impl = Certificate_Req_Impl_12;
using v12 = Client_Hello_Impl_12;
using v13 = Mock_Impl_13<Client_Hello_Impl>;
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Certificate_Req_Impl, Protocol_Version::TLS_V13>
struct implementation_trait<Certificate_Req_Impl>
{
using Ver_Impl = Mock_Certificate_Req_Impl_13; // TODO using Ver_Impl = Certificate_Req_Impl_13
using v12 = Certificate_Req_Impl_12;
using v13 = Mock_Impl_13<Certificate_Req_Impl>;
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Certificate_Verify_Impl, Protocol_Version::TLS_V12>
struct implementation_trait<Certificate_Verify_Impl>
{
using Ver_Impl = Certificate_Verify_Impl_12;
using v12 = Certificate_Verify_Impl_12;
using v13 = Mock_Impl_13<Certificate_Verify_Impl>;
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Certificate_Verify_Impl, Protocol_Version::TLS_V13>
struct implementation_trait<Certificate_Impl>
{
using Ver_Impl = Mock_Impl_13<Certificate_Verify_Impl>; // TODO using Ver_Impl = Certificate_Verify_Impl_13
using v12 = Certificate_Impl_12;
using v13 = Mock_Impl_13<Certificate_Impl>;
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Certificate_Impl, Protocol_Version::TLS_V12>
struct implementation_trait<Finished_Impl>
{
using Ver_Impl = Certificate_Impl_12;
using v12 = Finished_Impl_12;
using v13 = Mock_Impl_13<Finished_Impl>;
};

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Certificate_Impl, Protocol_Version::TLS_V13>
{
using Ver_Impl = Mock_Certificate_Impl_13; // TODO using Ver_Impl = Certificate_Impl_13
};
}

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Finished_Impl, Protocol_Version::TLS_V12>
{
using Ver_Impl = Finished_Impl_12;
};
namespace MessageFactory {

template<>
struct TLS_Message_Factory::Impl_Version_Trait<Finished_Impl, Protocol_Version::TLS_V13>
template <typename MessageBaseT, typename... ParamTs>
std::unique_ptr<MessageBaseT> create(const Protocol_Version &protocol_version, ParamTs&&... parameters)
{
using Ver_Impl = Mock_Impl_13<Finished_Impl>; // TODO using Ver_Impl = Finished_Impl_13
};
using impl_t = implementation_trait<MessageBaseT>;

if (protocol_version == Protocol_Version::TLS_V13)
{
return std::make_unique<typename impl_t::v13>(std::forward<ParamTs>(parameters)...);
}
else
{
return std::make_unique<typename impl_t::v12>(std::forward<ParamTs>(parameters)...);
}
}

}
}

}

#endif
Loading

0 comments on commit 775610b

Please sign in to comment.