Skip to content

Commit

Permalink
Add external PSK + middlebox compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
rperez authored and gpotter2 committed Apr 19, 2020
1 parent 551ab1a commit 7839ca9
Show file tree
Hide file tree
Showing 10 changed files with 382 additions and 125 deletions.
9 changes: 7 additions & 2 deletions scapy/layers/tls/automaton.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,11 @@ def raise_on_packet(self, pkt_cls, state, get_next_msg=True):
self.buffer_in = self.buffer_in[1:]
raise state()

def add_record(self, is_sslv2=None, is_tls13=None):
def add_record(self, is_sslv2=None, is_tls13=None, is_tls12=None):
"""
Add a new TLS or SSLv2 or TLS 1.3 record to the packets buffered out.
"""
if is_sslv2 is None and is_tls13 is None:
if is_sslv2 is None and is_tls13 is None and is_tls12 is None:
v = (self.cur_session.tls_version or
self.cur_session.advertised_tls_version)
if v in [0x0200, 0x0002]:
Expand All @@ -215,6 +215,11 @@ def add_record(self, is_sslv2=None, is_tls13=None):
self.buffer_out.append(SSLv2(tls_session=self.cur_session))
elif is_tls13:
self.buffer_out.append(TLS13(tls_session=self.cur_session))
# For TLS 1.3 middlebox compatibility, TLS record version must
# be 0x0303
elif is_tls12:
self.buffer_out.append(TLS(version="TLS 1.2",
tls_session=self.cur_session))
else:
self.buffer_out.append(TLS(tls_session=self.cur_session))

Expand Down
123 changes: 96 additions & 27 deletions scapy/layers/tls/automaton_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import print_function
import socket
import binascii

from scapy.config import conf
from scapy.pton_ntop import inet_pton
Expand All @@ -29,7 +30,7 @@
from scapy.layers.tls.session import tlsSession
from scapy.layers.tls.extensions import TLS_Ext_SupportedGroups, \
TLS_Ext_SupportedVersion_CH, TLS_Ext_SignatureAlgorithms, \
TLS_Ext_SupportedVersion_SH
TLS_Ext_SupportedVersion_SH, TLS_Ext_PSKKeyExchangeModes
from scapy.layers.tls.handshake import TLSCertificate, TLSCertificateRequest, \
TLSCertificateVerify, TLSClientHello, TLSClientKeyExchange, \
TLSEncryptedExtensions, TLSFinished, TLSServerHello, TLSServerHelloDone, \
Expand All @@ -41,11 +42,13 @@
SSLv2ClientFinished, SSLv2ServerFinished, SSLv2ClientCertificate, \
SSLv2RequestCertificate
from scapy.layers.tls.keyexchange_tls13 import TLS_Ext_KeyShare_CH, \
KeyShareEntry, TLS_Ext_KeyShare_HRR
KeyShareEntry, TLS_Ext_KeyShare_HRR, PSKIdentity, PSKBinderEntry, \
TLS_Ext_PreSharedKey_CH
from scapy.layers.tls.record import TLSAlert, TLSChangeCipherSpec, \
TLSApplicationData
from scapy.layers.tls.crypto.suites import _tls_cipher_suites
from scapy.layers.tls.crypto.groups import _tls_named_groups
from scapy.layers.tls.crypto.hkdf import TLS13_HKDF
from scapy.modules import six
from scapy.packet import Raw
from scapy.compat import bytes_encode
Expand Down Expand Up @@ -73,6 +76,7 @@ class TLSClientAutomaton(_TLSAutomaton):
def parse_args(self, server="127.0.0.1", dport=4433, server_name=None,
mycert=None, mykey=None,
client_hello=None, version=None,
psk=None, psk_mode=None,
data=None,
ciphersuite=None,
curve=None,
Expand Down Expand Up @@ -138,6 +142,8 @@ def parse_args(self, server="127.0.0.1", dport=4433, server_name=None,
else:
# Or secp256r1 otherwise
self.curve = 23
self.tls13_psk_secret = psk
self.tls13_psk_mode = psk_mode
if curve is not None:
for (group_id, ng) in _tls_named_groups.items():
if ng == curve:
Expand Down Expand Up @@ -171,14 +177,19 @@ def INITIAL(self):
@ATMT.state()
def INIT_TLS_SESSION(self):
self.cur_session = tlsSession(connection_end="client")
self.cur_session.client_certs = self.mycert
self.cur_session.client_key = self.mykey
s = self.cur_session
s.client_certs = self.mycert
s.client_key = self.mykey
v = self.advertised_tls_version
if v:
self.cur_session.advertised_tls_version = v
s.advertised_tls_version = v
else:
default_version = self.cur_session.advertised_tls_version
default_version = s.advertised_tls_version
self.advertised_tls_version = default_version

if s.advertised_tls_version >= 0x0304:
if self.tls13_psk_secret:
s.tls13_psk_secret = binascii.unhexlify(self.tls13_psk_secret)
raise self.CONNECT()

@ATMT.state()
Expand Down Expand Up @@ -872,21 +883,47 @@ def tls13_should_add_ClientHello(self):
if conf.crypto_valid_advanced:
supported_groups.append("x25519")
self.add_record(is_tls13=False)
ext = [TLS_Ext_SupportedVersion_CH(versions=["TLS 1.3"]),
TLS_Ext_SupportedGroups(groups=supported_groups),
TLS_Ext_KeyShare_CH(client_shares=[KeyShareEntry(group=self.curve)]), # noqa: E501
TLS_Ext_SignatureAlgorithms(sig_algs=["sha256+rsaepss",
"sha256+rsa"])]
if self.client_hello:
if not self.client_hello.ext:
self.client_hello.ext = ext
p = self.client_hello
else:
if self.ciphersuite is None:
c = 0x1301
else:
c = self.ciphersuite
p = TLS13ClientHello(ciphers=c, ext=ext)
p = TLS13ClientHello(ciphers=c)

ext = []
ext += TLS_Ext_SupportedVersion_CH(versions=["TLS 1.3"])

if self.cur_session.tls13_psk_secret:
if self.tls13_psk_mode == "psk_dhe_ke":
ext += TLS_Ext_PSKKeyExchangeModes(kxmodes="psk_dhe_ke")
ext += TLS_Ext_SupportedGroups(groups=supported_groups)
ext += TLS_Ext_KeyShare_CH(
client_shares=[KeyShareEntry(group=self.curve)]
)
else:
ext += TLS_Ext_PSKKeyExchangeModes(kxmodes="psk_ke")
# RFC844, section 4.2.11.
# "The "pre_shared_key" extension MUST be the last extension
# in the ClientHello "
hkdf = TLS13_HKDF("sha256")
hash_len = hkdf.hash.digest_size
psk_id = PSKIdentity(identity='Client_identity')
# XXX see how to not pass binder as argument
psk_binder_entry = PSKBinderEntry(binder_len=hash_len,
binder=b"\x00" * hash_len)

ext += TLS_Ext_PreSharedKey_CH(identities=[psk_id],
binders=[psk_binder_entry])
else:
ext += TLS_Ext_SupportedGroups(groups=supported_groups)
ext += TLS_Ext_KeyShare_CH(
client_shares=[KeyShareEntry(group=self.curve)]
)
ext += TLS_Ext_SignatureAlgorithms(sig_algs=["sha256+rsaepss",
"sha256+rsa"])
p.ext = ext
self.add_msg(p)
raise self.TLS13_ADDED_CLIENTHELLO()

Expand Down Expand Up @@ -967,10 +1004,32 @@ def tls13_should_add_ClientHello_Retry(self):
selected_version = e.version
if not selected_group or not selected_version:
raise self.CLOSE_NOTIFY()
ext = [TLS_Ext_SupportedVersion_CH(versions=[_tls_version[selected_version]]), # noqa: E501
TLS_Ext_SupportedGroups(groups=[_tls_named_groups[selected_group]]), # noqa: E501
TLS_Ext_KeyShare_CH(client_shares=[KeyShareEntry(group=selected_group)]), # noqa: E501
TLS_Ext_SignatureAlgorithms(sig_algs=["sha256+rsaepss"])]

ext = []
ext += TLS_Ext_SupportedVersion_CH(versions=[_tls_version[selected_version]]) # noqa: E501

if s.tls13_psk_secret:
if self.tls13_psk_mode == "psk_dhe_ke":
ext += TLS_Ext_PSKKeyExchangeModes(kxmodes="psk_dhe_ke"),
ext += TLS_Ext_SupportedGroups(groups=[_tls_named_groups[selected_group]]) # noqa: E501
ext += TLS_Ext_KeyShare_CH(client_shares=[KeyShareEntry(group=selected_group)]) # noqa: E501
else:
ext += TLS_Ext_PSKKeyExchangeModes(kxmodes="psk_ke")

hkdf = TLS13_HKDF("sha256")
hash_len = hkdf.hash.digest_size
psk_id = PSKIdentity(identity='Client_identity')
psk_binder_entry = PSKBinderEntry(binder_len=hash_len,
binder=b"\x00" * hash_len)

ext += TLS_Ext_PreSharedKey_CH(identities=[psk_id],
binders=[psk_binder_entry])

else:
ext += TLS_Ext_SupportedGroups(groups=[_tls_named_groups[selected_group]]) # noqa: E501
ext += TLS_Ext_KeyShare_CH(client_shares=[KeyShareEntry(group=selected_group)]) # noqa: E501
ext += TLS_Ext_SignatureAlgorithms(sig_algs=["sha256+rsaepss"])

p = TLS13ClientHello(ciphers=ciphersuite, ext=ext)
self.add_msg(p)
raise self.TLS13_ADDED_CLIENTHELLO()
Expand All @@ -979,21 +1038,22 @@ def tls13_should_add_ClientHello_Retry(self):
def TLS13_HANDLED_SERVERHELLO(self):
pass

@ATMT.state()
def TLS13_WAITING_ENCRYPTEDEXTENSIONS(self):
self.get_next_msg()

@ATMT.condition(TLS13_WAITING_ENCRYPTEDEXTENSIONS)
def tls13_should_handle_EncryptedExtensions(self):
self.raise_on_packet(TLSEncryptedExtensions,
self.TLS13_WAITING_CERTIFICATE)

@ATMT.condition(TLS13_HANDLED_SERVERHELLO, prio=1)
def tls13_should_handle_encrytpedExtensions(self):
self.raise_on_packet(TLSEncryptedExtensions,
self.TLS13_HANDLED_ENCRYPTEDEXTENSIONS)

@ATMT.condition(TLS13_HANDLED_SERVERHELLO, prio=2)
def tls13_should_handle_ChangeCipherSpec(self):
self.raise_on_packet(TLSChangeCipherSpec,
self.TLS13_HANDLED_CHANGE_CIPHER_SPEC)

@ATMT.state()
def TLS13_HANDLED_CHANGE_CIPHER_SPEC(self):
self.cur_session.middlebox_compatibility = True
raise self.TLS13_HANDLED_SERVERHELLO()

@ATMT.condition(TLS13_HANDLED_SERVERHELLO, prio=3)
def tls13_missing_encryptedExtension(self):
self.vprint("Missing TLS 1.3 EncryptedExtensions message!")
raise self.CLOSE_NOTIFY()
Expand All @@ -1015,6 +1075,12 @@ def tls13_should_handle_certificateRequest_from_encryptedExtensions(self):
def tls13_should_handle_certificate_from_encryptedExtensions(self):
self.tls13_should_handle_Certificate()

@ATMT.condition(TLS13_HANDLED_ENCRYPTEDEXTENSIONS, prio=3)
def tls13_should_handle_finished_from_encryptedExtensions(self):
if self.cur_session.tls13_psk_secret:
self.raise_on_packet(TLSFinished,
self.TLS13_HANDLED_FINISHED)

@ATMT.state()
def TLS13_HANDLED_CERTIFICATEREQUEST(self):
pass
Expand Down Expand Up @@ -1056,6 +1122,9 @@ def TLS13_HANDLED_FINISHED(self):

@ATMT.state()
def TLS13_PREPARE_CLIENTFLIGHT2(self):
if self.cur_session.middlebox_compatibility:
self.add_record(is_tls12=True)
self.add_msg(TLSChangeCipherSpec())
self.add_record(is_tls13=True)

@ATMT.condition(TLS13_PREPARE_CLIENTFLIGHT2, prio=1)
Expand Down
Loading

0 comments on commit 7839ca9

Please sign in to comment.