| OLD | NEW |
| (Empty) |
| 1 # -*- test-case-name: twisted.conch.test.test_ssh -*- | |
| 2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
| 3 # See LICENSE for details. | |
| 4 | |
| 5 try: | |
| 6 import Crypto | |
| 7 except ImportError: | |
| 8 Crypto = None | |
| 9 | |
| 10 from twisted.conch.ssh import common, session, forwarding | |
| 11 from twisted.conch import avatar, error | |
| 12 from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh | |
| 13 from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh | |
| 14 from twisted.cred import portal | |
| 15 from twisted.internet import defer, protocol, reactor | |
| 16 from twisted.internet.error import ProcessTerminated | |
| 17 from twisted.python import failure, log | |
| 18 from twisted.trial import unittest | |
| 19 | |
| 20 from test_recvline import LoopbackRelay | |
| 21 | |
| 22 import struct | |
| 23 | |
| 24 | |
| 25 class ConchTestRealm: | |
| 26 | |
| 27 def requestAvatar(self, avatarID, mind, *interfaces): | |
| 28 unittest.assertEquals(avatarID, 'testuser') | |
| 29 a = ConchTestAvatar() | |
| 30 return interfaces[0], a, a.logout | |
| 31 | |
| 32 class ConchTestAvatar(avatar.ConchUser): | |
| 33 loggedOut = False | |
| 34 | |
| 35 def __init__(self): | |
| 36 avatar.ConchUser.__init__(self) | |
| 37 self.listeners = {} | |
| 38 self.channelLookup.update({'session': session.SSHSession, | |
| 39 'direct-tcpip':forwarding.openConnectForwardingClient}) | |
| 40 self.subsystemLookup.update({'crazy': CrazySubsystem}) | |
| 41 | |
| 42 def global_foo(self, data): | |
| 43 unittest.assertEquals(data, 'bar') | |
| 44 return 1 | |
| 45 | |
| 46 def global_foo_2(self, data): | |
| 47 unittest.assertEquals(data, 'bar2') | |
| 48 return 1, 'data' | |
| 49 | |
| 50 def global_tcpip_forward(self, data): | |
| 51 host, port = forwarding.unpackGlobal_tcpip_forward(data) | |
| 52 try: listener = reactor.listenTCP(port, | |
| 53 forwarding.SSHListenForwardingFactory(self.conn, | |
| 54 (host, port), | |
| 55 forwarding.SSHListenServerForwardingChannel), | |
| 56 interface = host) | |
| 57 except: | |
| 58 log.err() | |
| 59 unittest.fail("something went wrong with remote->local forwarding") | |
| 60 return 0 | |
| 61 else: | |
| 62 self.listeners[(host, port)] = listener | |
| 63 return 1 | |
| 64 | |
| 65 def global_cancel_tcpip_forward(self, data): | |
| 66 host, port = forwarding.unpackGlobal_tcpip_forward(data) | |
| 67 listener = self.listeners.get((host, port), None) | |
| 68 if not listener: | |
| 69 return 0 | |
| 70 del self.listeners[(host, port)] | |
| 71 listener.stopListening() | |
| 72 return 1 | |
| 73 | |
| 74 def logout(self): | |
| 75 loggedOut = True | |
| 76 for listener in self.listeners.values(): | |
| 77 log.msg('stopListening %s' % listener) | |
| 78 listener.stopListening() | |
| 79 | |
| 80 class ConchSessionForTestAvatar: | |
| 81 | |
| 82 def __init__(self, avatar): | |
| 83 unittest.assert_(isinstance(avatar, ConchTestAvatar)) | |
| 84 self.avatar = avatar | |
| 85 self.cmd = None | |
| 86 self.proto = None | |
| 87 self.ptyReq = False | |
| 88 self.eof = 0 | |
| 89 | |
| 90 def getPty(self, term, windowSize, attrs): | |
| 91 log.msg('pty req') | |
| 92 unittest.assertEquals(term, 'conch-test-term') | |
| 93 unittest.assertEquals(windowSize, (24, 80, 0, 0)) | |
| 94 self.ptyReq = True | |
| 95 | |
| 96 def openShell(self, proto): | |
| 97 log.msg('openning shell') | |
| 98 unittest.assertEquals(self.ptyReq, True) | |
| 99 self.proto = proto | |
| 100 EchoTransport(proto) | |
| 101 self.cmd = 'shell' | |
| 102 | |
| 103 def execCommand(self, proto, cmd): | |
| 104 self.cmd = cmd | |
| 105 unittest.assert_(cmd.split()[0] in ['false', 'echo', 'secho', 'eecho','j
umboliah'], | |
| 106 'invalid command: %s' % cmd.split()[0]) | |
| 107 if cmd == 'jumboliah': | |
| 108 raise error.ConchError('bad exec') | |
| 109 self.proto = proto | |
| 110 f = cmd.split()[0] | |
| 111 if f == 'false': | |
| 112 FalseTransport(proto) | |
| 113 elif f == 'echo': | |
| 114 t = EchoTransport(proto) | |
| 115 t.write(cmd[5:]) | |
| 116 t.loseConnection() | |
| 117 elif f == 'secho': | |
| 118 t = SuperEchoTransport(proto) | |
| 119 t.write(cmd[6:]) | |
| 120 t.loseConnection() | |
| 121 elif f == 'eecho': | |
| 122 t = ErrEchoTransport(proto) | |
| 123 t.write(cmd[6:]) | |
| 124 t.loseConnection() | |
| 125 self.avatar.conn.transport.expectedLoseConnection = 1 | |
| 126 | |
| 127 # def closeReceived(self): | |
| 128 # #if self.proto: | |
| 129 # # self.proto.transport.loseConnection() | |
| 130 # self.loseConnection() | |
| 131 | |
| 132 def eofReceived(self): | |
| 133 self.eof = 1 | |
| 134 | |
| 135 def closed(self): | |
| 136 log.msg('closed cmd "%s"' % self.cmd) | |
| 137 if self.cmd == 'echo hello': | |
| 138 rwl = self.proto.session.remoteWindowLeft | |
| 139 unittest.assertEquals(rwl, 4) | |
| 140 elif self.cmd == 'eecho hello': | |
| 141 rwl = self.proto.session.remoteWindowLeft | |
| 142 unittest.assertEquals(rwl, 4) | |
| 143 elif self.cmd == 'shell': | |
| 144 unittest.assert_(self.eof) | |
| 145 | |
| 146 from twisted.python import components | |
| 147 components.registerAdapter(ConchSessionForTestAvatar, ConchTestAvatar, session.I
Session) | |
| 148 | |
| 149 class CrazySubsystem(protocol.Protocol): | |
| 150 | |
| 151 def __init__(self, *args, **kw): | |
| 152 pass | |
| 153 | |
| 154 def connectionMade(self): | |
| 155 """ | |
| 156 good ... good | |
| 157 """ | |
| 158 | |
| 159 class FalseTransport: | |
| 160 | |
| 161 def __init__(self, p): | |
| 162 p.makeConnection(self) | |
| 163 p.processEnded(failure.Failure(ProcessTerminated(255, None, None))) | |
| 164 | |
| 165 def loseConnection(self): | |
| 166 pass | |
| 167 | |
| 168 class EchoTransport: | |
| 169 | |
| 170 def __init__(self, p): | |
| 171 self.proto = p | |
| 172 p.makeConnection(self) | |
| 173 self.closed = 0 | |
| 174 | |
| 175 def write(self, data): | |
| 176 log.msg(repr(data)) | |
| 177 self.proto.outReceived(data) | |
| 178 self.proto.outReceived('\r\n') | |
| 179 if '\x00' in data: # mimic 'exit' for the shell test | |
| 180 self.loseConnection() | |
| 181 | |
| 182 def loseConnection(self): | |
| 183 if self.closed: return | |
| 184 self.closed = 1 | |
| 185 self.proto.inConnectionLost() | |
| 186 self.proto.outConnectionLost() | |
| 187 self.proto.errConnectionLost() | |
| 188 self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)
)) | |
| 189 | |
| 190 class ErrEchoTransport: | |
| 191 | |
| 192 def __init__(self, p): | |
| 193 self.proto = p | |
| 194 p.makeConnection(self) | |
| 195 self.closed = 0 | |
| 196 | |
| 197 def write(self, data): | |
| 198 self.proto.errReceived(data) | |
| 199 self.proto.errReceived('\r\n') | |
| 200 | |
| 201 def loseConnection(self): | |
| 202 if self.closed: return | |
| 203 self.closed = 1 | |
| 204 self.proto.inConnectionLost() | |
| 205 self.proto.outConnectionLost() | |
| 206 self.proto.errConnectionLost() | |
| 207 self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)
)) | |
| 208 | |
| 209 class SuperEchoTransport: | |
| 210 | |
| 211 def __init__(self, p): | |
| 212 self.proto = p | |
| 213 p.makeConnection(self) | |
| 214 self.closed = 0 | |
| 215 | |
| 216 def write(self, data): | |
| 217 self.proto.outReceived(data) | |
| 218 self.proto.outReceived('\r\n') | |
| 219 self.proto.errReceived(data) | |
| 220 self.proto.errReceived('\r\n') | |
| 221 | |
| 222 def loseConnection(self): | |
| 223 if self.closed: return | |
| 224 self.closed = 1 | |
| 225 self.proto.inConnectionLost() | |
| 226 self.proto.outConnectionLost() | |
| 227 self.proto.errConnectionLost() | |
| 228 self.proto.processEnded(failure.Failure(ProcessTerminated(0, None, None)
)) | |
| 229 | |
| 230 | |
| 231 if Crypto: # stuff that needs PyCrypto to even import | |
| 232 from twisted.conch import checkers | |
| 233 from twisted.conch.ssh import channel, connection, factory, keys | |
| 234 from twisted.conch.ssh import transport, userauth | |
| 235 | |
| 236 class UtilityTestCase(unittest.TestCase): | |
| 237 def testCounter(self): | |
| 238 c = transport._Counter('\x00\x00', 2) | |
| 239 for i in xrange(256 * 256): | |
| 240 self.assertEquals(c(), struct.pack('!H', (i + 1) % (2 ** 16))) | |
| 241 # It should wrap around, too. | |
| 242 for i in xrange(256 * 256): | |
| 243 self.assertEquals(c(), struct.pack('!H', (i + 1) % (2 ** 16))) | |
| 244 | |
| 245 | |
| 246 class ConchTestPublicKeyChecker(checkers.SSHPublicKeyDatabase): | |
| 247 def checkKey(self, credentials): | |
| 248 unittest.assertEquals(credentials.username, 'testuser', 'bad usernam
e') | |
| 249 unittest.assertEquals(credentials.blob, keys.getPublicKeyString(data
=publicDSA_openssh)) | |
| 250 return 1 | |
| 251 | |
| 252 class ConchTestPasswordChecker: | |
| 253 credentialInterfaces = checkers.IUsernamePassword, | |
| 254 | |
| 255 def requestAvatarId(self, credentials): | |
| 256 unittest.assertEquals(credentials.username, 'testuser', 'bad usernam
e') | |
| 257 unittest.assertEquals(credentials.password, 'testpass', 'bad passwor
d') | |
| 258 return defer.succeed(credentials.username) | |
| 259 | |
| 260 class ConchTestSSHChecker(checkers.SSHProtocolChecker): | |
| 261 | |
| 262 def areDone(self, avatarId): | |
| 263 unittest.assertEquals(avatarId, 'testuser') | |
| 264 if len(self.successfulCredentials[avatarId]) < 2: | |
| 265 return 0 | |
| 266 else: | |
| 267 return 1 | |
| 268 | |
| 269 class ConchTestServerFactory(factory.SSHFactory): | |
| 270 noisy = 0 | |
| 271 | |
| 272 services = { | |
| 273 'ssh-userauth':userauth.SSHUserAuthServer, | |
| 274 'ssh-connection':connection.SSHConnection | |
| 275 } | |
| 276 | |
| 277 def buildProtocol(self, addr): | |
| 278 proto = ConchTestServer() | |
| 279 proto.supportedPublicKeys = self.privateKeys.keys() | |
| 280 proto.factory = self | |
| 281 | |
| 282 if hasattr(self, 'expectedLoseConnection'): | |
| 283 proto.expectedLoseConnection = self.expectedLoseConnection | |
| 284 | |
| 285 self.proto = proto | |
| 286 return proto | |
| 287 | |
| 288 def getPublicKeys(self): | |
| 289 return { | |
| 290 'ssh-rsa':keys.getPublicKeyString(data=publicRSA_openssh), | |
| 291 'ssh-dss':keys.getPublicKeyString(data=publicDSA_openssh) | |
| 292 } | |
| 293 | |
| 294 def getPrivateKeys(self): | |
| 295 return { | |
| 296 'ssh-rsa':keys.getPrivateKeyObject(data=privateRSA_openssh), | |
| 297 'ssh-dss':keys.getPrivateKeyObject(data=privateDSA_openssh) | |
| 298 } | |
| 299 | |
| 300 def getPrimes(self): | |
| 301 return { | |
| 302 2048:[(transport.DH_GENERATOR, transport.DH_PRIME)] | |
| 303 } | |
| 304 | |
| 305 def getService(self, trans, name): | |
| 306 return factory.SSHFactory.getService(self, trans, name) | |
| 307 | |
| 308 class ConchTestBase: | |
| 309 | |
| 310 done = 0 | |
| 311 allowedToError = 0 | |
| 312 | |
| 313 def connectionLost(self, reason): | |
| 314 if self.done: | |
| 315 return | |
| 316 if not hasattr(self,'expectedLoseConnection'): | |
| 317 unittest.fail('unexpectedly lost connection %s\n%s' % (self, rea
son)) | |
| 318 self.done = 1 | |
| 319 | |
| 320 def receiveError(self, reasonCode, desc): | |
| 321 self.expectedLoseConnection = 1 | |
| 322 if not self.allowedToError: | |
| 323 unittest.fail('got disconnect for %s: reason %s, desc: %s' % | |
| 324 (self, reasonCode, desc)) | |
| 325 self.loseConnection() | |
| 326 | |
| 327 def receiveUnimplemented(self, seqID): | |
| 328 unittest.fail('got unimplemented: seqid %s' % seqID) | |
| 329 self.expectedLoseConnection = 1 | |
| 330 self.loseConnection() | |
| 331 | |
| 332 class ConchTestServer(ConchTestBase, transport.SSHServerTransport): | |
| 333 | |
| 334 def connectionLost(self, reason): | |
| 335 ConchTestBase.connectionLost(self, reason) | |
| 336 transport.SSHServerTransport.connectionLost(self, reason) | |
| 337 | |
| 338 class ConchTestClient(ConchTestBase, transport.SSHClientTransport): | |
| 339 | |
| 340 def connectionLost(self, reason): | |
| 341 ConchTestBase.connectionLost(self, reason) | |
| 342 transport.SSHClientTransport.connectionLost(self, reason) | |
| 343 | |
| 344 def verifyHostKey(self, key, fp): | |
| 345 unittest.assertEquals(key, keys.getPublicKeyString(data = publicRSA_
openssh)) | |
| 346 unittest.assertEquals(fp,'3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:
8f:af') | |
| 347 return defer.succeed(1) | |
| 348 | |
| 349 def connectionSecure(self): | |
| 350 self.requestService(ConchTestClientAuth('testuser', | |
| 351 ConchTestClientConnection())) | |
| 352 | |
| 353 class ConchTestClientAuth(userauth.SSHUserAuthClient): | |
| 354 | |
| 355 hasTriedNone = 0 # have we tried the 'none' auth yet? | |
| 356 canSucceedPublicKey = 0 # can we succed with this yet? | |
| 357 canSucceedPassword = 0 | |
| 358 | |
| 359 def ssh_USERAUTH_SUCCESS(self, packet): | |
| 360 if not self.canSucceedPassword and self.canSucceedPublicKey: | |
| 361 unittest.fail('got USERAUTH_SUCESS before password and publickey
') | |
| 362 userauth.SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet) | |
| 363 | |
| 364 def getPassword(self): | |
| 365 self.canSucceedPassword = 1 | |
| 366 return defer.succeed('testpass') | |
| 367 | |
| 368 def getPrivateKey(self): | |
| 369 self.canSucceedPublicKey = 1 | |
| 370 return defer.succeed(keys.getPrivateKeyObject(data=privateDSA_openss
h)) | |
| 371 | |
| 372 def getPublicKey(self): | |
| 373 return keys.getPublicKeyString(data=publicDSA_openssh) | |
| 374 | |
| 375 class ConchTestClientConnection(connection.SSHConnection): | |
| 376 | |
| 377 name = 'ssh-connection' | |
| 378 results = 0 | |
| 379 totalResults = 8 | |
| 380 | |
| 381 def serviceStarted(self): | |
| 382 self.openChannel(SSHTestFailExecChannel(conn = self)) | |
| 383 self.openChannel(SSHTestFalseChannel(conn = self)) | |
| 384 self.openChannel(SSHTestEchoChannel(localWindow=4, localMaxPacket=5,
conn = self)) | |
| 385 self.openChannel(SSHTestErrChannel(localWindow=4, localMaxPacket=5,
conn = self)) | |
| 386 self.openChannel(SSHTestMaxPacketChannel(localWindow=12, localMaxPac
ket=1, conn = self)) | |
| 387 self.openChannel(SSHTestShellChannel(conn = self)) | |
| 388 self.openChannel(SSHTestSubsystemChannel(conn = self)) | |
| 389 self.openChannel(SSHUnknownChannel(conn = self)) | |
| 390 | |
| 391 def addResult(self): | |
| 392 self.results += 1 | |
| 393 log.msg('got %s of %s results' % (self.results, self.totalResults)) | |
| 394 if self.results == self.totalResults: | |
| 395 self.transport.expectedLoseConnection = 1 | |
| 396 self.serviceStopped() | |
| 397 | |
| 398 class SSHUnknownChannel(channel.SSHChannel): | |
| 399 | |
| 400 name = 'crazy-unknown-channel' | |
| 401 | |
| 402 def openFailed(self, reason): | |
| 403 """ | |
| 404 good .... good | |
| 405 """ | |
| 406 log.msg('unknown open failed') | |
| 407 log.flushErrors() | |
| 408 self.conn.addResult() | |
| 409 | |
| 410 def channelOpen(self, ignored): | |
| 411 unittest.fail("opened unknown channel") | |
| 412 | |
| 413 class SSHTestFailExecChannel(channel.SSHChannel): | |
| 414 | |
| 415 name = 'session' | |
| 416 | |
| 417 def openFailed(self, reason): | |
| 418 unittest.fail('fail exec open failed: %s' % reason) | |
| 419 | |
| 420 def channelOpen(self, ignore): | |
| 421 d = self.conn.sendRequest(self, 'exec', common.NS('jumboliah'), 1) | |
| 422 d.addCallback(self._cbRequestWorked) | |
| 423 d.addErrback(self._ebRequestWorked) | |
| 424 log.msg('opened fail exec') | |
| 425 | |
| 426 def _cbRequestWorked(self, ignored): | |
| 427 unittest.fail('fail exec succeeded') | |
| 428 | |
| 429 def _ebRequestWorked(self, ignored): | |
| 430 log.msg('fail exec finished') | |
| 431 log.flushErrors() | |
| 432 self.conn.addResult() | |
| 433 self.loseConnection() | |
| 434 | |
| 435 class SSHTestFalseChannel(channel.SSHChannel): | |
| 436 | |
| 437 name = 'session' | |
| 438 | |
| 439 def openFailed(self, reason): | |
| 440 unittest.fail('false open failed: %s' % reason) | |
| 441 | |
| 442 def channelOpen(self, ignored): | |
| 443 d = self.conn.sendRequest(self, 'exec', common.NS('false'), 1) | |
| 444 d.addCallback(self._cbRequestWorked) | |
| 445 d.addErrback(self._ebRequestFailed) | |
| 446 log.msg('opened false') | |
| 447 | |
| 448 def _cbRequestWorked(self, ignored): | |
| 449 pass | |
| 450 | |
| 451 def _ebRequestFailed(self, reason): | |
| 452 unittest.fail('false exec failed: %s' % reason) | |
| 453 | |
| 454 def dataReceived(self, data): | |
| 455 unittest.fail('got data when using false') | |
| 456 | |
| 457 def request_exit_status(self, status): | |
| 458 status, = struct.unpack('>L', status) | |
| 459 if status == 0: | |
| 460 unittest.fail('false exit status was 0') | |
| 461 log.msg('finished false') | |
| 462 self.conn.addResult() | |
| 463 return 1 | |
| 464 | |
| 465 class SSHTestEchoChannel(channel.SSHChannel): | |
| 466 | |
| 467 name = 'session' | |
| 468 testBuf = '' | |
| 469 eofCalled = 0 | |
| 470 | |
| 471 def openFailed(self, reason): | |
| 472 unittest.fail('echo open failed: %s' % reason) | |
| 473 | |
| 474 def channelOpen(self, ignore): | |
| 475 d = self.conn.sendRequest(self, 'exec', common.NS('echo hello'), 1) | |
| 476 d.addErrback(self._ebRequestFailed) | |
| 477 log.msg('opened echo') | |
| 478 | |
| 479 def _ebRequestFailed(self, reason): | |
| 480 unittest.fail('echo exec failed: %s' % reason) | |
| 481 | |
| 482 def dataReceived(self, data): | |
| 483 self.testBuf += data | |
| 484 | |
| 485 def errReceived(self, dataType, data): | |
| 486 unittest.fail('echo channel got extended data') | |
| 487 | |
| 488 def request_exit_status(self, status): | |
| 489 self.status ,= struct.unpack('>L', status) | |
| 490 | |
| 491 def eofReceived(self): | |
| 492 log.msg('eof received') | |
| 493 self.eofCalled = 1 | |
| 494 | |
| 495 def closed(self): | |
| 496 if self.status != 0: | |
| 497 unittest.fail('echo exit status was not 0: %i' % self.status) | |
| 498 if self.testBuf != "hello\r\n": | |
| 499 unittest.fail('echo did not return hello: %s' % repr(self.testBu
f)) | |
| 500 unittest.assertEquals(self.localWindowLeft, 4) | |
| 501 unittest.assert_(self.eofCalled) | |
| 502 log.msg('finished echo') | |
| 503 self.conn.addResult() | |
| 504 return 1 | |
| 505 | |
| 506 class SSHTestErrChannel(channel.SSHChannel): | |
| 507 | |
| 508 name = 'session' | |
| 509 testBuf = '' | |
| 510 eofCalled = 0 | |
| 511 | |
| 512 def openFailed(self, reason): | |
| 513 unittest.fail('err open failed: %s' % reason) | |
| 514 | |
| 515 def channelOpen(self, ignore): | |
| 516 d = self.conn.sendRequest(self, 'exec', common.NS('eecho hello'), 1) | |
| 517 d.addErrback(self._ebRequestFailed) | |
| 518 log.msg('opened err') | |
| 519 | |
| 520 def _ebRequestFailed(self, reason): | |
| 521 unittest.fail('err exec failed: %s' % reason) | |
| 522 | |
| 523 def dataReceived(self, data): | |
| 524 unittest.fail('err channel got regular data: %s' % repr(data)) | |
| 525 | |
| 526 def extReceived(self, dataType, data): | |
| 527 unittest.assertEquals(dataType, connection.EXTENDED_DATA_STDERR) | |
| 528 self.testBuf += data | |
| 529 | |
| 530 def request_exit_status(self, status): | |
| 531 self.status ,= struct.unpack('>L', status) | |
| 532 | |
| 533 def eofReceived(self): | |
| 534 log.msg('eof received') | |
| 535 self.eofCalled = 1 | |
| 536 | |
| 537 def closed(self): | |
| 538 if self.status != 0: | |
| 539 unittest.fail('err exit status was not 0: %i' % self.status) | |
| 540 if self.testBuf != "hello\r\n": | |
| 541 unittest.fail('err did not return hello: %s' % repr(self.testBuf
)) | |
| 542 unittest.assertEquals(self.localWindowLeft, 4) | |
| 543 unittest.assert_(self.eofCalled) | |
| 544 log.msg('finished err') | |
| 545 self.conn.addResult() | |
| 546 return 1 | |
| 547 | |
| 548 class SSHTestMaxPacketChannel(channel.SSHChannel): | |
| 549 | |
| 550 name = 'session' | |
| 551 testBuf = '' | |
| 552 testExtBuf = '' | |
| 553 eofCalled = 0 | |
| 554 | |
| 555 def openFailed(self, reason): | |
| 556 unittest.fail('max packet open failed: %s' % reason) | |
| 557 | |
| 558 def channelOpen(self, ignore): | |
| 559 d = self.conn.sendRequest(self, 'exec', common.NS('secho hello'), 1) | |
| 560 d.addErrback(self._ebRequestFailed) | |
| 561 log.msg('opened max packet') | |
| 562 | |
| 563 def _ebRequestFailed(self, reason): | |
| 564 unittest.fail('max packet exec failed: %s' % reason) | |
| 565 | |
| 566 def dataReceived(self, data): | |
| 567 self.testBuf += data | |
| 568 | |
| 569 def extReceived(self, dataType, data): | |
| 570 unittest.assertEquals(dataType, connection.EXTENDED_DATA_STDERR) | |
| 571 self.testExtBuf += data | |
| 572 | |
| 573 def request_exit_status(self, status): | |
| 574 self.status ,= struct.unpack('>L', status) | |
| 575 | |
| 576 def eofReceived(self): | |
| 577 log.msg('eof received') | |
| 578 self.eofCalled = 1 | |
| 579 | |
| 580 def closed(self): | |
| 581 if self.status != 0: | |
| 582 unittest.fail('echo exit status was not 0: %i' % self.status) | |
| 583 unittest.assertEquals(self.testBuf, 'hello\r\n') | |
| 584 unittest.assertEquals(self.testExtBuf, 'hello\r\n') | |
| 585 unittest.assertEquals(self.localWindowLeft, 12) | |
| 586 unittest.assert_(self.eofCalled) | |
| 587 log.msg('finished max packet') | |
| 588 self.conn.addResult() | |
| 589 return 1 | |
| 590 | |
| 591 class SSHTestShellChannel(channel.SSHChannel): | |
| 592 | |
| 593 name = 'session' | |
| 594 testBuf = '' | |
| 595 eofCalled = 0 | |
| 596 closeCalled = 0 | |
| 597 | |
| 598 def openFailed(self, reason): | |
| 599 unittest.fail('shell open failed: %s' % reason) | |
| 600 | |
| 601 def channelOpen(self, ignored): | |
| 602 data = session.packRequest_pty_req('conch-test-term', (24, 80, 0, 0)
, '') | |
| 603 d = self.conn.sendRequest(self, 'pty-req', data, 1) | |
| 604 d.addCallback(self._cbPtyReq) | |
| 605 d.addErrback(self._ebPtyReq) | |
| 606 log.msg('opened shell') | |
| 607 | |
| 608 def _cbPtyReq(self, ignored): | |
| 609 d = self.conn.sendRequest(self, 'shell', '', 1) | |
| 610 d.addCallback(self._cbShellOpen) | |
| 611 d.addErrback(self._ebShellOpen) | |
| 612 | |
| 613 def _ebPtyReq(self, reason): | |
| 614 unittest.fail('pty request failed: %s' % reason) | |
| 615 | |
| 616 def _cbShellOpen(self, ignored): | |
| 617 self.write('testing the shell!\x00') | |
| 618 self.conn.sendEOF(self) | |
| 619 | |
| 620 def _ebShellOpen(self, reason): | |
| 621 unittest.fail('shell request failed: %s' % reason) | |
| 622 | |
| 623 def dataReceived(self, data): | |
| 624 self.testBuf += data | |
| 625 | |
| 626 def request_exit_status(self, status): | |
| 627 self.status ,= struct.unpack('>L', status) | |
| 628 | |
| 629 def eofReceived(self): | |
| 630 self.eofCalled = 1 | |
| 631 | |
| 632 def closed(self): | |
| 633 log.msg('calling shell closed') | |
| 634 if self.status != 0: | |
| 635 log.msg('shell exit status was not 0: %i' % self.status) | |
| 636 unittest.assertEquals(self.testBuf, 'testing the shell!\x00\r\n') | |
| 637 unittest.assert_(self.eofCalled) | |
| 638 log.msg('finished shell') | |
| 639 self.conn.addResult() | |
| 640 | |
| 641 class SSHTestSubsystemChannel(channel.SSHChannel): | |
| 642 | |
| 643 name = 'session' | |
| 644 | |
| 645 def openFailed(self, reason): | |
| 646 unittest.fail('subsystem open failed: %s' % reason) | |
| 647 | |
| 648 def channelOpen(self, ignore): | |
| 649 d = self.conn.sendRequest(self, 'subsystem', common.NS('not-crazy'),
1) | |
| 650 d.addCallback(self._cbRequestWorked) | |
| 651 d.addErrback(self._ebRequestFailed) | |
| 652 | |
| 653 | |
| 654 def _cbRequestWorked(self, ignored): | |
| 655 unittest.fail('opened non-crazy subsystem') | |
| 656 | |
| 657 def _ebRequestFailed(self, ignored): | |
| 658 d = self.conn.sendRequest(self, 'subsystem', common.NS('crazy'), 1) | |
| 659 d.addCallback(self._cbRealRequestWorked) | |
| 660 d.addErrback(self._ebRealRequestFailed) | |
| 661 | |
| 662 def _cbRealRequestWorked(self, ignored): | |
| 663 d1 = self.conn.sendGlobalRequest('foo', 'bar', 1) | |
| 664 d1.addErrback(self._ebFirstGlobal) | |
| 665 | |
| 666 d2 = self.conn.sendGlobalRequest('foo-2', 'bar2', 1) | |
| 667 d2.addCallback(lambda x: unittest.assertEquals(x, 'data')) | |
| 668 d2.addErrback(self._ebSecondGlobal) | |
| 669 | |
| 670 d3 = self.conn.sendGlobalRequest('bar', 'foo', 1) | |
| 671 d3.addCallback(self._cbThirdGlobal) | |
| 672 d3.addErrback(lambda x,s=self: log.msg('subsystem finished') or s.co
nn.addResult() or s.loseConnection()) | |
| 673 | |
| 674 def _ebRealRequestFailed(self, reason): | |
| 675 unittest.fail('opening crazy subsystem failed: %s' % reason) | |
| 676 | |
| 677 def _ebFirstGlobal(self, reason): | |
| 678 unittest.fail('first global request failed: %s' % reason) | |
| 679 | |
| 680 def _ebSecondGlobal(self, reason): | |
| 681 unittest.fail('second global request failed: %s' % reason) | |
| 682 | |
| 683 def _cbThirdGlobal(self, ignored): | |
| 684 unittest.fail('second global request succeeded') | |
| 685 | |
| 686 | |
| 687 | |
| 688 class SSHProtocolTestCase(unittest.TestCase): | |
| 689 | |
| 690 if not Crypto: | |
| 691 skip = "can't run w/o PyCrypto" | |
| 692 | |
| 693 def testOurServerOurClient(self): | |
| 694 """test the Conch server against the Conch client | |
| 695 """ | |
| 696 realm = ConchTestRealm() | |
| 697 p = portal.Portal(realm) | |
| 698 sshpc = ConchTestSSHChecker() | |
| 699 sshpc.registerChecker(ConchTestPasswordChecker()) | |
| 700 sshpc.registerChecker(ConchTestPublicKeyChecker()) | |
| 701 p.registerChecker(sshpc) | |
| 702 fac = ConchTestServerFactory() | |
| 703 fac.portal = p | |
| 704 fac.startFactory() | |
| 705 self.server = fac.buildProtocol(None) | |
| 706 self.clientTransport = LoopbackRelay(self.server) | |
| 707 self.client = ConchTestClient() | |
| 708 self.serverTransport = LoopbackRelay(self.client) | |
| 709 | |
| 710 self.server.makeConnection(self.serverTransport) | |
| 711 self.client.makeConnection(self.clientTransport) | |
| 712 | |
| 713 while self.serverTransport.buffer or self.clientTransport.buffer: | |
| 714 log.callWithContext({'system': 'serverTransport'}, | |
| 715 self.serverTransport.clearBuffer) | |
| 716 log.callWithContext({'system': 'clientTransport'}, | |
| 717 self.clientTransport.clearBuffer) | |
| 718 self.failIf(self.server.done and self.client.done) | |
| 719 | |
| 720 | |
| 721 class TestSSHFactory(unittest.TestCase): | |
| 722 | |
| 723 if not Crypto: | |
| 724 skip = "can't run w/o PyCrypto" | |
| 725 | |
| 726 def testMultipleFactories(self): | |
| 727 f1 = factory.SSHFactory() | |
| 728 f2 = factory.SSHFactory() | |
| 729 gpk = lambda: {'ssh-rsa' : keys.Key(None)} | |
| 730 f1.getPrimes = lambda: None | |
| 731 f2.getPrimes = lambda: {1:(2,3)} | |
| 732 f1.getPublicKeys = f2.getPublicKeys = gpk | |
| 733 f1.getPrivateKeys = f2.getPrivateKeys = gpk | |
| 734 f1.startFactory() | |
| 735 f2.startFactory() | |
| 736 p1 = f1.buildProtocol(None) | |
| 737 p2 = f2.buildProtocol(None) | |
| 738 self.failIf('diffie-hellman-group-exchange-sha1' in p1.supportedKeyExcha
nges, | |
| 739 p1.supportedKeyExchanges) | |
| 740 self.failUnless('diffie-hellman-group-exchange-sha1' in p2.supportedKeyE
xchanges, | |
| 741 p2.supportedKeyExchanges) | |
| 742 | |
| 743 | |
| 744 class EntropyTestCase(unittest.TestCase): | |
| 745 """ | |
| 746 Tests for L{common.entropy}. | |
| 747 """ | |
| 748 | |
| 749 def test_deprecation(self): | |
| 750 """ | |
| 751 Test the deprecation of L{common.entropy.get_bytes}. | |
| 752 """ | |
| 753 def wrapper(): | |
| 754 return common.entropy.get_bytes(10) | |
| 755 self.assertWarns(DeprecationWarning, | |
| 756 "entropy.get_bytes is deprecated, please use " | |
| 757 "twisted.python.randbytes.secureRandom instead.", | |
| 758 __file__, wrapper) | |
| 759 | |
| 760 | |
| 761 | |
| 762 class MPTestCase(unittest.TestCase): | |
| 763 """ | |
| 764 Tests for L{common.getMP}. | |
| 765 | |
| 766 @cvar getMP: a method providing a MP parser. | |
| 767 @type getMP: C{callable} | |
| 768 """ | |
| 769 getMP = staticmethod(common.getMP) | |
| 770 | |
| 771 if not Crypto: | |
| 772 skip = "can't run w/o PyCrypto" | |
| 773 | |
| 774 | |
| 775 def test_getMP(self): | |
| 776 """ | |
| 777 L{common.getMP} should parse the a multiple precision integer from a | |
| 778 string: a 4-byte length followed by length bytes of the integer. | |
| 779 """ | |
| 780 self.assertEquals( | |
| 781 self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01'), | |
| 782 (1, '')) | |
| 783 | |
| 784 | |
| 785 def test_getMPBigInteger(self): | |
| 786 """ | |
| 787 L{common.getMP} should be able to parse a big enough integer | |
| 788 (that doesn't fit on one byte). | |
| 789 """ | |
| 790 self.assertEquals( | |
| 791 self.getMP('\x00\x00\x00\x04\x01\x02\x03\x04'), | |
| 792 (16909060, '')) | |
| 793 | |
| 794 | |
| 795 def test_multipleGetMP(self): | |
| 796 """ | |
| 797 L{common.getMP} has the ability to parse multiple integer in the same | |
| 798 string. | |
| 799 """ | |
| 800 self.assertEquals( | |
| 801 self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01' | |
| 802 '\x00\x00\x00\x04\x00\x00\x00\x02', 2), | |
| 803 (1, 2, '')) | |
| 804 | |
| 805 | |
| 806 def test_getMPRemainingData(self): | |
| 807 """ | |
| 808 When more data than needed is sent to L{common.getMP}, it should return | |
| 809 the remaining data. | |
| 810 """ | |
| 811 self.assertEquals( | |
| 812 self.getMP('\x00\x00\x00\x04\x00\x00\x00\x01foo'), | |
| 813 (1, 'foo')) | |
| 814 | |
| 815 | |
| 816 def test_notEnoughData(self): | |
| 817 """ | |
| 818 When the string passed to L{common.getMP} doesn't even make 5 bytes, | |
| 819 it should raise a L{struct.error}. | |
| 820 """ | |
| 821 self.assertRaises(struct.error, self.getMP, '\x02\x00') | |
| 822 | |
| 823 | |
| 824 | |
| 825 class PyMPTestCase(MPTestCase): | |
| 826 """ | |
| 827 Tests for the python implementation of L{common.getMP}. | |
| 828 """ | |
| 829 getMP = staticmethod(common.getMP_py) | |
| 830 | |
| 831 | |
| 832 | |
| 833 class GMPYMPTestCase(MPTestCase): | |
| 834 """ | |
| 835 Tests for the gmpy implementation of L{common.getMP}. | |
| 836 """ | |
| 837 getMP = staticmethod(common._fastgetMP) | |
| 838 | |
| 839 | |
| 840 | |
| 841 try: | |
| 842 import gmpy | |
| 843 except ImportError: | |
| 844 GMPYMPTestCase.skip = "gmpy not available" | |
| OLD | NEW |