Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(171)

Unified Diff: third_party/tlslite/tlslite/messages.py

Issue 2205433002: Implement ALPN in tlslite. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Rebase. Created 4 years, 4 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: third_party/tlslite/tlslite/messages.py
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
index 5762ac64ca414b51e69be1b9bd7fc102f3b63646..3d684d43a7f24f725cffcd56d94b785b625364bf 100644
--- a/third_party/tlslite/tlslite/messages.py
+++ b/third_party/tlslite/tlslite/messages.py
@@ -99,6 +99,27 @@ class HandshakeMsg(object):
headerWriter.add(len(w.bytes), 3)
return headerWriter.bytes + w.bytes
+ def parse_next_protos(self, b):
+ protos = []
+ while True:
+ if len(b) == 0:
+ break
+ l = b[0]
+ b = b[1:]
+ if len(b) < l:
+ raise BadNextProtos(len(b))
+ protos.append(b[:l])
+ b = b[l:]
+ return protos
+
+ def next_protos_encoded(self, protocol_list):
+ b = bytearray()
+ for e in protocol_list:
+ if len(e) > 255 or len(e) == 0:
+ raise BadNextProtos(len(e))
+ b += bytearray( [len(e)] ) + bytearray(e)
+ return b
+
class ClientHello(HandshakeMsg):
def __init__(self, ssl2=False):
HandshakeMsg.__init__(self, HandshakeType.client_hello)
@@ -111,6 +132,7 @@ class ClientHello(HandshakeMsg):
self.compression_methods = [] # a list of 8-bit values
self.srp_username = None # a string
self.tack = False
+ self.alpn_protos_advertised = None
self.supports_npn = False
self.server_name = bytearray(0)
self.channel_id = False
@@ -120,8 +142,8 @@ class ClientHello(HandshakeMsg):
self.status_request = False
def create(self, version, random, session_id, cipher_suites,
- certificate_types=None, srpUsername=None,
- tack=False, supports_npn=False, serverName=None):
+ certificate_types=None, srpUsername=None, tack=False,
+ alpn_protos_advertised=None, supports_npn=False, serverName=None):
self.client_version = version
self.random = random
self.session_id = session_id
@@ -131,6 +153,7 @@ class ClientHello(HandshakeMsg):
if srpUsername:
self.srp_username = bytearray(srpUsername, "utf-8")
self.tack = tack
+ self.alpn_protos_advertised = alpn_protos_advertised
self.supports_npn = supports_npn
if serverName:
self.server_name = bytearray(serverName, "utf-8")
@@ -171,6 +194,11 @@ class ClientHello(HandshakeMsg):
self.certificate_types = p.getVarList(1, 1)
elif extType == ExtensionType.tack:
self.tack = True
+ elif extType == ExtensionType.alpn:
+ structLength = p.get(2)
+ if (structLength + 2 != extLength):
+ raise SyntaxError()
+ self.alpn_protos_advertised = self.parse_next_protos(p.getFixBytes(structLength))
elif extType == ExtensionType.supports_npn:
self.supports_npn = True
elif extType == ExtensionType.server_name:
@@ -243,6 +271,12 @@ class ClientHello(HandshakeMsg):
w2.add(ExtensionType.srp, 2)
w2.add(len(self.srp_username)+1, 2)
w2.addVarSeq(self.srp_username, 1, 1)
+ if self.alpn_protos_advertised is not None:
+ encoded_alpn_protos_advertised = self.next_protos_encoded(self.alpn_protos_advertised)
+ w2.add(ExtensionType.alpn, 2)
+ w2.add(len(encoded_alpn_protos_advertised) + 2, 2)
+ w2.add(len(encoded_alpn_protos_advertised), 2)
+ w2.addFixSeq(encoded_alpn_protos_advertised, 1)
if self.supports_npn:
w2.add(ExtensionType.supports_npn, 2)
w2.add(0, 2)
@@ -267,6 +301,13 @@ class BadNextProtos(Exception):
def __str__(self):
return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
+class InvalidAlpnResponse(Exception):
+ def __init__(self, l):
+ self.length = l
+
+ def __str__(self):
+ return 'ALPN server response protocol list has invalid length %d. It must be of length one.' % self.length
+
class ServerHello(HandshakeMsg):
def __init__(self):
HandshakeMsg.__init__(self, HandshakeType.server_hello)
@@ -277,6 +318,7 @@ class ServerHello(HandshakeMsg):
self.certificate_type = CertificateType.x509
self.compression_method = 0
self.tackExt = None
+ self.alpn_proto_selected = None
self.next_protos_advertised = None
self.next_protos = None
self.channel_id = False
@@ -286,7 +328,8 @@ class ServerHello(HandshakeMsg):
self.status_request = False
def create(self, version, random, session_id, cipher_suite,
- certificate_type, tackExt, next_protos_advertised):
+ certificate_type, tackExt, alpn_proto_selected,
+ next_protos_advertised):
self.server_version = version
self.random = random
self.session_id = session_id
@@ -294,6 +337,7 @@ class ServerHello(HandshakeMsg):
self.certificate_type = certificate_type
self.compression_method = 0
self.tackExt = tackExt
+ self.alpn_proto_selected = alpn_proto_selected
self.next_protos_advertised = next_protos_advertised
return self
@@ -316,35 +360,22 @@ class ServerHello(HandshakeMsg):
self.certificate_type = p.get(1)
elif extType == ExtensionType.tack and tackpyLoaded:
self.tackExt = TackExtension(p.getFixBytes(extLength))
+ elif extType == ExtensionType.alpn:
+ structLength = p.get(2)
+ if (structLength + 2 != extLength):
davidben 2016/08/03 23:34:22 Nit: No parens in Python
Bence 2016/08/04 18:41:44 Done.
+ raise SyntaxError()
+ alpn_protos = self.parse_next_protos(p.getFixBytes(structLength))
+ if (alpn_protos.len() != 1):
davidben 2016/08/03 23:34:22 Ditto.
Bence 2016/08/04 18:41:44 Done.
+ raise InvalidAlpnResponse(alpn_protos.len());
+ self.alpn_proto_selected = alpn_protos[0]
elif extType == ExtensionType.supports_npn:
- self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
+ self.next_protos = self.parse_next_protos(p.getFixBytes(extLength))
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
p.stopLengthCheck()
return self
- def __parse_next_protos(self, b):
- protos = []
- while True:
- if len(b) == 0:
- break
- l = b[0]
- b = b[1:]
- if len(b) < l:
- raise BadNextProtos(len(b))
- protos.append(b[:l])
- b = b[l:]
- return protos
-
- def __next_protos_encoded(self):
- b = bytearray()
- for e in self.next_protos_advertised:
- if len(e) > 255 or len(e) == 0:
- raise BadNextProtos(len(e))
- b += bytearray( [len(e)] ) + bytearray(e)
- return b
-
def write(self):
w = Writer()
w.add(self.server_version[0], 1)
@@ -365,8 +396,16 @@ class ServerHello(HandshakeMsg):
w2.add(ExtensionType.tack, 2)
w2.add(len(b), 2)
w2.bytes += b
- if self.next_protos_advertised is not None:
- encoded_next_protos_advertised = self.__next_protos_encoded()
+ if self.alpn_proto_selected is not None:
+ alpn_protos_single_element_list = [self.alpn_proto_selected]
+ encoded_alpn_protos_advertised = self.next_protos_encoded(alpn_protos_single_element_list)
+ w2.add(ExtensionType.alpn, 2)
+ w2.add(len(encoded_alpn_protos_advertised) + 2, 2)
+ w2.add(len(encoded_alpn_protos_advertised), 2)
+ w2.addFixSeq(encoded_alpn_protos_advertised, 1)
+ # Do not use NPN if ALPN is used.
+ elif self.next_protos_advertised is not None:
+ encoded_next_protos_advertised = self.next_protos_encoded(self.next_protos_advertised)
w2.add(ExtensionType.supports_npn, 2)
w2.add(len(encoded_next_protos_advertised), 2)
w2.addFixSeq(encoded_next_protos_advertised, 1)

Powered by Google App Engine
This is Rietveld 408576698