OLD | NEW |
| 1 # Authors: |
| 2 # Trevor Perrin |
| 3 # Google (adapted by Sam Rushing) - NPN support |
| 4 # Martin von Loewis - python 3 port |
| 5 # |
| 6 # See the LICENSE file for legal information regarding use of this file. |
| 7 |
1 """Helper class for TLSConnection.""" | 8 """Helper class for TLSConnection.""" |
2 from __future__ import generators | 9 from __future__ import generators |
3 | 10 |
4 from utils.compat import * | 11 from .utils.compat import * |
5 from utils.cryptomath import * | 12 from .utils.cryptomath import * |
6 from utils.cipherfactory import createAES, createRC4, createTripleDES | 13 from .utils.cipherfactory import createAES, createRC4, createTripleDES |
7 from utils.codec import * | 14 from .utils.codec import * |
8 from errors import * | 15 from .errors import * |
9 from messages import * | 16 from .messages import * |
10 from mathtls import * | 17 from .mathtls import * |
11 from constants import * | 18 from .constants import * |
12 from utils.cryptomath import getRandomBytes | 19 from .utils.cryptomath import getRandomBytes |
13 from utils import hmac | |
14 from fileobject import FileObject | |
15 | |
16 # The sha module is deprecated in Python 2.6 | |
17 try: | |
18 import sha | |
19 except ImportError: | |
20 from hashlib import sha1 as sha | |
21 | |
22 # The md5 module is deprecated in Python 2.6 | |
23 try: | |
24 import md5 | |
25 except ImportError: | |
26 from hashlib import md5 | |
27 | 20 |
28 import socket | 21 import socket |
29 import errno | 22 import errno |
30 import traceback | 23 import traceback |
31 | 24 |
32 class _ConnectionState: | 25 class _ConnectionState(object): |
33 def __init__(self): | 26 def __init__(self): |
34 self.macContext = None | 27 self.macContext = None |
35 self.encContext = None | 28 self.encContext = None |
36 self.seqnum = 0 | 29 self.seqnum = 0 |
37 | 30 |
38 def getSeqNumStr(self): | 31 def getSeqNumBytes(self): |
39 w = Writer(8) | 32 w = Writer() |
40 w.add(self.seqnum, 8) | 33 w.add(self.seqnum, 8) |
41 seqnumStr = bytesToString(w.bytes) | |
42 self.seqnum += 1 | 34 self.seqnum += 1 |
43 return seqnumStr | 35 return w.bytes |
44 | 36 |
45 | 37 |
46 class TLSRecordLayer: | 38 class TLSRecordLayer(object): |
47 """ | 39 """ |
48 This class handles data transmission for a TLS connection. | 40 This class handles data transmission for a TLS connection. |
49 | 41 |
50 Its only subclass is L{tlslite.TLSConnection.TLSConnection}. We've | 42 Its only subclass is L{tlslite.TLSConnection.TLSConnection}. We've |
51 separated the code in this class from TLSConnection to make things | 43 separated the code in this class from TLSConnection to make things |
52 more readable. | 44 more readable. |
53 | 45 |
54 | 46 |
55 @type sock: socket.socket | 47 @type sock: socket.socket |
56 @ivar sock: The underlying socket object. | 48 @ivar sock: The underlying socket object. |
57 | 49 |
58 @type session: L{tlslite.Session.Session} | 50 @type session: L{tlslite.Session.Session} |
59 @ivar session: The session corresponding to this connection. | 51 @ivar session: The session corresponding to this connection. |
60 | 52 |
61 Due to TLS session resumption, multiple connections can correspond | 53 Due to TLS session resumption, multiple connections can correspond |
62 to the same underlying session. | 54 to the same underlying session. |
63 | 55 |
64 @type version: tuple | 56 @type version: tuple |
65 @ivar version: The TLS version being used for this connection. | 57 @ivar version: The TLS version being used for this connection. |
66 | 58 |
67 (3,0) means SSL 3.0, and (3,1) means TLS 1.0. | 59 (3,0) means SSL 3.0, and (3,1) means TLS 1.0. |
68 | 60 |
69 @type closed: bool | 61 @type closed: bool |
70 @ivar closed: If this connection is closed. | 62 @ivar closed: If this connection is closed. |
71 | 63 |
72 @type resumed: bool | 64 @type resumed: bool |
73 @ivar resumed: If this connection is based on a resumed session. | 65 @ivar resumed: If this connection is based on a resumed session. |
74 | 66 |
75 @type allegedSharedKeyUsername: str or None | |
76 @ivar allegedSharedKeyUsername: This is set to the shared-key | |
77 username asserted by the client, whether the handshake succeeded or | |
78 not. If the handshake fails, this can be inspected to | |
79 determine if a guessing attack is in progress against a particular | |
80 user account. | |
81 | |
82 @type allegedSrpUsername: str or None | 67 @type allegedSrpUsername: str or None |
83 @ivar allegedSrpUsername: This is set to the SRP username | 68 @ivar allegedSrpUsername: This is set to the SRP username |
84 asserted by the client, whether the handshake succeeded or not. | 69 asserted by the client, whether the handshake succeeded or not. |
85 If the handshake fails, this can be inspected to determine | 70 If the handshake fails, this can be inspected to determine |
86 if a guessing attack is in progress against a particular user | 71 if a guessing attack is in progress against a particular user |
87 account. | 72 account. |
88 | 73 |
89 @type closeSocket: bool | 74 @type closeSocket: bool |
90 @ivar closeSocket: If the socket should be closed when the | 75 @ivar closeSocket: If the socket should be closed when the |
91 connection is closed (writable). | 76 connection is closed, defaults to True (writable). |
92 | 77 |
93 If you set this to True, TLS Lite will assume the responsibility of | 78 If you set this to True, TLS Lite will assume the responsibility of |
94 closing the socket when the TLS Connection is shutdown (either | 79 closing the socket when the TLS Connection is shutdown (either |
95 through an error or through the user calling close()). The default | 80 through an error or through the user calling close()). The default |
96 is False. | 81 is False. |
97 | 82 |
98 @type ignoreAbruptClose: bool | 83 @type ignoreAbruptClose: bool |
99 @ivar ignoreAbruptClose: If an abrupt close of the socket should | 84 @ivar ignoreAbruptClose: If an abrupt close of the socket should |
100 raise an error (writable). | 85 raise an error (writable). |
101 | 86 |
(...skipping 15 matching lines...) Expand all Loading... |
117 self.sock = sock | 102 self.sock = sock |
118 | 103 |
119 #My session object (Session instance; read-only) | 104 #My session object (Session instance; read-only) |
120 self.session = None | 105 self.session = None |
121 | 106 |
122 #Am I a client or server? | 107 #Am I a client or server? |
123 self._client = None | 108 self._client = None |
124 | 109 |
125 #Buffers for processing messages | 110 #Buffers for processing messages |
126 self._handshakeBuffer = [] | 111 self._handshakeBuffer = [] |
127 self._readBuffer = "" | 112 self.clearReadBuffer() |
| 113 self.clearWriteBuffer() |
128 | 114 |
129 #Handshake digests | 115 #Handshake digests |
130 self._handshake_md5 = md5.md5() | 116 self._handshake_md5 = hashlib.md5() |
131 self._handshake_sha = sha.sha() | 117 self._handshake_sha = hashlib.sha1() |
132 | 118 |
133 #TLS Protocol Version | 119 #TLS Protocol Version |
134 self.version = (0,0) #read-only | 120 self.version = (0,0) #read-only |
135 self._versionCheck = False #Once we choose a version, this is True | 121 self._versionCheck = False #Once we choose a version, this is True |
136 | 122 |
137 #Current and Pending connection states | 123 #Current and Pending connection states |
138 self._writeState = _ConnectionState() | 124 self._writeState = _ConnectionState() |
139 self._readState = _ConnectionState() | 125 self._readState = _ConnectionState() |
140 self._pendingWriteState = _ConnectionState() | 126 self._pendingWriteState = _ConnectionState() |
141 self._pendingReadState = _ConnectionState() | 127 self._pendingReadState = _ConnectionState() |
142 | 128 |
143 #Is the connection open? | 129 #Is the connection open? |
144 self.closed = True #read-only | 130 self.closed = True #read-only |
145 self._refCount = 0 #Used to trigger closure | 131 self._refCount = 0 #Used to trigger closure |
146 | 132 |
147 #Is this a resumed (or shared-key) session? | 133 #Is this a resumed session? |
148 self.resumed = False #read-only | 134 self.resumed = False #read-only |
149 | 135 |
150 #What username did the client claim in his handshake? | 136 #What username did the client claim in his handshake? |
151 self.allegedSharedKeyUsername = None | |
152 self.allegedSrpUsername = None | 137 self.allegedSrpUsername = None |
153 | 138 |
154 #On a call to close(), do we close the socket? (writeable) | 139 #On a call to close(), do we close the socket? (writeable) |
155 self.closeSocket = False | 140 self.closeSocket = True |
156 | 141 |
157 #If the socket is abruptly closed, do we ignore it | 142 #If the socket is abruptly closed, do we ignore it |
158 #and pretend the connection was shut down properly? (writeable) | 143 #and pretend the connection was shut down properly? (writeable) |
159 self.ignoreAbruptClose = False | 144 self.ignoreAbruptClose = False |
160 | 145 |
161 #Fault we will induce, for testing purposes | 146 #Fault we will induce, for testing purposes |
162 self.fault = None | 147 self.fault = None |
163 | 148 |
| 149 def clearReadBuffer(self): |
| 150 self._readBuffer = b'' |
| 151 |
| 152 def clearWriteBuffer(self): |
| 153 self._send_writer = None |
| 154 |
| 155 |
164 #********************************************************* | 156 #********************************************************* |
165 # Public Functions START | 157 # Public Functions START |
166 #********************************************************* | 158 #********************************************************* |
167 | 159 |
168 def read(self, max=None, min=1): | 160 def read(self, max=None, min=1): |
169 """Read some data from the TLS connection. | 161 """Read some data from the TLS connection. |
170 | 162 |
171 This function will block until at least 'min' bytes are | 163 This function will block until at least 'min' bytes are |
172 available (or the connection is closed). | 164 available (or the connection is closed). |
173 | 165 |
(...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
206 @rtype: iterable | 198 @rtype: iterable |
207 @return: A generator; see above for details. | 199 @return: A generator; see above for details. |
208 """ | 200 """ |
209 try: | 201 try: |
210 while len(self._readBuffer)<min and not self.closed: | 202 while len(self._readBuffer)<min and not self.closed: |
211 try: | 203 try: |
212 for result in self._getMsg(ContentType.application_data): | 204 for result in self._getMsg(ContentType.application_data): |
213 if result in (0,1): | 205 if result in (0,1): |
214 yield result | 206 yield result |
215 applicationData = result | 207 applicationData = result |
216 self._readBuffer += bytesToString(applicationData.write()) | 208 self._readBuffer += applicationData.write() |
217 except TLSRemoteAlert, alert: | 209 except TLSRemoteAlert as alert: |
218 if alert.description != AlertDescription.close_notify: | 210 if alert.description != AlertDescription.close_notify: |
219 raise | 211 raise |
220 except TLSAbruptCloseError: | 212 except TLSAbruptCloseError: |
221 if not self.ignoreAbruptClose: | 213 if not self.ignoreAbruptClose: |
222 raise | 214 raise |
223 else: | 215 else: |
224 self._shutdown(True) | 216 self._shutdown(True) |
225 | 217 |
226 if max == None: | 218 if max == None: |
227 max = len(self._readBuffer) | 219 max = len(self._readBuffer) |
228 | 220 |
229 returnStr = self._readBuffer[:max] | 221 returnBytes = self._readBuffer[:max] |
230 self._readBuffer = self._readBuffer[max:] | 222 self._readBuffer = self._readBuffer[max:] |
231 yield returnStr | 223 yield bytes(returnBytes) |
| 224 except GeneratorExit: |
| 225 raise |
232 except: | 226 except: |
233 self._shutdown(False) | 227 self._shutdown(False) |
234 raise | 228 raise |
235 | 229 |
| 230 def unread(self, b): |
| 231 """Add bytes to the front of the socket read buffer for future |
| 232 reading. Be careful using this in the context of select(...): if you |
| 233 unread the last data from a socket, that won't wake up selected waiters, |
| 234 and those waiters may hang forever. |
| 235 """ |
| 236 self._readBuffer = b + self._readBuffer |
| 237 |
236 def write(self, s): | 238 def write(self, s): |
237 """Write some data to the TLS connection. | 239 """Write some data to the TLS connection. |
238 | 240 |
239 This function will block until all the data has been sent. | 241 This function will block until all the data has been sent. |
240 | 242 |
241 If an exception is raised, the connection will have been | 243 If an exception is raised, the connection will have been |
242 automatically closed. | 244 automatically closed. |
243 | 245 |
244 @type s: str | 246 @type s: str |
245 @param s: The data to transmit to the other party. | 247 @param s: The data to transmit to the other party. |
246 | 248 |
247 @raise socket.error: If a socket error occurs. | 249 @raise socket.error: If a socket error occurs. |
248 """ | 250 """ |
249 for result in self.writeAsync(s): | 251 for result in self.writeAsync(s): |
250 pass | 252 pass |
251 | 253 |
252 def writeAsync(self, s): | 254 def writeAsync(self, s): |
253 """Start a write operation on the TLS connection. | 255 """Start a write operation on the TLS connection. |
254 | 256 |
255 This function returns a generator which behaves similarly to | 257 This function returns a generator which behaves similarly to |
256 write(). Successive invocations of the generator will return | 258 write(). Successive invocations of the generator will return |
257 1 if it is waiting to write to the socket, or will raise | 259 1 if it is waiting to write to the socket, or will raise |
258 StopIteration if the write operation has completed. | 260 StopIteration if the write operation has completed. |
259 | 261 |
260 @rtype: iterable | 262 @rtype: iterable |
261 @return: A generator; see above for details. | 263 @return: A generator; see above for details. |
262 """ | 264 """ |
263 try: | 265 try: |
264 if self.closed: | 266 if self.closed: |
265 raise ValueError() | 267 raise TLSClosedConnectionError("attempt to write to closed conne
ction") |
266 | 268 |
267 index = 0 | 269 index = 0 |
268 blockSize = 16384 | 270 blockSize = 16384 |
269 skipEmptyFrag = False | 271 randomizeFirstBlock = True |
270 while 1: | 272 while 1: |
271 startIndex = index * blockSize | 273 startIndex = index * blockSize |
272 endIndex = startIndex + blockSize | 274 endIndex = startIndex + blockSize |
273 if startIndex >= len(s): | 275 if startIndex >= len(s): |
274 break | 276 break |
275 if endIndex > len(s): | 277 if endIndex > len(s): |
276 endIndex = len(s) | 278 endIndex = len(s) |
277 block = stringToBytes(s[startIndex : endIndex]) | 279 block = bytearray(s[startIndex : endIndex]) |
278 applicationData = ApplicationData().create(block) | 280 applicationData = ApplicationData().create(block) |
279 for result in self._sendMsg(applicationData, skipEmptyFrag): | 281 for result in self._sendMsg(applicationData, \ |
| 282 randomizeFirstBlock): |
280 yield result | 283 yield result |
281 skipEmptyFrag = True #only send an empy fragment on 1st message | 284 randomizeFirstBlock = False #only on 1st message |
282 index += 1 | 285 index += 1 |
283 except: | 286 except GeneratorExit: |
| 287 raise |
| 288 except Exception: |
284 self._shutdown(False) | 289 self._shutdown(False) |
285 raise | 290 raise |
286 | 291 |
287 def close(self): | 292 def close(self): |
288 """Close the TLS connection. | 293 """Close the TLS connection. |
289 | 294 |
290 This function will block until it has exchanged close_notify | 295 This function will block until it has exchanged close_notify |
291 alerts with the other party. After doing so, it will shut down the | 296 alerts with the other party. After doing so, it will shut down the |
292 TLS connection. Further attempts to read through this connection | 297 TLS connection. Further attempts to read through this connection |
293 will return "". Further attempts to write through this connection | 298 will return "". Further attempts to write through this connection |
294 will raise ValueError. | 299 will raise ValueError. |
295 | 300 |
296 If makefile() has been called on this connection, the connection | 301 If makefile() has been called on this connection, the connection |
297 will be not be closed until the connection object and all file | 302 will be not be closed until the connection object and all file |
298 objects have been closed. | 303 objects have been closed. |
299 | 304 |
300 Even if an exception is raised, the connection will have been | 305 Even if an exception is raised, the connection will have been |
301 closed. | 306 closed. |
302 | 307 |
303 @raise socket.error: If a socket error occurs. | 308 @raise socket.error: If a socket error occurs. |
304 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed | 309 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed |
305 without a preceding alert. | 310 without a preceding alert. |
306 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. | 311 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. |
307 """ | 312 """ |
308 if not self.closed: | 313 if not self.closed: |
309 for result in self._decrefAsync(): | 314 for result in self._decrefAsync(): |
310 pass | 315 pass |
311 | 316 |
| 317 # Python 3 callback |
| 318 _decref_socketios = close |
| 319 |
312 def closeAsync(self): | 320 def closeAsync(self): |
313 """Start a close operation on the TLS connection. | 321 """Start a close operation on the TLS connection. |
314 | 322 |
315 This function returns a generator which behaves similarly to | 323 This function returns a generator which behaves similarly to |
316 close(). Successive invocations of the generator will return 0 | 324 close(). Successive invocations of the generator will return 0 |
317 if it is waiting to read from the socket, 1 if it is waiting | 325 if it is waiting to read from the socket, 1 if it is waiting |
318 to write to the socket, or will raise StopIteration if the | 326 to write to the socket, or will raise StopIteration if the |
319 close operation has completed. | 327 close operation has completed. |
320 | 328 |
321 @rtype: iterable | 329 @rtype: iterable |
322 @return: A generator; see above for details. | 330 @return: A generator; see above for details. |
323 """ | 331 """ |
324 if not self.closed: | 332 if not self.closed: |
325 for result in self._decrefAsync(): | 333 for result in self._decrefAsync(): |
326 yield result | 334 yield result |
327 | 335 |
328 def _decrefAsync(self): | 336 def _decrefAsync(self): |
329 self._refCount -= 1 | 337 self._refCount -= 1 |
330 if self._refCount == 0 and not self.closed: | 338 if self._refCount == 0 and not self.closed: |
331 try: | 339 try: |
332 for result in self._sendMsg(Alert().create(\ | 340 for result in self._sendMsg(Alert().create(\ |
333 AlertDescription.close_notify, AlertLevel.warning)): | 341 AlertDescription.close_notify, AlertLevel.warning)): |
334 yield result | 342 yield result |
335 alert = None | 343 alert = None |
336 # Forcing a shutdown as WinHTTP does not seem to be | 344 # By default close the socket, since it's been observed |
337 # responsive to the close notify. | 345 # that some other libraries will not respond to the |
338 prevCloseSocket = self.closeSocket | 346 # close_notify alert, thus leaving us hanging if we're |
339 self.closeSocket = True | 347 # expecting it |
340 self._shutdown(True) | 348 if self.closeSocket: |
341 self.closeSocket = prevCloseSocket | |
342 while not alert: | |
343 for result in self._getMsg((ContentType.alert, \ | |
344 ContentType.application_data)): | |
345 if result in (0,1): | |
346 yield result | |
347 if result.contentType == ContentType.alert: | |
348 alert = result | |
349 if alert.description == AlertDescription.close_notify: | |
350 self._shutdown(True) | 349 self._shutdown(True) |
351 else: | 350 else: |
352 raise TLSRemoteAlert(alert) | 351 while not alert: |
| 352 for result in self._getMsg((ContentType.alert, \ |
| 353 ContentType.application_data))
: |
| 354 if result in (0,1): |
| 355 yield result |
| 356 if result.contentType == ContentType.alert: |
| 357 alert = result |
| 358 if alert.description == AlertDescription.close_notify: |
| 359 self._shutdown(True) |
| 360 else: |
| 361 raise TLSRemoteAlert(alert) |
353 except (socket.error, TLSAbruptCloseError): | 362 except (socket.error, TLSAbruptCloseError): |
354 #If the other side closes the socket, that's okay | 363 #If the other side closes the socket, that's okay |
355 self._shutdown(True) | 364 self._shutdown(True) |
| 365 except GeneratorExit: |
| 366 raise |
356 except: | 367 except: |
357 self._shutdown(False) | 368 self._shutdown(False) |
358 raise | 369 raise |
359 | 370 |
| 371 def getVersionName(self): |
| 372 """Get the name of this TLS version. |
| 373 |
| 374 @rtype: str |
| 375 @return: The name of the TLS version used with this connection. |
| 376 Either None, 'SSL 3.0', 'TLS 1.0', or 'TLS 1.1'. |
| 377 """ |
| 378 if self.version == (3,0): |
| 379 return "SSL 3.0" |
| 380 elif self.version == (3,1): |
| 381 return "TLS 1.0" |
| 382 elif self.version == (3,2): |
| 383 return "TLS 1.1" |
| 384 else: |
| 385 return None |
| 386 |
360 def getCipherName(self): | 387 def getCipherName(self): |
361 """Get the name of the cipher used with this connection. | 388 """Get the name of the cipher used with this connection. |
362 | 389 |
363 @rtype: str | 390 @rtype: str |
364 @return: The name of the cipher used with this connection. | 391 @return: The name of the cipher used with this connection. |
365 Either 'aes128', 'aes256', 'rc4', or '3des'. | 392 Either 'aes128', 'aes256', 'rc4', or '3des'. |
366 """ | 393 """ |
367 if not self._writeState.encContext: | 394 if not self._writeState.encContext: |
368 return None | 395 return None |
369 return self._writeState.encContext.name | 396 return self._writeState.encContext.name |
370 | 397 |
371 def getCipherImplementation(self): | 398 def getCipherImplementation(self): |
372 """Get the name of the cipher implementation used with | 399 """Get the name of the cipher implementation used with |
373 this connection. | 400 this connection. |
374 | 401 |
375 @rtype: str | 402 @rtype: str |
376 @return: The name of the cipher implementation used with | 403 @return: The name of the cipher implementation used with |
377 this connection. Either 'python', 'cryptlib', 'openssl', | 404 this connection. Either 'python', 'openssl', or 'pycrypto'. |
378 or 'pycrypto'. | |
379 """ | 405 """ |
380 if not self._writeState.encContext: | 406 if not self._writeState.encContext: |
381 return None | 407 return None |
382 return self._writeState.encContext.implementation | 408 return self._writeState.encContext.implementation |
383 | 409 |
384 | 410 |
385 | 411 |
386 #Emulate a socket, somewhat - | 412 #Emulate a socket, somewhat - |
387 def send(self, s): | 413 def send(self, s): |
388 """Send data to the TLS connection (socket emulation). | 414 """Send data to the TLS connection (socket emulation). |
(...skipping 13 matching lines...) Expand all Loading... |
402 def recv(self, bufsize): | 428 def recv(self, bufsize): |
403 """Get some data from the TLS connection (socket emulation). | 429 """Get some data from the TLS connection (socket emulation). |
404 | 430 |
405 @raise socket.error: If a socket error occurs. | 431 @raise socket.error: If a socket error occurs. |
406 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed | 432 @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed |
407 without a preceding alert. | 433 without a preceding alert. |
408 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. | 434 @raise tlslite.errors.TLSAlert: If a TLS alert is signalled. |
409 """ | 435 """ |
410 return self.read(bufsize) | 436 return self.read(bufsize) |
411 | 437 |
| 438 def recv_into(self, b): |
| 439 # XXX doc string |
| 440 data = self.read(len(b)) |
| 441 if not data: |
| 442 return None |
| 443 b[:len(data)] = data |
| 444 return len(data) |
| 445 |
412 def makefile(self, mode='r', bufsize=-1): | 446 def makefile(self, mode='r', bufsize=-1): |
413 """Create a file object for the TLS connection (socket emulation). | 447 """Create a file object for the TLS connection (socket emulation). |
414 | 448 |
415 @rtype: L{tlslite.FileObject.FileObject} | 449 @rtype: L{socket._fileobject} |
416 """ | 450 """ |
417 self._refCount += 1 | 451 self._refCount += 1 |
418 return FileObject(self, mode, bufsize) | 452 # So, it is pretty fragile to be using Python internal objects |
| 453 # like this, but it is probably the best/easiest way to provide |
| 454 # matching behavior for socket emulation purposes. The 'close' |
| 455 # argument is nice, its apparently a recent addition to this |
| 456 # class, so that when fileobject.close() gets called, it will |
| 457 # close() us, causing the refcount to be decremented (decrefAsync). |
| 458 # |
| 459 # If this is the last close() on the outstanding fileobjects / |
| 460 # TLSConnection, then the "actual" close alerts will be sent, |
| 461 # socket closed, etc. |
| 462 if sys.version_info < (3,): |
| 463 return socket._fileobject(self, mode, bufsize, close=True) |
| 464 else: |
| 465 # XXX need to wrap this further if buffering is requested |
| 466 return socket.SocketIO(self, mode) |
419 | 467 |
420 def getsockname(self): | 468 def getsockname(self): |
421 """Return the socket's own address (socket emulation).""" | 469 """Return the socket's own address (socket emulation).""" |
422 return self.sock.getsockname() | 470 return self.sock.getsockname() |
423 | 471 |
424 def getpeername(self): | 472 def getpeername(self): |
425 """Return the remote address to which the socket is connected | 473 """Return the remote address to which the socket is connected |
426 (socket emulation).""" | 474 (socket emulation).""" |
427 return self.sock.getpeername() | 475 return self.sock.getpeername() |
428 | 476 |
429 def settimeout(self, value): | 477 def settimeout(self, value): |
430 """Set a timeout on blocking socket operations (socket emulation).""" | 478 """Set a timeout on blocking socket operations (socket emulation).""" |
431 return self.sock.settimeout(value) | 479 return self.sock.settimeout(value) |
432 | 480 |
433 def gettimeout(self): | 481 def gettimeout(self): |
434 """Return the timeout associated with socket operations (socket | 482 """Return the timeout associated with socket operations (socket |
435 emulation).""" | 483 emulation).""" |
436 return self.sock.gettimeout() | 484 return self.sock.gettimeout() |
437 | 485 |
438 def setsockopt(self, level, optname, value): | 486 def setsockopt(self, level, optname, value): |
439 """Set the value of the given socket option (socket emulation).""" | 487 """Set the value of the given socket option (socket emulation).""" |
440 return self.sock.setsockopt(level, optname, value) | 488 return self.sock.setsockopt(level, optname, value) |
441 | 489 |
| 490 def shutdown(self, how): |
| 491 """Shutdown the underlying socket.""" |
| 492 return self.sock.shutdown(how) |
| 493 |
| 494 def fileno(self): |
| 495 """Not implement in TLS Lite.""" |
| 496 raise NotImplementedError() |
| 497 |
442 | 498 |
443 #********************************************************* | 499 #********************************************************* |
444 # Public Functions END | 500 # Public Functions END |
445 #********************************************************* | 501 #********************************************************* |
446 | 502 |
447 def _shutdown(self, resumable): | 503 def _shutdown(self, resumable): |
448 self._writeState = _ConnectionState() | 504 self._writeState = _ConnectionState() |
449 self._readState = _ConnectionState() | 505 self._readState = _ConnectionState() |
450 #Don't do this: self._readBuffer = "" | |
451 self.version = (0,0) | 506 self.version = (0,0) |
452 self._versionCheck = False | 507 self._versionCheck = False |
453 self.closed = True | 508 self.closed = True |
454 if self.closeSocket: | 509 if self.closeSocket: |
455 self.sock.close() | 510 self.sock.close() |
456 | 511 |
457 #Even if resumable is False, we'll never toggle this on | 512 #Even if resumable is False, we'll never toggle this on |
458 if not resumable and self.session: | 513 if not resumable and self.session: |
459 self.session.resumable = False | 514 self.session.resumable = False |
460 | 515 |
461 | 516 |
462 def _sendError(self, alertDescription, errorStr=None): | 517 def _sendError(self, alertDescription, errorStr=None): |
463 alert = Alert().create(alertDescription, AlertLevel.fatal) | 518 alert = Alert().create(alertDescription, AlertLevel.fatal) |
464 for result in self._sendMsg(alert): | 519 for result in self._sendMsg(alert): |
465 yield result | 520 yield result |
466 self._shutdown(False) | 521 self._shutdown(False) |
467 raise TLSLocalAlert(alert, errorStr) | 522 raise TLSLocalAlert(alert, errorStr) |
468 | 523 |
469 def _sendMsgs(self, msgs): | 524 def _sendMsgs(self, msgs): |
470 skipEmptyFrag = False | 525 randomizeFirstBlock = True |
471 for msg in msgs: | 526 for msg in msgs: |
472 for result in self._sendMsg(msg, skipEmptyFrag): | 527 for result in self._sendMsg(msg, randomizeFirstBlock): |
473 yield result | 528 yield result |
474 skipEmptyFrag = True | 529 randomizeFirstBlock = True |
475 | 530 |
476 def _sendMsg(self, msg, skipEmptyFrag=False): | 531 def _sendMsg(self, msg, randomizeFirstBlock = True): |
477 bytes = msg.write() | 532 #Whenever we're connected and asked to send an app data message, |
| 533 #we first send the first byte of the message. This prevents |
| 534 #an attacker from launching a chosen-plaintext attack based on |
| 535 #knowing the next IV (a la BEAST). |
| 536 if not self.closed and randomizeFirstBlock and self.version <= (3,1) \ |
| 537 and self._writeState.encContext \ |
| 538 and self._writeState.encContext.isBlockCipher \ |
| 539 and isinstance(msg, ApplicationData): |
| 540 msgFirstByte = msg.splitFirstByte() |
| 541 for result in self._sendMsg(msgFirstByte, |
| 542 randomizeFirstBlock = False): |
| 543 yield result |
| 544 |
| 545 b = msg.write() |
| 546 |
| 547 # If a 1-byte message was passed in, and we "split" the |
| 548 # first(only) byte off above, we may have a 0-length msg: |
| 549 if len(b) == 0: |
| 550 return |
| 551 |
478 contentType = msg.contentType | 552 contentType = msg.contentType |
479 | 553 |
480 #Whenever we're connected and asked to send a message, | |
481 #we first send an empty Application Data message. This prevents | |
482 #an attacker from launching a chosen-plaintext attack based on | |
483 #knowing the next IV. | |
484 if not self.closed and not skipEmptyFrag and self.version == (3,1): | |
485 if self._writeState.encContext: | |
486 if self._writeState.encContext.isBlockCipher: | |
487 for result in self._sendMsg(ApplicationData(), | |
488 skipEmptyFrag=True): | |
489 yield result | |
490 | |
491 #Update handshake hashes | 554 #Update handshake hashes |
492 if contentType == ContentType.handshake: | 555 if contentType == ContentType.handshake: |
493 bytesStr = bytesToString(bytes) | 556 self._handshake_md5.update(compat26Str(b)) |
494 self._handshake_md5.update(bytesStr) | 557 self._handshake_sha.update(compat26Str(b)) |
495 self._handshake_sha.update(bytesStr) | |
496 | 558 |
497 #Calculate MAC | 559 #Calculate MAC |
498 if self._writeState.macContext: | 560 if self._writeState.macContext: |
499 seqnumStr = self._writeState.getSeqNumStr() | 561 seqnumBytes = self._writeState.getSeqNumBytes() |
500 bytesStr = bytesToString(bytes) | |
501 mac = self._writeState.macContext.copy() | 562 mac = self._writeState.macContext.copy() |
502 mac.update(seqnumStr) | 563 mac.update(compatHMAC(seqnumBytes)) |
503 mac.update(chr(contentType)) | 564 mac.update(compatHMAC(bytearray([contentType]))) |
504 if self.version == (3,0): | 565 if self.version == (3,0): |
505 mac.update( chr( int(len(bytes)/256) ) ) | 566 mac.update( compatHMAC( bytearray([len(b)//256] ))) |
506 mac.update( chr( int(len(bytes)%256) ) ) | 567 mac.update( compatHMAC( bytearray([len(b)%256] ))) |
507 elif self.version in ((3,1), (3,2)): | 568 elif self.version in ((3,1), (3,2)): |
508 mac.update(chr(self.version[0])) | 569 mac.update(compatHMAC( bytearray([self.version[0]] ))) |
509 mac.update(chr(self.version[1])) | 570 mac.update(compatHMAC( bytearray([self.version[1]] ))) |
510 mac.update( chr( int(len(bytes)/256) ) ) | 571 mac.update( compatHMAC( bytearray([len(b)//256] ))) |
511 mac.update( chr( int(len(bytes)%256) ) ) | 572 mac.update( compatHMAC( bytearray([len(b)%256] ))) |
512 else: | 573 else: |
513 raise AssertionError() | 574 raise AssertionError() |
514 mac.update(bytesStr) | 575 mac.update(compatHMAC(b)) |
515 macString = mac.digest() | 576 macBytes = bytearray(mac.digest()) |
516 macBytes = stringToBytes(macString) | |
517 if self.fault == Fault.badMAC: | 577 if self.fault == Fault.badMAC: |
518 macBytes[0] = (macBytes[0]+1) % 256 | 578 macBytes[0] = (macBytes[0]+1) % 256 |
519 | 579 |
520 #Encrypt for Block or Stream Cipher | 580 #Encrypt for Block or Stream Cipher |
521 if self._writeState.encContext: | 581 if self._writeState.encContext: |
522 #Add padding and encrypt (for Block Cipher): | 582 #Add padding and encrypt (for Block Cipher): |
523 if self._writeState.encContext.isBlockCipher: | 583 if self._writeState.encContext.isBlockCipher: |
524 | 584 |
525 #Add TLS 1.1 fixed block | 585 #Add TLS 1.1 fixed block |
526 if self.version == (3,2): | 586 if self.version == (3,2): |
527 bytes = self.fixedIVBlock + bytes | 587 b = self.fixedIVBlock + b |
528 | 588 |
529 #Add padding: bytes = bytes + (macBytes + paddingBytes) | 589 #Add padding: b = b + (macBytes + paddingBytes) |
530 currentLength = len(bytes) + len(macBytes) + 1 | 590 currentLength = len(b) + len(macBytes) |
531 blockLength = self._writeState.encContext.block_size | 591 blockLength = self._writeState.encContext.block_size |
532 paddingLength = blockLength-(currentLength % blockLength) | 592 paddingLength = blockLength - 1 - (currentLength % blockLength) |
533 | 593 |
534 paddingBytes = createByteArraySequence([paddingLength] * \ | 594 paddingBytes = bytearray([paddingLength] * (paddingLength+1)) |
535 (paddingLength+1)) | |
536 if self.fault == Fault.badPadding: | 595 if self.fault == Fault.badPadding: |
537 paddingBytes[0] = (paddingBytes[0]+1) % 256 | 596 paddingBytes[0] = (paddingBytes[0]+1) % 256 |
538 endBytes = concatArrays(macBytes, paddingBytes) | 597 endBytes = macBytes + paddingBytes |
539 bytes = concatArrays(bytes, endBytes) | 598 b += endBytes |
540 #Encrypt | 599 #Encrypt |
541 plaintext = stringToBytes(bytes) | 600 b = self._writeState.encContext.encrypt(b) |
542 ciphertext = self._writeState.encContext.encrypt(plaintext) | |
543 bytes = stringToBytes(ciphertext) | |
544 | 601 |
545 #Encrypt (for Stream Cipher) | 602 #Encrypt (for Stream Cipher) |
546 else: | 603 else: |
547 bytes = concatArrays(bytes, macBytes) | 604 b += macBytes |
548 plaintext = bytesToString(bytes) | 605 b = self._writeState.encContext.encrypt(b) |
549 ciphertext = self._writeState.encContext.encrypt(plaintext) | |
550 bytes = stringToBytes(ciphertext) | |
551 | 606 |
552 #Add record header and send | 607 #Add record header and send |
553 r = RecordHeader3().create(self.version, contentType, len(bytes)) | 608 r = RecordHeader3().create(self.version, contentType, len(b)) |
554 s = bytesToString(concatArrays(r.write(), bytes)) | 609 s = r.write() + b |
555 while 1: | 610 while 1: |
556 try: | 611 try: |
557 bytesSent = self.sock.send(s) #Might raise socket.error | 612 bytesSent = self.sock.send(s) #Might raise socket.error |
558 except socket.error, why: | 613 except socket.error as why: |
559 if why[0] == errno.EWOULDBLOCK: | 614 if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): |
560 yield 1 | 615 yield 1 |
561 continue | 616 continue |
562 else: | 617 else: |
563 raise | 618 # The socket was unexpectedly closed. The tricky part |
| 619 # is that there may be an alert sent by the other party |
| 620 # sitting in the read buffer. So, if we get here after |
| 621 # handshaking, we will just raise the error and let the |
| 622 # caller read more data if it would like, thus stumbling |
| 623 # upon the error. |
| 624 # |
| 625 # However, if we get here DURING handshaking, we take |
| 626 # it upon ourselves to see if the next message is an |
| 627 # Alert. |
| 628 if contentType == ContentType.handshake: |
| 629 |
| 630 # See if there's an alert record |
| 631 # Could raise socket.error or TLSAbruptCloseError |
| 632 for result in self._getNextRecord(): |
| 633 if result in (0,1): |
| 634 yield result |
| 635 |
| 636 # Closes the socket |
| 637 self._shutdown(False) |
| 638 |
| 639 # If we got an alert, raise it |
| 640 recordHeader, p = result |
| 641 if recordHeader.type == ContentType.alert: |
| 642 alert = Alert().parse(p) |
| 643 raise TLSRemoteAlert(alert) |
| 644 else: |
| 645 # If we got some other message who know what |
| 646 # the remote side is doing, just go ahead and |
| 647 # raise the socket.error |
| 648 raise |
564 if bytesSent == len(s): | 649 if bytesSent == len(s): |
565 return | 650 return |
566 s = s[bytesSent:] | 651 s = s[bytesSent:] |
567 yield 1 | 652 yield 1 |
568 | 653 |
569 | 654 |
570 def _getMsg(self, expectedType, secondaryType=None, constructorType=None): | 655 def _getMsg(self, expectedType, secondaryType=None, constructorType=None): |
571 try: | 656 try: |
572 if not isinstance(expectedType, tuple): | 657 if not isinstance(expectedType, tuple): |
573 expectedType = (expectedType,) | 658 expectedType = (expectedType,) |
(...skipping 109 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
683 subType = HandshakeType.client_hello | 768 subType = HandshakeType.client_hello |
684 else: | 769 else: |
685 subType = p.get(1) | 770 subType = p.get(1) |
686 if subType not in secondaryType: | 771 if subType not in secondaryType: |
687 for result in self._sendError(\ | 772 for result in self._sendError(\ |
688 AlertDescription.unexpected_message, | 773 AlertDescription.unexpected_message, |
689 "Expecting %s, got %s" % (str(secondaryType), su
bType)): | 774 "Expecting %s, got %s" % (str(secondaryType), su
bType)): |
690 yield result | 775 yield result |
691 | 776 |
692 #Update handshake hashes | 777 #Update handshake hashes |
693 sToHash = bytesToString(p.bytes) | 778 self._handshake_md5.update(compat26Str(p.bytes)) |
694 self._handshake_md5.update(sToHash) | 779 self._handshake_sha.update(compat26Str(p.bytes)) |
695 self._handshake_sha.update(sToHash) | |
696 | 780 |
697 #Parse based on handshake type | 781 #Parse based on handshake type |
698 if subType == HandshakeType.client_hello: | 782 if subType == HandshakeType.client_hello: |
699 yield ClientHello(recordHeader.ssl2).parse(p) | 783 yield ClientHello(recordHeader.ssl2).parse(p) |
700 elif subType == HandshakeType.server_hello: | 784 elif subType == HandshakeType.server_hello: |
701 yield ServerHello().parse(p) | 785 yield ServerHello().parse(p) |
702 elif subType == HandshakeType.certificate: | 786 elif subType == HandshakeType.certificate: |
703 yield Certificate(constructorType).parse(p) | 787 yield Certificate(constructorType).parse(p) |
704 elif subType == HandshakeType.certificate_request: | 788 elif subType == HandshakeType.certificate_request: |
705 yield CertificateRequest().parse(p) | 789 yield CertificateRequest().parse(p) |
706 elif subType == HandshakeType.certificate_verify: | 790 elif subType == HandshakeType.certificate_verify: |
707 yield CertificateVerify().parse(p) | 791 yield CertificateVerify().parse(p) |
708 elif subType == HandshakeType.server_key_exchange: | 792 elif subType == HandshakeType.server_key_exchange: |
709 yield ServerKeyExchange(constructorType).parse(p) | 793 yield ServerKeyExchange(constructorType).parse(p) |
710 elif subType == HandshakeType.server_hello_done: | 794 elif subType == HandshakeType.server_hello_done: |
711 yield ServerHelloDone().parse(p) | 795 yield ServerHelloDone().parse(p) |
712 elif subType == HandshakeType.client_key_exchange: | 796 elif subType == HandshakeType.client_key_exchange: |
713 yield ClientKeyExchange(constructorType, \ | 797 yield ClientKeyExchange(constructorType, \ |
714 self.version).parse(p) | 798 self.version).parse(p) |
715 elif subType == HandshakeType.finished: | 799 elif subType == HandshakeType.finished: |
716 yield Finished(self.version).parse(p) | 800 yield Finished(self.version).parse(p) |
| 801 elif subType == HandshakeType.next_protocol: |
| 802 yield NextProtocol().parse(p) |
717 elif subType == HandshakeType.encrypted_extensions: | 803 elif subType == HandshakeType.encrypted_extensions: |
718 yield EncryptedExtensions().parse(p) | 804 yield EncryptedExtensions().parse(p) |
719 else: | 805 else: |
720 raise AssertionError() | 806 raise AssertionError() |
721 | 807 |
722 #If an exception was raised by a Parser or Message instance: | 808 #If an exception was raised by a Parser or Message instance: |
723 except SyntaxError, e: | 809 except SyntaxError as e: |
724 for result in self._sendError(AlertDescription.decode_error, | 810 for result in self._sendError(AlertDescription.decode_error, |
725 formatExceptionTrace(e)): | 811 formatExceptionTrace(e)): |
726 yield result | 812 yield result |
727 | 813 |
728 | 814 |
729 #Returns next record or next handshake message | 815 #Returns next record or next handshake message |
730 def _getNextRecord(self): | 816 def _getNextRecord(self): |
731 | 817 |
732 #If there's a handshake message waiting, return it | 818 #If there's a handshake message waiting, return it |
733 if self._handshakeBuffer: | 819 if self._handshakeBuffer: |
734 recordHeader, bytes = self._handshakeBuffer[0] | 820 recordHeader, b = self._handshakeBuffer[0] |
735 self._handshakeBuffer = self._handshakeBuffer[1:] | 821 self._handshakeBuffer = self._handshakeBuffer[1:] |
736 yield (recordHeader, Parser(bytes)) | 822 yield (recordHeader, Parser(b)) |
737 return | 823 return |
738 | 824 |
739 #Otherwise... | 825 #Otherwise... |
740 #Read the next record header | 826 #Read the next record header |
741 bytes = createByteArraySequence([]) | 827 b = bytearray(0) |
742 recordHeaderLength = 1 | 828 recordHeaderLength = 1 |
743 ssl2 = False | 829 ssl2 = False |
744 while 1: | 830 while 1: |
745 try: | 831 try: |
746 s = self.sock.recv(recordHeaderLength-len(bytes)) | 832 s = self.sock.recv(recordHeaderLength-len(b)) |
747 except socket.error, why: | 833 except socket.error as why: |
748 if why[0] == errno.EWOULDBLOCK: | 834 if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): |
749 yield 0 | 835 yield 0 |
750 continue | 836 continue |
751 else: | 837 else: |
752 raise | 838 raise |
753 | 839 |
754 #If the connection was abruptly closed, raise an error | 840 #If the connection was abruptly closed, raise an error |
755 if len(s)==0: | 841 if len(s)==0: |
756 raise TLSAbruptCloseError() | 842 raise TLSAbruptCloseError() |
757 | 843 |
758 bytes += stringToBytes(s) | 844 b += bytearray(s) |
759 if len(bytes)==1: | 845 if len(b)==1: |
760 if bytes[0] in ContentType.all: | 846 if b[0] in ContentType.all: |
761 ssl2 = False | 847 ssl2 = False |
762 recordHeaderLength = 5 | 848 recordHeaderLength = 5 |
763 elif bytes[0] == 128: | 849 elif b[0] == 128: |
764 ssl2 = True | 850 ssl2 = True |
765 recordHeaderLength = 2 | 851 recordHeaderLength = 2 |
766 else: | 852 else: |
767 raise SyntaxError() | 853 raise SyntaxError() |
768 if len(bytes) == recordHeaderLength: | 854 if len(b) == recordHeaderLength: |
769 break | 855 break |
770 | 856 |
771 #Parse the record header | 857 #Parse the record header |
772 if ssl2: | 858 if ssl2: |
773 r = RecordHeader2().parse(Parser(bytes)) | 859 r = RecordHeader2().parse(Parser(b)) |
774 else: | 860 else: |
775 r = RecordHeader3().parse(Parser(bytes)) | 861 r = RecordHeader3().parse(Parser(b)) |
776 | 862 |
777 #Check the record header fields | 863 #Check the record header fields |
778 if r.length > 18432: | 864 if r.length > 18432: |
779 for result in self._sendError(AlertDescription.record_overflow): | 865 for result in self._sendError(AlertDescription.record_overflow): |
780 yield result | 866 yield result |
781 | 867 |
782 #Read the record contents | 868 #Read the record contents |
783 bytes = createByteArraySequence([]) | 869 b = bytearray(0) |
784 while 1: | 870 while 1: |
785 try: | 871 try: |
786 s = self.sock.recv(r.length - len(bytes)) | 872 s = self.sock.recv(r.length - len(b)) |
787 except socket.error, why: | 873 except socket.error as why: |
788 if why[0] == errno.EWOULDBLOCK: | 874 if why.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN): |
789 yield 0 | 875 yield 0 |
790 continue | 876 continue |
791 else: | 877 else: |
792 raise | 878 raise |
793 | 879 |
794 #If the connection is closed, raise a socket error | 880 #If the connection is closed, raise a socket error |
795 if len(s)==0: | 881 if len(s)==0: |
796 raise TLSAbruptCloseError() | 882 raise TLSAbruptCloseError() |
797 | 883 |
798 bytes += stringToBytes(s) | 884 b += bytearray(s) |
799 if len(bytes) == r.length: | 885 if len(b) == r.length: |
800 break | 886 break |
801 | 887 |
802 #Check the record header fields (2) | 888 #Check the record header fields (2) |
803 #We do this after reading the contents from the socket, so that | 889 #We do this after reading the contents from the socket, so that |
804 #if there's an error, we at least don't leave extra bytes in the | 890 #if there's an error, we at least don't leave extra bytes in the |
805 #socket.. | 891 #socket.. |
806 # | 892 # |
807 # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP. | 893 # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP. |
808 # SO WE LEAVE IT OUT FOR NOW. | 894 # SO WE LEAVE IT OUT FOR NOW. |
809 # | 895 # |
810 #if self._versionCheck and r.version != self.version: | 896 #if self._versionCheck and r.version != self.version: |
811 # for result in self._sendError(AlertDescription.protocol_version, | 897 # for result in self._sendError(AlertDescription.protocol_version, |
812 # "Version in header field: %s, should be %s" % (str(r.versio
n), | 898 # "Version in header field: %s, should be %s" % (str(r.versio
n), |
813 # str(self.version
))): | 899 # str(self.version
))): |
814 # yield result | 900 # yield result |
815 | 901 |
816 #Decrypt the record | 902 #Decrypt the record |
817 for result in self._decryptRecord(r.type, bytes): | 903 for result in self._decryptRecord(r.type, b): |
818 if result in (0,1): | 904 if result in (0,1): yield result |
819 yield result | 905 else: break |
820 else: | 906 b = result |
821 break | 907 p = Parser(b) |
822 bytes = result | |
823 p = Parser(bytes) | |
824 | 908 |
825 #If it doesn't contain handshake messages, we can just return it | 909 #If it doesn't contain handshake messages, we can just return it |
826 if r.type != ContentType.handshake: | 910 if r.type != ContentType.handshake: |
827 yield (r, p) | 911 yield (r, p) |
828 #If it's an SSLv2 ClientHello, we can return it as well | 912 #If it's an SSLv2 ClientHello, we can return it as well |
829 elif r.ssl2: | 913 elif r.ssl2: |
830 yield (r, p) | 914 yield (r, p) |
831 else: | 915 else: |
832 #Otherwise, we loop through and add the handshake messages to the | 916 #Otherwise, we loop through and add the handshake messages to the |
833 #handshake buffer | 917 #handshake buffer |
834 while 1: | 918 while 1: |
835 if p.index == len(bytes): #If we're at the end | 919 if p.index == len(b): #If we're at the end |
836 if not self._handshakeBuffer: | 920 if not self._handshakeBuffer: |
837 for result in self._sendError(\ | 921 for result in self._sendError(\ |
838 AlertDescription.decode_error, \ | 922 AlertDescription.decode_error, \ |
839 "Received empty handshake record"): | 923 "Received empty handshake record"): |
840 yield result | 924 yield result |
841 break | 925 break |
842 #There needs to be at least 4 bytes to get a header | 926 #There needs to be at least 4 bytes to get a header |
843 if p.index+4 > len(bytes): | 927 if p.index+4 > len(b): |
844 for result in self._sendError(\ | 928 for result in self._sendError(\ |
845 AlertDescription.decode_error, | 929 AlertDescription.decode_error, |
846 "A record has a partial handshake message (1)"): | 930 "A record has a partial handshake message (1)"): |
847 yield result | 931 yield result |
848 p.get(1) # skip handshake type | 932 p.get(1) # skip handshake type |
849 msgLength = p.get(3) | 933 msgLength = p.get(3) |
850 if p.index+msgLength > len(bytes): | 934 if p.index+msgLength > len(b): |
851 for result in self._sendError(\ | 935 for result in self._sendError(\ |
852 AlertDescription.decode_error, | 936 AlertDescription.decode_error, |
853 "A record has a partial handshake message (2)"): | 937 "A record has a partial handshake message (2)"): |
854 yield result | 938 yield result |
855 | 939 |
856 handshakePair = (r, bytes[p.index-4 : p.index+msgLength]) | 940 handshakePair = (r, b[p.index-4 : p.index+msgLength]) |
857 self._handshakeBuffer.append(handshakePair) | 941 self._handshakeBuffer.append(handshakePair) |
858 p.index += msgLength | 942 p.index += msgLength |
859 | 943 |
860 #We've moved at least one handshake message into the | 944 #We've moved at least one handshake message into the |
861 #handshakeBuffer, return the first one | 945 #handshakeBuffer, return the first one |
862 recordHeader, bytes = self._handshakeBuffer[0] | 946 recordHeader, b = self._handshakeBuffer[0] |
863 self._handshakeBuffer = self._handshakeBuffer[1:] | 947 self._handshakeBuffer = self._handshakeBuffer[1:] |
864 yield (recordHeader, Parser(bytes)) | 948 yield (recordHeader, Parser(b)) |
865 | 949 |
866 | 950 |
867 def _decryptRecord(self, recordType, bytes): | 951 def _decryptRecord(self, recordType, b): |
868 if self._readState.encContext: | 952 if self._readState.encContext: |
869 | 953 |
870 #Decrypt if it's a block cipher | 954 #Decrypt if it's a block cipher |
871 if self._readState.encContext.isBlockCipher: | 955 if self._readState.encContext.isBlockCipher: |
872 blockLength = self._readState.encContext.block_size | 956 blockLength = self._readState.encContext.block_size |
873 if len(bytes) % blockLength != 0: | 957 if len(b) % blockLength != 0: |
874 for result in self._sendError(\ | 958 for result in self._sendError(\ |
875 AlertDescription.decryption_failed, | 959 AlertDescription.decryption_failed, |
876 "Encrypted data not a multiple of blocksize"): | 960 "Encrypted data not a multiple of blocksize"): |
877 yield result | 961 yield result |
878 ciphertext = bytesToString(bytes) | 962 b = self._readState.encContext.decrypt(b) |
879 plaintext = self._readState.encContext.decrypt(ciphertext) | |
880 if self.version == (3,2): #For TLS 1.1, remove explicit IV | 963 if self.version == (3,2): #For TLS 1.1, remove explicit IV |
881 plaintext = plaintext[self._readState.encContext.block_size
: ] | 964 b = b[self._readState.encContext.block_size : ] |
882 bytes = stringToBytes(plaintext) | |
883 | 965 |
884 #Check padding | 966 #Check padding |
885 paddingGood = True | 967 paddingGood = True |
886 paddingLength = bytes[-1] | 968 paddingLength = b[-1] |
887 if (paddingLength+1) > len(bytes): | 969 if (paddingLength+1) > len(b): |
888 paddingGood=False | 970 paddingGood=False |
889 totalPaddingLength = 0 | 971 totalPaddingLength = 0 |
890 else: | 972 else: |
891 if self.version == (3,0): | 973 if self.version == (3,0): |
892 totalPaddingLength = paddingLength+1 | 974 totalPaddingLength = paddingLength+1 |
893 elif self.version in ((3,1), (3,2)): | 975 elif self.version in ((3,1), (3,2)): |
894 totalPaddingLength = paddingLength+1 | 976 totalPaddingLength = paddingLength+1 |
895 paddingBytes = bytes[-totalPaddingLength:-1] | 977 paddingBytes = b[-totalPaddingLength:-1] |
896 for byte in paddingBytes: | 978 for byte in paddingBytes: |
897 if byte != paddingLength: | 979 if byte != paddingLength: |
898 paddingGood = False | 980 paddingGood = False |
899 totalPaddingLength = 0 | 981 totalPaddingLength = 0 |
900 else: | 982 else: |
901 raise AssertionError() | 983 raise AssertionError() |
902 | 984 |
903 #Decrypt if it's a stream cipher | 985 #Decrypt if it's a stream cipher |
904 else: | 986 else: |
905 paddingGood = True | 987 paddingGood = True |
906 ciphertext = bytesToString(bytes) | 988 b = self._readState.encContext.decrypt(b) |
907 plaintext = self._readState.encContext.decrypt(ciphertext) | |
908 bytes = stringToBytes(plaintext) | |
909 totalPaddingLength = 0 | 989 totalPaddingLength = 0 |
910 | 990 |
911 #Check MAC | 991 #Check MAC |
912 macGood = True | 992 macGood = True |
913 macLength = self._readState.macContext.digest_size | 993 macLength = self._readState.macContext.digest_size |
914 endLength = macLength + totalPaddingLength | 994 endLength = macLength + totalPaddingLength |
915 if endLength > len(bytes): | 995 if endLength > len(b): |
916 macGood = False | 996 macGood = False |
917 else: | 997 else: |
918 #Read MAC | 998 #Read MAC |
919 startIndex = len(bytes) - endLength | 999 startIndex = len(b) - endLength |
920 endIndex = startIndex + macLength | 1000 endIndex = startIndex + macLength |
921 checkBytes = bytes[startIndex : endIndex] | 1001 checkBytes = b[startIndex : endIndex] |
922 | 1002 |
923 #Calculate MAC | 1003 #Calculate MAC |
924 seqnumStr = self._readState.getSeqNumStr() | 1004 seqnumBytes = self._readState.getSeqNumBytes() |
925 bytes = bytes[:-endLength] | 1005 b = b[:-endLength] |
926 bytesStr = bytesToString(bytes) | |
927 mac = self._readState.macContext.copy() | 1006 mac = self._readState.macContext.copy() |
928 mac.update(seqnumStr) | 1007 mac.update(compatHMAC(seqnumBytes)) |
929 mac.update(chr(recordType)) | 1008 mac.update(compatHMAC(bytearray([recordType]))) |
930 if self.version == (3,0): | 1009 if self.version == (3,0): |
931 mac.update( chr( int(len(bytes)/256) ) ) | 1010 mac.update( compatHMAC(bytearray( [len(b)//256] ) )) |
932 mac.update( chr( int(len(bytes)%256) ) ) | 1011 mac.update( compatHMAC(bytearray( [len(b)%256] ) )) |
933 elif self.version in ((3,1), (3,2)): | 1012 elif self.version in ((3,1), (3,2)): |
934 mac.update(chr(self.version[0])) | 1013 mac.update(compatHMAC(bytearray( [self.version[0]] ) )) |
935 mac.update(chr(self.version[1])) | 1014 mac.update(compatHMAC(bytearray( [self.version[1]] ) )) |
936 mac.update( chr( int(len(bytes)/256) ) ) | 1015 mac.update(compatHMAC(bytearray( [len(b)//256] ) )) |
937 mac.update( chr( int(len(bytes)%256) ) ) | 1016 mac.update(compatHMAC(bytearray( [len(b)%256] ) )) |
938 else: | 1017 else: |
939 raise AssertionError() | 1018 raise AssertionError() |
940 mac.update(bytesStr) | 1019 mac.update(compatHMAC(b)) |
941 macString = mac.digest() | 1020 macBytes = bytearray(mac.digest()) |
942 macBytes = stringToBytes(macString) | |
943 | 1021 |
944 #Compare MACs | 1022 #Compare MACs |
945 if macBytes != checkBytes: | 1023 if macBytes != checkBytes: |
946 macGood = False | 1024 macGood = False |
947 | 1025 |
948 if not (paddingGood and macGood): | 1026 if not (paddingGood and macGood): |
949 for result in self._sendError(AlertDescription.bad_record_mac, | 1027 for result in self._sendError(AlertDescription.bad_record_mac, |
950 "MAC failure (or padding failure)"): | 1028 "MAC failure (or padding failure)"): |
951 yield result | 1029 yield result |
952 | 1030 |
953 yield bytes | 1031 yield b |
954 | 1032 |
955 def _handshakeStart(self, client): | 1033 def _handshakeStart(self, client): |
| 1034 if not self.closed: |
| 1035 raise ValueError("Renegotiation disallowed for security reasons") |
956 self._client = client | 1036 self._client = client |
957 self._handshake_md5 = md5.md5() | 1037 self._handshake_md5 = hashlib.md5() |
958 self._handshake_sha = sha.sha() | 1038 self._handshake_sha = hashlib.sha1() |
959 self._handshakeBuffer = [] | 1039 self._handshakeBuffer = [] |
960 self.allegedSharedKeyUsername = None | |
961 self.allegedSrpUsername = None | 1040 self.allegedSrpUsername = None |
962 self._refCount = 1 | 1041 self._refCount = 1 |
963 | 1042 |
964 def _handshakeDone(self, resumed): | 1043 def _handshakeDone(self, resumed): |
965 self.resumed = resumed | 1044 self.resumed = resumed |
966 self.closed = False | 1045 self.closed = False |
967 | 1046 |
968 def _calcPendingStates(self, clientRandom, serverRandom, implementations): | 1047 def _calcPendingStates(self, cipherSuite, masterSecret, |
969 if self.session.cipherSuite in CipherSuite.aes128Suites: | 1048 clientRandom, serverRandom, implementations): |
970 macLength = 20 | 1049 if cipherSuite in CipherSuite.aes128Suites: |
971 keyLength = 16 | 1050 keyLength = 16 |
972 ivLength = 16 | 1051 ivLength = 16 |
973 createCipherFunc = createAES | 1052 createCipherFunc = createAES |
974 elif self.session.cipherSuite in CipherSuite.aes256Suites: | 1053 elif cipherSuite in CipherSuite.aes256Suites: |
975 macLength = 20 | |
976 keyLength = 32 | 1054 keyLength = 32 |
977 ivLength = 16 | 1055 ivLength = 16 |
978 createCipherFunc = createAES | 1056 createCipherFunc = createAES |
979 elif self.session.cipherSuite in CipherSuite.rc4Suites: | 1057 elif cipherSuite in CipherSuite.rc4Suites: |
980 macLength = 20 | |
981 keyLength = 16 | 1058 keyLength = 16 |
982 ivLength = 0 | 1059 ivLength = 0 |
983 createCipherFunc = createRC4 | 1060 createCipherFunc = createRC4 |
984 elif self.session.cipherSuite in CipherSuite.tripleDESSuites: | 1061 elif cipherSuite in CipherSuite.tripleDESSuites: |
985 macLength = 20 | |
986 keyLength = 24 | 1062 keyLength = 24 |
987 ivLength = 8 | 1063 ivLength = 8 |
988 createCipherFunc = createTripleDES | 1064 createCipherFunc = createTripleDES |
989 else: | 1065 else: |
990 raise AssertionError() | 1066 raise AssertionError() |
| 1067 |
| 1068 if cipherSuite in CipherSuite.shaSuites: |
| 1069 macLength = 20 |
| 1070 digestmod = hashlib.sha1 |
| 1071 elif cipherSuite in CipherSuite.md5Suites: |
| 1072 macLength = 16 |
| 1073 digestmod = hashlib.md5 |
991 | 1074 |
992 if self.version == (3,0): | 1075 if self.version == (3,0): |
993 createMACFunc = MAC_SSL | 1076 createMACFunc = createMAC_SSL |
994 elif self.version in ((3,1), (3,2)): | 1077 elif self.version in ((3,1), (3,2)): |
995 createMACFunc = hmac.HMAC | 1078 createMACFunc = createHMAC |
996 | 1079 |
997 outputLength = (macLength*2) + (keyLength*2) + (ivLength*2) | 1080 outputLength = (macLength*2) + (keyLength*2) + (ivLength*2) |
998 | 1081 |
999 #Calculate Keying Material from Master Secret | 1082 #Calculate Keying Material from Master Secret |
1000 if self.version == (3,0): | 1083 if self.version == (3,0): |
1001 keyBlock = PRF_SSL(self.session.masterSecret, | 1084 keyBlock = PRF_SSL(masterSecret, |
1002 concatArrays(serverRandom, clientRandom), | 1085 serverRandom + clientRandom, |
1003 outputLength) | 1086 outputLength) |
1004 elif self.version in ((3,1), (3,2)): | 1087 elif self.version in ((3,1), (3,2)): |
1005 keyBlock = PRF(self.session.masterSecret, | 1088 keyBlock = PRF(masterSecret, |
1006 "key expansion", | 1089 b"key expansion", |
1007 concatArrays(serverRandom,clientRandom), | 1090 serverRandom + clientRandom, |
1008 outputLength) | 1091 outputLength) |
1009 else: | 1092 else: |
1010 raise AssertionError() | 1093 raise AssertionError() |
1011 | 1094 |
1012 #Slice up Keying Material | 1095 #Slice up Keying Material |
1013 clientPendingState = _ConnectionState() | 1096 clientPendingState = _ConnectionState() |
1014 serverPendingState = _ConnectionState() | 1097 serverPendingState = _ConnectionState() |
1015 p = Parser(keyBlock) | 1098 p = Parser(keyBlock) |
1016 clientMACBlock = bytesToString(p.getFixBytes(macLength)) | 1099 clientMACBlock = p.getFixBytes(macLength) |
1017 serverMACBlock = bytesToString(p.getFixBytes(macLength)) | 1100 serverMACBlock = p.getFixBytes(macLength) |
1018 clientKeyBlock = bytesToString(p.getFixBytes(keyLength)) | 1101 clientKeyBlock = p.getFixBytes(keyLength) |
1019 serverKeyBlock = bytesToString(p.getFixBytes(keyLength)) | 1102 serverKeyBlock = p.getFixBytes(keyLength) |
1020 clientIVBlock = bytesToString(p.getFixBytes(ivLength)) | 1103 clientIVBlock = p.getFixBytes(ivLength) |
1021 serverIVBlock = bytesToString(p.getFixBytes(ivLength)) | 1104 serverIVBlock = p.getFixBytes(ivLength) |
1022 clientPendingState.macContext = createMACFunc(clientMACBlock, | 1105 clientPendingState.macContext = createMACFunc( |
1023 digestmod=sha) | 1106 compatHMAC(clientMACBlock), digestmod=digestmod) |
1024 serverPendingState.macContext = createMACFunc(serverMACBlock, | 1107 serverPendingState.macContext = createMACFunc( |
1025 digestmod=sha) | 1108 compatHMAC(serverMACBlock), digestmod=digestmod) |
1026 clientPendingState.encContext = createCipherFunc(clientKeyBlock, | 1109 clientPendingState.encContext = createCipherFunc(clientKeyBlock, |
1027 clientIVBlock, | 1110 clientIVBlock, |
1028 implementations) | 1111 implementations) |
1029 serverPendingState.encContext = createCipherFunc(serverKeyBlock, | 1112 serverPendingState.encContext = createCipherFunc(serverKeyBlock, |
1030 serverIVBlock, | 1113 serverIVBlock, |
1031 implementations) | 1114 implementations) |
1032 | 1115 |
1033 #Assign new connection states to pending states | 1116 #Assign new connection states to pending states |
1034 if self._client: | 1117 if self._client: |
1035 self._pendingWriteState = clientPendingState | 1118 self._pendingWriteState = clientPendingState |
1036 self._pendingReadState = serverPendingState | 1119 self._pendingReadState = serverPendingState |
1037 else: | 1120 else: |
1038 self._pendingWriteState = serverPendingState | 1121 self._pendingWriteState = serverPendingState |
1039 self._pendingReadState = clientPendingState | 1122 self._pendingReadState = clientPendingState |
1040 | 1123 |
1041 if self.version == (3,2) and ivLength: | 1124 if self.version == (3,2) and ivLength: |
1042 #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC | 1125 #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC |
1043 #residue to create the IV for each sent block) | 1126 #residue to create the IV for each sent block) |
1044 self.fixedIVBlock = getRandomBytes(ivLength) | 1127 self.fixedIVBlock = getRandomBytes(ivLength) |
1045 | 1128 |
1046 def _changeWriteState(self): | 1129 def _changeWriteState(self): |
1047 self._writeState = self._pendingWriteState | 1130 self._writeState = self._pendingWriteState |
1048 self._pendingWriteState = _ConnectionState() | 1131 self._pendingWriteState = _ConnectionState() |
1049 | 1132 |
1050 def _changeReadState(self): | 1133 def _changeReadState(self): |
1051 self._readState = self._pendingReadState | 1134 self._readState = self._pendingReadState |
1052 self._pendingReadState = _ConnectionState() | 1135 self._pendingReadState = _ConnectionState() |
1053 | 1136 |
1054 def _sendFinished(self): | |
1055 #Send ChangeCipherSpec | |
1056 for result in self._sendMsg(ChangeCipherSpec()): | |
1057 yield result | |
1058 | |
1059 #Switch to pending write state | |
1060 self._changeWriteState() | |
1061 | |
1062 #Calculate verification data | |
1063 verifyData = self._calcFinished(True) | |
1064 if self.fault == Fault.badFinished: | |
1065 verifyData[0] = (verifyData[0]+1)%256 | |
1066 | |
1067 #Send Finished message under new state | |
1068 finished = Finished(self.version).create(verifyData) | |
1069 for result in self._sendMsg(finished): | |
1070 yield result | |
1071 | |
1072 def _getChangeCipherSpec(self): | |
1073 #Get and check ChangeCipherSpec | |
1074 for result in self._getMsg(ContentType.change_cipher_spec): | |
1075 if result in (0,1): | |
1076 yield result | |
1077 changeCipherSpec = result | |
1078 | |
1079 if changeCipherSpec.type != 1: | |
1080 for result in self._sendError(AlertDescription.illegal_parameter, | |
1081 "ChangeCipherSpec type incorrect"): | |
1082 yield result | |
1083 | |
1084 #Switch to pending read state | |
1085 self._changeReadState() | |
1086 | |
1087 def _getEncryptedExtensions(self): | |
1088 for result in self._getMsg(ContentType.handshake, | |
1089 HandshakeType.encrypted_extensions): | |
1090 if result in (0,1): | |
1091 yield result | |
1092 encrypted_extensions = result | |
1093 self.channel_id = encrypted_extensions.channel_id_key | |
1094 | |
1095 def _getFinished(self): | |
1096 #Calculate verification data | |
1097 verifyData = self._calcFinished(False) | |
1098 | |
1099 #Get and check Finished message under new state | |
1100 for result in self._getMsg(ContentType.handshake, | |
1101 HandshakeType.finished): | |
1102 if result in (0,1): | |
1103 yield result | |
1104 finished = result | |
1105 if finished.verify_data != verifyData: | |
1106 for result in self._sendError(AlertDescription.decrypt_error, | |
1107 "Finished message is incorrect"): | |
1108 yield result | |
1109 | |
1110 def _calcFinished(self, send=True): | |
1111 if self.version == (3,0): | |
1112 if (self._client and send) or (not self._client and not send): | |
1113 senderStr = "\x43\x4C\x4E\x54" | |
1114 else: | |
1115 senderStr = "\x53\x52\x56\x52" | |
1116 | |
1117 verifyData = self._calcSSLHandshakeHash(self.session.masterSecret, | |
1118 senderStr) | |
1119 return verifyData | |
1120 | |
1121 elif self.version in ((3,1), (3,2)): | |
1122 if (self._client and send) or (not self._client and not send): | |
1123 label = "client finished" | |
1124 else: | |
1125 label = "server finished" | |
1126 | |
1127 handshakeHashes = stringToBytes(self._handshake_md5.digest() + \ | |
1128 self._handshake_sha.digest()) | |
1129 verifyData = PRF(self.session.masterSecret, label, handshakeHashes, | |
1130 12) | |
1131 return verifyData | |
1132 else: | |
1133 raise AssertionError() | |
1134 | |
1135 #Used for Finished messages and CertificateVerify messages in SSL v3 | 1137 #Used for Finished messages and CertificateVerify messages in SSL v3 |
1136 def _calcSSLHandshakeHash(self, masterSecret, label): | 1138 def _calcSSLHandshakeHash(self, masterSecret, label): |
1137 masterSecretStr = bytesToString(masterSecret) | |
1138 | |
1139 imac_md5 = self._handshake_md5.copy() | 1139 imac_md5 = self._handshake_md5.copy() |
1140 imac_sha = self._handshake_sha.copy() | 1140 imac_sha = self._handshake_sha.copy() |
1141 | 1141 |
1142 imac_md5.update(label + masterSecretStr + '\x36'*48) | 1142 imac_md5.update(compatHMAC(label + masterSecret + bytearray([0x36]*48))) |
1143 imac_sha.update(label + masterSecretStr + '\x36'*40) | 1143 imac_sha.update(compatHMAC(label + masterSecret + bytearray([0x36]*40))) |
1144 | 1144 |
1145 md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \ | 1145 md5Bytes = MD5(masterSecret + bytearray([0x5c]*48) + \ |
1146 imac_md5.digest()).digest() | 1146 bytearray(imac_md5.digest())) |
1147 shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \ | 1147 shaBytes = SHA1(masterSecret + bytearray([0x5c]*40) + \ |
1148 imac_sha.digest()).digest() | 1148 bytearray(imac_sha.digest())) |
1149 | 1149 |
1150 return stringToBytes(md5Str + shaStr) | 1150 return md5Bytes + shaBytes |
| 1151 |
OLD | NEW |