OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 """ | |
5 Tests for ssh/transport.py and the classes therein. | |
6 """ | |
7 | |
8 import md5, sha | |
9 | |
10 try: | |
11 import Crypto | |
12 except ImportError: | |
13 Crypto = None | |
14 class transport: # fictional modules to make classes work | |
15 class SSHTransportBase: pass | |
16 class SSHServerTransport: pass | |
17 class SSHClientTransport: pass | |
18 class factory: | |
19 class SSHFactory: | |
20 pass | |
21 else: | |
22 from twisted.conch.ssh import transport, common, keys, factory | |
23 from twisted.conch.test import keydata | |
24 | |
25 from twisted.trial import unittest | |
26 from twisted.internet import defer | |
27 from twisted.protocols import loopback | |
28 from twisted.python import randbytes | |
29 from twisted.python.reflect import qual | |
30 from twisted.conch.ssh import service | |
31 from twisted.test import proto_helpers | |
32 | |
33 from twisted.conch.error import ConchError | |
34 | |
35 | |
36 | |
37 class MockTransportBase(transport.SSHTransportBase): | |
38 """ | |
39 A base class for the client and server protocols. Stores the messages | |
40 it receieves instead of ignoring them. | |
41 | |
42 @ivar errors: a list of tuples: (reasonCode, description) | |
43 @ivar unimplementeds: a list of integers: sequence number | |
44 @ivar debugs: a list of tuples: (alwaysDisplay, message, lang) | |
45 @ivar ignoreds: a list of strings: ignored data | |
46 """ | |
47 | |
48 | |
49 def connectionMade(self): | |
50 """ | |
51 Set up instance variables. | |
52 """ | |
53 transport.SSHTransportBase.connectionMade(self) | |
54 self.errors = [] | |
55 self.unimplementeds = [] | |
56 self.debugs = [] | |
57 self.ignoreds = [] | |
58 | |
59 | |
60 def receiveError(self, reasonCode, description): | |
61 """ | |
62 Store any errors received. | |
63 | |
64 @type reasonCode: C{int} | |
65 @type description: C{str} | |
66 """ | |
67 self.errors.append((reasonCode, description)) | |
68 | |
69 | |
70 def receiveUnimplemented(self, seqnum): | |
71 """ | |
72 Store any unimplemented packet messages. | |
73 | |
74 @type seqnum: C{int} | |
75 """ | |
76 self.unimplementeds.append(seqnum) | |
77 | |
78 | |
79 def receiveDebug(self, alwaysDisplay, message, lang): | |
80 """ | |
81 Store any debug messages. | |
82 | |
83 @type alwaysDisplay: C{bool} | |
84 @type message: C{str} | |
85 @type lang: C{str} | |
86 """ | |
87 self.debugs.append((alwaysDisplay, message, lang)) | |
88 | |
89 | |
90 def ssh_IGNORE(self, packet): | |
91 """ | |
92 Store any ignored data. | |
93 | |
94 @type packet: C{str} | |
95 """ | |
96 self.ignoreds.append(packet) | |
97 | |
98 | |
99 class MockCipher(object): | |
100 """ | |
101 A mocked-up version of twisted.conch.ssh.transport.SSHCiphers. | |
102 """ | |
103 outCipType = 'test' | |
104 encBlockSize = 6 | |
105 inCipType = 'test' | |
106 decBlockSize = 6 | |
107 inMACType = 'test' | |
108 outMACType = 'test' | |
109 verifyDigestSize = 1 | |
110 usedEncrypt = False | |
111 usedDecrypt = False | |
112 outMAC = (None, '', '', 1) | |
113 inMAC = (None, '', '', 1) | |
114 keys = () | |
115 | |
116 | |
117 def encrypt(self, x): | |
118 """ | |
119 Called to encrypt the packet. Simply record that encryption was used | |
120 and return the data unchanged. | |
121 """ | |
122 self.usedEncrypt = True | |
123 if (len(x) % self.encBlockSize) != 0: | |
124 raise RuntimeError("length %i modulo blocksize %i is not 0: %i" % | |
125 (len(x), self.encBlockSize, len(x) % self.encBlockSize)) | |
126 return x | |
127 | |
128 | |
129 def decrypt(self, x): | |
130 """ | |
131 Called to decrypt the packet. Simply record that decryption was used | |
132 and return the data unchanged. | |
133 """ | |
134 self.usedDecrypt = True | |
135 if (len(x) % self.encBlockSize) != 0: | |
136 raise RuntimeError("length %i modulo blocksize %i is not 0: %i" % | |
137 (len(x), self.decBlockSize, len(x) % self.decBlockSize)) | |
138 return x | |
139 | |
140 | |
141 def makeMAC(self, outgoingPacketSequence, payload): | |
142 """ | |
143 Make a Message Authentication Code by sending the character value of | |
144 the outgoing packet. | |
145 """ | |
146 return chr(outgoingPacketSequence) | |
147 | |
148 | |
149 def verify(self, incomingPacketSequence, packet, macData): | |
150 """ | |
151 Verify the Message Authentication Code by checking that the packet | |
152 sequence number is the same. | |
153 """ | |
154 return chr(incomingPacketSequence) == macData | |
155 | |
156 | |
157 def setKeys(self, ivOut, keyOut, ivIn, keyIn, macIn, macOut): | |
158 """ | |
159 Record the keys. | |
160 """ | |
161 self.keys = (ivOut, keyOut, ivIn, keyIn, macIn, macOut) | |
162 | |
163 | |
164 | |
165 class MockCompression: | |
166 """ | |
167 A mocked-up compression, based on the zlib interface. Instead of | |
168 compressing, it reverses the data and adds a 0x66 byte to the end. | |
169 """ | |
170 | |
171 | |
172 def compress(self, payload): | |
173 return payload[::-1] # reversed | |
174 | |
175 | |
176 def decompress(self, payload): | |
177 return payload[:-1][::-1] | |
178 | |
179 | |
180 def flush(self, kind): | |
181 return '\x66' | |
182 | |
183 | |
184 | |
185 class MockService(service.SSHService): | |
186 """ | |
187 A mocked-up service, based on twisted.conch.ssh.service.SSHService. | |
188 | |
189 @ivar started: True if this service has been started. | |
190 @ivar stopped: True if this service has been stopped. | |
191 """ | |
192 name = "MockService" | |
193 started = False | |
194 stopped = False | |
195 protocolMessages = {0xff: "MSG_TEST", 71: "MSG_fiction"} | |
196 | |
197 | |
198 def logPrefix(self): | |
199 return "MockService" | |
200 | |
201 | |
202 def serviceStarted(self): | |
203 """ | |
204 Record that the service was started. | |
205 """ | |
206 self.started = True | |
207 | |
208 | |
209 def serviceStopped(self): | |
210 """ | |
211 Record that the service was stopped. | |
212 """ | |
213 self.stopped = True | |
214 | |
215 | |
216 def ssh_TEST(self, packet): | |
217 """ | |
218 A message that this service responds to. | |
219 """ | |
220 self.transport.sendPacket(0xff, packet) | |
221 | |
222 | |
223 class MockFactory(factory.SSHFactory): | |
224 """ | |
225 A mocked-up factory based on twisted.conch.ssh.factory.SSHFactory. | |
226 """ | |
227 services = { | |
228 'ssh-userauth': MockService} | |
229 | |
230 | |
231 def getPublicKeys(self): | |
232 """ | |
233 Return the public keys that authenticate this server. | |
234 """ | |
235 return { | |
236 'ssh-rsa': keys.Key.fromString(keydata.publicRSA_openssh), | |
237 'ssh-dsa': keys.Key.fromString(keydata.publicDSA_openssh)} | |
238 | |
239 | |
240 def getPrivateKeys(self): | |
241 """ | |
242 Return the private keys that authenticate this server. | |
243 """ | |
244 return { | |
245 'ssh-rsa': keys.Key.fromString(keydata.privateRSA_openssh), | |
246 'ssh-dsa': keys.Key.fromString(keydata.privateDSA_openssh)} | |
247 | |
248 | |
249 def getPrimes(self): | |
250 """ | |
251 Return the Diffie-Hellman primes that can be used for the | |
252 diffie-hellman-group-exchange-sha1 key exchange. | |
253 """ | |
254 return { | |
255 1024: ((2, transport.DH_PRIME),), | |
256 2048: ((3, transport.DH_PRIME),), | |
257 4096: ((5, 7),)} | |
258 | |
259 | |
260 | |
261 class MockOldFactoryPublicKeys(MockFactory): | |
262 """ | |
263 The old SSHFactory returned mappings from key names to strings from | |
264 getPublicKeys(). We return those here for testing. | |
265 """ | |
266 | |
267 | |
268 def getPublicKeys(self): | |
269 """ | |
270 We used to map key types to public key blobs as strings. | |
271 """ | |
272 keys = MockFactory.getPublicKeys(self) | |
273 for name, key in keys.items()[:]: | |
274 keys[name] = key.blob() | |
275 return keys | |
276 | |
277 | |
278 | |
279 class MockOldFactoryPrivateKeys(MockFactory): | |
280 """ | |
281 The old SSHFactory returned mappings from key names to PyCrypto key | |
282 objects from getPrivateKeys(). We return those here for testing. | |
283 """ | |
284 | |
285 | |
286 def getPrivateKeys(self): | |
287 """ | |
288 We used to map key types to PyCrypto key objects. | |
289 """ | |
290 keys = MockFactory.getPrivateKeys(self) | |
291 for name, key in keys.items()[:]: | |
292 keys[name] = key.keyObject | |
293 return keys | |
294 | |
295 | |
296 | |
297 class TransportTestCase(unittest.TestCase): | |
298 """ | |
299 Base class for transport test cases. | |
300 """ | |
301 klass = None | |
302 | |
303 if Crypto is None: | |
304 skip = "cannot run w/o PyCrypto" | |
305 | |
306 | |
307 def setUp(self): | |
308 self.transport = proto_helpers.StringTransport() | |
309 self.proto = self.klass() | |
310 self.packets = [] | |
311 def secureRandom(len): | |
312 """ | |
313 Return a consistent entropy value | |
314 """ | |
315 return '\x99' * len | |
316 self.oldSecureRandom = randbytes.secureRandom | |
317 randbytes.secureRandom = secureRandom | |
318 def stubSendPacket(messageType, payload): | |
319 self.packets.append((messageType, payload)) | |
320 self.proto.makeConnection(self.transport) | |
321 # we just let the kex packet go into the transport | |
322 self.proto.sendPacket = stubSendPacket | |
323 | |
324 | |
325 def tearDown(self): | |
326 randbytes.secureRandom = self.oldSecureRandom | |
327 self.oldSecureRandom = None | |
328 | |
329 | |
330 | |
331 class BaseSSHTransportTestCase(TransportTestCase): | |
332 """ | |
333 Test TransportBase. It implements the non-server/client specific | |
334 parts of the SSH transport protocol. | |
335 """ | |
336 | |
337 klass = MockTransportBase | |
338 | |
339 | |
340 def test_sendVersion(self): | |
341 """ | |
342 Test that the first thing sent over the connection is the version | |
343 string. | |
344 """ | |
345 # the other setup was done in the setup method | |
346 self.assertEquals(self.transport.value().split('\r\n', 1)[0], | |
347 "SSH-2.0-Twisted") | |
348 | |
349 | |
350 def test_sendPacketPlain(self): | |
351 """ | |
352 Test that plain (unencrypted, uncompressed) packets are sent | |
353 correctly. The format is:: | |
354 uint32 length (including type and padding length) | |
355 byte padding length | |
356 byte type | |
357 bytes[length-padding length-2] data | |
358 bytes[padding length] padding | |
359 """ | |
360 proto = MockTransportBase() | |
361 proto.makeConnection(self.transport) | |
362 self.transport.clear() | |
363 message = ord('A') | |
364 payload = 'BCDEFG' | |
365 proto.sendPacket(message, payload) | |
366 value = self.transport.value() | |
367 self.assertEquals(value, '\x00\x00\x00\x0c\x04ABCDEFG\x99\x99\x99\x99') | |
368 | |
369 | |
370 def test_sendPacketEncrypted(self): | |
371 """ | |
372 Test that packets sent while encryption is enabled are sent | |
373 correctly. The whole packet should be encrypted. | |
374 """ | |
375 proto = MockTransportBase() | |
376 proto.makeConnection(self.transport) | |
377 proto.currentEncryptions = testCipher = MockCipher() | |
378 message = ord('A') | |
379 payload = 'BC' | |
380 self.transport.clear() | |
381 proto.sendPacket(message, payload) | |
382 self.assertTrue(testCipher.usedEncrypt) | |
383 value = self.transport.value() | |
384 self.assertEquals(value, '\x00\x00\x00\x08\x04ABC\x99\x99\x99\x99\x01') | |
385 | |
386 | |
387 def test_sendPacketCompressed(self): | |
388 """ | |
389 Test that packets sent while compression is enabled are sent | |
390 correctly. The packet type and data should be encrypted. | |
391 """ | |
392 proto = MockTransportBase() | |
393 proto.makeConnection(self.transport) | |
394 proto.outgoingCompression = MockCompression() | |
395 self.transport.clear() | |
396 proto.sendPacket(ord('A'), 'B') | |
397 value = self.transport.value() | |
398 self.assertEquals( | |
399 value, | |
400 '\x00\x00\x00\x0c\x08BA\x66\x99\x99\x99\x99\x99\x99\x99\x99') | |
401 | |
402 | |
403 def test_sendPacketBoth(self): | |
404 """ | |
405 Test that packets sent while compression and encryption are | |
406 enabled are sent correctly. The packet type and data should be | |
407 compressed and then the whole packet should be encrypted. | |
408 """ | |
409 proto = MockTransportBase() | |
410 proto.makeConnection(self.transport) | |
411 proto.currentEncryptions = testCipher = MockCipher() | |
412 proto.outgoingCompression = MockCompression() | |
413 message = ord('A') | |
414 payload = 'BC' | |
415 self.transport.clear() | |
416 proto.sendPacket(message, payload) | |
417 value = self.transport.value() | |
418 self.assertEquals( | |
419 value, | |
420 '\x00\x00\x00\x0e\x09CBA\x66\x99\x99\x99\x99\x99\x99\x99\x99\x99' | |
421 '\x01') | |
422 | |
423 | |
424 def test_getPacketPlain(self): | |
425 """ | |
426 Test that packets are retrieved correctly out of the buffer when | |
427 no encryption is enabled. | |
428 """ | |
429 proto = MockTransportBase() | |
430 proto.makeConnection(self.transport) | |
431 self.transport.clear() | |
432 proto.sendPacket(ord('A'), 'BC') | |
433 proto.buf = self.transport.value() + 'extra' | |
434 self.assertEquals(proto.getPacket(), 'ABC') | |
435 self.assertEquals(proto.buf, 'extra') | |
436 | |
437 | |
438 def test_getPacketEncrypted(self): | |
439 """ | |
440 Test that encrypted packets are retrieved correctly. | |
441 See test_sendPacketEncrypted. | |
442 """ | |
443 proto = MockTransportBase() | |
444 proto.sendKexInit = lambda: None # don't send packets | |
445 proto.makeConnection(self.transport) | |
446 self.transport.clear() | |
447 proto.currentEncryptions = testCipher = MockCipher() | |
448 proto.sendPacket(ord('A'), 'BCD') | |
449 value = self.transport.value() | |
450 proto.buf = value[:MockCipher.decBlockSize] | |
451 self.assertEquals(proto.getPacket(), None) | |
452 self.assertTrue(testCipher.usedDecrypt) | |
453 self.assertEquals(proto.first, '\x00\x00\x00\x0e\x09A') | |
454 proto.buf += value[MockCipher.decBlockSize:] | |
455 self.assertEquals(proto.getPacket(), 'ABCD') | |
456 self.assertEquals(proto.buf, '') | |
457 | |
458 | |
459 def test_getPacketCompressed(self): | |
460 """ | |
461 Test that compressed packets are retrieved correctly. See | |
462 test_sendPacketCompressed. | |
463 """ | |
464 proto = MockTransportBase() | |
465 proto.makeConnection(self.transport) | |
466 self.transport.clear() | |
467 proto.outgoingCompression = MockCompression() | |
468 proto.incomingCompression = proto.outgoingCompression | |
469 proto.sendPacket(ord('A'), 'BCD') | |
470 proto.buf = self.transport.value() | |
471 self.assertEquals(proto.getPacket(), 'ABCD') | |
472 | |
473 | |
474 def test_getPacketBoth(self): | |
475 """ | |
476 Test that compressed and encrypted packets are retrieved correctly. | |
477 See test_sendPacketBoth. | |
478 """ | |
479 proto = MockTransportBase() | |
480 proto.sendKexInit = lambda: None | |
481 proto.makeConnection(self.transport) | |
482 self.transport.clear() | |
483 proto.currentEncryptions = testCipher = MockCipher() | |
484 proto.outgoingCompression = MockCompression() | |
485 proto.incomingCompression = proto.outgoingCompression | |
486 proto.sendPacket(ord('A'), 'BCDEFG') | |
487 proto.buf = self.transport.value() | |
488 self.assertEquals(proto.getPacket(), 'ABCDEFG') | |
489 | |
490 | |
491 def test_ciphersAreValid(self): | |
492 """ | |
493 Test that all the supportedCiphers are valid. | |
494 """ | |
495 ciphers = transport.SSHCiphers('A', 'B', 'C', 'D') | |
496 iv = key = '\x00' * 16 | |
497 for cipName in self.proto.supportedCiphers: | |
498 self.assertTrue(ciphers._getCipher(cipName, iv, key)) | |
499 | |
500 | |
501 def test_sendKexInit(self): | |
502 """ | |
503 Test that the KEXINIT (key exchange initiation) message is sent | |
504 correctly. Payload:: | |
505 bytes[16] cookie | |
506 string key exchange algorithms | |
507 string public key algorithms | |
508 string outgoing ciphers | |
509 string incoming ciphers | |
510 string outgoing MACs | |
511 string incoming MACs | |
512 string outgoing compressions | |
513 string incoming compressions | |
514 bool first packet follows | |
515 uint32 0 | |
516 """ | |
517 value = self.transport.value().split('\r\n', 1)[1] | |
518 self.proto.buf = value | |
519 packet = self.proto.getPacket() | |
520 self.assertEquals(packet[0], chr(transport.MSG_KEXINIT)) | |
521 self.assertEquals(packet[1:17], '\x99' * 16) | |
522 (kex, pubkeys, ciphers1, ciphers2, macs1, macs2, compressions1, | |
523 compressions2, languages1, languages2, | |
524 buf) = common.getNS(packet[17:], 10) | |
525 | |
526 self.assertEquals(kex, ','.join(self.proto.supportedKeyExchanges)) | |
527 self.assertEquals(pubkeys, ','.join(self.proto.supportedPublicKeys)) | |
528 self.assertEquals(ciphers1, ','.join(self.proto.supportedCiphers)) | |
529 self.assertEquals(ciphers2, ','.join(self.proto.supportedCiphers)) | |
530 self.assertEquals(macs1, ','.join(self.proto.supportedMACs)) | |
531 self.assertEquals(macs2, ','.join(self.proto.supportedMACs)) | |
532 self.assertEquals(compressions1, | |
533 ','.join(self.proto.supportedCompressions)) | |
534 self.assertEquals(compressions2, | |
535 ','.join(self.proto.supportedCompressions)) | |
536 self.assertEquals(languages1, ','.join(self.proto.supportedLanguages)) | |
537 self.assertEquals(languages2, ','.join(self.proto.supportedLanguages)) | |
538 self.assertEquals(buf, '\x00' * 5) | |
539 | |
540 | |
541 def test_sendDebug(self): | |
542 """ | |
543 Test that debug messages are sent correctly. Payload:: | |
544 bool always display | |
545 string debug message | |
546 string language | |
547 """ | |
548 self.proto.sendDebug("test", True, 'en') | |
549 self.assertEquals( | |
550 self.packets, | |
551 [(transport.MSG_DEBUG, | |
552 "\x01\x00\x00\x00\x04test\x00\x00\x00\x02en")]) | |
553 | |
554 | |
555 def test_receiveDebug(self): | |
556 """ | |
557 Test that debug messages are received correctly. See test_sendDebug. | |
558 """ | |
559 self.proto.dispatchMessage( | |
560 transport.MSG_DEBUG, | |
561 '\x01\x00\x00\x00\x04test\x00\x00\x00\x02en') | |
562 self.assertEquals(self.proto.debugs, [(True, 'test', 'en')]) | |
563 | |
564 | |
565 def test_sendIgnore(self): | |
566 """ | |
567 Test that ignored messages are sent correctly. Payload:: | |
568 string ignored data | |
569 """ | |
570 self.proto.sendIgnore("test") | |
571 self.assertEquals( | |
572 self.packets, [(transport.MSG_IGNORE, | |
573 '\x00\x00\x00\x04test')]) | |
574 | |
575 | |
576 def test_receiveIgnore(self): | |
577 """ | |
578 Test that ignored messages are received correctly. See | |
579 test_sendIgnore. | |
580 """ | |
581 self.proto.dispatchMessage(transport.MSG_IGNORE, 'test') | |
582 self.assertEquals(self.proto.ignoreds, ['test']) | |
583 | |
584 | |
585 def test_sendUnimplemented(self): | |
586 """ | |
587 Test that unimplemented messages are sent correctly. Payload:: | |
588 uint32 sequence number | |
589 """ | |
590 self.proto.sendUnimplemented() | |
591 self.assertEquals( | |
592 self.packets, [(transport.MSG_UNIMPLEMENTED, | |
593 '\x00\x00\x00\x00')]) | |
594 | |
595 | |
596 def test_receiveUnimplemented(self): | |
597 """ | |
598 Test that unimplemented messages are received correctly. See | |
599 test_sendUnimplemented. | |
600 """ | |
601 self.proto.dispatchMessage(transport.MSG_UNIMPLEMENTED, | |
602 '\x00\x00\x00\xff') | |
603 self.assertEquals(self.proto.unimplementeds, [255]) | |
604 | |
605 | |
606 def test_sendDisconnect(self): | |
607 """ | |
608 Test that disconnection messages are sent correctly. Payload:: | |
609 uint32 reason code | |
610 string reason description | |
611 string language | |
612 """ | |
613 disconnected = [False] | |
614 def stubLoseConnection(): | |
615 disconnected[0] = True | |
616 self.transport.loseConnection = stubLoseConnection | |
617 self.proto.sendDisconnect(0xff, "test") | |
618 self.assertEquals( | |
619 self.packets, | |
620 [(transport.MSG_DISCONNECT, | |
621 "\x00\x00\x00\xff\x00\x00\x00\x04test\x00\x00\x00\x00")]) | |
622 self.assertTrue(disconnected[0]) | |
623 | |
624 | |
625 def test_receiveDisconnect(self): | |
626 """ | |
627 Test that disconnection messages are received correctly. See | |
628 test_sendDisconnect. | |
629 """ | |
630 disconnected = [False] | |
631 def stubLoseConnection(): | |
632 disconnected[0] = True | |
633 self.transport.loseConnection = stubLoseConnection | |
634 self.proto.dispatchMessage(transport.MSG_DISCONNECT, | |
635 '\x00\x00\x00\xff\x00\x00\x00\x04test') | |
636 self.assertEquals(self.proto.errors, [(255, 'test')]) | |
637 self.assertTrue(disconnected[0]) | |
638 | |
639 | |
640 def test_dataReceived(self): | |
641 """ | |
642 Test that dataReceived parses packets and dispatches them to | |
643 ssh_* methods. | |
644 """ | |
645 kexInit = [False] | |
646 def stubKEXINIT(packet): | |
647 kexInit[0] = True | |
648 self.proto.ssh_KEXINIT = stubKEXINIT | |
649 self.proto.dataReceived(self.transport.value()) | |
650 self.assertTrue(self.proto.gotVersion) | |
651 self.assertEquals(self.proto.ourVersionString, | |
652 self.proto.otherVersionString) | |
653 self.assertTrue(kexInit[0]) | |
654 | |
655 | |
656 def test_service(self): | |
657 """ | |
658 Test that the transport can set the running service and dispatches | |
659 packets to the service's packetReceived method. | |
660 """ | |
661 service = MockService() | |
662 self.proto.setService(service) | |
663 self.assertEquals(self.proto.service, service) | |
664 self.assertTrue(service.started) | |
665 self.proto.dispatchMessage(0xff, "test") | |
666 self.assertEquals(self.packets, [(0xff, "test")]) | |
667 | |
668 service2 = MockService() | |
669 self.proto.setService(service2) | |
670 self.assertTrue(service2.started) | |
671 self.assertTrue(service.stopped) | |
672 | |
673 self.proto.connectionLost(None) | |
674 self.assertTrue(service2.stopped) | |
675 | |
676 | |
677 def test_avatar(self): | |
678 """ | |
679 Test that the transport notifies the avatar of disconnections. | |
680 """ | |
681 disconnected = [False] | |
682 def logout(): | |
683 disconnected[0] = True | |
684 self.proto.logoutFunction = logout | |
685 self.proto.avatar = True | |
686 | |
687 self.proto.connectionLost(None) | |
688 self.assertTrue(disconnected[0]) | |
689 | |
690 | |
691 def test_isEncrypted(self): | |
692 """ | |
693 Test that the transport accurately reflects its encrypted status. | |
694 """ | |
695 self.assertFalse(self.proto.isEncrypted('in')) | |
696 self.assertFalse(self.proto.isEncrypted('out')) | |
697 self.assertFalse(self.proto.isEncrypted('both')) | |
698 self.proto.currentEncryptions = MockCipher() | |
699 self.assertTrue(self.proto.isEncrypted('in')) | |
700 self.assertTrue(self.proto.isEncrypted('out')) | |
701 self.assertTrue(self.proto.isEncrypted('both')) | |
702 self.proto.currentEncryptions = transport.SSHCiphers('none', 'none', | |
703 'none', 'none') | |
704 self.assertFalse(self.proto.isEncrypted('in')) | |
705 self.assertFalse(self.proto.isEncrypted('out')) | |
706 self.assertFalse(self.proto.isEncrypted('both')) | |
707 | |
708 self.assertRaises(TypeError, self.proto.isEncrypted, 'bad') | |
709 | |
710 | |
711 def test_isVerified(self): | |
712 """ | |
713 Test that the transport accurately reflects its verified status. | |
714 """ | |
715 self.assertFalse(self.proto.isVerified('in')) | |
716 self.assertFalse(self.proto.isVerified('out')) | |
717 self.assertFalse(self.proto.isVerified('both')) | |
718 self.proto.currentEncryptions = MockCipher() | |
719 self.assertTrue(self.proto.isVerified('in')) | |
720 self.assertTrue(self.proto.isVerified('out')) | |
721 self.assertTrue(self.proto.isVerified('both')) | |
722 self.proto.currentEncryptions = transport.SSHCiphers('none', 'none', | |
723 'none', 'none') | |
724 self.assertFalse(self.proto.isVerified('in')) | |
725 self.assertFalse(self.proto.isVerified('out')) | |
726 self.assertFalse(self.proto.isVerified('both')) | |
727 | |
728 self.assertRaises(TypeError, self.proto.isVerified, 'bad') | |
729 | |
730 | |
731 def test_loseConnection(self): | |
732 """ | |
733 Test that loseConnection sends a disconnect message and closes the | |
734 connection. | |
735 """ | |
736 disconnected = [False] | |
737 def stubLoseConnection(): | |
738 disconnected[0] = True | |
739 self.transport.loseConnection = stubLoseConnection | |
740 self.proto.loseConnection() | |
741 self.assertEquals(self.packets[0][0], transport.MSG_DISCONNECT) | |
742 self.assertEquals(self.packets[0][1][3], | |
743 chr(transport.DISCONNECT_CONNECTION_LOST)) | |
744 | |
745 | |
746 def test_badVersion(self): | |
747 """ | |
748 Test that the transport disconnects when it receives a bad version. | |
749 """ | |
750 def testBad(version): | |
751 self.packets = [] | |
752 self.proto.gotVersion = False | |
753 disconnected = [False] | |
754 def stubLoseConnection(): | |
755 disconnected[0] = True | |
756 self.transport.loseConnection = stubLoseConnection | |
757 for c in version + '\r\n': | |
758 self.proto.dataReceived(c) | |
759 self.assertTrue(disconnected[0]) | |
760 self.assertEquals(self.packets[0][0], transport.MSG_DISCONNECT) | |
761 self.assertEquals( | |
762 self.packets[0][1][3], | |
763 chr(transport.DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED)) | |
764 testBad('SSH-1.5-OpenSSH') | |
765 testBad('SSH-3.0-Twisted') | |
766 testBad('GET / HTTP/1.1') | |
767 | |
768 | |
769 def test_dataBeforeVersion(self): | |
770 """ | |
771 Test that the transport ignores data sent before the version string. | |
772 """ | |
773 proto = MockTransportBase() | |
774 proto.makeConnection(proto_helpers.StringTransport()) | |
775 data = ("""here's some stuff beforehand | |
776 here's some other stuff | |
777 """ + proto.ourVersionString + "\r\n") | |
778 [proto.dataReceived(c) for c in data] | |
779 self.assertTrue(proto.gotVersion) | |
780 self.assertEquals(proto.otherVersionString, proto.ourVersionString) | |
781 | |
782 | |
783 def test_compatabilityVersion(self): | |
784 """ | |
785 Test that the transport treats the compatbility version (1.99) | |
786 as equivalent to version 2.0. | |
787 """ | |
788 proto = MockTransportBase() | |
789 proto.makeConnection(proto_helpers.StringTransport()) | |
790 proto.dataReceived("SSH-1.99-OpenSSH\n") | |
791 self.assertTrue(proto.gotVersion) | |
792 self.assertEquals(proto.otherVersionString, "SSH-1.99-OpenSSH") | |
793 | |
794 | |
795 def test_badPackets(self): | |
796 """ | |
797 Test that the transport disconnects with an error when it receives | |
798 bad packets. | |
799 """ | |
800 def testBad(packet, error=transport.DISCONNECT_PROTOCOL_ERROR): | |
801 self.packets = [] | |
802 self.proto.buf = packet | |
803 self.assertEquals(self.proto.getPacket(), None) | |
804 self.assertEquals(len(self.packets), 1) | |
805 self.assertEquals(self.packets[0][0], transport.MSG_DISCONNECT) | |
806 self.assertEquals(self.packets[0][1][3], chr(error)) | |
807 | |
808 testBad('\xff' * 8) # big packet | |
809 testBad('\x00\x00\x00\x05\x00BCDE') # length not modulo blocksize | |
810 oldEncryptions = self.proto.currentEncryptions | |
811 self.proto.currentEncryptions = MockCipher() | |
812 testBad('\x00\x00\x00\x08\x06AB123456', # bad MAC | |
813 transport.DISCONNECT_MAC_ERROR) | |
814 self.proto.currentEncryptions.decrypt = lambda x: x[:-1] | |
815 testBad('\x00\x00\x00\x08\x06BCDEFGHIJK') # bad decryption | |
816 self.proto.currentEncryptions = oldEncryptions | |
817 self.proto.incomingCompression = MockCompression() | |
818 def stubDecompress(payload): | |
819 raise Exception('bad compression') | |
820 self.proto.incomingCompression.decompress = stubDecompress | |
821 testBad('\x00\x00\x00\x04\x00BCDE', # bad decompression | |
822 transport.DISCONNECT_COMPRESSION_ERROR) | |
823 self.flushLoggedErrors() | |
824 | |
825 | |
826 def test_unimplementedPackets(self): | |
827 """ | |
828 Test that unimplemented packet types cause MSG_UNIMPLEMENTED packets | |
829 to be sent. | |
830 """ | |
831 seqnum = self.proto.incomingPacketSequence | |
832 def checkUnimplemented(seqnum=seqnum): | |
833 self.assertEquals(self.packets[0][0], | |
834 transport.MSG_UNIMPLEMENTED) | |
835 self.assertEquals(self.packets[0][1][3], chr(seqnum)) | |
836 self.proto.packets = [] | |
837 seqnum += 1 | |
838 | |
839 self.proto.dispatchMessage(40, '') | |
840 checkUnimplemented() | |
841 transport.messages[41] = 'MSG_fiction' | |
842 self.proto.dispatchMessage(41, '') | |
843 checkUnimplemented() | |
844 self.proto.dispatchMessage(60, '') | |
845 checkUnimplemented() | |
846 self.proto.setService(MockService()) | |
847 self.proto.dispatchMessage(70, '') | |
848 checkUnimplemented() | |
849 self.proto.dispatchMessage(71, '') | |
850 checkUnimplemented() | |
851 | |
852 | |
853 def test_getKey(self): | |
854 """ | |
855 Test that _getKey generates the correct keys. | |
856 """ | |
857 self.proto.sessionID = 'EF' | |
858 | |
859 k1 = sha.new('AB' + 'CD' | |
860 + 'K' + self.proto.sessionID).digest() | |
861 k2 = sha.new('ABCD' + k1).digest() | |
862 self.assertEquals(self.proto._getKey('K', 'AB', 'CD'), k1 + k2) | |
863 | |
864 | |
865 def test_multipleClasses(self): | |
866 """ | |
867 Test that multiple instances have distinct states. | |
868 """ | |
869 proto = self.proto | |
870 proto.dataReceived(self.transport.value()) | |
871 proto.currentEncryptions = MockCipher() | |
872 proto.outgoingCompression = MockCompression() | |
873 proto.incomingCompression = MockCompression() | |
874 proto.setService(MockService()) | |
875 proto2 = MockTransportBase() | |
876 proto2.makeConnection(proto_helpers.StringTransport()) | |
877 proto2.sendIgnore('') | |
878 self.failIfEquals(proto.gotVersion, proto2.gotVersion) | |
879 self.failIfEquals(proto.transport, proto2.transport) | |
880 self.failIfEquals(proto.outgoingPacketSequence, | |
881 proto2.outgoingPacketSequence) | |
882 self.failIfEquals(proto.incomingPacketSequence, | |
883 proto2.incomingPacketSequence) | |
884 self.failIfEquals(proto.currentEncryptions, | |
885 proto2.currentEncryptions) | |
886 self.failIfEquals(proto.service, proto2.service) | |
887 | |
888 | |
889 | |
890 class ServerAndClientSSHTransportBaseCase: | |
891 """ | |
892 Tests that need to be run on both the server and the client. | |
893 """ | |
894 | |
895 | |
896 def checkDisconnected(self, kind=None): | |
897 """ | |
898 Helper function to check if the transport disconnected. | |
899 """ | |
900 if kind is None: | |
901 kind = transport.DISCONNECT_PROTOCOL_ERROR | |
902 self.assertEquals(self.packets[-1][0], transport.MSG_DISCONNECT) | |
903 self.assertEquals(self.packets[-1][1][3], chr(kind)) | |
904 | |
905 | |
906 def connectModifiedProtocol(self, protoModification, | |
907 kind=None): | |
908 """ | |
909 Helper function to connect a modified protocol to the test protocol | |
910 and test for disconnection. | |
911 """ | |
912 if kind is None: | |
913 kind = transport.DISCONNECT_KEY_EXCHANGE_FAILED | |
914 proto2 = self.klass() | |
915 protoModification(proto2) | |
916 proto2.makeConnection(proto_helpers.StringTransport()) | |
917 self.proto.dataReceived(proto2.transport.value()) | |
918 if kind: | |
919 self.checkDisconnected(kind) | |
920 return proto2 | |
921 | |
922 | |
923 def test_disconnectIfCantMatchKex(self): | |
924 """ | |
925 Test that the transport disconnects if it can't match the key | |
926 exchange | |
927 """ | |
928 def blankKeyExchanges(proto2): | |
929 proto2.supportedKeyExchanges = [] | |
930 self.connectModifiedProtocol(blankKeyExchanges) | |
931 | |
932 | |
933 def test_disconnectIfCantMatchKeyAlg(self): | |
934 """ | |
935 Like test_disconnectIfCantMatchKex, but for the key algorithm. | |
936 """ | |
937 def blankPublicKeys(proto2): | |
938 proto2.supportedPublicKeys = [] | |
939 self.connectModifiedProtocol(blankPublicKeys) | |
940 | |
941 | |
942 def test_disconnectIfCantMatchCompression(self): | |
943 """ | |
944 Like test_disconnectIfCantMatchKex, but for the compression. | |
945 """ | |
946 def blankCompressions(proto2): | |
947 proto2.supportedCompressions = [] | |
948 self.connectModifiedProtocol(blankCompressions) | |
949 | |
950 | |
951 def test_disconnectIfCantMatchCipher(self): | |
952 """ | |
953 Like test_disconnectIfCantMatchKex, but for the encryption. | |
954 """ | |
955 def blankCiphers(proto2): | |
956 proto2.supportedCiphers = [] | |
957 self.connectModifiedProtocol(blankCiphers) | |
958 | |
959 | |
960 def test_disconnectIfCantMatchMAC(self): | |
961 """ | |
962 Like test_disconnectIfCantMatchKex, but for the MAC. | |
963 """ | |
964 def blankMACs(proto2): | |
965 proto2.supportedMACs = [] | |
966 self.connectModifiedProtocol(blankMACs) | |
967 | |
968 | |
969 | |
970 class ServerSSHTransportTestCase(ServerAndClientSSHTransportBaseCase, | |
971 TransportTestCase): | |
972 """ | |
973 Tests for the SSHServerTransport. | |
974 """ | |
975 | |
976 klass = transport.SSHServerTransport | |
977 | |
978 | |
979 def setUp(self): | |
980 TransportTestCase.setUp(self) | |
981 self.proto.factory = MockFactory() | |
982 self.proto.factory.startFactory() | |
983 | |
984 | |
985 def tearDown(self): | |
986 TransportTestCase.tearDown(self) | |
987 self.proto.factory.stopFactory() | |
988 del self.proto.factory | |
989 | |
990 | |
991 def test_KEXINIT(self): | |
992 """ | |
993 Test that receiving a KEXINIT packet sets up the correct values on the | |
994 server. | |
995 """ | |
996 self.proto.dataReceived( 'SSH-2.0-Twisted\r\n\x00\x00\x01\xd4\t\x14' | |
997 '\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99' | |
998 '\x99\x00\x00\x00=diffie-hellman-group1-sha1,diffie-hellman-g' | |
999 'roup-exchange-sha1\x00\x00\x00\x0fssh-dss,ssh-rsa\x00\x00\x00' | |
1000 '\x85aes128-ctr,aes128-cbc,aes192-ctr,aes192-cbc,aes256-ctr,ae' | |
1001 's256-cbc,cast128-ctr,cast128-cbc,blowfish-ctr,blowfish-cbc,3d' | |
1002 'es-ctr,3des-cbc\x00\x00\x00\x85aes128-ctr,aes128-cbc,aes192-c' | |
1003 'tr,aes192-cbc,aes256-ctr,aes256-cbc,cast128-ctr,cast128-cbc,b' | |
1004 'lowfish-ctr,blowfish-cbc,3des-ctr,3des-cbc\x00\x00\x00\x12hma' | |
1005 'c-md5,hmac-sha1\x00\x00\x00\x12hmac-md5,hmac-sha1\x00\x00\x00' | |
1006 '\tnone,zlib\x00\x00\x00\tnone,zlib\x00\x00\x00\x00\x00\x00' | |
1007 '\x00\x00\x00\x00\x00\x00\x00\x99\x99\x99\x99\x99\x99\x99\x99' | |
1008 '\x99') | |
1009 self.assertEquals(self.proto.kexAlg, | |
1010 'diffie-hellman-group1-sha1') | |
1011 self.assertEquals(self.proto.keyAlg, | |
1012 'ssh-dss') | |
1013 self.assertEquals(self.proto.outgoingCompressionType, | |
1014 'none') | |
1015 self.assertEquals(self.proto.incomingCompressionType, | |
1016 'none') | |
1017 ne = self.proto.nextEncryptions | |
1018 self.assertEquals(ne.outCipType, 'aes128-ctr') | |
1019 self.assertEquals(ne.inCipType, 'aes128-ctr') | |
1020 self.assertEquals(ne.outMACType, 'hmac-md5') | |
1021 self.assertEquals(ne.inMACType, 'hmac-md5') | |
1022 | |
1023 | |
1024 def test_ignoreGuessPacketKex(self): | |
1025 """ | |
1026 The client is allowed to send a guessed key exchange packet | |
1027 after it sends the KEXINIT packet. However, if the key exchanges | |
1028 do not match, that guess packet must be ignored. This tests that | |
1029 the packet is ignored in the case of the key exchange method not | |
1030 matching. | |
1031 """ | |
1032 kexInitPacket = '\x00' * 16 + ( | |
1033 ''.join([common.NS(x) for x in | |
1034 [','.join(y) for y in | |
1035 [self.proto.supportedKeyExchanges[::-1], | |
1036 self.proto.supportedPublicKeys, | |
1037 self.proto.supportedCiphers, | |
1038 self.proto.supportedCiphers, | |
1039 self.proto.supportedMACs, | |
1040 self.proto.supportedMACs, | |
1041 self.proto.supportedCompressions, | |
1042 self.proto.supportedCompressions, | |
1043 self.proto.supportedLanguages, | |
1044 self.proto.supportedLanguages]]])) + ( | |
1045 '\xff\x00\x00\x00\x00') | |
1046 self.proto.ssh_KEXINIT(kexInitPacket) | |
1047 self.assertTrue(self.proto.ignoreNextPacket) | |
1048 self.proto.ssh_DEBUG("\x01\x00\x00\x00\x04test\x00\x00\x00\x00") | |
1049 self.assertTrue(self.proto.ignoreNextPacket) | |
1050 | |
1051 | |
1052 self.proto.ssh_KEX_DH_GEX_REQUEST_OLD('\x00\x00\x08\x00') | |
1053 self.assertFalse(self.proto.ignoreNextPacket) | |
1054 self.assertEquals(self.packets, []) | |
1055 self.proto.ignoreNextPacket = True | |
1056 | |
1057 self.proto.ssh_KEX_DH_GEX_REQUEST('\x00\x00\x08\x00' * 3) | |
1058 self.assertFalse(self.proto.ignoreNextPacket) | |
1059 self.assertEquals(self.packets, []) | |
1060 | |
1061 | |
1062 def test_ignoreGuessPacketKey(self): | |
1063 """ | |
1064 Like test_ignoreGuessPacketKex, but for an incorrectly guessed | |
1065 public key format. | |
1066 """ | |
1067 kexInitPacket = '\x00' * 16 + ( | |
1068 ''.join([common.NS(x) for x in | |
1069 [','.join(y) for y in | |
1070 [self.proto.supportedKeyExchanges, | |
1071 self.proto.supportedPublicKeys[::-1], | |
1072 self.proto.supportedCiphers, | |
1073 self.proto.supportedCiphers, | |
1074 self.proto.supportedMACs, | |
1075 self.proto.supportedMACs, | |
1076 self.proto.supportedCompressions, | |
1077 self.proto.supportedCompressions, | |
1078 self.proto.supportedLanguages, | |
1079 self.proto.supportedLanguages]]])) + ( | |
1080 '\xff\x00\x00\x00\x00') | |
1081 self.proto.ssh_KEXINIT(kexInitPacket) | |
1082 self.assertTrue(self.proto.ignoreNextPacket) | |
1083 self.proto.ssh_DEBUG("\x01\x00\x00\x00\x04test\x00\x00\x00\x00") | |
1084 self.assertTrue(self.proto.ignoreNextPacket) | |
1085 | |
1086 self.proto.ssh_KEX_DH_GEX_REQUEST_OLD('\x00\x00\x08\x00') | |
1087 self.assertFalse(self.proto.ignoreNextPacket) | |
1088 self.assertEquals(self.packets, []) | |
1089 self.proto.ignoreNextPacket = True | |
1090 | |
1091 self.proto.ssh_KEX_DH_GEX_REQUEST('\x00\x00\x08\x00' * 3) | |
1092 self.assertFalse(self.proto.ignoreNextPacket) | |
1093 self.assertEquals(self.packets, []) | |
1094 | |
1095 | |
1096 def test_KEXDH_INIT(self): | |
1097 """ | |
1098 Test that the KEXDH_INIT packet causes the server to send a | |
1099 KEXDH_REPLY with the server's public key and a signature. | |
1100 """ | |
1101 self.proto.supportedKeyExchanges = ['diffie-hellman-group1-sha1'] | |
1102 self.proto.supportedPublicKeys = ['ssh-rsa'] | |
1103 self.proto.dataReceived(self.transport.value()) | |
1104 e = pow(transport.DH_GENERATOR, 5000, | |
1105 transport.DH_PRIME) | |
1106 | |
1107 self.proto.ssh_KEX_DH_GEX_REQUEST_OLD(common.MP(e)) | |
1108 y = common.getMP('\x00\x00\x00\x40' + '\x99' * 64)[0] | |
1109 f = common._MPpow(transport.DH_GENERATOR, y, transport.DH_PRIME) | |
1110 sharedSecret = common._MPpow(e, y, transport.DH_PRIME) | |
1111 | |
1112 h = sha.new() | |
1113 h.update(common.NS(self.proto.ourVersionString) * 2) | |
1114 h.update(common.NS(self.proto.ourKexInitPayload) * 2) | |
1115 h.update(common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob())) | |
1116 h.update(common.MP(e)) | |
1117 h.update(f) | |
1118 h.update(sharedSecret) | |
1119 exchangeHash = h.digest() | |
1120 | |
1121 signature = self.proto.factory.privateKeys['ssh-rsa'].sign( | |
1122 exchangeHash) | |
1123 | |
1124 self.assertEquals( | |
1125 self.packets, | |
1126 [(transport.MSG_KEXDH_REPLY, | |
1127 common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()) | |
1128 + f + common.NS(signature)), | |
1129 (transport.MSG_NEWKEYS, '')]) | |
1130 | |
1131 | |
1132 def test_KEX_DH_GEX_REQUEST_OLD(self): | |
1133 """ | |
1134 Test that the KEX_DH_GEX_REQUEST_OLD message causes the server | |
1135 to reply with a KEX_DH_GEX_GROUP message with the correct | |
1136 Diffie-Hellman group. | |
1137 """ | |
1138 self.proto.supportedKeyExchanges = [ | |
1139 'diffie-hellman-group-exchange-sha1'] | |
1140 self.proto.supportedPublicKeys = ['ssh-rsa'] | |
1141 self.proto.dataReceived(self.transport.value()) | |
1142 self.proto.ssh_KEX_DH_GEX_REQUEST_OLD('\x00\x00\x04\x00') | |
1143 self.assertEquals( | |
1144 self.packets, | |
1145 [(transport.MSG_KEX_DH_GEX_GROUP, | |
1146 common.MP(transport.DH_PRIME) + '\x00\x00\x00\x01\x02')]) | |
1147 self.assertEquals(self.proto.g, 2) | |
1148 self.assertEquals(self.proto.p, transport.DH_PRIME) | |
1149 | |
1150 | |
1151 def test_KEX_DH_GEX_REQUEST_OLD_badKexAlg(self): | |
1152 """ | |
1153 Test that if the server recieves a KEX_DH_GEX_REQUEST_OLD message | |
1154 and the key exchange algorithm is not 'diffie-hellman-group1-sha1' or | |
1155 'diffie-hellman-group-exchange-sha1', we raise a ConchError. | |
1156 """ | |
1157 self.proto.kexAlg = None | |
1158 self.assertRaises(ConchError, self.proto.ssh_KEX_DH_GEX_REQUEST_OLD, | |
1159 None) | |
1160 | |
1161 | |
1162 def test_KEX_DH_GEX_REQUEST(self): | |
1163 """ | |
1164 Test that the KEX_DH_GEX_REQUEST message causes the server to reply | |
1165 with a KEX_DH_GEX_GROUP message with the correct Diffie-Hellman | |
1166 group. | |
1167 """ | |
1168 self.proto.supportedKeyExchanges = [ | |
1169 'diffie-hellman-group-exchange-sha1'] | |
1170 self.proto.supportedPublicKeys = ['ssh-rsa'] | |
1171 self.proto.dataReceived(self.transport.value()) | |
1172 self.proto.ssh_KEX_DH_GEX_REQUEST('\x00\x00\x04\x00\x00\x00\x08\x00' + | |
1173 '\x00\x00\x0c\x00') | |
1174 self.assertEquals( | |
1175 self.packets, | |
1176 [(transport.MSG_KEX_DH_GEX_GROUP, | |
1177 common.MP(transport.DH_PRIME) + '\x00\x00\x00\x01\x03')]) | |
1178 self.assertEquals(self.proto.g, 3) | |
1179 self.assertEquals(self.proto.p, transport.DH_PRIME) | |
1180 | |
1181 | |
1182 def test_KEX_DH_GEX_INIT_after_REQUEST(self): | |
1183 """ | |
1184 Test that the KEX_DH_GEX_INIT message after the client sends | |
1185 KEX_DH_GEX_REQUEST causes the server to send a KEX_DH_GEX_INIT message | |
1186 with a public key and signature. | |
1187 """ | |
1188 self.test_KEX_DH_GEX_REQUEST() | |
1189 e = pow(self.proto.g, 3, self.proto.p) | |
1190 y = common.getMP('\x00\x00\x00\x80' + '\x99' * 128)[0] | |
1191 f = common._MPpow(self.proto.g, y, self.proto.p) | |
1192 sharedSecret = common._MPpow(e, y, self.proto.p) | |
1193 h = sha.new() | |
1194 h.update(common.NS(self.proto.ourVersionString) * 2) | |
1195 h.update(common.NS(self.proto.ourKexInitPayload) * 2) | |
1196 h.update(common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob())) | |
1197 h.update('\x00\x00\x04\x00\x00\x00\x08\x00\x00\x00\x0c\x00') | |
1198 h.update(common.MP(self.proto.p)) | |
1199 h.update(common.MP(self.proto.g)) | |
1200 h.update(common.MP(e)) | |
1201 h.update(f) | |
1202 h.update(sharedSecret) | |
1203 exchangeHash = h.digest() | |
1204 self.proto.ssh_KEX_DH_GEX_INIT(common.MP(e)) | |
1205 self.assertEquals( | |
1206 self.packets[1], | |
1207 (transport.MSG_KEX_DH_GEX_REPLY, | |
1208 common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()) + | |
1209 f + common.NS(self.proto.factory.privateKeys['ssh-rsa'].sign( | |
1210 exchangeHash)))) | |
1211 | |
1212 | |
1213 def test_KEX_DH_GEX_INIT_after_REQUEST_OLD(self): | |
1214 """ | |
1215 Test that the KEX_DH_GEX_INIT message after the client sends | |
1216 KEX_DH_GEX_REQUEST_OLD causes the server to sent a KEX_DH_GEX_INIT | |
1217 message with a public key and signature. | |
1218 """ | |
1219 self.test_KEX_DH_GEX_REQUEST_OLD() | |
1220 e = pow(self.proto.g, 3, self.proto.p) | |
1221 y = common.getMP('\x00\x00\x00\x80' + '\x99' * 128)[0] | |
1222 f = common._MPpow(self.proto.g, y, self.proto.p) | |
1223 sharedSecret = common._MPpow(e, y, self.proto.p) | |
1224 h = sha.new() | |
1225 h.update(common.NS(self.proto.ourVersionString) * 2) | |
1226 h.update(common.NS(self.proto.ourKexInitPayload) * 2) | |
1227 h.update(common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob())) | |
1228 h.update('\x00\x00\x04\x00') | |
1229 h.update(common.MP(self.proto.p)) | |
1230 h.update(common.MP(self.proto.g)) | |
1231 h.update(common.MP(e)) | |
1232 h.update(f) | |
1233 h.update(sharedSecret) | |
1234 exchangeHash = h.digest() | |
1235 self.proto.ssh_KEX_DH_GEX_INIT(common.MP(e)) | |
1236 self.assertEquals( | |
1237 self.packets[1:], | |
1238 [(transport.MSG_KEX_DH_GEX_REPLY, | |
1239 common.NS(self.proto.factory.publicKeys['ssh-rsa'].blob()) + | |
1240 f + common.NS(self.proto.factory.privateKeys['ssh-rsa'].sign( | |
1241 exchangeHash))), | |
1242 (transport.MSG_NEWKEYS, '')]) | |
1243 | |
1244 | |
1245 def test_keySetup(self): | |
1246 """ | |
1247 Test that _keySetup sets up the next encryption keys. | |
1248 """ | |
1249 self.proto.nextEncryptions = MockCipher() | |
1250 self.proto._keySetup('AB', 'CD') | |
1251 self.assertEquals(self.proto.sessionID, 'CD') | |
1252 self.proto._keySetup('AB', 'EF') | |
1253 self.assertEquals(self.proto.sessionID, 'CD') | |
1254 self.assertEquals(self.packets[-1], (transport.MSG_NEWKEYS, '')) | |
1255 newKeys = [self.proto._getKey(c, 'AB', 'EF') for c in 'ABCDEF'] | |
1256 self.assertEquals( | |
1257 self.proto.nextEncryptions.keys, | |
1258 (newKeys[1], newKeys[3], newKeys[0], newKeys[2], newKeys[5], | |
1259 newKeys[4])) | |
1260 | |
1261 | |
1262 def test_NEWKEYS(self): | |
1263 """ | |
1264 Test that NEWKEYS transitions the keys in nextEncryptions to | |
1265 currentEncryptions. | |
1266 """ | |
1267 self.test_KEXINIT() | |
1268 | |
1269 self.proto.nextEncryptions = transport.SSHCiphers('none', 'none', | |
1270 'none', 'none') | |
1271 self.proto.ssh_NEWKEYS('') | |
1272 self.assertIdentical(self.proto.currentEncryptions, | |
1273 self.proto.nextEncryptions) | |
1274 self.assertIdentical(self.proto.outgoingCompression, None) | |
1275 self.assertIdentical(self.proto.incomingCompression, None) | |
1276 self.proto.outgoingCompressionType = 'zlib' | |
1277 self.proto.ssh_NEWKEYS('') | |
1278 self.failIfIdentical(self.proto.outgoingCompression, None) | |
1279 self.proto.incomingCompressionType = 'zlib' | |
1280 self.proto.ssh_NEWKEYS('') | |
1281 self.failIfIdentical(self.proto.incomingCompression, None) | |
1282 | |
1283 | |
1284 def test_SERVICE_REQUEST(self): | |
1285 """ | |
1286 Test that the SERVICE_REQUEST message requests and starts a | |
1287 service. | |
1288 """ | |
1289 self.proto.ssh_SERVICE_REQUEST(common.NS('ssh-userauth')) | |
1290 self.assertEquals(self.packets, [(transport.MSG_SERVICE_ACCEPT, | |
1291 common.NS('ssh-userauth'))]) | |
1292 self.assertEquals(self.proto.service.name, 'MockService') | |
1293 | |
1294 | |
1295 def test_disconnectNEWKEYSData(self): | |
1296 """ | |
1297 Test that NEWKEYS disconnects if it receives data. | |
1298 """ | |
1299 self.proto.ssh_NEWKEYS("bad packet") | |
1300 self.checkDisconnected() | |
1301 | |
1302 | |
1303 def test_disconnectSERVICE_REQUESTBadService(self): | |
1304 """ | |
1305 Test that SERVICE_REQUESTS disconnects if an unknown service is | |
1306 requested. | |
1307 """ | |
1308 self.proto.ssh_SERVICE_REQUEST(common.NS('no service')) | |
1309 self.checkDisconnected(transport.DISCONNECT_SERVICE_NOT_AVAILABLE) | |
1310 | |
1311 | |
1312 | |
1313 class ClientSSHTransportTestCase(ServerAndClientSSHTransportBaseCase, | |
1314 TransportTestCase): | |
1315 """ | |
1316 Tests for SSHClientTransport. | |
1317 """ | |
1318 | |
1319 klass = transport.SSHClientTransport | |
1320 | |
1321 | |
1322 def test_KEXINIT(self): | |
1323 """ | |
1324 Test that receiving a KEXINIT packet sets up the correct values on the | |
1325 client. The way algorithms are picks is that the first item in the | |
1326 client's list that is also in the server's list is chosen. | |
1327 """ | |
1328 self.proto.dataReceived( 'SSH-2.0-Twisted\r\n\x00\x00\x01\xd4\t\x14' | |
1329 '\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99\x99' | |
1330 '\x99\x00\x00\x00=diffie-hellman-group1-sha1,diffie-hellman-g' | |
1331 'roup-exchange-sha1\x00\x00\x00\x0fssh-dss,ssh-rsa\x00\x00\x00' | |
1332 '\x85aes128-ctr,aes128-cbc,aes192-ctr,aes192-cbc,aes256-ctr,ae' | |
1333 's256-cbc,cast128-ctr,cast128-cbc,blowfish-ctr,blowfish-cbc,3d' | |
1334 'es-ctr,3des-cbc\x00\x00\x00\x85aes128-ctr,aes128-cbc,aes192-c' | |
1335 'tr,aes192-cbc,aes256-ctr,aes256-cbc,cast128-ctr,cast128-cbc,b' | |
1336 'lowfish-ctr,blowfish-cbc,3des-ctr,3des-cbc\x00\x00\x00\x12hma' | |
1337 'c-md5,hmac-sha1\x00\x00\x00\x12hmac-md5,hmac-sha1\x00\x00\x00' | |
1338 '\tzlib,none\x00\x00\x00\tzlib,none\x00\x00\x00\x00\x00\x00' | |
1339 '\x00\x00\x00\x00\x00\x00\x00\x99\x99\x99\x99\x99\x99\x99\x99' | |
1340 '\x99') | |
1341 self.assertEquals(self.proto.kexAlg, | |
1342 'diffie-hellman-group-exchange-sha1') | |
1343 self.assertEquals(self.proto.keyAlg, | |
1344 'ssh-rsa') | |
1345 self.assertEquals(self.proto.outgoingCompressionType, | |
1346 'none') | |
1347 self.assertEquals(self.proto.incomingCompressionType, | |
1348 'none') | |
1349 ne = self.proto.nextEncryptions | |
1350 self.assertEquals(ne.outCipType, 'aes256-ctr') | |
1351 self.assertEquals(ne.inCipType, 'aes256-ctr') | |
1352 self.assertEquals(ne.outMACType, 'hmac-sha1') | |
1353 self.assertEquals(ne.inMACType, 'hmac-sha1') | |
1354 | |
1355 | |
1356 def verifyHostKey(self, pubKey, fingerprint): | |
1357 """ | |
1358 Mock version of SSHClientTransport.verifyHostKey. | |
1359 """ | |
1360 self.calledVerifyHostKey = True | |
1361 self.assertEquals(pubKey, self.blob) | |
1362 self.assertEquals(fingerprint.replace(':', ''), | |
1363 md5.new(pubKey).hexdigest()) | |
1364 return defer.succeed(True) | |
1365 | |
1366 | |
1367 def setUp(self): | |
1368 TransportTestCase.setUp(self) | |
1369 self.blob = keys.Key.fromString(keydata.publicRSA_openssh).blob() | |
1370 self.privObj = keys.Key.fromString(keydata.privateRSA_openssh) | |
1371 self.calledVerifyHostKey = False | |
1372 self.proto.verifyHostKey = self.verifyHostKey | |
1373 | |
1374 | |
1375 def test_notImplementedClientMethods(self): | |
1376 """ | |
1377 verifyHostKey() should return a Deferred which fails with a | |
1378 NotImplementedError exception. connectionSecure() should raise | |
1379 NotImplementedError(). | |
1380 """ | |
1381 self.assertRaises(NotImplementedError, self.klass().connectionSecure) | |
1382 def _checkRaises(f): | |
1383 f.trap(NotImplementedError) | |
1384 d = self.klass().verifyHostKey(None, None) | |
1385 return d.addCallback(self.fail).addErrback(_checkRaises) | |
1386 | |
1387 | |
1388 def test_KEXINIT_groupexchange(self): | |
1389 """ | |
1390 Test that a KEXINIT packet with a group-exchange key exchange results | |
1391 in a KEX_DH_GEX_REQUEST_OLD message.. | |
1392 """ | |
1393 self.proto.supportedKeyExchanges = [ | |
1394 'diffie-hellman-group-exchange-sha1'] | |
1395 self.proto.dataReceived(self.transport.value()) | |
1396 self.assertEquals(self.packets, [(transport.MSG_KEX_DH_GEX_REQUEST_OLD, | |
1397 '\x00\x00\x08\x00')]) | |
1398 | |
1399 | |
1400 def test_KEXINIT_group1(self): | |
1401 """ | |
1402 Like test_KEXINIT_groupexchange, but for the group-1 key exchange. | |
1403 """ | |
1404 self.proto.supportedKeyExchanges = ['diffie-hellman-group1-sha1'] | |
1405 self.proto.dataReceived(self.transport.value()) | |
1406 self.assertEquals(common.MP(self.proto.x)[5:], '\x99' * 64) | |
1407 self.assertEquals(self.packets, | |
1408 [(transport.MSG_KEXDH_INIT, self.proto.e)]) | |
1409 | |
1410 | |
1411 def test_KEXINIT_badKexAlg(self): | |
1412 """ | |
1413 Test that the client raises a ConchError if it receives a | |
1414 KEXINIT message bug doesn't have a key exchange algorithm that we | |
1415 understand. | |
1416 """ | |
1417 self.proto.supportedKeyExchanges = ['diffie-hellman-group2-sha1'] | |
1418 data = self.transport.value().replace('group1', 'group2') | |
1419 self.assertRaises(ConchError, self.proto.dataReceived, data) | |
1420 | |
1421 | |
1422 def test_KEXDH_REPLY(self): | |
1423 """ | |
1424 Test that the KEXDH_REPLY message verifies the server. | |
1425 """ | |
1426 self.test_KEXINIT_group1() | |
1427 | |
1428 sharedSecret = common._MPpow(transport.DH_GENERATOR, | |
1429 self.proto.x, transport.DH_PRIME) | |
1430 h = sha.new() | |
1431 h.update(common.NS(self.proto.ourVersionString) * 2) | |
1432 h.update(common.NS(self.proto.ourKexInitPayload) * 2) | |
1433 h.update(common.NS(self.blob)) | |
1434 h.update(self.proto.e) | |
1435 h.update('\x00\x00\x00\x01\x02') # f | |
1436 h.update(sharedSecret) | |
1437 exchangeHash = h.digest() | |
1438 | |
1439 def _cbTestKEXDH_REPLY(value): | |
1440 self.assertIdentical(value, None) | |
1441 self.assertEquals(self.calledVerifyHostKey, True) | |
1442 self.assertEquals(self.proto.sessionID, exchangeHash) | |
1443 | |
1444 signature = self.privObj.sign(exchangeHash) | |
1445 | |
1446 d = self.proto.ssh_KEX_DH_GEX_GROUP( | |
1447 (common.NS(self.blob) + '\x00\x00\x00\x01\x02' + | |
1448 common.NS(signature))) | |
1449 d.addCallback(_cbTestKEXDH_REPLY) | |
1450 | |
1451 return d | |
1452 | |
1453 | |
1454 def test_KEX_DH_GEX_GROUP(self): | |
1455 """ | |
1456 Test that the KEX_DH_GEX_GROUP message results in a | |
1457 KEX_DH_GEX_INIT message with the client's Diffie-Hellman public key. | |
1458 """ | |
1459 self.test_KEXINIT_groupexchange() | |
1460 self.proto.ssh_KEX_DH_GEX_GROUP( | |
1461 '\x00\x00\x00\x01\x0f\x00\x00\x00\x01\x02') | |
1462 self.assertEquals(self.proto.p, 15) | |
1463 self.assertEquals(self.proto.g, 2) | |
1464 self.assertEquals(common.MP(self.proto.x)[5:], '\x99' * 40) | |
1465 self.assertEquals(self.proto.e, | |
1466 common.MP(pow(2, self.proto.x, 15))) | |
1467 self.assertEquals(self.packets[1:], [(transport.MSG_KEX_DH_GEX_INIT, | |
1468 self.proto.e)]) | |
1469 | |
1470 | |
1471 def test_KEX_DH_GEX_REPLY(self): | |
1472 """ | |
1473 Test that the KEX_DH_GEX_REPLY message results in a verified | |
1474 server. | |
1475 """ | |
1476 | |
1477 self.test_KEX_DH_GEX_GROUP() | |
1478 sharedSecret = common._MPpow(3, self.proto.x, self.proto.p) | |
1479 h = sha.new() | |
1480 h.update(common.NS(self.proto.ourVersionString) * 2) | |
1481 h.update(common.NS(self.proto.ourKexInitPayload) * 2) | |
1482 h.update(common.NS(self.blob)) | |
1483 h.update('\x00\x00\x08\x00\x00\x00\x00\x01\x0f\x00\x00\x00\x01\x02') | |
1484 h.update(self.proto.e) | |
1485 h.update('\x00\x00\x00\x01\x03') # f | |
1486 h.update(sharedSecret) | |
1487 exchangeHash = h.digest() | |
1488 | |
1489 def _cbTestKEX_DH_GEX_REPLY(value): | |
1490 self.assertIdentical(value, None) | |
1491 self.assertEquals(self.calledVerifyHostKey, True) | |
1492 self.assertEquals(self.proto.sessionID, exchangeHash) | |
1493 | |
1494 signature = self.privObj.sign(exchangeHash) | |
1495 | |
1496 d = self.proto.ssh_KEX_DH_GEX_REPLY( | |
1497 common.NS(self.blob) + | |
1498 '\x00\x00\x00\x01\x03' + | |
1499 common.NS(signature)) | |
1500 d.addCallback(_cbTestKEX_DH_GEX_REPLY) | |
1501 return d | |
1502 | |
1503 | |
1504 def test_keySetup(self): | |
1505 """ | |
1506 Test that _keySetup sets up the next encryption keys. | |
1507 """ | |
1508 self.proto.nextEncryptions = MockCipher() | |
1509 self.proto._keySetup('AB', 'CD') | |
1510 self.assertEquals(self.proto.sessionID, 'CD') | |
1511 self.proto._keySetup('AB', 'EF') | |
1512 self.assertEquals(self.proto.sessionID, 'CD') | |
1513 self.assertEquals(self.packets[-1], (transport.MSG_NEWKEYS, '')) | |
1514 newKeys = [self.proto._getKey(c, 'AB', 'EF') for c in 'ABCDEF'] | |
1515 self.assertEquals(self.proto.nextEncryptions.keys, | |
1516 (newKeys[0], newKeys[2], newKeys[1], newKeys[3], | |
1517 newKeys[4], newKeys[5])) | |
1518 | |
1519 | |
1520 def test_NEWKEYS(self): | |
1521 """ | |
1522 Test that NEWKEYS transitions the keys from nextEncryptions to | |
1523 currentEncryptions. | |
1524 """ | |
1525 self.test_KEXINIT() | |
1526 secure = [False] | |
1527 def stubConnectionSecure(): | |
1528 secure[0] = True | |
1529 self.proto.connectionSecure = stubConnectionSecure | |
1530 | |
1531 self.proto.nextEncryptions = transport.SSHCiphers('none', 'none', | |
1532 'none', 'none') | |
1533 self.proto.ssh_NEWKEYS('') | |
1534 | |
1535 self.failIfIdentical(self.proto.currentEncryptions, | |
1536 self.proto.nextEncryptions) | |
1537 | |
1538 self.proto.nextEncryptions = MockCipher() | |
1539 self.proto._keySetup('AB', 'EF') | |
1540 self.assertIdentical(self.proto.outgoingCompression, None) | |
1541 self.assertIdentical(self.proto.incomingCompression, None) | |
1542 self.assertIdentical(self.proto.currentEncryptions, | |
1543 self.proto.nextEncryptions) | |
1544 self.assertTrue(secure[0]) | |
1545 self.proto.outgoingCompressionType = 'zlib' | |
1546 self.proto.ssh_NEWKEYS('') | |
1547 self.failIfIdentical(self.proto.outgoingCompression, None) | |
1548 self.proto.incomingCompressionType = 'zlib' | |
1549 self.proto.ssh_NEWKEYS('') | |
1550 self.failIfIdentical(self.proto.incomingCompression, None) | |
1551 | |
1552 | |
1553 def test_SERVICE_ACCEPT(self): | |
1554 """ | |
1555 Test that the SERVICE_ACCEPT packet starts the requested service. | |
1556 """ | |
1557 self.proto.instance = MockService() | |
1558 self.proto.ssh_SERVICE_ACCEPT('\x00\x00\x00\x0bMockService') | |
1559 self.assertTrue(self.proto.instance.started) | |
1560 | |
1561 | |
1562 def test_requestService(self): | |
1563 """ | |
1564 Test that requesting a service sends a SERVICE_REQUEST packet. | |
1565 """ | |
1566 self.proto.requestService(MockService()) | |
1567 self.assertEquals(self.packets, [(transport.MSG_SERVICE_REQUEST, | |
1568 '\x00\x00\x00\x0bMockService')]) | |
1569 | |
1570 | |
1571 def test_disconnectKEXDH_REPLYBadSignature(self): | |
1572 """ | |
1573 Test that KEXDH_REPLY disconnects if the signature is bad. | |
1574 """ | |
1575 self.test_KEXDH_REPLY() | |
1576 self.proto._continueKEXDH_REPLY(None, self.blob, 3, "bad signature") | |
1577 self.checkDisconnected(transport.DISCONNECT_KEY_EXCHANGE_FAILED) | |
1578 | |
1579 | |
1580 def test_disconnectGEX_REPLYBadSignature(self): | |
1581 """ | |
1582 Like test_disconnectKEXDH_REPLYBadSignature, but for DH_GEX_REPLY. | |
1583 """ | |
1584 self.test_KEX_DH_GEX_REPLY() | |
1585 self.proto._continueGEX_REPLY(None, self.blob, 3, "bad signature") | |
1586 self.checkDisconnected(transport.DISCONNECT_KEY_EXCHANGE_FAILED) | |
1587 | |
1588 | |
1589 def test_disconnectNEWKEYSData(self): | |
1590 """ | |
1591 Test that NEWKEYS disconnects if it receives data. | |
1592 """ | |
1593 self.proto.ssh_NEWKEYS("bad packet") | |
1594 self.checkDisconnected() | |
1595 | |
1596 | |
1597 def test_disconnectSERVICE_ACCEPT(self): | |
1598 """ | |
1599 Test that SERVICE_ACCEPT disconnects if the accepted protocol is | |
1600 differet from the asked-for protocol. | |
1601 """ | |
1602 self.proto.instance = MockService() | |
1603 self.proto.ssh_SERVICE_ACCEPT('\x00\x00\x00\x03bad') | |
1604 self.checkDisconnected() | |
1605 | |
1606 | |
1607 | |
1608 class SSHCiphersTestCase(unittest.TestCase): | |
1609 """ | |
1610 Tests for the SSHCiphers helper class. | |
1611 """ | |
1612 if Crypto is None: | |
1613 skip = "cannot run w/o PyCrypto" | |
1614 | |
1615 | |
1616 def test_init(self): | |
1617 """ | |
1618 Test that the initializer sets up the SSHCiphers object. | |
1619 """ | |
1620 ciphers = transport.SSHCiphers('A', 'B', 'C', 'D') | |
1621 self.assertEquals(ciphers.outCipType, 'A') | |
1622 self.assertEquals(ciphers.inCipType, 'B') | |
1623 self.assertEquals(ciphers.outMACType, 'C') | |
1624 self.assertEquals(ciphers.inMACType, 'D') | |
1625 | |
1626 | |
1627 def test_getCipher(self): | |
1628 """ | |
1629 Test that the _getCipher method returns the correct cipher. | |
1630 """ | |
1631 ciphers = transport.SSHCiphers('A', 'B', 'C', 'D') | |
1632 iv = key = '\x00' * 16 | |
1633 for cipName, (modName, keySize, counter) in ciphers.cipherMap.items(): | |
1634 cip = ciphers._getCipher(cipName, iv, key) | |
1635 if cipName == 'none': | |
1636 self.assertIsInstance(cip, transport._DummyCipher) | |
1637 else: | |
1638 self.assertTrue(str(cip).startswith('<' + modName)) | |
1639 | |
1640 | |
1641 def test_getMAC(self): | |
1642 """ | |
1643 Test that the _getMAC method returns the correct MAC. | |
1644 """ | |
1645 ciphers = transport.SSHCiphers('A', 'B', 'C', 'D') | |
1646 key = '\x00' * 64 | |
1647 for macName, mac in ciphers.macMap.items(): | |
1648 mod = ciphers._getMAC(macName, key) | |
1649 if macName == 'none': | |
1650 self.assertIdentical(mac, None) | |
1651 else: | |
1652 self.assertEquals(mod[0], mac) | |
1653 self.assertEquals(mod[1], | |
1654 Crypto.Cipher.XOR.new('\x36').encrypt(key)) | |
1655 self.assertEquals(mod[2], | |
1656 Crypto.Cipher.XOR.new('\x5c').encrypt(key)) | |
1657 self.assertEquals(mod[3], len(mod[0].new().digest())) | |
1658 | |
1659 | |
1660 def test_setKeysCiphers(self): | |
1661 """ | |
1662 Test that setKeys sets up the ciphers. | |
1663 """ | |
1664 key = '\x00' * 64 | |
1665 cipherItems = transport.SSHCiphers.cipherMap.items() | |
1666 for cipName, (modName, keySize, counter) in cipherItems: | |
1667 encCipher = transport.SSHCiphers(cipName, 'none', 'none', 'none') | |
1668 decCipher = transport.SSHCiphers('none', cipName, 'none', 'none') | |
1669 cip = encCipher._getCipher(cipName, key, key) | |
1670 bs = cip.block_size | |
1671 encCipher.setKeys(key, key, '', '', '', '') | |
1672 decCipher.setKeys('', '', key, key, '', '') | |
1673 self.assertEquals(encCipher.encBlockSize, bs) | |
1674 self.assertEquals(decCipher.decBlockSize, bs) | |
1675 enc = cip.encrypt(key[:bs]) | |
1676 enc2 = cip.encrypt(key[:bs]) | |
1677 if counter: | |
1678 self.failIfEquals(enc, enc2) | |
1679 self.assertEquals(encCipher.encrypt(key[:bs]), enc) | |
1680 self.assertEquals(encCipher.encrypt(key[:bs]), enc2) | |
1681 self.assertEquals(decCipher.decrypt(enc), key[:bs]) | |
1682 self.assertEquals(decCipher.decrypt(enc2), key[:bs]) | |
1683 | |
1684 | |
1685 def test_setKeysMACs(self): | |
1686 """ | |
1687 Test that setKeys sets up the MACs. | |
1688 """ | |
1689 key = '\x00' * 64 | |
1690 for macName, mod in transport.SSHCiphers.macMap.items(): | |
1691 outMac = transport.SSHCiphers('none', 'none', macName, 'none') | |
1692 inMac = transport.SSHCiphers('none', 'none', 'none', macName) | |
1693 outMac.setKeys('', '', '', '', key, '') | |
1694 inMac.setKeys('', '', '', '', '', key) | |
1695 if mod: | |
1696 ds = mod.digest_size | |
1697 else: | |
1698 ds = 0 | |
1699 self.assertEquals(inMac.verifyDigestSize, ds) | |
1700 if mod: | |
1701 mod, i, o, ds = outMac._getMAC(macName, key) | |
1702 seqid = 0 | |
1703 data = key | |
1704 packet = '\x00' * 4 + key | |
1705 if mod: | |
1706 mac = mod.new(o + mod.new(i + packet).digest()).digest() | |
1707 else: | |
1708 mac = '' | |
1709 self.assertEquals(outMac.makeMAC(seqid, data), mac) | |
1710 self.assertTrue(inMac.verify(seqid, data, mac)) | |
1711 | |
1712 | |
1713 | |
1714 class CounterTestCase(unittest.TestCase): | |
1715 """ | |
1716 Tests for the _Counter helper class. | |
1717 """ | |
1718 if Crypto is None: | |
1719 skip = "cannot run w/o PyCrypto" | |
1720 | |
1721 | |
1722 def test_init(self): | |
1723 """ | |
1724 Test that the counter is initialized correctly. | |
1725 """ | |
1726 counter = transport._Counter('\x00' * 8 + '\xff' * 8, 8) | |
1727 self.assertEquals(counter.blockSize, 8) | |
1728 self.assertEquals(counter.count.tostring(), '\x00' * 8) | |
1729 | |
1730 | |
1731 def test_count(self): | |
1732 """ | |
1733 Test that the counter counts incrementally and wraps at the top. | |
1734 """ | |
1735 counter = transport._Counter('\x00', 1) | |
1736 self.assertEquals(counter(), '\x01') | |
1737 self.assertEquals(counter(), '\x02') | |
1738 [counter() for i in range(252)] | |
1739 self.assertEquals(counter(), '\xff') | |
1740 self.assertEquals(counter(), '\x00') | |
1741 | |
1742 | |
1743 | |
1744 class TransportLoopbackTestCase(unittest.TestCase): | |
1745 """ | |
1746 Test the server transport and client transport against each other, | |
1747 """ | |
1748 if Crypto is None: | |
1749 skip = "cannot run w/o PyCrypto" | |
1750 | |
1751 | |
1752 def _runClientServer(self, mod): | |
1753 """ | |
1754 Run an async client and server, modifying each using the mod function | |
1755 provided. Returns a Deferred called back when both Protocols have | |
1756 disconnected. | |
1757 | |
1758 @type mod: C{func} | |
1759 @rtype: C{defer.Deferred} | |
1760 """ | |
1761 factory = MockFactory() | |
1762 server = transport.SSHServerTransport() | |
1763 server.factory = factory | |
1764 factory.startFactory() | |
1765 server.errors = [] | |
1766 server.receiveError = lambda code, desc: server.errors.append(( | |
1767 code, desc)) | |
1768 client = transport.SSHClientTransport() | |
1769 client.verifyHostKey = lambda x, y: defer.succeed(None) | |
1770 client.errors = [] | |
1771 client.receiveError = lambda code, desc: client.errors.append(( | |
1772 code, desc)) | |
1773 client.connectionSecure = lambda: client.loseConnection() | |
1774 server = mod(server) | |
1775 client = mod(client) | |
1776 def check(ignored, server, client): | |
1777 name = repr([server.supportedCiphers[0], | |
1778 server.supportedMACs[0], | |
1779 server.supportedKeyExchanges[0], | |
1780 server.supportedCompressions[0]]) | |
1781 self.assertEquals(client.errors, []) | |
1782 self.assertEquals(server.errors, [( | |
1783 transport.DISCONNECT_CONNECTION_LOST, | |
1784 "user closed connection")]) | |
1785 if server.supportedCiphers[0] == 'none': | |
1786 self.assertFalse(server.isEncrypted(), name) | |
1787 self.assertFalse(client.isEncrypted(), name) | |
1788 else: | |
1789 self.assertTrue(server.isEncrypted(), name) | |
1790 self.assertTrue(client.isEncrypted(), name) | |
1791 if server.supportedMACs[0] == 'none': | |
1792 self.assertFalse(server.isVerified(), name) | |
1793 self.assertFalse(client.isVerified(), name) | |
1794 else: | |
1795 self.assertTrue(server.isVerified(), name) | |
1796 self.assertTrue(client.isVerified(), name) | |
1797 | |
1798 d = loopback.loopbackAsync(server, client) | |
1799 d.addCallback(check, server, client) | |
1800 return d | |
1801 | |
1802 | |
1803 def test_ciphers(self): | |
1804 """ | |
1805 Test that the client and server play nicely together, in all | |
1806 the various combinations of ciphers. | |
1807 """ | |
1808 deferreds = [] | |
1809 for cipher in transport.SSHTransportBase.supportedCiphers + ['none']: | |
1810 def setCipher(proto): | |
1811 proto.supportedCiphers = [cipher] | |
1812 return proto | |
1813 deferreds.append(self._runClientServer(setCipher)) | |
1814 return defer.DeferredList(deferreds, fireOnOneErrback=True) | |
1815 | |
1816 | |
1817 def test_macs(self): | |
1818 """ | |
1819 Like test_ciphers, but for the various MACs. | |
1820 """ | |
1821 deferreds = [] | |
1822 for mac in transport.SSHTransportBase.supportedMACs + ['none']: | |
1823 def setMAC(proto): | |
1824 proto.supportedMACs = [mac] | |
1825 return proto | |
1826 deferreds.append(self._runClientServer(setMAC)) | |
1827 return defer.DeferredList(deferreds, fireOnOneErrback=True) | |
1828 | |
1829 | |
1830 def test_keyexchanges(self): | |
1831 """ | |
1832 Like test_ciphers, but for the various key exchanges. | |
1833 """ | |
1834 deferreds = [] | |
1835 for kex in transport.SSHTransportBase.supportedKeyExchanges: | |
1836 def setKeyExchange(proto): | |
1837 proto.supportedKeyExchanges = [kex] | |
1838 return proto | |
1839 deferreds.append(self._runClientServer(setKeyExchange)) | |
1840 return defer.DeferredList(deferreds, fireOnOneErrback=True) | |
1841 | |
1842 | |
1843 def test_compressions(self): | |
1844 """ | |
1845 Like test_ciphers, but for the various compressions. | |
1846 """ | |
1847 deferreds = [] | |
1848 for compression in transport.SSHTransportBase.supportedCompressions: | |
1849 def setCompression(proto): | |
1850 proto.supportedCompressions = [compression] | |
1851 return proto | |
1852 deferreds.append(self._runClientServer(setCompression)) | |
1853 return defer.DeferredList(deferreds, fireOnOneErrback=True) | |
1854 | |
1855 | |
1856 | |
1857 class OldFactoryTestCase(unittest.TestCase): | |
1858 """ | |
1859 The old C{SSHFactory.getPublicKeys}() returned mappings of key names to | |
1860 strings of key blobs and mappings of key names to PyCrypto key objects from | |
1861 C{SSHFactory.getPrivateKeys}() (they could also be specified with the | |
1862 C{publicKeys} and C{privateKeys} attributes). This is no longer supported | |
1863 by the C{SSHServerTransport}, so we warn the user if they create an old | |
1864 factory. | |
1865 """ | |
1866 | |
1867 | |
1868 def test_getPublicKeysWarning(self): | |
1869 """ | |
1870 If the return value of C{getPublicKeys}() isn't a mapping from key | |
1871 names to C{Key} objects, then warn the user and convert the mapping. | |
1872 """ | |
1873 sshFactory = MockOldFactoryPublicKeys() | |
1874 self.assertWarns(DeprecationWarning, | |
1875 "Returning a mapping from strings to strings from" | |
1876 " getPublicKeys()/publicKeys (in %s) is deprecated. Return " | |
1877 "a mapping from strings to Key objects instead." % | |
1878 (qual(MockOldFactoryPublicKeys),), | |
1879 factory.__file__, sshFactory.startFactory) | |
1880 self.assertEquals(sshFactory.publicKeys, MockFactory().getPublicKeys()) | |
1881 | |
1882 | |
1883 def test_getPrivateKeysWarning(self): | |
1884 """ | |
1885 If the return value of C{getPrivateKeys}() isn't a mapping from key | |
1886 names to C{Key} objects, then warn the user and convert the mapping. | |
1887 """ | |
1888 sshFactory = MockOldFactoryPrivateKeys() | |
1889 self.assertWarns(DeprecationWarning, | |
1890 "Returning a mapping from strings to PyCrypto key objects from" | |
1891 " getPrivateKeys()/privateKeys (in %s) is deprecated. Return" | |
1892 " a mapping from strings to Key objects instead." % | |
1893 (qual(MockOldFactoryPrivateKeys),), | |
1894 factory.__file__, sshFactory.startFactory) | |
1895 self.assertEquals(sshFactory.privateKeys, | |
1896 MockFactory().getPrivateKeys()) | |
1897 | |
1898 | |
1899 def test_publicKeysWarning(self): | |
1900 """ | |
1901 If the value of the C{publicKeys} attribute isn't a mapping from key | |
1902 names to C{Key} objects, then warn the user and convert the mapping. | |
1903 """ | |
1904 sshFactory = MockOldFactoryPublicKeys() | |
1905 sshFactory.publicKeys = sshFactory.getPublicKeys() | |
1906 self.assertWarns(DeprecationWarning, | |
1907 "Returning a mapping from strings to strings from" | |
1908 " getPublicKeys()/publicKeys (in %s) is deprecated. Return " | |
1909 "a mapping from strings to Key objects instead." % | |
1910 (qual(MockOldFactoryPublicKeys),), | |
1911 factory.__file__, sshFactory.startFactory) | |
1912 self.assertEquals(sshFactory.publicKeys, MockFactory().getPublicKeys()) | |
1913 | |
1914 | |
1915 def test_privateKeysWarning(self): | |
1916 """ | |
1917 If the return value of C{privateKeys} attribute isn't a mapping from | |
1918 key names to C{Key} objects, then warn the user and convert the | |
1919 mapping. | |
1920 """ | |
1921 sshFactory = MockOldFactoryPrivateKeys() | |
1922 sshFactory.privateKeys = sshFactory.getPrivateKeys() | |
1923 self.assertWarns(DeprecationWarning, | |
1924 "Returning a mapping from strings to PyCrypto key objects from" | |
1925 " getPrivateKeys()/privateKeys (in %s) is deprecated. Return" | |
1926 " a mapping from strings to Key objects instead." % | |
1927 (qual(MockOldFactoryPrivateKeys),), | |
1928 factory.__file__, sshFactory.startFactory) | |
1929 self.assertEquals(sshFactory.privateKeys, | |
1930 MockFactory().getPrivateKeys()) | |
OLD | NEW |