| Index: third_party/tlslite/tlslite/messages.py
|
| diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
|
| index 5762ac64ca414b51e69be1b9bd7fc102f3b63646..1ce9320e13b211bf5178bc245ada4e6e1c967d0b 100644
|
| --- a/third_party/tlslite/tlslite/messages.py
|
| +++ b/third_party/tlslite/tlslite/messages.py
|
| @@ -18,6 +18,27 @@ from .x509 import X509
|
| from .x509certchain import X509CertChain
|
| from .utils.tackwrapper import *
|
|
|
| +def parse_next_protos(b):
|
| + protos = []
|
| + while True:
|
| + if len(b) == 0:
|
| + break
|
| + l = b[0]
|
| + b = b[1:]
|
| + if len(b) < l:
|
| + raise BadNextProtos(len(b))
|
| + protos.append(b[:l])
|
| + b = b[l:]
|
| + return protos
|
| +
|
| +def next_protos_encoded(protocol_list):
|
| + b = bytearray()
|
| + for e in protocol_list:
|
| + if len(e) > 255 or len(e) == 0:
|
| + raise BadNextProtos(len(e))
|
| + b += bytearray( [len(e)] ) + bytearray(e)
|
| + return b
|
| +
|
| class RecordHeader3(object):
|
| def __init__(self):
|
| self.type = 0
|
| @@ -111,6 +132,7 @@ class ClientHello(HandshakeMsg):
|
| self.compression_methods = [] # a list of 8-bit values
|
| self.srp_username = None # a string
|
| self.tack = False
|
| + self.alpn_protos_advertised = None
|
| self.supports_npn = False
|
| self.server_name = bytearray(0)
|
| self.channel_id = False
|
| @@ -121,7 +143,8 @@ class ClientHello(HandshakeMsg):
|
|
|
| def create(self, version, random, session_id, cipher_suites,
|
| certificate_types=None, srpUsername=None,
|
| - tack=False, supports_npn=False, serverName=None):
|
| + tack=False, alpn_protos_advertised=None,
|
| + supports_npn=False, serverName=None):
|
| self.client_version = version
|
| self.random = random
|
| self.session_id = session_id
|
| @@ -131,6 +154,7 @@ class ClientHello(HandshakeMsg):
|
| if srpUsername:
|
| self.srp_username = bytearray(srpUsername, "utf-8")
|
| self.tack = tack
|
| + self.alpn_protos_advertised = alpn_protos_advertised
|
| self.supports_npn = supports_npn
|
| if serverName:
|
| self.server_name = bytearray(serverName, "utf-8")
|
| @@ -171,6 +195,11 @@ class ClientHello(HandshakeMsg):
|
| self.certificate_types = p.getVarList(1, 1)
|
| elif extType == ExtensionType.tack:
|
| self.tack = True
|
| + elif extType == ExtensionType.alpn:
|
| + structLength = p.get(2)
|
| + if structLength + 2 != extLength:
|
| + raise SyntaxError()
|
| + self.alpn_protos_advertised = parse_next_protos(p.getFixBytes(structLength))
|
| elif extType == ExtensionType.supports_npn:
|
| self.supports_npn = True
|
| elif extType == ExtensionType.server_name:
|
| @@ -243,6 +272,12 @@ class ClientHello(HandshakeMsg):
|
| w2.add(ExtensionType.srp, 2)
|
| w2.add(len(self.srp_username)+1, 2)
|
| w2.addVarSeq(self.srp_username, 1, 1)
|
| + if self.alpn_protos_advertised is not None:
|
| + encoded_alpn_protos_advertised = next_protos_encoded(self.alpn_protos_advertised)
|
| + w2.add(ExtensionType.alpn, 2)
|
| + w2.add(len(encoded_alpn_protos_advertised) + 2, 2)
|
| + w2.add(len(encoded_alpn_protos_advertised), 2)
|
| + w2.addFixSeq(encoded_alpn_protos_advertised, 1)
|
| if self.supports_npn:
|
| w2.add(ExtensionType.supports_npn, 2)
|
| w2.add(0, 2)
|
| @@ -267,6 +302,13 @@ class BadNextProtos(Exception):
|
| def __str__(self):
|
| return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
|
|
|
| +class InvalidALPNResponse(Exception):
|
| + def __init__(self, l):
|
| + self.length = l
|
| +
|
| + def __str__(self):
|
| + return 'ALPN server response protocol list has invalid length %d. It must be of length one.' % self.length
|
| +
|
| class ServerHello(HandshakeMsg):
|
| def __init__(self):
|
| HandshakeMsg.__init__(self, HandshakeType.server_hello)
|
| @@ -277,6 +319,7 @@ class ServerHello(HandshakeMsg):
|
| self.certificate_type = CertificateType.x509
|
| self.compression_method = 0
|
| self.tackExt = None
|
| + self.alpn_proto_selected = None
|
| self.next_protos_advertised = None
|
| self.next_protos = None
|
| self.channel_id = False
|
| @@ -286,7 +329,8 @@ class ServerHello(HandshakeMsg):
|
| self.status_request = False
|
|
|
| def create(self, version, random, session_id, cipher_suite,
|
| - certificate_type, tackExt, next_protos_advertised):
|
| + certificate_type, tackExt, alpn_proto_selected,
|
| + next_protos_advertised):
|
| self.server_version = version
|
| self.random = random
|
| self.session_id = session_id
|
| @@ -294,6 +338,7 @@ class ServerHello(HandshakeMsg):
|
| self.certificate_type = certificate_type
|
| self.compression_method = 0
|
| self.tackExt = tackExt
|
| + self.alpn_proto_selected = alpn_proto_selected
|
| self.next_protos_advertised = next_protos_advertised
|
| return self
|
|
|
| @@ -316,35 +361,22 @@ class ServerHello(HandshakeMsg):
|
| self.certificate_type = p.get(1)
|
| elif extType == ExtensionType.tack and tackpyLoaded:
|
| self.tackExt = TackExtension(p.getFixBytes(extLength))
|
| + elif extType == ExtensionType.alpn:
|
| + structLength = p.get(2)
|
| + if structLength + 2 != extLength:
|
| + raise SyntaxError()
|
| + alpn_protos = parse_next_protos(p.getFixBytes(structLength))
|
| + if len(alpn_protos) != 1:
|
| + raise InvalidALPNResponse(len(alpn_protos));
|
| + self.alpn_proto_selected = alpn_protos[0]
|
| elif extType == ExtensionType.supports_npn:
|
| - self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
|
| + self.next_protos = parse_next_protos(p.getFixBytes(extLength))
|
| else:
|
| p.getFixBytes(extLength)
|
| soFar += 4 + extLength
|
| p.stopLengthCheck()
|
| return self
|
|
|
| - def __parse_next_protos(self, b):
|
| - protos = []
|
| - while True:
|
| - if len(b) == 0:
|
| - break
|
| - l = b[0]
|
| - b = b[1:]
|
| - if len(b) < l:
|
| - raise BadNextProtos(len(b))
|
| - protos.append(b[:l])
|
| - b = b[l:]
|
| - return protos
|
| -
|
| - def __next_protos_encoded(self):
|
| - b = bytearray()
|
| - for e in self.next_protos_advertised:
|
| - if len(e) > 255 or len(e) == 0:
|
| - raise BadNextProtos(len(e))
|
| - b += bytearray( [len(e)] ) + bytearray(e)
|
| - return b
|
| -
|
| def write(self):
|
| w = Writer()
|
| w.add(self.server_version[0], 1)
|
| @@ -365,8 +397,15 @@ class ServerHello(HandshakeMsg):
|
| w2.add(ExtensionType.tack, 2)
|
| w2.add(len(b), 2)
|
| w2.bytes += b
|
| + if self.alpn_proto_selected is not None:
|
| + alpn_protos_single_element_list = [self.alpn_proto_selected]
|
| + encoded_alpn_protos_advertised = next_protos_encoded(alpn_protos_single_element_list)
|
| + w2.add(ExtensionType.alpn, 2)
|
| + w2.add(len(encoded_alpn_protos_advertised) + 2, 2)
|
| + w2.add(len(encoded_alpn_protos_advertised), 2)
|
| + w2.addFixSeq(encoded_alpn_protos_advertised, 1)
|
| if self.next_protos_advertised is not None:
|
| - encoded_next_protos_advertised = self.__next_protos_encoded()
|
| + encoded_next_protos_advertised = next_protos_encoded(self.next_protos_advertised)
|
| w2.add(ExtensionType.supports_npn, 2)
|
| w2.add(len(encoded_next_protos_advertised), 2)
|
| w2.addFixSeq(encoded_next_protos_advertised, 1)
|
|
|