OLD | NEW |
| 1 # Authors: |
| 2 # Trevor Perrin |
| 3 # Google - handling CertificateRequest.certificate_types |
| 4 # Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support |
| 5 # Dimitris Moraitis - Anon ciphersuites |
| 6 # |
| 7 # See the LICENSE file for legal information regarding use of this file. |
| 8 |
1 """Classes representing TLS messages.""" | 9 """Classes representing TLS messages.""" |
2 | 10 |
3 from utils.compat import * | 11 from .utils.compat import * |
4 from utils.cryptomath import * | 12 from .utils.cryptomath import * |
5 from errors import * | 13 from .errors import * |
6 from utils.codec import * | 14 from .utils.codec import * |
7 from constants import * | 15 from .constants import * |
8 from x509 import X509 | 16 from .x509 import X509 |
9 from x509certchain import X509CertChain | 17 from .x509certchain import X509CertChain |
| 18 from .utils.tackwrapper import * |
10 | 19 |
11 # The sha module is deprecated in Python 2.6 | 20 class RecordHeader3(object): |
12 try: | |
13 import sha | |
14 except ImportError: | |
15 from hashlib import sha1 as sha | |
16 | |
17 # The md5 module is deprecated in Python 2.6 | |
18 try: | |
19 import md5 | |
20 except ImportError: | |
21 from hashlib import md5 | |
22 | |
23 class RecordHeader3: | |
24 def __init__(self): | 21 def __init__(self): |
25 self.type = 0 | 22 self.type = 0 |
26 self.version = (0,0) | 23 self.version = (0,0) |
27 self.length = 0 | 24 self.length = 0 |
28 self.ssl2 = False | 25 self.ssl2 = False |
29 | 26 |
30 def create(self, version, type, length): | 27 def create(self, version, type, length): |
31 self.type = type | 28 self.type = type |
32 self.version = version | 29 self.version = version |
33 self.length = length | 30 self.length = length |
34 return self | 31 return self |
35 | 32 |
36 def write(self): | 33 def write(self): |
37 w = Writer(5) | 34 w = Writer() |
38 w.add(self.type, 1) | 35 w.add(self.type, 1) |
39 w.add(self.version[0], 1) | 36 w.add(self.version[0], 1) |
40 w.add(self.version[1], 1) | 37 w.add(self.version[1], 1) |
41 w.add(self.length, 2) | 38 w.add(self.length, 2) |
42 return w.bytes | 39 return w.bytes |
43 | 40 |
44 def parse(self, p): | 41 def parse(self, p): |
45 self.type = p.get(1) | 42 self.type = p.get(1) |
46 self.version = (p.get(1), p.get(1)) | 43 self.version = (p.get(1), p.get(1)) |
47 self.length = p.get(2) | 44 self.length = p.get(2) |
48 self.ssl2 = False | 45 self.ssl2 = False |
49 return self | 46 return self |
50 | 47 |
51 class RecordHeader2: | 48 class RecordHeader2(object): |
52 def __init__(self): | 49 def __init__(self): |
53 self.type = 0 | 50 self.type = 0 |
54 self.version = (0,0) | 51 self.version = (0,0) |
55 self.length = 0 | 52 self.length = 0 |
56 self.ssl2 = True | 53 self.ssl2 = True |
57 | 54 |
58 def parse(self, p): | 55 def parse(self, p): |
59 if p.get(1)!=128: | 56 if p.get(1)!=128: |
60 raise SyntaxError() | 57 raise SyntaxError() |
61 self.type = ContentType.handshake | 58 self.type = ContentType.handshake |
62 self.version = (2,0) | 59 self.version = (2,0) |
63 #We don't support 2-byte-length-headers; could be a problem | 60 #We don't support 2-byte-length-headers; could be a problem |
64 self.length = p.get(1) | 61 self.length = p.get(1) |
65 return self | 62 return self |
66 | 63 |
67 | 64 |
68 class Msg: | 65 class Alert(object): |
69 def preWrite(self, trial): | |
70 if trial: | |
71 w = Writer() | |
72 else: | |
73 length = self.write(True) | |
74 w = Writer(length) | |
75 return w | |
76 | |
77 def postWrite(self, w, trial): | |
78 if trial: | |
79 return w.index | |
80 else: | |
81 return w.bytes | |
82 | |
83 class Alert(Msg): | |
84 def __init__(self): | 66 def __init__(self): |
85 self.contentType = ContentType.alert | 67 self.contentType = ContentType.alert |
86 self.level = 0 | 68 self.level = 0 |
87 self.description = 0 | 69 self.description = 0 |
88 | 70 |
89 def create(self, description, level=AlertLevel.fatal): | 71 def create(self, description, level=AlertLevel.fatal): |
90 self.level = level | 72 self.level = level |
91 self.description = description | 73 self.description = description |
92 return self | 74 return self |
93 | 75 |
94 def parse(self, p): | 76 def parse(self, p): |
95 p.setLengthCheck(2) | 77 p.setLengthCheck(2) |
96 self.level = p.get(1) | 78 self.level = p.get(1) |
97 self.description = p.get(1) | 79 self.description = p.get(1) |
98 p.stopLengthCheck() | 80 p.stopLengthCheck() |
99 return self | 81 return self |
100 | 82 |
101 def write(self): | 83 def write(self): |
102 w = Writer(2) | 84 w = Writer() |
103 w.add(self.level, 1) | 85 w.add(self.level, 1) |
104 w.add(self.description, 1) | 86 w.add(self.description, 1) |
105 return w.bytes | 87 return w.bytes |
106 | 88 |
107 | 89 |
108 class HandshakeMsg(Msg): | 90 class HandshakeMsg(object): |
109 def preWrite(self, handshakeType, trial): | 91 def __init__(self, handshakeType): |
110 if trial: | 92 self.contentType = ContentType.handshake |
111 w = Writer() | 93 self.handshakeType = handshakeType |
112 w.add(handshakeType, 1) | 94 |
113 w.add(0, 3) | 95 def postWrite(self, w): |
114 else: | 96 headerWriter = Writer() |
115 length = self.write(True) | 97 headerWriter.add(self.handshakeType, 1) |
116 w = Writer(length) | 98 headerWriter.add(len(w.bytes), 3) |
117 w.add(handshakeType, 1) | 99 return headerWriter.bytes + w.bytes |
118 w.add(length-4, 3) | |
119 return w | |
120 | |
121 | 100 |
122 class ClientHello(HandshakeMsg): | 101 class ClientHello(HandshakeMsg): |
123 def __init__(self, ssl2=False): | 102 def __init__(self, ssl2=False): |
124 self.contentType = ContentType.handshake | 103 HandshakeMsg.__init__(self, HandshakeType.client_hello) |
125 self.ssl2 = ssl2 | 104 self.ssl2 = ssl2 |
126 self.client_version = (0,0) | 105 self.client_version = (0,0) |
127 self.random = createByteArrayZeros(32) | 106 self.random = bytearray(32) |
128 self.session_id = createByteArraySequence([]) | 107 self.session_id = bytearray(0) |
129 self.cipher_suites = [] # a list of 16-bit values | 108 self.cipher_suites = [] # a list of 16-bit values |
130 self.certificate_types = [CertificateType.x509] | 109 self.certificate_types = [CertificateType.x509] |
131 self.compression_methods = [] # a list of 8-bit values | 110 self.compression_methods = [] # a list of 8-bit values |
132 self.srp_username = None # a string | 111 self.srp_username = None # a string |
| 112 self.tack = False |
| 113 self.supports_npn = False |
| 114 self.server_name = bytearray(0) |
133 self.channel_id = False | 115 self.channel_id = False |
134 self.support_signed_cert_timestamps = False | 116 self.support_signed_cert_timestamps = False |
135 self.status_request = False | 117 self.status_request = False |
136 | 118 |
137 def create(self, version, random, session_id, cipher_suites, | 119 def create(self, version, random, session_id, cipher_suites, |
138 certificate_types=None, srp_username=None): | 120 certificate_types=None, srpUsername=None, |
| 121 tack=False, supports_npn=False, serverName=None): |
139 self.client_version = version | 122 self.client_version = version |
140 self.random = random | 123 self.random = random |
141 self.session_id = session_id | 124 self.session_id = session_id |
142 self.cipher_suites = cipher_suites | 125 self.cipher_suites = cipher_suites |
143 self.certificate_types = certificate_types | 126 self.certificate_types = certificate_types |
144 self.compression_methods = [0] | 127 self.compression_methods = [0] |
145 self.srp_username = srp_username | 128 if srpUsername: |
| 129 self.srp_username = bytearray(srpUsername, "utf-8") |
| 130 self.tack = tack |
| 131 self.supports_npn = supports_npn |
| 132 if serverName: |
| 133 self.server_name = bytearray(serverName, "utf-8") |
146 return self | 134 return self |
147 | 135 |
148 def parse(self, p): | 136 def parse(self, p): |
149 if self.ssl2: | 137 if self.ssl2: |
150 self.client_version = (p.get(1), p.get(1)) | 138 self.client_version = (p.get(1), p.get(1)) |
151 cipherSpecsLength = p.get(2) | 139 cipherSpecsLength = p.get(2) |
152 sessionIDLength = p.get(2) | 140 sessionIDLength = p.get(2) |
153 randomLength = p.get(2) | 141 randomLength = p.get(2) |
154 self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3)) | 142 self.cipher_suites = p.getFixList(3, cipherSpecsLength//3) |
155 self.session_id = p.getFixBytes(sessionIDLength) | 143 self.session_id = p.getFixBytes(sessionIDLength) |
156 self.random = p.getFixBytes(randomLength) | 144 self.random = p.getFixBytes(randomLength) |
157 if len(self.random) < 32: | 145 if len(self.random) < 32: |
158 zeroBytes = 32-len(self.random) | 146 zeroBytes = 32-len(self.random) |
159 self.random = createByteArrayZeros(zeroBytes) + self.random | 147 self.random = bytearray(zeroBytes) + self.random |
160 self.compression_methods = [0]#Fake this value | 148 self.compression_methods = [0]#Fake this value |
161 | 149 |
162 #We're not doing a stopLengthCheck() for SSLv2, oh well.. | 150 #We're not doing a stopLengthCheck() for SSLv2, oh well.. |
163 else: | 151 else: |
164 p.startLengthCheck(3) | 152 p.startLengthCheck(3) |
165 self.client_version = (p.get(1), p.get(1)) | 153 self.client_version = (p.get(1), p.get(1)) |
166 self.random = p.getFixBytes(32) | 154 self.random = p.getFixBytes(32) |
167 self.session_id = p.getVarBytes(1) | 155 self.session_id = p.getVarBytes(1) |
168 self.cipher_suites = p.getVarList(2, 2) | 156 self.cipher_suites = p.getVarList(2, 2) |
169 self.compression_methods = p.getVarList(1, 1) | 157 self.compression_methods = p.getVarList(1, 1) |
170 if not p.atLengthCheck(): | 158 if not p.atLengthCheck(): |
171 totalExtLength = p.get(2) | 159 totalExtLength = p.get(2) |
172 soFar = 0 | 160 soFar = 0 |
173 while soFar != totalExtLength: | 161 while soFar != totalExtLength: |
174 extType = p.get(2) | 162 extType = p.get(2) |
175 extLength = p.get(2) | 163 extLength = p.get(2) |
176 if extType == 6: | 164 index1 = p.index |
177 self.srp_username = bytesToString(p.getVarBytes(1)) | 165 if extType == ExtensionType.srp: |
178 elif extType == 7: | 166 self.srp_username = p.getVarBytes(1) |
| 167 elif extType == ExtensionType.cert_type: |
179 self.certificate_types = p.getVarList(1, 1) | 168 self.certificate_types = p.getVarList(1, 1) |
| 169 elif extType == ExtensionType.tack: |
| 170 self.tack = True |
| 171 elif extType == ExtensionType.supports_npn: |
| 172 self.supports_npn = True |
| 173 elif extType == ExtensionType.server_name: |
| 174 serverNameListBytes = p.getFixBytes(extLength) |
| 175 p2 = Parser(serverNameListBytes) |
| 176 p2.startLengthCheck(2) |
| 177 while 1: |
| 178 if p2.atLengthCheck(): |
| 179 break # no host_name, oh well |
| 180 name_type = p2.get(1) |
| 181 hostNameBytes = p2.getVarBytes(2) |
| 182 if name_type == NameType.host_name: |
| 183 self.server_name = hostNameBytes |
| 184 break |
180 elif extType == ExtensionType.channel_id: | 185 elif extType == ExtensionType.channel_id: |
181 self.channel_id = True | 186 self.channel_id = True |
182 elif extType == ExtensionType.signed_cert_timestamps: | 187 elif extType == ExtensionType.signed_cert_timestamps: |
183 if extLength: | 188 if extLength: |
184 raise SyntaxError() | 189 raise SyntaxError() |
185 self.support_signed_cert_timestamps = True | 190 self.support_signed_cert_timestamps = True |
186 elif extType == ExtensionType.status_request: | 191 elif extType == ExtensionType.status_request: |
187 # Extension contents are currently ignored. | 192 # Extension contents are currently ignored. |
188 # According to RFC 6066, this is not strictly forbidden | 193 # According to RFC 6066, this is not strictly forbidden |
189 # (although it is suboptimal): | 194 # (although it is suboptimal): |
190 # Servers that receive a client hello containing the | 195 # Servers that receive a client hello containing the |
191 # "status_request" extension MAY return a suitable | 196 # "status_request" extension MAY return a suitable |
192 # certificate status response to the client along with | 197 # certificate status response to the client along with |
193 # their certificate. If OCSP is requested, they | 198 # their certificate. If OCSP is requested, they |
194 # SHOULD use the information contained in the extension | 199 # SHOULD use the information contained in the extension |
195 # when selecting an OCSP responder and SHOULD include | 200 # when selecting an OCSP responder and SHOULD include |
196 # request_extensions in the OCSP request. | 201 # request_extensions in the OCSP request. |
197 p.getFixBytes(extLength) | 202 p.getFixBytes(extLength) |
198 self.status_request = True | 203 self.status_request = True |
199 else: | 204 else: |
200 p.getFixBytes(extLength) | 205 _ = p.getFixBytes(extLength) |
| 206 index2 = p.index |
| 207 if index2 - index1 != extLength: |
| 208 raise SyntaxError("Bad length for extension_data") |
201 soFar += 4 + extLength | 209 soFar += 4 + extLength |
202 p.stopLengthCheck() | 210 p.stopLengthCheck() |
203 return self | 211 return self |
204 | 212 |
205 def write(self, trial=False): | 213 def write(self): |
206 w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial) | 214 w = Writer() |
207 w.add(self.client_version[0], 1) | 215 w.add(self.client_version[0], 1) |
208 w.add(self.client_version[1], 1) | 216 w.add(self.client_version[1], 1) |
209 w.addFixSeq(self.random, 1) | 217 w.addFixSeq(self.random, 1) |
210 w.addVarSeq(self.session_id, 1, 1) | 218 w.addVarSeq(self.session_id, 1, 1) |
211 w.addVarSeq(self.cipher_suites, 2, 2) | 219 w.addVarSeq(self.cipher_suites, 2, 2) |
212 w.addVarSeq(self.compression_methods, 1, 1) | 220 w.addVarSeq(self.compression_methods, 1, 1) |
213 | 221 |
214 extLength = 0 | 222 w2 = Writer() # For Extensions |
215 if self.certificate_types and self.certificate_types != \ | 223 if self.certificate_types and self.certificate_types != \ |
216 [CertificateType.x509]: | 224 [CertificateType.x509]: |
217 extLength += 5 + len(self.certificate_types) | 225 w2.add(ExtensionType.cert_type, 2) |
| 226 w2.add(len(self.certificate_types)+1, 2) |
| 227 w2.addVarSeq(self.certificate_types, 1, 1) |
218 if self.srp_username: | 228 if self.srp_username: |
219 extLength += 5 + len(self.srp_username) | 229 w2.add(ExtensionType.srp, 2) |
220 if extLength > 0: | 230 w2.add(len(self.srp_username)+1, 2) |
221 w.add(extLength, 2) | 231 w2.addVarSeq(self.srp_username, 1, 1) |
| 232 if self.supports_npn: |
| 233 w2.add(ExtensionType.supports_npn, 2) |
| 234 w2.add(0, 2) |
| 235 if self.server_name: |
| 236 w2.add(ExtensionType.server_name, 2) |
| 237 w2.add(len(self.server_name)+5, 2) |
| 238 w2.add(len(self.server_name)+3, 2) |
| 239 w2.add(NameType.host_name, 1) |
| 240 w2.addVarSeq(self.server_name, 1, 2) |
| 241 if self.tack: |
| 242 w2.add(ExtensionType.tack, 2) |
| 243 w2.add(0, 2) |
| 244 if len(w2.bytes): |
| 245 w.add(len(w2.bytes), 2) |
| 246 w.bytes += w2.bytes |
| 247 return self.postWrite(w) |
222 | 248 |
223 if self.certificate_types and self.certificate_types != \ | 249 class BadNextProtos(Exception): |
224 [CertificateType.x509]: | 250 def __init__(self, l): |
225 w.add(7, 2) | 251 self.length = l |
226 w.add(len(self.certificate_types)+1, 2) | |
227 w.addVarSeq(self.certificate_types, 1, 1) | |
228 if self.srp_username: | |
229 w.add(6, 2) | |
230 w.add(len(self.srp_username)+1, 2) | |
231 w.addVarSeq(stringToBytes(self.srp_username), 1, 1) | |
232 | 252 |
233 return HandshakeMsg.postWrite(self, w, trial) | 253 def __str__(self): |
234 | 254 return 'Cannot encode a list of next protocols because it contains an el
ement with invalid length %d. Element lengths must be 0 < x < 256' % self.length |
235 | 255 |
236 class ServerHello(HandshakeMsg): | 256 class ServerHello(HandshakeMsg): |
237 def __init__(self): | 257 def __init__(self): |
238 self.contentType = ContentType.handshake | 258 HandshakeMsg.__init__(self, HandshakeType.server_hello) |
239 self.server_version = (0,0) | 259 self.server_version = (0,0) |
240 self.random = createByteArrayZeros(32) | 260 self.random = bytearray(32) |
241 self.session_id = createByteArraySequence([]) | 261 self.session_id = bytearray(0) |
242 self.cipher_suite = 0 | 262 self.cipher_suite = 0 |
243 self.certificate_type = CertificateType.x509 | 263 self.certificate_type = CertificateType.x509 |
244 self.compression_method = 0 | 264 self.compression_method = 0 |
| 265 self.tackExt = None |
| 266 self.next_protos_advertised = None |
| 267 self.next_protos = None |
245 self.channel_id = False | 268 self.channel_id = False |
246 self.signed_cert_timestamps = None | 269 self.signed_cert_timestamps = None |
247 self.status_request = False | 270 self.status_request = False |
248 | 271 |
249 def create(self, version, random, session_id, cipher_suite, | 272 def create(self, version, random, session_id, cipher_suite, |
250 certificate_type): | 273 certificate_type, tackExt, next_protos_advertised): |
251 self.server_version = version | 274 self.server_version = version |
252 self.random = random | 275 self.random = random |
253 self.session_id = session_id | 276 self.session_id = session_id |
254 self.cipher_suite = cipher_suite | 277 self.cipher_suite = cipher_suite |
255 self.certificate_type = certificate_type | 278 self.certificate_type = certificate_type |
256 self.compression_method = 0 | 279 self.compression_method = 0 |
| 280 self.tackExt = tackExt |
| 281 self.next_protos_advertised = next_protos_advertised |
257 return self | 282 return self |
258 | 283 |
259 def parse(self, p): | 284 def parse(self, p): |
260 p.startLengthCheck(3) | 285 p.startLengthCheck(3) |
261 self.server_version = (p.get(1), p.get(1)) | 286 self.server_version = (p.get(1), p.get(1)) |
262 self.random = p.getFixBytes(32) | 287 self.random = p.getFixBytes(32) |
263 self.session_id = p.getVarBytes(1) | 288 self.session_id = p.getVarBytes(1) |
264 self.cipher_suite = p.get(2) | 289 self.cipher_suite = p.get(2) |
265 self.compression_method = p.get(1) | 290 self.compression_method = p.get(1) |
266 if not p.atLengthCheck(): | 291 if not p.atLengthCheck(): |
267 totalExtLength = p.get(2) | 292 totalExtLength = p.get(2) |
268 soFar = 0 | 293 soFar = 0 |
269 while soFar != totalExtLength: | 294 while soFar != totalExtLength: |
270 extType = p.get(2) | 295 extType = p.get(2) |
271 extLength = p.get(2) | 296 extLength = p.get(2) |
272 if extType == 7: | 297 if extType == ExtensionType.cert_type: |
| 298 if extLength != 1: |
| 299 raise SyntaxError() |
273 self.certificate_type = p.get(1) | 300 self.certificate_type = p.get(1) |
| 301 elif extType == ExtensionType.tack and tackpyLoaded: |
| 302 self.tackExt = TackExtension(p.getFixBytes(extLength)) |
| 303 elif extType == ExtensionType.supports_npn: |
| 304 self.next_protos = self.__parse_next_protos(p.getFixBytes(ex
tLength)) |
274 else: | 305 else: |
275 p.getFixBytes(extLength) | 306 p.getFixBytes(extLength) |
276 soFar += 4 + extLength | 307 soFar += 4 + extLength |
277 p.stopLengthCheck() | 308 p.stopLengthCheck() |
278 return self | 309 return self |
279 | 310 |
280 def write(self, trial=False): | 311 def __parse_next_protos(self, b): |
281 w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial) | 312 protos = [] |
| 313 while True: |
| 314 if len(b) == 0: |
| 315 break |
| 316 l = b[0] |
| 317 b = b[1:] |
| 318 if len(b) < l: |
| 319 raise BadNextProtos(len(b)) |
| 320 protos.append(b[:l]) |
| 321 b = b[l:] |
| 322 return protos |
| 323 |
| 324 def __next_protos_encoded(self): |
| 325 b = bytearray() |
| 326 for e in self.next_protos_advertised: |
| 327 if len(e) > 255 or len(e) == 0: |
| 328 raise BadNextProtos(len(e)) |
| 329 b += bytearray( [len(e)] ) + bytearray(e) |
| 330 return b |
| 331 |
| 332 def write(self): |
| 333 w = Writer() |
282 w.add(self.server_version[0], 1) | 334 w.add(self.server_version[0], 1) |
283 w.add(self.server_version[1], 1) | 335 w.add(self.server_version[1], 1) |
284 w.addFixSeq(self.random, 1) | 336 w.addFixSeq(self.random, 1) |
285 w.addVarSeq(self.session_id, 1, 1) | 337 w.addVarSeq(self.session_id, 1, 1) |
286 w.add(self.cipher_suite, 2) | 338 w.add(self.cipher_suite, 2) |
287 w.add(self.compression_method, 1) | 339 w.add(self.compression_method, 1) |
288 | 340 |
289 extLength = 0 | 341 w2 = Writer() # For Extensions |
290 if self.certificate_type and self.certificate_type != \ | 342 if self.certificate_type and self.certificate_type != \ |
291 CertificateType.x509: | 343 CertificateType.x509: |
292 extLength += 5 | 344 w2.add(ExtensionType.cert_type, 2) |
| 345 w2.add(1, 2) |
| 346 w2.add(self.certificate_type, 1) |
| 347 if self.tackExt: |
| 348 b = self.tackExt.serialize() |
| 349 w2.add(ExtensionType.tack, 2) |
| 350 w2.add(len(b), 2) |
| 351 w2.bytes += b |
| 352 if self.next_protos_advertised is not None: |
| 353 encoded_next_protos_advertised = self.__next_protos_encoded() |
| 354 w2.add(ExtensionType.supports_npn, 2) |
| 355 w2.add(len(encoded_next_protos_advertised), 2) |
| 356 w2.addFixSeq(encoded_next_protos_advertised, 1) |
| 357 if self.channel_id: |
| 358 w2.add(ExtensionType.channel_id, 2) |
| 359 w2.add(0, 2) |
| 360 if self.signed_cert_timestamps: |
| 361 w2.add(ExtensionType.signed_cert_timestamps, 2) |
| 362 w2.addVarSeq(bytearray(self.signed_cert_timestamps), 1, 2) |
| 363 if self.status_request: |
| 364 w2.add(ExtensionType.status_request, 2) |
| 365 w2.add(0, 2) |
| 366 if len(w2.bytes): |
| 367 w.add(len(w2.bytes), 2) |
| 368 w.bytes += w2.bytes |
| 369 return self.postWrite(w) |
293 | 370 |
294 if self.channel_id: | |
295 extLength += 4 | |
296 | |
297 if self.signed_cert_timestamps: | |
298 extLength += 4 + len(self.signed_cert_timestamps) | |
299 | |
300 if self.status_request: | |
301 extLength += 4 | |
302 | |
303 if extLength != 0: | |
304 w.add(extLength, 2) | |
305 | |
306 if self.certificate_type and self.certificate_type != \ | |
307 CertificateType.x509: | |
308 w.add(7, 2) | |
309 w.add(1, 2) | |
310 w.add(self.certificate_type, 1) | |
311 | |
312 if self.channel_id: | |
313 w.add(ExtensionType.channel_id, 2) | |
314 w.add(0, 2) | |
315 | |
316 if self.signed_cert_timestamps: | |
317 w.add(ExtensionType.signed_cert_timestamps, 2) | |
318 w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2) | |
319 | |
320 if self.status_request: | |
321 w.add(ExtensionType.status_request, 2) | |
322 w.add(0, 2) | |
323 | |
324 return HandshakeMsg.postWrite(self, w, trial) | |
325 | 371 |
326 class Certificate(HandshakeMsg): | 372 class Certificate(HandshakeMsg): |
327 def __init__(self, certificateType): | 373 def __init__(self, certificateType): |
| 374 HandshakeMsg.__init__(self, HandshakeType.certificate) |
328 self.certificateType = certificateType | 375 self.certificateType = certificateType |
329 self.contentType = ContentType.handshake | |
330 self.certChain = None | 376 self.certChain = None |
331 | 377 |
332 def create(self, certChain): | 378 def create(self, certChain): |
333 self.certChain = certChain | 379 self.certChain = certChain |
334 return self | 380 return self |
335 | 381 |
336 def parse(self, p): | 382 def parse(self, p): |
337 p.startLengthCheck(3) | 383 p.startLengthCheck(3) |
338 if self.certificateType == CertificateType.x509: | 384 if self.certificateType == CertificateType.x509: |
339 chainLength = p.get(3) | 385 chainLength = p.get(3) |
340 index = 0 | 386 index = 0 |
341 certificate_list = [] | 387 certificate_list = [] |
342 while index != chainLength: | 388 while index != chainLength: |
343 certBytes = p.getVarBytes(3) | 389 certBytes = p.getVarBytes(3) |
344 x509 = X509() | 390 x509 = X509() |
345 x509.parseBinary(certBytes) | 391 x509.parseBinary(certBytes) |
346 certificate_list.append(x509) | 392 certificate_list.append(x509) |
347 index += len(certBytes)+3 | 393 index += len(certBytes)+3 |
348 if certificate_list: | 394 if certificate_list: |
349 self.certChain = X509CertChain(certificate_list) | 395 self.certChain = X509CertChain(certificate_list) |
350 elif self.certificateType == CertificateType.cryptoID: | |
351 s = bytesToString(p.getVarBytes(2)) | |
352 if s: | |
353 try: | |
354 import cryptoIDlib.CertChain | |
355 except ImportError: | |
356 raise SyntaxError(\ | |
357 "cryptoID cert chain received, cryptoIDlib not present") | |
358 self.certChain = cryptoIDlib.CertChain.CertChain().parse(s) | |
359 else: | 396 else: |
360 raise AssertionError() | 397 raise AssertionError() |
361 | 398 |
362 p.stopLengthCheck() | 399 p.stopLengthCheck() |
363 return self | 400 return self |
364 | 401 |
365 def write(self, trial=False): | 402 def write(self): |
366 w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial) | 403 w = Writer() |
367 if self.certificateType == CertificateType.x509: | 404 if self.certificateType == CertificateType.x509: |
368 chainLength = 0 | 405 chainLength = 0 |
369 if self.certChain: | 406 if self.certChain: |
370 certificate_list = self.certChain.x509List | 407 certificate_list = self.certChain.x509List |
371 else: | 408 else: |
372 certificate_list = [] | 409 certificate_list = [] |
373 #determine length | 410 #determine length |
374 for cert in certificate_list: | 411 for cert in certificate_list: |
375 bytes = cert.writeBytes() | 412 bytes = cert.writeBytes() |
376 chainLength += len(bytes)+3 | 413 chainLength += len(bytes)+3 |
377 #add bytes | 414 #add bytes |
378 w.add(chainLength, 3) | 415 w.add(chainLength, 3) |
379 for cert in certificate_list: | 416 for cert in certificate_list: |
380 bytes = cert.writeBytes() | 417 bytes = cert.writeBytes() |
381 w.addVarSeq(bytes, 1, 3) | 418 w.addVarSeq(bytes, 1, 3) |
382 elif self.certificateType == CertificateType.cryptoID: | |
383 if self.certChain: | |
384 bytes = stringToBytes(self.certChain.write()) | |
385 else: | |
386 bytes = createByteArraySequence([]) | |
387 w.addVarSeq(bytes, 1, 2) | |
388 else: | 419 else: |
389 raise AssertionError() | 420 raise AssertionError() |
390 return HandshakeMsg.postWrite(self, w, trial) | 421 return self.postWrite(w) |
391 | 422 |
392 class CertificateStatus(HandshakeMsg): | 423 class CertificateStatus(HandshakeMsg): |
393 def __init__(self): | 424 def __init__(self): |
394 self.contentType = ContentType.handshake | 425 HandshakeMsg.__init__(self, HandshakeType.certificate_status) |
395 | 426 |
396 def create(self, ocsp_response): | 427 def create(self, ocsp_response): |
397 self.ocsp_response = ocsp_response | 428 self.ocsp_response = ocsp_response |
398 return self | 429 return self |
399 | 430 |
400 # Defined for the sake of completeness, even though we currently only | 431 # Defined for the sake of completeness, even though we currently only |
401 # support sending the status message (server-side), not requesting | 432 # support sending the status message (server-side), not requesting |
402 # or receiving it (client-side). | 433 # or receiving it (client-side). |
403 def parse(self, p): | 434 def parse(self, p): |
404 p.startLengthCheck(3) | 435 p.startLengthCheck(3) |
405 status_type = p.get(1) | 436 status_type = p.get(1) |
406 # Only one type is specified, so hardwire it. | 437 # Only one type is specified, so hardwire it. |
407 if status_type != CertificateStatusType.ocsp: | 438 if status_type != CertificateStatusType.ocsp: |
408 raise SyntaxError() | 439 raise SyntaxError() |
409 ocsp_response = p.getVarBytes(3) | 440 ocsp_response = p.getVarBytes(3) |
410 if not ocsp_response: | 441 if not ocsp_response: |
411 # Can't be empty | 442 # Can't be empty |
412 raise SyntaxError() | 443 raise SyntaxError() |
413 self.ocsp_response = ocsp_response | 444 self.ocsp_response = ocsp_response |
| 445 p.stopLengthCheck() |
414 return self | 446 return self |
415 | 447 |
416 def write(self, trial=False): | 448 def write(self): |
417 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_status, | 449 w = Writer() |
418 trial) | |
419 w.add(CertificateStatusType.ocsp, 1) | 450 w.add(CertificateStatusType.ocsp, 1) |
420 w.addVarSeq(stringToBytes(self.ocsp_response), 1, 3) | 451 w.addVarSeq(bytearray(self.ocsp_response), 1, 3) |
421 return HandshakeMsg.postWrite(self, w, trial) | 452 return self.postWrite(w) |
422 | 453 |
423 class CertificateRequest(HandshakeMsg): | 454 class CertificateRequest(HandshakeMsg): |
424 def __init__(self): | 455 def __init__(self): |
425 self.contentType = ContentType.handshake | 456 HandshakeMsg.__init__(self, HandshakeType.certificate_request) |
426 #Apple's Secure Transport library rejects empty certificate_types, so | 457 #Apple's Secure Transport library rejects empty certificate_types, so |
427 #default to rsa_sign. | 458 #default to rsa_sign. |
428 self.certificate_types = [ClientCertificateType.rsa_sign] | 459 self.certificate_types = [ClientCertificateType.rsa_sign] |
429 self.certificate_authorities = [] | 460 self.certificate_authorities = [] |
430 | 461 |
431 def create(self, certificate_types, certificate_authorities): | 462 def create(self, certificate_types, certificate_authorities): |
432 self.certificate_types = certificate_types | 463 self.certificate_types = certificate_types |
433 self.certificate_authorities = certificate_authorities | 464 self.certificate_authorities = certificate_authorities |
434 return self | 465 return self |
435 | 466 |
436 def parse(self, p): | 467 def parse(self, p): |
437 p.startLengthCheck(3) | 468 p.startLengthCheck(3) |
438 self.certificate_types = p.getVarList(1, 1) | 469 self.certificate_types = p.getVarList(1, 1) |
439 ca_list_length = p.get(2) | 470 ca_list_length = p.get(2) |
440 index = 0 | 471 index = 0 |
441 self.certificate_authorities = [] | 472 self.certificate_authorities = [] |
442 while index != ca_list_length: | 473 while index != ca_list_length: |
443 ca_bytes = p.getVarBytes(2) | 474 ca_bytes = p.getVarBytes(2) |
444 self.certificate_authorities.append(ca_bytes) | 475 self.certificate_authorities.append(ca_bytes) |
445 index += len(ca_bytes)+2 | 476 index += len(ca_bytes)+2 |
446 p.stopLengthCheck() | 477 p.stopLengthCheck() |
447 return self | 478 return self |
448 | 479 |
449 def write(self, trial=False): | 480 def write(self): |
450 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request, | 481 w = Writer() |
451 trial) | |
452 w.addVarSeq(self.certificate_types, 1, 1) | 482 w.addVarSeq(self.certificate_types, 1, 1) |
453 caLength = 0 | 483 caLength = 0 |
454 #determine length | 484 #determine length |
455 for ca_dn in self.certificate_authorities: | 485 for ca_dn in self.certificate_authorities: |
456 caLength += len(ca_dn)+2 | 486 caLength += len(ca_dn)+2 |
457 w.add(caLength, 2) | 487 w.add(caLength, 2) |
458 #add bytes | 488 #add bytes |
459 for ca_dn in self.certificate_authorities: | 489 for ca_dn in self.certificate_authorities: |
460 w.addVarSeq(ca_dn, 1, 2) | 490 w.addVarSeq(ca_dn, 1, 2) |
461 return HandshakeMsg.postWrite(self, w, trial) | 491 return self.postWrite(w) |
462 | 492 |
463 class ServerKeyExchange(HandshakeMsg): | 493 class ServerKeyExchange(HandshakeMsg): |
464 def __init__(self, cipherSuite): | 494 def __init__(self, cipherSuite): |
| 495 HandshakeMsg.__init__(self, HandshakeType.server_key_exchange) |
465 self.cipherSuite = cipherSuite | 496 self.cipherSuite = cipherSuite |
466 self.contentType = ContentType.handshake | 497 self.srp_N = 0 |
467 self.srp_N = 0L | 498 self.srp_g = 0 |
468 self.srp_g = 0L | 499 self.srp_s = bytearray(0) |
469 self.srp_s = createByteArraySequence([]) | 500 self.srp_B = 0 |
470 self.srp_B = 0L | 501 # Anon DH params: |
471 self.signature = createByteArraySequence([]) | 502 self.dh_p = 0 |
| 503 self.dh_g = 0 |
| 504 self.dh_Ys = 0 |
| 505 self.signature = bytearray(0) |
472 | 506 |
473 def createSRP(self, srp_N, srp_g, srp_s, srp_B): | 507 def createSRP(self, srp_N, srp_g, srp_s, srp_B): |
474 self.srp_N = srp_N | 508 self.srp_N = srp_N |
475 self.srp_g = srp_g | 509 self.srp_g = srp_g |
476 self.srp_s = srp_s | 510 self.srp_s = srp_s |
477 self.srp_B = srp_B | 511 self.srp_B = srp_B |
478 return self | 512 return self |
| 513 |
| 514 def createDH(self, dh_p, dh_g, dh_Ys): |
| 515 self.dh_p = dh_p |
| 516 self.dh_g = dh_g |
| 517 self.dh_Ys = dh_Ys |
| 518 return self |
479 | 519 |
480 def parse(self, p): | 520 def parse(self, p): |
481 p.startLengthCheck(3) | 521 p.startLengthCheck(3) |
482 self.srp_N = bytesToNumber(p.getVarBytes(2)) | 522 if self.cipherSuite in CipherSuite.srpAllSuites: |
483 self.srp_g = bytesToNumber(p.getVarBytes(2)) | 523 self.srp_N = bytesToNumber(p.getVarBytes(2)) |
484 self.srp_s = p.getVarBytes(1) | 524 self.srp_g = bytesToNumber(p.getVarBytes(2)) |
485 self.srp_B = bytesToNumber(p.getVarBytes(2)) | 525 self.srp_s = p.getVarBytes(1) |
486 if self.cipherSuite in CipherSuite.srpRsaSuites: | 526 self.srp_B = bytesToNumber(p.getVarBytes(2)) |
487 self.signature = p.getVarBytes(2) | 527 if self.cipherSuite in CipherSuite.srpCertSuites: |
| 528 self.signature = p.getVarBytes(2) |
| 529 elif self.cipherSuite in CipherSuite.anonSuites: |
| 530 self.dh_p = bytesToNumber(p.getVarBytes(2)) |
| 531 self.dh_g = bytesToNumber(p.getVarBytes(2)) |
| 532 self.dh_Ys = bytesToNumber(p.getVarBytes(2)) |
488 p.stopLengthCheck() | 533 p.stopLengthCheck() |
489 return self | 534 return self |
490 | 535 |
491 def write(self, trial=False): | 536 def write(self): |
492 w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange, | 537 w = Writer() |
493 trial) | 538 if self.cipherSuite in CipherSuite.srpAllSuites: |
494 w.addVarSeq(numberToBytes(self.srp_N), 1, 2) | 539 w.addVarSeq(numberToByteArray(self.srp_N), 1, 2) |
495 w.addVarSeq(numberToBytes(self.srp_g), 1, 2) | 540 w.addVarSeq(numberToByteArray(self.srp_g), 1, 2) |
496 w.addVarSeq(self.srp_s, 1, 1) | 541 w.addVarSeq(self.srp_s, 1, 1) |
497 w.addVarSeq(numberToBytes(self.srp_B), 1, 2) | 542 w.addVarSeq(numberToByteArray(self.srp_B), 1, 2) |
498 if self.cipherSuite in CipherSuite.srpRsaSuites: | 543 if self.cipherSuite in CipherSuite.srpCertSuites: |
499 w.addVarSeq(self.signature, 1, 2) | 544 w.addVarSeq(self.signature, 1, 2) |
500 return HandshakeMsg.postWrite(self, w, trial) | 545 elif self.cipherSuite in CipherSuite.anonSuites: |
| 546 w.addVarSeq(numberToByteArray(self.dh_p), 1, 2) |
| 547 w.addVarSeq(numberToByteArray(self.dh_g), 1, 2) |
| 548 w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2) |
| 549 if self.cipherSuite in []: # TODO support for signed_params |
| 550 w.addVarSeq(self.signature, 1, 2) |
| 551 return self.postWrite(w) |
501 | 552 |
502 def hash(self, clientRandom, serverRandom): | 553 def hash(self, clientRandom, serverRandom): |
503 oldCipherSuite = self.cipherSuite | 554 oldCipherSuite = self.cipherSuite |
504 self.cipherSuite = None | 555 self.cipherSuite = None |
505 try: | 556 try: |
506 bytes = clientRandom + serverRandom + self.write()[4:] | 557 bytes = clientRandom + serverRandom + self.write()[4:] |
507 s = bytesToString(bytes) | 558 return MD5(bytes) + SHA1(bytes) |
508 return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest()) | |
509 finally: | 559 finally: |
510 self.cipherSuite = oldCipherSuite | 560 self.cipherSuite = oldCipherSuite |
511 | 561 |
512 class ServerHelloDone(HandshakeMsg): | 562 class ServerHelloDone(HandshakeMsg): |
513 def __init__(self): | 563 def __init__(self): |
514 self.contentType = ContentType.handshake | 564 HandshakeMsg.__init__(self, HandshakeType.server_hello_done) |
515 | 565 |
516 def create(self): | 566 def create(self): |
517 return self | 567 return self |
518 | 568 |
519 def parse(self, p): | 569 def parse(self, p): |
520 p.startLengthCheck(3) | 570 p.startLengthCheck(3) |
521 p.stopLengthCheck() | 571 p.stopLengthCheck() |
522 return self | 572 return self |
523 | 573 |
524 def write(self, trial=False): | 574 def write(self): |
525 w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial) | 575 w = Writer() |
526 return HandshakeMsg.postWrite(self, w, trial) | 576 return self.postWrite(w) |
527 | 577 |
528 class ClientKeyExchange(HandshakeMsg): | 578 class ClientKeyExchange(HandshakeMsg): |
529 def __init__(self, cipherSuite, version=None): | 579 def __init__(self, cipherSuite, version=None): |
| 580 HandshakeMsg.__init__(self, HandshakeType.client_key_exchange) |
530 self.cipherSuite = cipherSuite | 581 self.cipherSuite = cipherSuite |
531 self.version = version | 582 self.version = version |
532 self.contentType = ContentType.handshake | |
533 self.srp_A = 0 | 583 self.srp_A = 0 |
534 self.encryptedPreMasterSecret = createByteArraySequence([]) | 584 self.encryptedPreMasterSecret = bytearray(0) |
535 | 585 |
536 def createSRP(self, srp_A): | 586 def createSRP(self, srp_A): |
537 self.srp_A = srp_A | 587 self.srp_A = srp_A |
538 return self | 588 return self |
539 | 589 |
540 def createRSA(self, encryptedPreMasterSecret): | 590 def createRSA(self, encryptedPreMasterSecret): |
541 self.encryptedPreMasterSecret = encryptedPreMasterSecret | 591 self.encryptedPreMasterSecret = encryptedPreMasterSecret |
542 return self | 592 return self |
543 | 593 |
| 594 def createDH(self, dh_Yc): |
| 595 self.dh_Yc = dh_Yc |
| 596 return self |
| 597 |
544 def parse(self, p): | 598 def parse(self, p): |
545 p.startLengthCheck(3) | 599 p.startLengthCheck(3) |
546 if self.cipherSuite in CipherSuite.srpSuites + \ | 600 if self.cipherSuite in CipherSuite.srpAllSuites: |
547 CipherSuite.srpRsaSuites: | |
548 self.srp_A = bytesToNumber(p.getVarBytes(2)) | 601 self.srp_A = bytesToNumber(p.getVarBytes(2)) |
549 elif self.cipherSuite in CipherSuite.rsaSuites: | 602 elif self.cipherSuite in CipherSuite.certSuites: |
550 if self.version in ((3,1), (3,2)): | 603 if self.version in ((3,1), (3,2)): |
551 self.encryptedPreMasterSecret = p.getVarBytes(2) | 604 self.encryptedPreMasterSecret = p.getVarBytes(2) |
552 elif self.version == (3,0): | 605 elif self.version == (3,0): |
553 self.encryptedPreMasterSecret = \ | 606 self.encryptedPreMasterSecret = \ |
554 p.getFixBytes(len(p.bytes)-p.index) | 607 p.getFixBytes(len(p.bytes)-p.index) |
555 else: | 608 else: |
556 raise AssertionError() | 609 raise AssertionError() |
| 610 elif self.cipherSuite in CipherSuite.anonSuites: |
| 611 self.dh_Yc = bytesToNumber(p.getVarBytes(2)) |
557 else: | 612 else: |
558 raise AssertionError() | 613 raise AssertionError() |
559 p.stopLengthCheck() | 614 p.stopLengthCheck() |
560 return self | 615 return self |
561 | 616 |
562 def write(self, trial=False): | 617 def write(self): |
563 w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange, | 618 w = Writer() |
564 trial) | 619 if self.cipherSuite in CipherSuite.srpAllSuites: |
565 if self.cipherSuite in CipherSuite.srpSuites + \ | 620 w.addVarSeq(numberToByteArray(self.srp_A), 1, 2) |
566 CipherSuite.srpRsaSuites: | 621 elif self.cipherSuite in CipherSuite.certSuites: |
567 w.addVarSeq(numberToBytes(self.srp_A), 1, 2) | |
568 elif self.cipherSuite in CipherSuite.rsaSuites: | |
569 if self.version in ((3,1), (3,2)): | 622 if self.version in ((3,1), (3,2)): |
570 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2) | 623 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2) |
571 elif self.version == (3,0): | 624 elif self.version == (3,0): |
572 w.addFixSeq(self.encryptedPreMasterSecret, 1) | 625 w.addFixSeq(self.encryptedPreMasterSecret, 1) |
573 else: | 626 else: |
574 raise AssertionError() | 627 raise AssertionError() |
| 628 elif self.cipherSuite in CipherSuite.anonSuites: |
| 629 w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2) |
575 else: | 630 else: |
576 raise AssertionError() | 631 raise AssertionError() |
577 return HandshakeMsg.postWrite(self, w, trial) | 632 return self.postWrite(w) |
578 | 633 |
579 class CertificateVerify(HandshakeMsg): | 634 class CertificateVerify(HandshakeMsg): |
580 def __init__(self): | 635 def __init__(self): |
581 self.contentType = ContentType.handshake | 636 HandshakeMsg.__init__(self, HandshakeType.certificate_verify) |
582 self.signature = createByteArraySequence([]) | 637 self.signature = bytearray(0) |
583 | 638 |
584 def create(self, signature): | 639 def create(self, signature): |
585 self.signature = signature | 640 self.signature = signature |
586 return self | 641 return self |
587 | 642 |
588 def parse(self, p): | 643 def parse(self, p): |
589 p.startLengthCheck(3) | 644 p.startLengthCheck(3) |
590 self.signature = p.getVarBytes(2) | 645 self.signature = p.getVarBytes(2) |
591 p.stopLengthCheck() | 646 p.stopLengthCheck() |
592 return self | 647 return self |
593 | 648 |
594 def write(self, trial=False): | 649 def write(self): |
595 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify, | 650 w = Writer() |
596 trial) | |
597 w.addVarSeq(self.signature, 1, 2) | 651 w.addVarSeq(self.signature, 1, 2) |
598 return HandshakeMsg.postWrite(self, w, trial) | 652 return self.postWrite(w) |
599 | 653 |
600 class ChangeCipherSpec(Msg): | 654 class ChangeCipherSpec(object): |
601 def __init__(self): | 655 def __init__(self): |
602 self.contentType = ContentType.change_cipher_spec | 656 self.contentType = ContentType.change_cipher_spec |
603 self.type = 1 | 657 self.type = 1 |
604 | 658 |
605 def create(self): | 659 def create(self): |
606 self.type = 1 | 660 self.type = 1 |
607 return self | 661 return self |
608 | 662 |
609 def parse(self, p): | 663 def parse(self, p): |
610 p.setLengthCheck(1) | 664 p.setLengthCheck(1) |
611 self.type = p.get(1) | 665 self.type = p.get(1) |
612 p.stopLengthCheck() | 666 p.stopLengthCheck() |
613 return self | 667 return self |
614 | 668 |
| 669 def write(self): |
| 670 w = Writer() |
| 671 w.add(self.type,1) |
| 672 return w.bytes |
| 673 |
| 674 |
| 675 class NextProtocol(HandshakeMsg): |
| 676 def __init__(self): |
| 677 HandshakeMsg.__init__(self, HandshakeType.next_protocol) |
| 678 self.next_proto = None |
| 679 |
| 680 def create(self, next_proto): |
| 681 self.next_proto = next_proto |
| 682 return self |
| 683 |
| 684 def parse(self, p): |
| 685 p.startLengthCheck(3) |
| 686 self.next_proto = p.getVarBytes(1) |
| 687 _ = p.getVarBytes(1) |
| 688 p.stopLengthCheck() |
| 689 return self |
| 690 |
615 def write(self, trial=False): | 691 def write(self, trial=False): |
616 w = Msg.preWrite(self, trial) | 692 w = Writer() |
617 w.add(self.type,1) | 693 w.addVarSeq(self.next_proto, 1, 1) |
618 return Msg.postWrite(self, w, trial) | 694 paddingLen = 32 - ((len(self.next_proto) + 2) % 32) |
619 | 695 w.addVarSeq(bytearray(paddingLen), 1, 1) |
| 696 return self.postWrite(w) |
620 | 697 |
621 class Finished(HandshakeMsg): | 698 class Finished(HandshakeMsg): |
622 def __init__(self, version): | 699 def __init__(self, version): |
623 self.contentType = ContentType.handshake | 700 HandshakeMsg.__init__(self, HandshakeType.finished) |
624 self.version = version | 701 self.version = version |
625 self.verify_data = createByteArraySequence([]) | 702 self.verify_data = bytearray(0) |
626 | 703 |
627 def create(self, verify_data): | 704 def create(self, verify_data): |
628 self.verify_data = verify_data | 705 self.verify_data = verify_data |
629 return self | 706 return self |
630 | 707 |
631 def parse(self, p): | 708 def parse(self, p): |
632 p.startLengthCheck(3) | 709 p.startLengthCheck(3) |
633 if self.version == (3,0): | 710 if self.version == (3,0): |
634 self.verify_data = p.getFixBytes(36) | 711 self.verify_data = p.getFixBytes(36) |
635 elif self.version in ((3,1), (3,2)): | 712 elif self.version in ((3,1), (3,2)): |
636 self.verify_data = p.getFixBytes(12) | 713 self.verify_data = p.getFixBytes(12) |
637 else: | 714 else: |
638 raise AssertionError() | 715 raise AssertionError() |
639 p.stopLengthCheck() | 716 p.stopLengthCheck() |
640 return self | 717 return self |
641 | 718 |
642 def write(self, trial=False): | 719 def write(self): |
643 w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial) | 720 w = Writer() |
644 w.addFixSeq(self.verify_data, 1) | 721 w.addFixSeq(self.verify_data, 1) |
645 return HandshakeMsg.postWrite(self, w, trial) | 722 return self.postWrite(w) |
646 | 723 |
647 class EncryptedExtensions(HandshakeMsg): | 724 class EncryptedExtensions(HandshakeMsg): |
648 def __init__(self): | 725 def __init__(self): |
649 self.channel_id_key = None | 726 self.channel_id_key = None |
650 self.channel_id_proof = None | 727 self.channel_id_proof = None |
651 | 728 |
652 def parse(self, p): | 729 def parse(self, p): |
653 p.startLengthCheck(3) | 730 p.startLengthCheck(3) |
654 soFar = 0 | 731 soFar = 0 |
655 while soFar != p.lengthCheck: | 732 while soFar != p.lengthCheck: |
656 extType = p.get(2) | 733 extType = p.get(2) |
657 extLength = p.get(2) | 734 extLength = p.get(2) |
658 if extType == ExtensionType.channel_id: | 735 if extType == ExtensionType.channel_id: |
659 if extLength != 32*4: | 736 if extLength != 32*4: |
660 raise SyntaxError() | 737 raise SyntaxError() |
661 self.channel_id_key = p.getFixBytes(64) | 738 self.channel_id_key = p.getFixBytes(64) |
662 self.channel_id_proof = p.getFixBytes(64) | 739 self.channel_id_proof = p.getFixBytes(64) |
663 else: | 740 else: |
664 p.getFixBytes(extLength) | 741 p.getFixBytes(extLength) |
665 soFar += 4 + extLength | 742 soFar += 4 + extLength |
666 p.stopLengthCheck() | 743 p.stopLengthCheck() |
667 return self | 744 return self |
668 | 745 |
669 class ApplicationData(Msg): | 746 class ApplicationData(object): |
670 def __init__(self): | 747 def __init__(self): |
671 self.contentType = ContentType.application_data | 748 self.contentType = ContentType.application_data |
672 self.bytes = createByteArraySequence([]) | 749 self.bytes = bytearray(0) |
673 | 750 |
674 def create(self, bytes): | 751 def create(self, bytes): |
675 self.bytes = bytes | 752 self.bytes = bytes |
676 return self | 753 return self |
| 754 |
| 755 def splitFirstByte(self): |
| 756 newMsg = ApplicationData().create(self.bytes[:1]) |
| 757 self.bytes = self.bytes[1:] |
| 758 return newMsg |
677 | 759 |
678 def parse(self, p): | 760 def parse(self, p): |
679 self.bytes = p.bytes | 761 self.bytes = p.bytes |
680 return self | 762 return self |
681 | 763 |
682 def write(self): | 764 def write(self): |
683 return self.bytes | 765 return self.bytes |
OLD | NEW |