Index: third_party/tlslite/tlslite/messages.py |
diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py |
index 5762ac64ca414b51e69be1b9bd7fc102f3b63646..1ce9320e13b211bf5178bc245ada4e6e1c967d0b 100644 |
--- a/third_party/tlslite/tlslite/messages.py |
+++ b/third_party/tlslite/tlslite/messages.py |
@@ -18,6 +18,27 @@ from .x509 import X509 |
from .x509certchain import X509CertChain |
from .utils.tackwrapper import * |
+def parse_next_protos(b): |
+ protos = [] |
+ while True: |
+ if len(b) == 0: |
+ break |
+ l = b[0] |
+ b = b[1:] |
+ if len(b) < l: |
+ raise BadNextProtos(len(b)) |
+ protos.append(b[:l]) |
+ b = b[l:] |
+ return protos |
+ |
+def next_protos_encoded(protocol_list): |
+ b = bytearray() |
+ for e in protocol_list: |
+ if len(e) > 255 or len(e) == 0: |
+ raise BadNextProtos(len(e)) |
+ b += bytearray( [len(e)] ) + bytearray(e) |
+ return b |
+ |
class RecordHeader3(object): |
def __init__(self): |
self.type = 0 |
@@ -111,6 +132,7 @@ class ClientHello(HandshakeMsg): |
self.compression_methods = [] # a list of 8-bit values |
self.srp_username = None # a string |
self.tack = False |
+ self.alpn_protos_advertised = None |
self.supports_npn = False |
self.server_name = bytearray(0) |
self.channel_id = False |
@@ -121,7 +143,8 @@ class ClientHello(HandshakeMsg): |
def create(self, version, random, session_id, cipher_suites, |
certificate_types=None, srpUsername=None, |
- tack=False, supports_npn=False, serverName=None): |
+ tack=False, alpn_protos_advertised=None, |
+ supports_npn=False, serverName=None): |
self.client_version = version |
self.random = random |
self.session_id = session_id |
@@ -131,6 +154,7 @@ class ClientHello(HandshakeMsg): |
if srpUsername: |
self.srp_username = bytearray(srpUsername, "utf-8") |
self.tack = tack |
+ self.alpn_protos_advertised = alpn_protos_advertised |
self.supports_npn = supports_npn |
if serverName: |
self.server_name = bytearray(serverName, "utf-8") |
@@ -171,6 +195,11 @@ class ClientHello(HandshakeMsg): |
self.certificate_types = p.getVarList(1, 1) |
elif extType == ExtensionType.tack: |
self.tack = True |
+ elif extType == ExtensionType.alpn: |
+ structLength = p.get(2) |
+ if structLength + 2 != extLength: |
+ raise SyntaxError() |
+ self.alpn_protos_advertised = parse_next_protos(p.getFixBytes(structLength)) |
elif extType == ExtensionType.supports_npn: |
self.supports_npn = True |
elif extType == ExtensionType.server_name: |
@@ -243,6 +272,12 @@ class ClientHello(HandshakeMsg): |
w2.add(ExtensionType.srp, 2) |
w2.add(len(self.srp_username)+1, 2) |
w2.addVarSeq(self.srp_username, 1, 1) |
+ if self.alpn_protos_advertised is not None: |
+ encoded_alpn_protos_advertised = next_protos_encoded(self.alpn_protos_advertised) |
+ w2.add(ExtensionType.alpn, 2) |
+ w2.add(len(encoded_alpn_protos_advertised) + 2, 2) |
+ w2.add(len(encoded_alpn_protos_advertised), 2) |
+ w2.addFixSeq(encoded_alpn_protos_advertised, 1) |
if self.supports_npn: |
w2.add(ExtensionType.supports_npn, 2) |
w2.add(0, 2) |
@@ -267,6 +302,13 @@ class BadNextProtos(Exception): |
def __str__(self): |
return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length |
+class InvalidALPNResponse(Exception): |
+ def __init__(self, l): |
+ self.length = l |
+ |
+ def __str__(self): |
+ return 'ALPN server response protocol list has invalid length %d. It must be of length one.' % self.length |
+ |
class ServerHello(HandshakeMsg): |
def __init__(self): |
HandshakeMsg.__init__(self, HandshakeType.server_hello) |
@@ -277,6 +319,7 @@ class ServerHello(HandshakeMsg): |
self.certificate_type = CertificateType.x509 |
self.compression_method = 0 |
self.tackExt = None |
+ self.alpn_proto_selected = None |
self.next_protos_advertised = None |
self.next_protos = None |
self.channel_id = False |
@@ -286,7 +329,8 @@ class ServerHello(HandshakeMsg): |
self.status_request = False |
def create(self, version, random, session_id, cipher_suite, |
- certificate_type, tackExt, next_protos_advertised): |
+ certificate_type, tackExt, alpn_proto_selected, |
+ next_protos_advertised): |
self.server_version = version |
self.random = random |
self.session_id = session_id |
@@ -294,6 +338,7 @@ class ServerHello(HandshakeMsg): |
self.certificate_type = certificate_type |
self.compression_method = 0 |
self.tackExt = tackExt |
+ self.alpn_proto_selected = alpn_proto_selected |
self.next_protos_advertised = next_protos_advertised |
return self |
@@ -316,35 +361,22 @@ class ServerHello(HandshakeMsg): |
self.certificate_type = p.get(1) |
elif extType == ExtensionType.tack and tackpyLoaded: |
self.tackExt = TackExtension(p.getFixBytes(extLength)) |
+ elif extType == ExtensionType.alpn: |
+ structLength = p.get(2) |
+ if structLength + 2 != extLength: |
+ raise SyntaxError() |
+ alpn_protos = parse_next_protos(p.getFixBytes(structLength)) |
+ if len(alpn_protos) != 1: |
+ raise InvalidALPNResponse(len(alpn_protos)); |
+ self.alpn_proto_selected = alpn_protos[0] |
elif extType == ExtensionType.supports_npn: |
- self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength)) |
+ self.next_protos = parse_next_protos(p.getFixBytes(extLength)) |
else: |
p.getFixBytes(extLength) |
soFar += 4 + extLength |
p.stopLengthCheck() |
return self |
- def __parse_next_protos(self, b): |
- protos = [] |
- while True: |
- if len(b) == 0: |
- break |
- l = b[0] |
- b = b[1:] |
- if len(b) < l: |
- raise BadNextProtos(len(b)) |
- protos.append(b[:l]) |
- b = b[l:] |
- return protos |
- |
- def __next_protos_encoded(self): |
- b = bytearray() |
- for e in self.next_protos_advertised: |
- if len(e) > 255 or len(e) == 0: |
- raise BadNextProtos(len(e)) |
- b += bytearray( [len(e)] ) + bytearray(e) |
- return b |
- |
def write(self): |
w = Writer() |
w.add(self.server_version[0], 1) |
@@ -365,8 +397,15 @@ class ServerHello(HandshakeMsg): |
w2.add(ExtensionType.tack, 2) |
w2.add(len(b), 2) |
w2.bytes += b |
+ if self.alpn_proto_selected is not None: |
+ alpn_protos_single_element_list = [self.alpn_proto_selected] |
+ encoded_alpn_protos_advertised = next_protos_encoded(alpn_protos_single_element_list) |
+ w2.add(ExtensionType.alpn, 2) |
+ w2.add(len(encoded_alpn_protos_advertised) + 2, 2) |
+ w2.add(len(encoded_alpn_protos_advertised), 2) |
+ w2.addFixSeq(encoded_alpn_protos_advertised, 1) |
if self.next_protos_advertised is not None: |
- encoded_next_protos_advertised = self.__next_protos_encoded() |
+ encoded_next_protos_advertised = next_protos_encoded(self.next_protos_advertised) |
w2.add(ExtensionType.supports_npn, 2) |
w2.add(len(encoded_next_protos_advertised), 2) |
w2.addFixSeq(encoded_next_protos_advertised, 1) |