| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2005 Divmod, Inc. | |
| 2 # Copyright (c) 2007 Twisted Matrix Laboratories. | |
| 3 # See LICENSE for details. | |
| 4 | |
| 5 from twisted.python import filepath | |
| 6 from twisted.python.failure import Failure | |
| 7 from twisted.protocols import amp | |
| 8 from twisted.test import iosim | |
| 9 from twisted.trial import unittest | |
| 10 from twisted.internet import protocol, defer, error, reactor, interfaces | |
| 11 | |
| 12 | |
| 13 class TestProto(protocol.Protocol): | |
| 14 def __init__(self, onConnLost, dataToSend): | |
| 15 self.onConnLost = onConnLost | |
| 16 self.dataToSend = dataToSend | |
| 17 | |
| 18 def connectionMade(self): | |
| 19 self.data = [] | |
| 20 self.transport.write(self.dataToSend) | |
| 21 | |
| 22 def dataReceived(self, bytes): | |
| 23 self.data.append(bytes) | |
| 24 # self.transport.loseConnection() | |
| 25 | |
| 26 def connectionLost(self, reason): | |
| 27 self.onConnLost.callback(self.data) | |
| 28 | |
| 29 | |
| 30 | |
| 31 class SimpleSymmetricProtocol(amp.AMP): | |
| 32 | |
| 33 def sendHello(self, text): | |
| 34 return self.callRemoteString( | |
| 35 "hello", | |
| 36 hello=text) | |
| 37 | |
| 38 def amp_HELLO(self, box): | |
| 39 return amp.Box(hello=box['hello']) | |
| 40 | |
| 41 def amp_HOWDOYOUDO(self, box): | |
| 42 return amp.QuitBox(howdoyoudo='world') | |
| 43 | |
| 44 | |
| 45 | |
| 46 class UnfriendlyGreeting(Exception): | |
| 47 """Greeting was insufficiently kind. | |
| 48 """ | |
| 49 | |
| 50 class DeathThreat(Exception): | |
| 51 """Greeting was insufficiently kind. | |
| 52 """ | |
| 53 | |
| 54 class UnknownProtocol(Exception): | |
| 55 """Asked to switch to the wrong protocol. | |
| 56 """ | |
| 57 | |
| 58 | |
| 59 class TransportPeer(amp.Argument): | |
| 60 # this serves as some informal documentation for how to get variables from | |
| 61 # the protocol or your environment and pass them to methods as arguments. | |
| 62 def retrieve(self, d, name, proto): | |
| 63 return '' | |
| 64 | |
| 65 def fromStringProto(self, notAString, proto): | |
| 66 return proto.transport.getPeer() | |
| 67 | |
| 68 def toBox(self, name, strings, objects, proto): | |
| 69 return | |
| 70 | |
| 71 | |
| 72 | |
| 73 class Hello(amp.Command): | |
| 74 | |
| 75 commandName = 'hello' | |
| 76 | |
| 77 arguments = [('hello', amp.String()), | |
| 78 ('optional', amp.Boolean(optional=True)), | |
| 79 ('print', amp.Unicode(optional=True)), | |
| 80 ('from', TransportPeer(optional=True)), | |
| 81 ('mixedCase', amp.String(optional=True)), | |
| 82 ('dash-arg', amp.String(optional=True)), | |
| 83 ('underscore_arg', amp.String(optional=True))] | |
| 84 | |
| 85 response = [('hello', amp.String()), | |
| 86 ('print', amp.Unicode(optional=True))] | |
| 87 | |
| 88 errors = {UnfriendlyGreeting: 'UNFRIENDLY'} | |
| 89 | |
| 90 fatalErrors = {DeathThreat: 'DEAD'} | |
| 91 | |
| 92 class NoAnswerHello(Hello): | |
| 93 commandName = Hello.commandName | |
| 94 requiresAnswer = False | |
| 95 | |
| 96 class FutureHello(amp.Command): | |
| 97 commandName = 'hello' | |
| 98 | |
| 99 arguments = [('hello', amp.String()), | |
| 100 ('optional', amp.Boolean(optional=True)), | |
| 101 ('print', amp.Unicode(optional=True)), | |
| 102 ('from', TransportPeer(optional=True)), | |
| 103 ('bonus', amp.String(optional=True)), # addt'l arguments | |
| 104 # should generally be | |
| 105 # added at the end, and | |
| 106 # be optional... | |
| 107 ] | |
| 108 | |
| 109 response = [('hello', amp.String()), | |
| 110 ('print', amp.Unicode(optional=True))] | |
| 111 | |
| 112 errors = {UnfriendlyGreeting: 'UNFRIENDLY'} | |
| 113 | |
| 114 class WTF(amp.Command): | |
| 115 """ | |
| 116 An example of an invalid command. | |
| 117 """ | |
| 118 | |
| 119 | |
| 120 class BrokenReturn(amp.Command): | |
| 121 """ An example of a perfectly good command, but the handler is going to retu
rn | |
| 122 None... | |
| 123 """ | |
| 124 | |
| 125 commandName = 'broken_return' | |
| 126 | |
| 127 class Goodbye(amp.Command): | |
| 128 # commandName left blank on purpose: this tests implicit command names. | |
| 129 response = [('goodbye', amp.String())] | |
| 130 responseType = amp.QuitBox | |
| 131 | |
| 132 class Howdoyoudo(amp.Command): | |
| 133 commandName = 'howdoyoudo' | |
| 134 # responseType = amp.QuitBox | |
| 135 | |
| 136 class WaitForever(amp.Command): | |
| 137 commandName = 'wait_forever' | |
| 138 | |
| 139 class GetList(amp.Command): | |
| 140 commandName = 'getlist' | |
| 141 arguments = [('length', amp.Integer())] | |
| 142 response = [('body', amp.AmpList([('x', amp.Integer())]))] | |
| 143 | |
| 144 class SecuredPing(amp.Command): | |
| 145 # XXX TODO: actually make this refuse to send over an insecure connection | |
| 146 response = [('pinged', amp.Boolean())] | |
| 147 | |
| 148 class TestSwitchProto(amp.ProtocolSwitchCommand): | |
| 149 commandName = 'Switch-Proto' | |
| 150 | |
| 151 arguments = [ | |
| 152 ('name', amp.String()), | |
| 153 ] | |
| 154 errors = {UnknownProtocol: 'UNKNOWN'} | |
| 155 | |
| 156 class SingleUseFactory(protocol.ClientFactory): | |
| 157 def __init__(self, proto): | |
| 158 self.proto = proto | |
| 159 self.proto.factory = self | |
| 160 | |
| 161 def buildProtocol(self, addr): | |
| 162 p, self.proto = self.proto, None | |
| 163 return p | |
| 164 | |
| 165 reasonFailed = None | |
| 166 | |
| 167 def clientConnectionFailed(self, connector, reason): | |
| 168 self.reasonFailed = reason | |
| 169 return | |
| 170 | |
| 171 THING_I_DONT_UNDERSTAND = 'gwebol nargo' | |
| 172 class ThingIDontUnderstandError(Exception): | |
| 173 pass | |
| 174 | |
| 175 class FactoryNotifier(amp.AMP): | |
| 176 factory = None | |
| 177 def connectionMade(self): | |
| 178 if self.factory is not None: | |
| 179 self.factory.theProto = self | |
| 180 if hasattr(self.factory, 'onMade'): | |
| 181 self.factory.onMade.callback(None) | |
| 182 | |
| 183 def emitpong(self): | |
| 184 from twisted.internet.interfaces import ISSLTransport | |
| 185 if not ISSLTransport.providedBy(self.transport): | |
| 186 raise DeathThreat("only send secure pings over secure channels") | |
| 187 return {'pinged': True} | |
| 188 SecuredPing.responder(emitpong) | |
| 189 | |
| 190 | |
| 191 class SimpleSymmetricCommandProtocol(FactoryNotifier): | |
| 192 maybeLater = None | |
| 193 def __init__(self, onConnLost=None): | |
| 194 amp.AMP.__init__(self) | |
| 195 self.onConnLost = onConnLost | |
| 196 | |
| 197 def sendHello(self, text): | |
| 198 return self.callRemote(Hello, hello=text) | |
| 199 | |
| 200 def sendUnicodeHello(self, text, translation): | |
| 201 return self.callRemote(Hello, hello=text, Print=translation) | |
| 202 | |
| 203 greeted = False | |
| 204 | |
| 205 def cmdHello(self, hello, From, optional=None, Print=None, | |
| 206 mixedCase=None, dash_arg=None, underscore_arg=None): | |
| 207 assert From == self.transport.getPeer() | |
| 208 if hello == THING_I_DONT_UNDERSTAND: | |
| 209 raise ThingIDontUnderstandError() | |
| 210 if hello.startswith('fuck'): | |
| 211 raise UnfriendlyGreeting("Don't be a dick.") | |
| 212 if hello == 'die': | |
| 213 raise DeathThreat("aieeeeeeeee") | |
| 214 result = dict(hello=hello) | |
| 215 if Print is not None: | |
| 216 result.update(dict(Print=Print)) | |
| 217 self.greeted = True | |
| 218 return result | |
| 219 Hello.responder(cmdHello) | |
| 220 | |
| 221 def cmdGetlist(self, length): | |
| 222 return {'body': [dict(x=1)] * length} | |
| 223 GetList.responder(cmdGetlist) | |
| 224 | |
| 225 def waitforit(self): | |
| 226 self.waiting = defer.Deferred() | |
| 227 return self.waiting | |
| 228 WaitForever.responder(waitforit) | |
| 229 | |
| 230 def howdo(self): | |
| 231 return dict(howdoyoudo='world') | |
| 232 Howdoyoudo.responder(howdo) | |
| 233 | |
| 234 def saybye(self): | |
| 235 return dict(goodbye="everyone") | |
| 236 Goodbye.responder(saybye) | |
| 237 | |
| 238 def switchToTestProtocol(self, fail=False): | |
| 239 if fail: | |
| 240 name = 'no-proto' | |
| 241 else: | |
| 242 name = 'test-proto' | |
| 243 p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA) | |
| 244 return self.callRemote( | |
| 245 TestSwitchProto, | |
| 246 SingleUseFactory(p), name=name).addCallback(lambda ign: p) | |
| 247 | |
| 248 def switchit(self, name): | |
| 249 if name == 'test-proto': | |
| 250 return TestProto(self.onConnLost, SWITCH_SERVER_DATA) | |
| 251 raise UnknownProtocol(name) | |
| 252 TestSwitchProto.responder(switchit) | |
| 253 | |
| 254 def donothing(self): | |
| 255 return None | |
| 256 BrokenReturn.responder(donothing) | |
| 257 | |
| 258 | |
| 259 class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol): | |
| 260 def switchit(self, name): | |
| 261 if name == 'test-proto': | |
| 262 self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA
) | |
| 263 self.maybeLater = defer.Deferred() | |
| 264 return self.maybeLater | |
| 265 raise UnknownProtocol(name) | |
| 266 TestSwitchProto.responder(switchit) | |
| 267 | |
| 268 class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): | |
| 269 def badResponder(self, hello, From, optional=None, Print=None, | |
| 270 mixedCase=None, dash_arg=None, underscore_arg=None): | |
| 271 """ | |
| 272 This responder does nothing and forgets to return a dictionary. | |
| 273 """ | |
| 274 NoAnswerHello.responder(badResponder) | |
| 275 | |
| 276 class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): | |
| 277 def goodNoAnswerResponder(self, hello, From, optional=None, Print=None, | |
| 278 mixedCase=None, dash_arg=None, underscore_arg=None
): | |
| 279 return dict(hello=hello+"-noanswer") | |
| 280 NoAnswerHello.responder(goodNoAnswerResponder) | |
| 281 | |
| 282 def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol, | |
| 283 ClientClass=SimpleSymmetricProtocol, | |
| 284 *a, **kw): | |
| 285 """Returns a 3-tuple: (client, server, pump) | |
| 286 """ | |
| 287 return iosim.connectedServerAndClient( | |
| 288 ServerClass, ClientClass, | |
| 289 *a, **kw) | |
| 290 | |
| 291 class TotallyDumbProtocol(protocol.Protocol): | |
| 292 buf = '' | |
| 293 def dataReceived(self, data): | |
| 294 self.buf += data | |
| 295 | |
| 296 class LiteralAmp(amp.AMP): | |
| 297 def __init__(self): | |
| 298 self.boxes = [] | |
| 299 | |
| 300 def ampBoxReceived(self, box): | |
| 301 self.boxes.append(box) | |
| 302 return | |
| 303 | |
| 304 class ParsingTest(unittest.TestCase): | |
| 305 | |
| 306 def test_booleanValues(self): | |
| 307 """ | |
| 308 Verify that the Boolean parser parses 'True' and 'False', but nothing | |
| 309 else. | |
| 310 """ | |
| 311 b = amp.Boolean() | |
| 312 self.assertEquals(b.fromString("True"), True) | |
| 313 self.assertEquals(b.fromString("False"), False) | |
| 314 self.assertRaises(TypeError, b.fromString, "ninja") | |
| 315 self.assertRaises(TypeError, b.fromString, "true") | |
| 316 self.assertRaises(TypeError, b.fromString, "TRUE") | |
| 317 self.assertEquals(b.toString(True), 'True') | |
| 318 self.assertEquals(b.toString(False), 'False') | |
| 319 | |
| 320 def test_pathValueRoundTrip(self): | |
| 321 """ | |
| 322 Verify the 'Path' argument can parse and emit a file path. | |
| 323 """ | |
| 324 fp = filepath.FilePath(self.mktemp()) | |
| 325 p = amp.Path() | |
| 326 s = p.toString(fp) | |
| 327 v = p.fromString(s) | |
| 328 self.assertNotIdentical(fp, v) # sanity check | |
| 329 self.assertEquals(fp, v) | |
| 330 | |
| 331 | |
| 332 def test_sillyEmptyThing(self): | |
| 333 """ | |
| 334 Test that empty boxes raise an error; they aren't supposed to be sent | |
| 335 on purpose. | |
| 336 """ | |
| 337 a = amp.AMP() | |
| 338 return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box()) | |
| 339 | |
| 340 | |
| 341 def test_ParsingRoundTrip(self): | |
| 342 """ | |
| 343 Verify that various kinds of data make it through the encode/parse | |
| 344 round-trip unharmed. | |
| 345 """ | |
| 346 c, s, p = connectedServerAndClient(ClientClass=LiteralAmp, | |
| 347 ServerClass=LiteralAmp) | |
| 348 | |
| 349 SIMPLE = ('simple', 'test') | |
| 350 CE = ('ceq', ': ') | |
| 351 CR = ('crtest', 'test\r') | |
| 352 LF = ('lftest', 'hello\n') | |
| 353 NEWLINE = ('newline', 'test\r\none\r\ntwo') | |
| 354 NEWLINE2 = ('newline2', 'test\r\none\r\n two') | |
| 355 BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline') | |
| 356 BODYTEST = ('body', 'blah\r\n\r\ntesttest') | |
| 357 | |
| 358 testData = [ | |
| 359 [SIMPLE], | |
| 360 [SIMPLE, BODYTEST], | |
| 361 [SIMPLE, CE], | |
| 362 [SIMPLE, CR], | |
| 363 [SIMPLE, CE, CR, LF], | |
| 364 [CE, CR, LF], | |
| 365 [SIMPLE, NEWLINE, CE, NEWLINE2], | |
| 366 [BODYTEST, SIMPLE, NEWLINE] | |
| 367 ] | |
| 368 | |
| 369 for test in testData: | |
| 370 jb = amp.Box() | |
| 371 jb.update(dict(test)) | |
| 372 jb._sendTo(c) | |
| 373 p.flush() | |
| 374 self.assertEquals(s.boxes[-1], jb) | |
| 375 | |
| 376 | |
| 377 | |
| 378 class FakeLocator(object): | |
| 379 """ | |
| 380 This is a fake implementation of the interface implied by | |
| 381 L{CommandLocator}. | |
| 382 """ | |
| 383 def __init__(self): | |
| 384 """ | |
| 385 Remember the given keyword arguments as a set of responders. | |
| 386 """ | |
| 387 self.commands = {} | |
| 388 | |
| 389 | |
| 390 def locateResponder(self, commandName): | |
| 391 """ | |
| 392 Look up and return a function passed as a keyword argument of the given | |
| 393 name to the constructor. | |
| 394 """ | |
| 395 return self.commands[commandName] | |
| 396 | |
| 397 | |
| 398 class FakeSender: | |
| 399 """ | |
| 400 This is a fake implementation of the 'box sender' interface implied by | |
| 401 L{AMP}. | |
| 402 """ | |
| 403 def __init__(self): | |
| 404 """ | |
| 405 Create a fake sender and initialize the list of received boxes and | |
| 406 unhandled errors. | |
| 407 """ | |
| 408 self.sentBoxes = [] | |
| 409 self.unhandledErrors = [] | |
| 410 self.expectedErrors = 0 | |
| 411 | |
| 412 | |
| 413 def expectError(self): | |
| 414 """ | |
| 415 Expect one error, so that the test doesn't fail. | |
| 416 """ | |
| 417 self.expectedErrors += 1 | |
| 418 | |
| 419 | |
| 420 def sendBox(self, box): | |
| 421 """ | |
| 422 Accept a box, but don't do anything. | |
| 423 """ | |
| 424 self.sentBoxes.append(box) | |
| 425 | |
| 426 | |
| 427 def unhandledError(self, failure): | |
| 428 """ | |
| 429 Deal with failures by instantly re-raising them for easier debugging. | |
| 430 """ | |
| 431 self.expectedErrors -= 1 | |
| 432 if self.expectedErrors < 0: | |
| 433 failure.raiseException() | |
| 434 else: | |
| 435 self.unhandledErrors.append(failure) | |
| 436 | |
| 437 | |
| 438 | |
| 439 class CommandDispatchTests(unittest.TestCase): | |
| 440 """ | |
| 441 The AMP CommandDispatcher class dispatches converts AMP boxes into commands | |
| 442 and responses using Command.responder decorator. | |
| 443 | |
| 444 Note: Originally, AMP's factoring was such that many tests for this | |
| 445 functionality are now implemented as full round-trip tests in L{AMPTest}. | |
| 446 Future tests should be written at this level instead, to ensure API | |
| 447 compatibility and to provide more granular, readable units of test | |
| 448 coverage. | |
| 449 """ | |
| 450 | |
| 451 def setUp(self): | |
| 452 """ | |
| 453 Create a dispatcher to use. | |
| 454 """ | |
| 455 self.locator = FakeLocator() | |
| 456 self.sender = FakeSender() | |
| 457 self.dispatcher = amp.BoxDispatcher(self.locator) | |
| 458 self.dispatcher.startReceivingBoxes(self.sender) | |
| 459 | |
| 460 | |
| 461 def test_receivedAsk(self): | |
| 462 """ | |
| 463 L{CommandDispatcher.ampBoxReceived} should locate the appropriate | |
| 464 command in its responder lookup, based on the '_ask' key. | |
| 465 """ | |
| 466 received = [] | |
| 467 def thunk(box): | |
| 468 received.append(box) | |
| 469 return amp.Box({"hello": "goodbye"}) | |
| 470 input = amp.Box(_command="hello", | |
| 471 _ask="test-command-id", | |
| 472 hello="world") | |
| 473 self.locator.commands['hello'] = thunk | |
| 474 self.dispatcher.ampBoxReceived(input) | |
| 475 self.assertEquals(received, [input]) | |
| 476 | |
| 477 | |
| 478 def test_sendUnhandledError(self): | |
| 479 """ | |
| 480 L{CommandDispatcher} should relay its unhandled errors in responding to | |
| 481 boxes to its boxSender. | |
| 482 """ | |
| 483 err = RuntimeError("something went wrong, oh no") | |
| 484 self.sender.expectError() | |
| 485 self.dispatcher.unhandledError(Failure(err)) | |
| 486 self.assertEqual(len(self.sender.unhandledErrors), 1) | |
| 487 self.assertEqual(self.sender.unhandledErrors[0].value, err) | |
| 488 | |
| 489 | |
| 490 def test_unhandledSerializationError(self): | |
| 491 """ | |
| 492 Errors during serialization ought to be relayed to the sender's | |
| 493 unhandledError method. | |
| 494 """ | |
| 495 err = RuntimeError("something undefined went wrong") | |
| 496 def thunk(result): | |
| 497 class BrokenBox(amp.Box): | |
| 498 def _sendTo(self, proto): | |
| 499 raise err | |
| 500 return BrokenBox() | |
| 501 self.locator.commands['hello'] = thunk | |
| 502 input = amp.Box(_command="hello", | |
| 503 _ask="test-command-id", | |
| 504 hello="world") | |
| 505 self.sender.expectError() | |
| 506 self.dispatcher.ampBoxReceived(input) | |
| 507 self.assertEquals(len(self.sender.unhandledErrors), 1) | |
| 508 self.assertEquals(self.sender.unhandledErrors[0].value, err) | |
| 509 | |
| 510 | |
| 511 def test_callRemote(self): | |
| 512 """ | |
| 513 L{CommandDispatcher.callRemote} should emit a properly formatted '_ask' | |
| 514 box to its boxSender and record an outstanding L{Deferred}. When a | |
| 515 corresponding '_answer' packet is received, the L{Deferred} should be | |
| 516 fired, and the results translated via the given L{Command}'s response | |
| 517 de-serialization. | |
| 518 """ | |
| 519 D = self.dispatcher.callRemote(Hello, hello='world') | |
| 520 self.assertEquals(self.sender.sentBoxes, | |
| 521 [amp.AmpBox(_command="hello", | |
| 522 _ask="1", | |
| 523 hello="world")]) | |
| 524 answers = [] | |
| 525 D.addCallback(answers.append) | |
| 526 self.assertEquals(answers, []) | |
| 527 self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay", | |
| 528 'print': "ignored", | |
| 529 '_answer': "1"})) | |
| 530 self.assertEquals(answers, [dict(hello="yay", | |
| 531 Print=u"ignored")]) | |
| 532 | |
| 533 | |
| 534 class SimpleGreeting(amp.Command): | |
| 535 """ | |
| 536 A very simple greeting command that uses a few basic argument types. | |
| 537 """ | |
| 538 commandName = 'simple' | |
| 539 arguments = [('greeting', amp.Unicode()), | |
| 540 ('cookie', amp.Integer())] | |
| 541 response = [('cookieplus', amp.Integer())] | |
| 542 | |
| 543 | |
| 544 class TestLocator(amp.CommandLocator): | |
| 545 """ | |
| 546 A locator which implements a responder to a 'hello' command. | |
| 547 """ | |
| 548 def __init__(self): | |
| 549 self.greetings = [] | |
| 550 | |
| 551 | |
| 552 def greetingResponder(self, greeting, cookie): | |
| 553 self.greetings.append((greeting, cookie)) | |
| 554 return dict(cookieplus=cookie + 3) | |
| 555 greetingResponder = SimpleGreeting.responder(greetingResponder) | |
| 556 | |
| 557 | |
| 558 | |
| 559 class OverrideLocatorAMP(amp.AMP): | |
| 560 def __init__(self): | |
| 561 amp.AMP.__init__(self) | |
| 562 self.customResponder = object() | |
| 563 self.expectations = {"custom": self.customResponder} | |
| 564 self.greetings = [] | |
| 565 | |
| 566 | |
| 567 def lookupFunction(self, name): | |
| 568 """ | |
| 569 Override the deprecated lookupFunction function. | |
| 570 """ | |
| 571 if name in self.expectations: | |
| 572 result = self.expectations[name] | |
| 573 return result | |
| 574 else: | |
| 575 return super(OverrideLocatorAMP, self).lookupFunction(name) | |
| 576 | |
| 577 | |
| 578 def greetingResponder(self, greeting, cookie): | |
| 579 self.greetings.append((greeting, cookie)) | |
| 580 return dict(cookieplus=cookie + 3) | |
| 581 greetingResponder = SimpleGreeting.responder(greetingResponder) | |
| 582 | |
| 583 | |
| 584 | |
| 585 | |
| 586 class CommandLocatorTests(unittest.TestCase): | |
| 587 """ | |
| 588 The CommandLocator should enable users to specify responders to commands as | |
| 589 functions that take structured objects, annotated with metadata. | |
| 590 """ | |
| 591 | |
| 592 def test_responderDecorator(self): | |
| 593 """ | |
| 594 A method on a L{CommandLocator} subclass decorated with a L{Command} | |
| 595 subclass's L{responder} decorator should be returned from | |
| 596 locateResponder, wrapped in logic to serialize and deserialize its | |
| 597 arguments. | |
| 598 """ | |
| 599 locator = TestLocator() | |
| 600 responderCallable = locator.locateResponder("simple") | |
| 601 result = responderCallable(amp.Box(greeting="ni hao", cookie="5")) | |
| 602 def done(values): | |
| 603 self.assertEquals(values, amp.AmpBox(cookieplus='8')) | |
| 604 return result.addCallback(done) | |
| 605 | |
| 606 | |
| 607 def test_lookupFunctionDeprecatedOverride(self): | |
| 608 """ | |
| 609 Subclasses which override locateResponder under its old name, | |
| 610 lookupFunction, should have the override invoked instead. (This tests | |
| 611 an AMP subclass, because in the version of the code that could invoke | |
| 612 this deprecated code path, there was no L{CommandLocator}.) | |
| 613 """ | |
| 614 locator = OverrideLocatorAMP() | |
| 615 customResponderObject = self.assertWarns( | |
| 616 PendingDeprecationWarning, | |
| 617 "Override locateResponder, not lookupFunction.", | |
| 618 __file__, lambda : locator.locateResponder("custom")) | |
| 619 self.assertEquals(locator.customResponder, customResponderObject) | |
| 620 # Make sure upcalling works too | |
| 621 normalResponderObject = self.assertWarns( | |
| 622 PendingDeprecationWarning, | |
| 623 "Override locateResponder, not lookupFunction.", | |
| 624 __file__, lambda : locator.locateResponder("simple")) | |
| 625 result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5")) | |
| 626 def done(values): | |
| 627 self.assertEquals(values, amp.AmpBox(cookieplus='8')) | |
| 628 return result.addCallback(done) | |
| 629 | |
| 630 | |
| 631 def test_lookupFunctionDeprecatedInvoke(self): | |
| 632 """ | |
| 633 Invoking locateResponder under its old name, lookupFunction, should | |
| 634 emit a deprecation warning, but do the same thing. | |
| 635 """ | |
| 636 locator = TestLocator() | |
| 637 responderCallable = self.assertWarns( | |
| 638 PendingDeprecationWarning, | |
| 639 "Call locateResponder, not lookupFunction.", __file__, | |
| 640 lambda : locator.lookupFunction("simple")) | |
| 641 result = responderCallable(amp.Box(greeting="ni hao", cookie="5")) | |
| 642 def done(values): | |
| 643 self.assertEquals(values, amp.AmpBox(cookieplus='8')) | |
| 644 return result.addCallback(done) | |
| 645 | |
| 646 | |
| 647 | |
| 648 SWITCH_CLIENT_DATA = 'Success!' | |
| 649 SWITCH_SERVER_DATA = 'No, really. Success.' | |
| 650 | |
| 651 | |
| 652 class BinaryProtocolTests(unittest.TestCase): | |
| 653 """ | |
| 654 Tests for L{amp.BinaryBoxProtocol}. | |
| 655 """ | |
| 656 | |
| 657 def setUp(self): | |
| 658 """ | |
| 659 Keep track of all boxes received by this test in its capacity as an | |
| 660 L{IBoxReceiver} implementor. | |
| 661 """ | |
| 662 self.boxes = [] | |
| 663 self.data = [] | |
| 664 | |
| 665 | |
| 666 def startReceivingBoxes(self, sender): | |
| 667 """ | |
| 668 Implement L{IBoxReceiver.startReceivingBoxes} to do nothing. | |
| 669 """ | |
| 670 | |
| 671 | |
| 672 def ampBoxReceived(self, box): | |
| 673 """ | |
| 674 A box was received by the protocol. | |
| 675 """ | |
| 676 self.boxes.append(box) | |
| 677 | |
| 678 stopReason = None | |
| 679 def stopReceivingBoxes(self, reason): | |
| 680 """ | |
| 681 Record the reason that we stopped receiving boxes. | |
| 682 """ | |
| 683 self.stopReason = reason | |
| 684 | |
| 685 | |
| 686 # fake ITransport | |
| 687 def getPeer(self): | |
| 688 return 'no peer' | |
| 689 | |
| 690 | |
| 691 def getHost(self): | |
| 692 return 'no host' | |
| 693 | |
| 694 | |
| 695 def write(self, data): | |
| 696 self.data.append(data) | |
| 697 | |
| 698 | |
| 699 def test_receiveBoxStateMachine(self): | |
| 700 """ | |
| 701 When a binary box protocol receives: | |
| 702 * a key | |
| 703 * a value | |
| 704 * an empty string | |
| 705 it should emit a box and send it to its boxReceiver. | |
| 706 """ | |
| 707 a = amp.BinaryBoxProtocol(self) | |
| 708 a.stringReceived("hello") | |
| 709 a.stringReceived("world") | |
| 710 a.stringReceived("") | |
| 711 self.assertEquals(self.boxes, [amp.AmpBox(hello="world")]) | |
| 712 | |
| 713 | |
| 714 def test_receiveBoxData(self): | |
| 715 """ | |
| 716 When a binary box protocol receives the serialized form of an AMP box, | |
| 717 it should emit a similar box to its boxReceiver. | |
| 718 """ | |
| 719 a = amp.BinaryBoxProtocol(self) | |
| 720 a.dataReceived(amp.Box({"testKey": "valueTest", | |
| 721 "anotherKey": "anotherValue"}).serialize()) | |
| 722 self.assertEquals(self.boxes, | |
| 723 [amp.Box({"testKey": "valueTest", | |
| 724 "anotherKey": "anotherValue"})]) | |
| 725 | |
| 726 | |
| 727 def test_sendBox(self): | |
| 728 """ | |
| 729 When a binary box protocol sends a box, it should emit the serialized | |
| 730 bytes of that box to its transport. | |
| 731 """ | |
| 732 a = amp.BinaryBoxProtocol(self) | |
| 733 a.makeConnection(self) | |
| 734 aBox = amp.Box({"testKey": "valueTest", | |
| 735 "someData": "hello"}) | |
| 736 a.makeConnection(self) | |
| 737 a.sendBox(aBox) | |
| 738 self.assertEquals(''.join(self.data), aBox.serialize()) | |
| 739 | |
| 740 | |
| 741 def test_connectionLostStopSendingBoxes(self): | |
| 742 """ | |
| 743 When a binary box protocol loses its connection, it should notify its | |
| 744 box receiver that it has stopped receiving boxes. | |
| 745 """ | |
| 746 a = amp.BinaryBoxProtocol(self) | |
| 747 a.makeConnection(self) | |
| 748 aBox = amp.Box({"sample": "data"}) | |
| 749 a.makeConnection(self) | |
| 750 connectionFailure = Failure(RuntimeError()) | |
| 751 a.connectionLost(connectionFailure) | |
| 752 self.assertIdentical(self.stopReason, connectionFailure) | |
| 753 | |
| 754 | |
| 755 def test_protocolSwitch(self): | |
| 756 """ | |
| 757 L{BinaryBoxProtocol} has the capacity to switch to a different protocol | |
| 758 on a box boundary. When a protocol is in the process of switching, it | |
| 759 cannot receive traffic. | |
| 760 """ | |
| 761 otherProto = TestProto(None, "outgoing data") | |
| 762 test = self | |
| 763 class SwitchyReceiver: | |
| 764 switched = False | |
| 765 def startReceivingBoxes(self, sender): | |
| 766 pass | |
| 767 def ampBoxReceived(self, box): | |
| 768 test.assertFalse(self.switched, | |
| 769 "Should only receive one box!") | |
| 770 self.switched = True | |
| 771 a._lockForSwitch() | |
| 772 a._switchTo(otherProto) | |
| 773 a = amp.BinaryBoxProtocol(SwitchyReceiver()) | |
| 774 anyOldBox = amp.Box({"include": "lots", | |
| 775 "of": "data"}) | |
| 776 a.makeConnection(self) | |
| 777 # Include a 0-length box at the beginning of the next protocol's data, | |
| 778 # to make sure that AMP doesn't eat the data or try to deliver extra | |
| 779 # boxes either... | |
| 780 moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!" | |
| 781 a.dataReceived(moreThanOneBox) | |
| 782 self.assertIdentical(otherProto.transport, self) | |
| 783 self.assertEquals("".join(otherProto.data), "\x00\x00Hello, world!") | |
| 784 self.assertEquals(self.data, ["outgoing data"]) | |
| 785 a.dataReceived("more data") | |
| 786 self.assertEquals("".join(otherProto.data), | |
| 787 "\x00\x00Hello, world!more data") | |
| 788 self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox) | |
| 789 | |
| 790 | |
| 791 def test_protocolSwitchInvalidStates(self): | |
| 792 """ | |
| 793 In order to make sure the protocol never gets any invalid data sent | |
| 794 into the middle of a box, it must be locked for switching before it is | |
| 795 switched. It can only be unlocked if the switch failed, and attempting | |
| 796 to send a box while it is locked should raise an exception. | |
| 797 """ | |
| 798 a = amp.BinaryBoxProtocol(self) | |
| 799 a.makeConnection(self) | |
| 800 sampleBox = amp.Box({"some": "data"}) | |
| 801 a._lockForSwitch() | |
| 802 self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox) | |
| 803 a._unlockFromSwitch() | |
| 804 a.sendBox(sampleBox) | |
| 805 self.assertEquals(''.join(self.data), sampleBox.serialize()) | |
| 806 a._lockForSwitch() | |
| 807 otherProto = TestProto(None, "outgoing data") | |
| 808 a._switchTo(otherProto) | |
| 809 self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch) | |
| 810 | |
| 811 | |
| 812 def test_protocolSwitchLoseConnection(self): | |
| 813 """ | |
| 814 When the protocol is switched, it should notify its nested protocol of | |
| 815 disconnection. | |
| 816 """ | |
| 817 class Loser(protocol.Protocol): | |
| 818 reason = None | |
| 819 def connectionLost(self, reason): | |
| 820 self.reason = reason | |
| 821 connectionLoser = Loser() | |
| 822 a = amp.BinaryBoxProtocol(self) | |
| 823 a.makeConnection(self) | |
| 824 a._lockForSwitch() | |
| 825 a._switchTo(connectionLoser) | |
| 826 connectionFailure = Failure(RuntimeError()) | |
| 827 a.connectionLost(connectionFailure) | |
| 828 self.assertEquals(connectionLoser.reason, connectionFailure) | |
| 829 | |
| 830 | |
| 831 def test_protocolSwitchLoseClientConnection(self): | |
| 832 """ | |
| 833 When the protocol is switched, it should notify its nested client | |
| 834 protocol factory of disconnection. | |
| 835 """ | |
| 836 class ClientLoser: | |
| 837 reason = None | |
| 838 def clientConnectionLost(self, connector, reason): | |
| 839 self.reason = reason | |
| 840 a = amp.BinaryBoxProtocol(self) | |
| 841 connectionLoser = protocol.Protocol() | |
| 842 clientLoser = ClientLoser() | |
| 843 a.makeConnection(self) | |
| 844 a._lockForSwitch() | |
| 845 a._switchTo(connectionLoser, clientLoser) | |
| 846 connectionFailure = Failure(RuntimeError()) | |
| 847 a.connectionLost(connectionFailure) | |
| 848 self.assertEquals(clientLoser.reason, connectionFailure) | |
| 849 | |
| 850 | |
| 851 | |
| 852 class AMPTest(unittest.TestCase): | |
| 853 | |
| 854 def test_interfaceDeclarations(self): | |
| 855 """ | |
| 856 The classes in the amp module ought to implement the interfaces that | |
| 857 are declared for their benefit. | |
| 858 """ | |
| 859 for interface, implementation in [(amp.IBoxSender, amp.BinaryBoxProtocol
), | |
| 860 (amp.IBoxReceiver, amp.BoxDispatcher), | |
| 861 (amp.IResponderLocator, amp.CommandLoc
ator), | |
| 862 (amp.IResponderLocator, amp.SimpleStri
ngLocator), | |
| 863 (amp.IBoxSender, amp.AMP), | |
| 864 (amp.IBoxReceiver, amp.AMP), | |
| 865 (amp.IResponderLocator, amp.AMP)]: | |
| 866 self.failUnless(interface.implementedBy(implementation), | |
| 867 "%s does not implements(%s)" % (implementation, inte
rface)) | |
| 868 | |
| 869 | |
| 870 def test_helloWorld(self): | |
| 871 """ | |
| 872 Verify that a simple command can be sent and its response received with | |
| 873 the simple low-level string-based API. | |
| 874 """ | |
| 875 c, s, p = connectedServerAndClient() | |
| 876 L = [] | |
| 877 HELLO = 'world' | |
| 878 c.sendHello(HELLO).addCallback(L.append) | |
| 879 p.flush() | |
| 880 self.assertEquals(L[0]['hello'], HELLO) | |
| 881 | |
| 882 | |
| 883 def test_wireFormatRoundTrip(self): | |
| 884 """ | |
| 885 Verify that mixed-case, underscored and dashed arguments are mapped to | |
| 886 their python names properly. | |
| 887 """ | |
| 888 c, s, p = connectedServerAndClient() | |
| 889 L = [] | |
| 890 HELLO = 'world' | |
| 891 c.sendHello(HELLO).addCallback(L.append) | |
| 892 p.flush() | |
| 893 self.assertEquals(L[0]['hello'], HELLO) | |
| 894 | |
| 895 | |
| 896 def test_helloWorldUnicode(self): | |
| 897 """ | |
| 898 Verify that unicode arguments can be encoded and decoded. | |
| 899 """ | |
| 900 c, s, p = connectedServerAndClient( | |
| 901 ServerClass=SimpleSymmetricCommandProtocol, | |
| 902 ClientClass=SimpleSymmetricCommandProtocol) | |
| 903 L = [] | |
| 904 HELLO = 'world' | |
| 905 HELLO_UNICODE = 'wor\u1234ld' | |
| 906 c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append) | |
| 907 p.flush() | |
| 908 self.assertEquals(L[0]['hello'], HELLO) | |
| 909 self.assertEquals(L[0]['Print'], HELLO_UNICODE) | |
| 910 | |
| 911 | |
| 912 def test_unknownCommandLow(self): | |
| 913 """ | |
| 914 Verify that unknown commands using low-level APIs will be rejected with
an | |
| 915 error, but will NOT terminate the connection. | |
| 916 """ | |
| 917 c, s, p = connectedServerAndClient() | |
| 918 L = [] | |
| 919 def clearAndAdd(e): | |
| 920 """ | |
| 921 You can't propagate the error... | |
| 922 """ | |
| 923 e.trap(amp.UnhandledCommand) | |
| 924 return "OK" | |
| 925 c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append) | |
| 926 p.flush() | |
| 927 self.assertEquals(L.pop(), "OK") | |
| 928 HELLO = 'world' | |
| 929 c.sendHello(HELLO).addCallback(L.append) | |
| 930 p.flush() | |
| 931 self.assertEquals(L[0]['hello'], HELLO) | |
| 932 | |
| 933 | |
| 934 def test_unknownCommandHigh(self): | |
| 935 """ | |
| 936 Verify that unknown commands using high-level APIs will be rejected with
an | |
| 937 error, but will NOT terminate the connection. | |
| 938 """ | |
| 939 c, s, p = connectedServerAndClient() | |
| 940 L = [] | |
| 941 def clearAndAdd(e): | |
| 942 """ | |
| 943 You can't propagate the error... | |
| 944 """ | |
| 945 e.trap(amp.UnhandledCommand) | |
| 946 return "OK" | |
| 947 c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append) | |
| 948 p.flush() | |
| 949 self.assertEquals(L.pop(), "OK") | |
| 950 HELLO = 'world' | |
| 951 c.sendHello(HELLO).addCallback(L.append) | |
| 952 p.flush() | |
| 953 self.assertEquals(L[0]['hello'], HELLO) | |
| 954 | |
| 955 | |
| 956 def test_brokenReturnValue(self): | |
| 957 """ | |
| 958 It can be very confusing if you write some code which responds to a | |
| 959 command, but gets the return value wrong. Most commonly you end up | |
| 960 returning None instead of a dictionary. | |
| 961 | |
| 962 Verify that if that happens, the framework logs a useful error. | |
| 963 """ | |
| 964 L = [] | |
| 965 SimpleSymmetricCommandProtocol().dispatchCommand( | |
| 966 amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append) | |
| 967 blr = L[0].trap(amp.BadLocalReturn) | |
| 968 self.failUnlessIn('None', repr(L[0].value)) | |
| 969 | |
| 970 | |
| 971 | |
| 972 def test_unknownArgument(self): | |
| 973 """ | |
| 974 Verify that unknown arguments are ignored, and not passed to a Python | |
| 975 function which can't accept them. | |
| 976 """ | |
| 977 c, s, p = connectedServerAndClient( | |
| 978 ServerClass=SimpleSymmetricCommandProtocol, | |
| 979 ClientClass=SimpleSymmetricCommandProtocol) | |
| 980 L = [] | |
| 981 HELLO = 'world' | |
| 982 # c.sendHello(HELLO).addCallback(L.append) | |
| 983 c.callRemote(FutureHello, | |
| 984 hello=HELLO, | |
| 985 bonus="I'm not in the book!").addCallback( | |
| 986 L.append) | |
| 987 p.flush() | |
| 988 self.assertEquals(L[0]['hello'], HELLO) | |
| 989 | |
| 990 | |
| 991 def test_simpleReprs(self): | |
| 992 """ | |
| 993 Verify that the various Box objects repr properly, for debugging. | |
| 994 """ | |
| 995 self.assertEquals(type(repr(amp._TLSBox())), str) | |
| 996 self.assertEquals(type(repr(amp._SwitchBox('a'))), str) | |
| 997 self.assertEquals(type(repr(amp.QuitBox())), str) | |
| 998 self.assertEquals(type(repr(amp.AmpBox())), str) | |
| 999 self.failUnless("AmpBox" in repr(amp.AmpBox())) | |
| 1000 | |
| 1001 def test_keyTooLong(self): | |
| 1002 """ | |
| 1003 Verify that a key that is too long will immediately raise a synchronous | |
| 1004 exception. | |
| 1005 """ | |
| 1006 c, s, p = connectedServerAndClient() | |
| 1007 L = [] | |
| 1008 x = "H" * (0xff+1) | |
| 1009 tl = self.assertRaises(amp.TooLong, | |
| 1010 c.callRemoteString, "Hello", | |
| 1011 **{x: "hi"}) | |
| 1012 self.failUnless(tl.isKey) | |
| 1013 self.failUnless(tl.isLocal) | |
| 1014 self.failUnlessIdentical(tl.keyName, None) | |
| 1015 self.failUnlessIdentical(tl.value, x) | |
| 1016 self.failUnless(str(len(x)) in repr(tl)) | |
| 1017 self.failUnless("key" in repr(tl)) | |
| 1018 | |
| 1019 | |
| 1020 def test_valueTooLong(self): | |
| 1021 """ | |
| 1022 Verify that attempting to send value longer than 64k will immediately | |
| 1023 raise an exception. | |
| 1024 """ | |
| 1025 c, s, p = connectedServerAndClient() | |
| 1026 L = [] | |
| 1027 x = "H" * (0xffff+1) | |
| 1028 tl = self.assertRaises(amp.TooLong, c.sendHello, x) | |
| 1029 p.flush() | |
| 1030 self.failIf(tl.isKey) | |
| 1031 self.failUnless(tl.isLocal) | |
| 1032 self.failUnlessIdentical(tl.keyName, 'hello') | |
| 1033 self.failUnlessIdentical(tl.value, x) | |
| 1034 self.failUnless(str(len(x)) in repr(tl)) | |
| 1035 self.failUnless("value" in repr(tl)) | |
| 1036 self.failUnless('hello' in repr(tl)) | |
| 1037 | |
| 1038 | |
| 1039 def test_helloWorldCommand(self): | |
| 1040 """ | |
| 1041 Verify that a simple command can be sent and its response received with | |
| 1042 the high-level value parsing API. | |
| 1043 """ | |
| 1044 c, s, p = connectedServerAndClient( | |
| 1045 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1046 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1047 L = [] | |
| 1048 HELLO = 'world' | |
| 1049 c.sendHello(HELLO).addCallback(L.append) | |
| 1050 p.flush() | |
| 1051 self.assertEquals(L[0]['hello'], HELLO) | |
| 1052 | |
| 1053 | |
| 1054 def test_helloErrorHandling(self): | |
| 1055 """ | |
| 1056 Verify that if a known error type is raised and handled, it will be | |
| 1057 properly relayed to the other end of the connection and translated into | |
| 1058 an exception, and no error will be logged. | |
| 1059 """ | |
| 1060 L=[] | |
| 1061 c, s, p = connectedServerAndClient( | |
| 1062 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1063 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1064 HELLO = 'fuck you' | |
| 1065 c.sendHello(HELLO).addErrback(L.append) | |
| 1066 p.flush() | |
| 1067 L[0].trap(UnfriendlyGreeting) | |
| 1068 self.assertEquals(str(L[0].value), "Don't be a dick.") | |
| 1069 | |
| 1070 | |
| 1071 def test_helloFatalErrorHandling(self): | |
| 1072 """ | |
| 1073 Verify that if a known, fatal error type is raised and handled, it will | |
| 1074 be properly relayed to the other end of the connection and translated | |
| 1075 into an exception, no error will be logged, and the connection will be | |
| 1076 terminated. | |
| 1077 """ | |
| 1078 L=[] | |
| 1079 c, s, p = connectedServerAndClient( | |
| 1080 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1081 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1082 HELLO = 'die' | |
| 1083 c.sendHello(HELLO).addErrback(L.append) | |
| 1084 p.flush() | |
| 1085 L.pop().trap(DeathThreat) | |
| 1086 c.sendHello(HELLO).addErrback(L.append) | |
| 1087 p.flush() | |
| 1088 L.pop().trap(error.ConnectionDone) | |
| 1089 | |
| 1090 | |
| 1091 | |
| 1092 def test_helloNoErrorHandling(self): | |
| 1093 """ | |
| 1094 Verify that if an unknown error type is raised, it will be relayed to | |
| 1095 the other end of the connection and translated into an exception, it | |
| 1096 will be logged, and then the connection will be dropped. | |
| 1097 """ | |
| 1098 L=[] | |
| 1099 c, s, p = connectedServerAndClient( | |
| 1100 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1101 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1102 HELLO = THING_I_DONT_UNDERSTAND | |
| 1103 c.sendHello(HELLO).addErrback(L.append) | |
| 1104 p.flush() | |
| 1105 ure = L.pop() | |
| 1106 ure.trap(amp.UnknownRemoteError) | |
| 1107 c.sendHello(HELLO).addErrback(L.append) | |
| 1108 cl = L.pop() | |
| 1109 cl.trap(error.ConnectionDone) | |
| 1110 # The exception should have been logged. | |
| 1111 self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError)) | |
| 1112 | |
| 1113 | |
| 1114 | |
| 1115 def test_lateAnswer(self): | |
| 1116 """ | |
| 1117 Verify that a command that does not get answered until after the | |
| 1118 connection terminates will not cause any errors. | |
| 1119 """ | |
| 1120 c, s, p = connectedServerAndClient( | |
| 1121 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1122 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1123 L = [] | |
| 1124 HELLO = 'world' | |
| 1125 c.callRemote(WaitForever).addErrback(L.append) | |
| 1126 p.flush() | |
| 1127 self.assertEquals(L, []) | |
| 1128 s.transport.loseConnection() | |
| 1129 p.flush() | |
| 1130 L.pop().trap(error.ConnectionDone) | |
| 1131 # Just make sure that it doesn't error... | |
| 1132 s.waiting.callback({}) | |
| 1133 return s.waiting | |
| 1134 | |
| 1135 | |
| 1136 def test_requiresNoAnswer(self): | |
| 1137 """ | |
| 1138 Verify that a command that requires no answer is run. | |
| 1139 """ | |
| 1140 L=[] | |
| 1141 c, s, p = connectedServerAndClient( | |
| 1142 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1143 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1144 HELLO = 'world' | |
| 1145 c.callRemote(NoAnswerHello, hello=HELLO) | |
| 1146 p.flush() | |
| 1147 self.failUnless(s.greeted) | |
| 1148 | |
| 1149 | |
| 1150 def test_requiresNoAnswerFail(self): | |
| 1151 """ | |
| 1152 Verify that commands sent after a failed no-answer request do not comple
te. | |
| 1153 """ | |
| 1154 L=[] | |
| 1155 c, s, p = connectedServerAndClient( | |
| 1156 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1157 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1158 HELLO = 'fuck you' | |
| 1159 c.callRemote(NoAnswerHello, hello=HELLO) | |
| 1160 p.flush() | |
| 1161 # This should be logged locally. | |
| 1162 self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError)) | |
| 1163 HELLO = 'world' | |
| 1164 c.callRemote(Hello, hello=HELLO).addErrback(L.append) | |
| 1165 p.flush() | |
| 1166 L.pop().trap(error.ConnectionDone) | |
| 1167 self.failIf(s.greeted) | |
| 1168 | |
| 1169 | |
| 1170 def test_noAnswerResponderBadAnswer(self): | |
| 1171 """ | |
| 1172 Verify that responders of requiresAnswer=False commands have to return | |
| 1173 a dictionary anyway. | |
| 1174 | |
| 1175 (requiresAnswer is a hint from the _client_ - the server may be called | |
| 1176 upon to answer commands in any case, if the client wants to know when | |
| 1177 they complete.) | |
| 1178 """ | |
| 1179 c, s, p = connectedServerAndClient( | |
| 1180 ServerClass=BadNoAnswerCommandProtocol, | |
| 1181 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1182 c.callRemote(NoAnswerHello, hello="hello") | |
| 1183 p.flush() | |
| 1184 le = self.flushLoggedErrors(amp.BadLocalReturn) | |
| 1185 self.assertEquals(len(le), 1) | |
| 1186 | |
| 1187 | |
| 1188 def test_noAnswerResponderAskedForAnswer(self): | |
| 1189 """ | |
| 1190 Verify that responders with requiresAnswer=False will actually respond | |
| 1191 if the client sets requiresAnswer=True. In other words, verify that | |
| 1192 requiresAnswer is a hint honored only by the client. | |
| 1193 """ | |
| 1194 c, s, p = connectedServerAndClient( | |
| 1195 ServerClass=NoAnswerCommandProtocol, | |
| 1196 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1197 L = [] | |
| 1198 c.callRemote(Hello, hello="Hello!").addCallback(L.append) | |
| 1199 p.flush() | |
| 1200 self.assertEquals(len(L), 1) | |
| 1201 self.assertEquals(L, [dict(hello="Hello!-noanswer", | |
| 1202 Print=None)]) # Optional response argument | |
| 1203 | |
| 1204 | |
| 1205 def test_ampListCommand(self): | |
| 1206 """ | |
| 1207 Test encoding of an argument that uses the AmpList encoding. | |
| 1208 """ | |
| 1209 c, s, p = connectedServerAndClient( | |
| 1210 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1211 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1212 L = [] | |
| 1213 c.callRemote(GetList, length=10).addCallback(L.append) | |
| 1214 p.flush() | |
| 1215 values = L.pop().get('body') | |
| 1216 self.assertEquals(values, [{'x': 1}] * 10) | |
| 1217 | |
| 1218 | |
| 1219 def test_failEarlyOnArgSending(self): | |
| 1220 """ | |
| 1221 Verify that if we pass an invalid argument list (omitting an argument),
an | |
| 1222 exception will be raised. | |
| 1223 """ | |
| 1224 okayCommand = Hello(hello="What?") | |
| 1225 self.assertRaises(amp.InvalidSignature, Hello) | |
| 1226 | |
| 1227 | |
| 1228 def test_doubleProtocolSwitch(self): | |
| 1229 """ | |
| 1230 As a debugging aid, a protocol system should raise a | |
| 1231 L{ProtocolSwitched} exception when asked to switch a protocol that is | |
| 1232 already switched. | |
| 1233 """ | |
| 1234 serverDeferred = defer.Deferred() | |
| 1235 serverProto = SimpleSymmetricCommandProtocol(serverDeferred) | |
| 1236 clientDeferred = defer.Deferred() | |
| 1237 clientProto = SimpleSymmetricCommandProtocol(clientDeferred) | |
| 1238 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto, | |
| 1239 ClientClass=lambda: clientProto) | |
| 1240 def switched(result): | |
| 1241 self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol) | |
| 1242 self.testSucceeded = True | |
| 1243 c.switchToTestProtocol().addCallback(switched) | |
| 1244 p.flush() | |
| 1245 self.failUnless(self.testSucceeded) | |
| 1246 | |
| 1247 | |
| 1248 def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol, | |
| 1249 spuriousTraffic=False, | |
| 1250 spuriousError=False): | |
| 1251 """ | |
| 1252 Verify that it is possible to switch to another protocol mid-connection
and | |
| 1253 send data to it successfully. | |
| 1254 """ | |
| 1255 self.testSucceeded = False | |
| 1256 | |
| 1257 serverDeferred = defer.Deferred() | |
| 1258 serverProto = switcher(serverDeferred) | |
| 1259 clientDeferred = defer.Deferred() | |
| 1260 clientProto = switcher(clientDeferred) | |
| 1261 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto, | |
| 1262 ClientClass=lambda: clientProto) | |
| 1263 | |
| 1264 if spuriousTraffic: | |
| 1265 wfdr = [] # remote | |
| 1266 wfd = c.callRemote(WaitForever).addErrback(wfdr.append) | |
| 1267 switchDeferred = c.switchToTestProtocol() | |
| 1268 if spuriousTraffic: | |
| 1269 self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world') | |
| 1270 | |
| 1271 def cbConnsLost(((serverSuccess, serverData), | |
| 1272 (clientSuccess, clientData))): | |
| 1273 self.failUnless(serverSuccess) | |
| 1274 self.failUnless(clientSuccess) | |
| 1275 self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA) | |
| 1276 self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA) | |
| 1277 self.testSucceeded = True | |
| 1278 | |
| 1279 def cbSwitch(proto): | |
| 1280 return defer.DeferredList( | |
| 1281 [serverDeferred, clientDeferred]).addCallback(cbConnsLost) | |
| 1282 | |
| 1283 switchDeferred.addCallback(cbSwitch) | |
| 1284 p.flush() | |
| 1285 if serverProto.maybeLater is not None: | |
| 1286 serverProto.maybeLater.callback(serverProto.maybeLaterProto) | |
| 1287 p.flush() | |
| 1288 if spuriousTraffic: | |
| 1289 # switch is done here; do this here to make sure that if we're | |
| 1290 # going to corrupt the connection, we do it before it's closed. | |
| 1291 if spuriousError: | |
| 1292 s.waiting.errback(amp.RemoteAmpError( | |
| 1293 "SPURIOUS", | |
| 1294 "Here's some traffic in the form of an error.")) | |
| 1295 else: | |
| 1296 s.waiting.callback({}) | |
| 1297 p.flush() | |
| 1298 c.transport.loseConnection() # close it | |
| 1299 p.flush() | |
| 1300 self.failUnless(self.testSucceeded) | |
| 1301 | |
| 1302 | |
| 1303 def test_protocolSwitchDeferred(self): | |
| 1304 """ | |
| 1305 Verify that protocol-switching even works if the value returned from | |
| 1306 the command that does the switch is deferred. | |
| 1307 """ | |
| 1308 return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtoco
l) | |
| 1309 | |
| 1310 | |
| 1311 def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol): | |
| 1312 """ | |
| 1313 Verify that if we try to switch protocols and it fails, the connection | |
| 1314 stays up and we can go back to speaking AMP. | |
| 1315 """ | |
| 1316 self.testSucceeded = False | |
| 1317 | |
| 1318 serverDeferred = defer.Deferred() | |
| 1319 serverProto = switcher(serverDeferred) | |
| 1320 clientDeferred = defer.Deferred() | |
| 1321 clientProto = switcher(clientDeferred) | |
| 1322 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto, | |
| 1323 ClientClass=lambda: clientProto) | |
| 1324 L = [] | |
| 1325 switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append) | |
| 1326 p.flush() | |
| 1327 L.pop().trap(UnknownProtocol) | |
| 1328 self.failIf(self.testSucceeded) | |
| 1329 # It's a known error, so let's send a "hello" on the same connection; | |
| 1330 # it should work. | |
| 1331 c.sendHello('world').addCallback(L.append) | |
| 1332 p.flush() | |
| 1333 self.assertEqual(L.pop()['hello'], 'world') | |
| 1334 | |
| 1335 | |
| 1336 def test_trafficAfterSwitch(self): | |
| 1337 """ | |
| 1338 Verify that attempts to send traffic after a switch will not corrupt | |
| 1339 the nested protocol. | |
| 1340 """ | |
| 1341 return self.test_protocolSwitch(spuriousTraffic=True) | |
| 1342 | |
| 1343 | |
| 1344 def test_errorAfterSwitch(self): | |
| 1345 """ | |
| 1346 Returning an error after a protocol switch should record the underlying | |
| 1347 error. | |
| 1348 """ | |
| 1349 return self.test_protocolSwitch(spuriousTraffic=True, | |
| 1350 spuriousError=True) | |
| 1351 | |
| 1352 | |
| 1353 def test_quitBoxQuits(self): | |
| 1354 """ | |
| 1355 Verify that commands with a responseType of QuitBox will in fact | |
| 1356 terminate the connection. | |
| 1357 """ | |
| 1358 c, s, p = connectedServerAndClient( | |
| 1359 ServerClass=SimpleSymmetricCommandProtocol, | |
| 1360 ClientClass=SimpleSymmetricCommandProtocol) | |
| 1361 | |
| 1362 L = [] | |
| 1363 HELLO = 'world' | |
| 1364 GOODBYE = 'everyone' | |
| 1365 c.sendHello(HELLO).addCallback(L.append) | |
| 1366 p.flush() | |
| 1367 self.assertEquals(L.pop()['hello'], HELLO) | |
| 1368 c.callRemote(Goodbye).addCallback(L.append) | |
| 1369 p.flush() | |
| 1370 self.assertEquals(L.pop()['goodbye'], GOODBYE) | |
| 1371 c.sendHello(HELLO).addErrback(L.append) | |
| 1372 L.pop().trap(error.ConnectionDone) | |
| 1373 | |
| 1374 | |
| 1375 def test_basicLiteralEmit(self): | |
| 1376 """ | |
| 1377 Verify that the command dictionaries for a callRemoteN look correct | |
| 1378 after being serialized and parsed. | |
| 1379 """ | |
| 1380 c, s, p = connectedServerAndClient() | |
| 1381 L = [] | |
| 1382 s.ampBoxReceived = L.append | |
| 1383 c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test', | |
| 1384 dash_arg='x', underscore_arg='y') | |
| 1385 p.flush() | |
| 1386 self.assertEquals(len(L), 1) | |
| 1387 for k, v in [('_command', Hello.commandName), | |
| 1388 ('hello', 'hello test'), | |
| 1389 ('mixedCase', 'mixed case arg test'), | |
| 1390 ('dash-arg', 'x'), | |
| 1391 ('underscore_arg', 'y')]: | |
| 1392 self.assertEquals(L[-1].pop(k), v) | |
| 1393 L[-1].pop('_ask') | |
| 1394 self.assertEquals(L[-1], {}) | |
| 1395 | |
| 1396 | |
| 1397 def test_basicStructuredEmit(self): | |
| 1398 """ | |
| 1399 Verify that a call similar to basicLiteralEmit's is handled properly wit
h | |
| 1400 high-level quoting and passing to Python methods, and that argument | |
| 1401 names are correctly handled. | |
| 1402 """ | |
| 1403 L = [] | |
| 1404 class StructuredHello(amp.AMP): | |
| 1405 def h(self, *a, **k): | |
| 1406 L.append((a, k)) | |
| 1407 return dict(hello='aaa') | |
| 1408 Hello.responder(h) | |
| 1409 c, s, p = connectedServerAndClient(ServerClass=StructuredHello) | |
| 1410 c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test', | |
| 1411 dash_arg='x', underscore_arg='y').addCallback(L.append) | |
| 1412 p.flush() | |
| 1413 self.assertEquals(len(L), 2) | |
| 1414 self.assertEquals(L[0], | |
| 1415 ((), dict( | |
| 1416 hello='hello test', | |
| 1417 mixedCase='mixed case arg test', | |
| 1418 dash_arg='x', | |
| 1419 underscore_arg='y', | |
| 1420 | |
| 1421 # XXX - should optional arguments just not be passed? | |
| 1422 # passing None seems a little odd, looking at the way it | |
| 1423 # turns out here... -glyph | |
| 1424 From=('file', 'file'), | |
| 1425 Print=None, | |
| 1426 optional=None, | |
| 1427 ))) | |
| 1428 self.assertEquals(L[1], dict(Print=None, hello='aaa')) | |
| 1429 | |
| 1430 class PretendRemoteCertificateAuthority: | |
| 1431 def checkIsPretendRemote(self): | |
| 1432 return True | |
| 1433 | |
| 1434 class IOSimCert: | |
| 1435 verifyCount = 0 | |
| 1436 | |
| 1437 def options(self, *ign): | |
| 1438 return self | |
| 1439 | |
| 1440 def iosimVerify(self, otherCert): | |
| 1441 """ | |
| 1442 This isn't a real certificate, and wouldn't work on a real socket, but | |
| 1443 iosim specifies a different API so that we don't have to do any crypto | |
| 1444 math to demonstrate that the right functions get called in the right | |
| 1445 places. | |
| 1446 """ | |
| 1447 assert otherCert is self | |
| 1448 self.verifyCount += 1 | |
| 1449 return True | |
| 1450 | |
| 1451 class OKCert(IOSimCert): | |
| 1452 def options(self, x): | |
| 1453 assert x.checkIsPretendRemote() | |
| 1454 return self | |
| 1455 | |
| 1456 class GrumpyCert(IOSimCert): | |
| 1457 def iosimVerify(self, otherCert): | |
| 1458 self.verifyCount += 1 | |
| 1459 return False | |
| 1460 | |
| 1461 class DroppyCert(IOSimCert): | |
| 1462 def __init__(self, toDrop): | |
| 1463 self.toDrop = toDrop | |
| 1464 | |
| 1465 def iosimVerify(self, otherCert): | |
| 1466 self.verifyCount += 1 | |
| 1467 self.toDrop.loseConnection() | |
| 1468 return True | |
| 1469 | |
| 1470 class SecurableProto(FactoryNotifier): | |
| 1471 | |
| 1472 factory = None | |
| 1473 | |
| 1474 def verifyFactory(self): | |
| 1475 return [PretendRemoteCertificateAuthority()] | |
| 1476 | |
| 1477 def getTLSVars(self): | |
| 1478 cert = self.certFactory() | |
| 1479 verify = self.verifyFactory() | |
| 1480 return dict( | |
| 1481 tls_localCertificate=cert, | |
| 1482 tls_verifyAuthorities=verify) | |
| 1483 amp.StartTLS.responder(getTLSVars) | |
| 1484 | |
| 1485 | |
| 1486 | |
| 1487 class TLSTest(unittest.TestCase): | |
| 1488 def test_startingTLS(self): | |
| 1489 """ | |
| 1490 Verify that starting TLS and succeeding at handshaking sends all the | |
| 1491 notifications to all the right places. | |
| 1492 """ | |
| 1493 cli, svr, p = connectedServerAndClient( | |
| 1494 ServerClass=SecurableProto, | |
| 1495 ClientClass=SecurableProto) | |
| 1496 | |
| 1497 okc = OKCert() | |
| 1498 svr.certFactory = lambda : okc | |
| 1499 | |
| 1500 cli.callRemote( | |
| 1501 amp.StartTLS, tls_localCertificate=okc, | |
| 1502 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()]) | |
| 1503 | |
| 1504 # let's buffer something to be delivered securely | |
| 1505 L = [] | |
| 1506 d = cli.callRemote(SecuredPing).addCallback(L.append) | |
| 1507 p.flush() | |
| 1508 # once for client once for server | |
| 1509 self.assertEquals(okc.verifyCount, 2) | |
| 1510 L = [] | |
| 1511 d = cli.callRemote(SecuredPing).addCallback(L.append) | |
| 1512 p.flush() | |
| 1513 self.assertEqual(L[0], {'pinged': True}) | |
| 1514 | |
| 1515 | |
| 1516 def test_startTooManyTimes(self): | |
| 1517 """ | |
| 1518 Verify that the protocol will complain if we attempt to renegotiate TLS, | |
| 1519 which we don't support. | |
| 1520 """ | |
| 1521 cli, svr, p = connectedServerAndClient( | |
| 1522 ServerClass=SecurableProto, | |
| 1523 ClientClass=SecurableProto) | |
| 1524 | |
| 1525 okc = OKCert() | |
| 1526 svr.certFactory = lambda : okc | |
| 1527 | |
| 1528 cli.callRemote(amp.StartTLS, | |
| 1529 tls_localCertificate=okc, | |
| 1530 tls_verifyAuthorities=[PretendRemoteCertificateAuthority(
)]) | |
| 1531 p.flush() | |
| 1532 cli.noPeerCertificate = True # this is totally fake | |
| 1533 self.assertRaises( | |
| 1534 amp.OnlyOneTLS, | |
| 1535 cli.callRemote, | |
| 1536 amp.StartTLS, | |
| 1537 tls_localCertificate=okc, | |
| 1538 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()]) | |
| 1539 | |
| 1540 | |
| 1541 def test_negotiationFailed(self): | |
| 1542 """ | |
| 1543 Verify that starting TLS and failing on both sides at handshaking sends | |
| 1544 notifications to all the right places and terminates the connection. | |
| 1545 """ | |
| 1546 | |
| 1547 badCert = GrumpyCert() | |
| 1548 | |
| 1549 cli, svr, p = connectedServerAndClient( | |
| 1550 ServerClass=SecurableProto, | |
| 1551 ClientClass=SecurableProto) | |
| 1552 svr.certFactory = lambda : badCert | |
| 1553 | |
| 1554 cli.callRemote(amp.StartTLS, | |
| 1555 tls_localCertificate=badCert) | |
| 1556 | |
| 1557 p.flush() | |
| 1558 # once for client once for server - but both fail | |
| 1559 self.assertEquals(badCert.verifyCount, 2) | |
| 1560 d = cli.callRemote(SecuredPing) | |
| 1561 p.flush() | |
| 1562 self.assertFailure(d, iosim.OpenSSLVerifyError) | |
| 1563 | |
| 1564 | |
| 1565 def test_negotiationFailedByClosing(self): | |
| 1566 """ | |
| 1567 Verify that starting TLS and failing by way of a lost connection | |
| 1568 notices that it is probably an SSL problem. | |
| 1569 """ | |
| 1570 | |
| 1571 cli, svr, p = connectedServerAndClient( | |
| 1572 ServerClass=SecurableProto, | |
| 1573 ClientClass=SecurableProto) | |
| 1574 droppyCert = DroppyCert(svr.transport) | |
| 1575 svr.certFactory = lambda : droppyCert | |
| 1576 | |
| 1577 secure = cli.callRemote(amp.StartTLS, | |
| 1578 tls_localCertificate=droppyCert) | |
| 1579 | |
| 1580 p.flush() | |
| 1581 | |
| 1582 self.assertEquals(droppyCert.verifyCount, 2) | |
| 1583 | |
| 1584 d = cli.callRemote(SecuredPing) | |
| 1585 p.flush() | |
| 1586 | |
| 1587 # it might be a good idea to move this exception somewhere more | |
| 1588 # reasonable. | |
| 1589 self.assertFailure(d, error.PeerVerifyError) | |
| 1590 | |
| 1591 | |
| 1592 | |
| 1593 class InheritedError(Exception): | |
| 1594 """ | |
| 1595 This error is used to check inheritance. | |
| 1596 """ | |
| 1597 | |
| 1598 | |
| 1599 | |
| 1600 class OtherInheritedError(Exception): | |
| 1601 """ | |
| 1602 This is a distinct error for checking inheritance. | |
| 1603 """ | |
| 1604 | |
| 1605 | |
| 1606 | |
| 1607 class BaseCommand(amp.Command): | |
| 1608 """ | |
| 1609 This provides a command that will be subclassed. | |
| 1610 """ | |
| 1611 errors = {InheritedError: 'INHERITED_ERROR'} | |
| 1612 | |
| 1613 | |
| 1614 | |
| 1615 class InheritedCommand(BaseCommand): | |
| 1616 """ | |
| 1617 This is a command which subclasses another command but does not override | |
| 1618 anything. | |
| 1619 """ | |
| 1620 | |
| 1621 | |
| 1622 | |
| 1623 class AddErrorsCommand(BaseCommand): | |
| 1624 """ | |
| 1625 This is a command which subclasses another command but adds errors to the | |
| 1626 list. | |
| 1627 """ | |
| 1628 arguments = [('other', amp.Boolean())] | |
| 1629 errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'} | |
| 1630 | |
| 1631 | |
| 1632 | |
| 1633 class NormalCommandProtocol(amp.AMP): | |
| 1634 """ | |
| 1635 This is a protocol which responds to L{BaseCommand}, and is used to test | |
| 1636 that inheritance does not interfere with the normal handling of errors. | |
| 1637 """ | |
| 1638 def resp(self): | |
| 1639 raise InheritedError() | |
| 1640 BaseCommand.responder(resp) | |
| 1641 | |
| 1642 | |
| 1643 | |
| 1644 class InheritedCommandProtocol(amp.AMP): | |
| 1645 """ | |
| 1646 This is a protocol which responds to L{InheritedCommand}, and is used to | |
| 1647 test that inherited commands inherit their bases' errors if they do not | |
| 1648 respond to any of their own. | |
| 1649 """ | |
| 1650 def resp(self): | |
| 1651 raise InheritedError() | |
| 1652 InheritedCommand.responder(resp) | |
| 1653 | |
| 1654 | |
| 1655 | |
| 1656 class AddedCommandProtocol(amp.AMP): | |
| 1657 """ | |
| 1658 This is a protocol which responds to L{AddErrorsCommand}, and is used to | |
| 1659 test that inherited commands can add their own new types of errors, but | |
| 1660 still respond in the same way to their parents types of errors. | |
| 1661 """ | |
| 1662 def resp(self, other): | |
| 1663 if other: | |
| 1664 raise OtherInheritedError() | |
| 1665 else: | |
| 1666 raise InheritedError() | |
| 1667 AddErrorsCommand.responder(resp) | |
| 1668 | |
| 1669 | |
| 1670 | |
| 1671 class CommandInheritanceTests(unittest.TestCase): | |
| 1672 """ | |
| 1673 These tests verify that commands inherit error conditions properly. | |
| 1674 """ | |
| 1675 | |
| 1676 def errorCheck(self, err, proto, cmd, **kw): | |
| 1677 """ | |
| 1678 Check that the appropriate kind of error is raised when a given command | |
| 1679 is sent to a given protocol. | |
| 1680 """ | |
| 1681 c, s, p = connectedServerAndClient(ServerClass=proto, | |
| 1682 ClientClass=proto) | |
| 1683 d = c.callRemote(cmd, **kw) | |
| 1684 d2 = self.failUnlessFailure(d, err) | |
| 1685 p.flush() | |
| 1686 return d2 | |
| 1687 | |
| 1688 | |
| 1689 def test_basicErrorPropagation(self): | |
| 1690 """ | |
| 1691 Verify that errors specified in a superclass are respected normally | |
| 1692 even if it has subclasses. | |
| 1693 """ | |
| 1694 return self.errorCheck( | |
| 1695 InheritedError, NormalCommandProtocol, BaseCommand) | |
| 1696 | |
| 1697 | |
| 1698 def test_inheritedErrorPropagation(self): | |
| 1699 """ | |
| 1700 Verify that errors specified in a superclass command are propagated to | |
| 1701 its subclasses. | |
| 1702 """ | |
| 1703 return self.errorCheck( | |
| 1704 InheritedError, InheritedCommandProtocol, InheritedCommand) | |
| 1705 | |
| 1706 | |
| 1707 def test_inheritedErrorAddition(self): | |
| 1708 """ | |
| 1709 Verify that new errors specified in a subclass of an existing command | |
| 1710 are honored even if the superclass defines some errors. | |
| 1711 """ | |
| 1712 return self.errorCheck( | |
| 1713 OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=T
rue) | |
| 1714 | |
| 1715 | |
| 1716 def test_additionWithOriginalError(self): | |
| 1717 """ | |
| 1718 Verify that errors specified in a command's superclass are respected | |
| 1719 even if that command defines new errors itself. | |
| 1720 """ | |
| 1721 return self.errorCheck( | |
| 1722 InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False) | |
| 1723 | |
| 1724 | |
| 1725 def _loseAndPass(err, proto): | |
| 1726 # be specific, pass on the error to the client. | |
| 1727 err.trap(error.ConnectionLost, error.ConnectionDone) | |
| 1728 del proto.connectionLost | |
| 1729 proto.connectionLost(err) | |
| 1730 | |
| 1731 | |
| 1732 class LiveFireBase: | |
| 1733 """ | |
| 1734 Utility for connected reactor-using tests. | |
| 1735 """ | |
| 1736 | |
| 1737 def setUp(self): | |
| 1738 """ | |
| 1739 Create an amp server and connect a client to it. | |
| 1740 """ | |
| 1741 from twisted.internet import reactor | |
| 1742 self.serverFactory = protocol.ServerFactory() | |
| 1743 self.serverFactory.protocol = self.serverProto | |
| 1744 self.clientFactory = protocol.ClientFactory() | |
| 1745 self.clientFactory.protocol = self.clientProto | |
| 1746 self.clientFactory.onMade = defer.Deferred() | |
| 1747 self.serverFactory.onMade = defer.Deferred() | |
| 1748 self.serverPort = reactor.listenTCP(0, self.serverFactory) | |
| 1749 self.addCleanup(self.serverPort.stopListening) | |
| 1750 self.clientConn = reactor.connectTCP( | |
| 1751 '127.0.0.1', self.serverPort.getHost().port, | |
| 1752 self.clientFactory) | |
| 1753 self.addCleanup(self.clientConn.disconnect) | |
| 1754 def getProtos(rlst): | |
| 1755 self.cli = self.clientFactory.theProto | |
| 1756 self.svr = self.serverFactory.theProto | |
| 1757 dl = defer.DeferredList([self.clientFactory.onMade, | |
| 1758 self.serverFactory.onMade]) | |
| 1759 return dl.addCallback(getProtos) | |
| 1760 | |
| 1761 def tearDown(self): | |
| 1762 """ | |
| 1763 Cleanup client and server connections, and check the error got at | |
| 1764 C{connectionLost}. | |
| 1765 """ | |
| 1766 L = [] | |
| 1767 for conn in self.cli, self.svr: | |
| 1768 if conn.transport is not None: | |
| 1769 # depend on amp's function connection-dropping behavior | |
| 1770 d = defer.Deferred().addErrback(_loseAndPass, conn) | |
| 1771 conn.connectionLost = d.errback | |
| 1772 conn.transport.loseConnection() | |
| 1773 L.append(d) | |
| 1774 return defer.gatherResults(L | |
| 1775 ).addErrback(lambda first: first.value.subFailure) | |
| 1776 | |
| 1777 | |
| 1778 def show(x): | |
| 1779 import sys | |
| 1780 sys.stdout.write(x+'\n') | |
| 1781 sys.stdout.flush() | |
| 1782 | |
| 1783 | |
| 1784 def tempSelfSigned(): | |
| 1785 from twisted.internet import ssl | |
| 1786 | |
| 1787 sharedDN = ssl.DN(CN='shared') | |
| 1788 key = ssl.KeyPair.generate() | |
| 1789 cr = key.certificateRequest(sharedDN) | |
| 1790 sscrd = key.signCertificateRequest( | |
| 1791 sharedDN, cr, lambda dn: True, 1234567) | |
| 1792 cert = key.newCertificate(sscrd) | |
| 1793 return cert | |
| 1794 | |
| 1795 tempcert = tempSelfSigned() | |
| 1796 | |
| 1797 | |
| 1798 class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase): | |
| 1799 clientProto = SecurableProto | |
| 1800 serverProto = SecurableProto | |
| 1801 def test_liveFireCustomTLS(self): | |
| 1802 """ | |
| 1803 Using real, live TLS, actually negotiate a connection. | |
| 1804 | |
| 1805 This also looks at the 'peerCertificate' attribute's correctness, since | |
| 1806 that's actually loaded using OpenSSL calls, but the main purpose is to | |
| 1807 make sure that we didn't miss anything obvious in iosim about TLS | |
| 1808 negotiations. | |
| 1809 """ | |
| 1810 | |
| 1811 cert = tempcert | |
| 1812 | |
| 1813 self.svr.verifyFactory = lambda : [cert] | |
| 1814 self.svr.certFactory = lambda : cert | |
| 1815 # only needed on the server, we specify the client below. | |
| 1816 | |
| 1817 def secured(rslt): | |
| 1818 x = cert.digest() | |
| 1819 def pinged(rslt2): | |
| 1820 # Interesting. OpenSSL won't even _tell_ us about the peer | |
| 1821 # cert until we negotiate. we should be able to do this in | |
| 1822 # 'secured' instead, but it looks like we can't. I think this | |
| 1823 # is a bug somewhere far deeper than here. | |
| 1824 self.failUnlessEqual(x, self.cli.hostCertificate.digest()) | |
| 1825 self.failUnlessEqual(x, self.cli.peerCertificate.digest()) | |
| 1826 self.failUnlessEqual(x, self.svr.hostCertificate.digest()) | |
| 1827 self.failUnlessEqual(x, self.svr.peerCertificate.digest()) | |
| 1828 return self.cli.callRemote(SecuredPing).addCallback(pinged) | |
| 1829 return self.cli.callRemote(amp.StartTLS, | |
| 1830 tls_localCertificate=cert, | |
| 1831 tls_verifyAuthorities=[cert]).addCallback(sec
ured) | |
| 1832 | |
| 1833 | |
| 1834 class SlightlySmartTLS(SimpleSymmetricCommandProtocol): | |
| 1835 """ | |
| 1836 Specific implementation of server side protocol with different | |
| 1837 management of TLS. | |
| 1838 """ | |
| 1839 def getTLSVars(self): | |
| 1840 """ | |
| 1841 @return: the global C{tempcert} certificate as local certificate. | |
| 1842 """ | |
| 1843 return dict(tls_localCertificate=tempcert) | |
| 1844 amp.StartTLS.responder(getTLSVars) | |
| 1845 | |
| 1846 | |
| 1847 class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase): | |
| 1848 | |
| 1849 clientProto = SimpleSymmetricCommandProtocol | |
| 1850 serverProto = SimpleSymmetricCommandProtocol | |
| 1851 | |
| 1852 def test_liveFireDefaultTLS(self): | |
| 1853 """ | |
| 1854 Verify that out of the box, we can start TLS to at least encrypt the | |
| 1855 connection, even if we don't have any certificates to use. | |
| 1856 """ | |
| 1857 def secured(result): | |
| 1858 return self.cli.callRemote(SecuredPing) | |
| 1859 return self.cli.callRemote(amp.StartTLS).addCallback(secured) | |
| 1860 | |
| 1861 | |
| 1862 class WithServerTLSVerification(LiveFireBase, unittest.TestCase): | |
| 1863 clientProto = SimpleSymmetricCommandProtocol | |
| 1864 serverProto = SlightlySmartTLS | |
| 1865 | |
| 1866 def test_anonymousVerifyingClient(self): | |
| 1867 """ | |
| 1868 Verify that anonymous clients can verify server certificates. | |
| 1869 """ | |
| 1870 def secured(result): | |
| 1871 return self.cli.callRemote(SecuredPing) | |
| 1872 return self.cli.callRemote(amp.StartTLS, | |
| 1873 tls_verifyAuthorities=[tempcert] | |
| 1874 ).addCallback(secured) | |
| 1875 | |
| 1876 | |
| 1877 | |
| 1878 class ProtocolIncludingArgument(amp.Argument): | |
| 1879 """ | |
| 1880 An L{amp.Argument} which encodes its parser and serializer | |
| 1881 arguments *including the protocol* into its parsed and serialized | |
| 1882 forms. | |
| 1883 """ | |
| 1884 | |
| 1885 def fromStringProto(self, string, protocol): | |
| 1886 """ | |
| 1887 Don't decode anything; just return all possible information. | |
| 1888 | |
| 1889 @return: A two-tuple of the input string and the protocol. | |
| 1890 """ | |
| 1891 return (string, protocol) | |
| 1892 | |
| 1893 def toStringProto(self, obj, protocol): | |
| 1894 """ | |
| 1895 Encode identifying information about L{object} and protocol | |
| 1896 into a string for later verification. | |
| 1897 | |
| 1898 @type obj: L{object} | |
| 1899 @type protocol: L{amp.AMP} | |
| 1900 """ | |
| 1901 return "%s:%s" % (id(obj), id(protocol)) | |
| 1902 | |
| 1903 | |
| 1904 | |
| 1905 class ProtocolIncludingCommand(amp.Command): | |
| 1906 """ | |
| 1907 A command that has argument and response schemas which use | |
| 1908 L{ProtocolIncludingArgument}. | |
| 1909 """ | |
| 1910 arguments = [('weird', ProtocolIncludingArgument())] | |
| 1911 response = [('weird', ProtocolIncludingArgument())] | |
| 1912 | |
| 1913 | |
| 1914 | |
| 1915 class MagicSchemaCommand(amp.Command): | |
| 1916 """ | |
| 1917 A command which overrides L{parseResponse}, L{parseArguments}, and | |
| 1918 L{makeResponse}. | |
| 1919 """ | |
| 1920 def parseResponse(self, strings, protocol): | |
| 1921 """ | |
| 1922 Don't do any parsing, just jam the input strings and protocol | |
| 1923 onto the C{protocol.parseResponseArguments} attribute as a | |
| 1924 two-tuple. Return the original strings. | |
| 1925 """ | |
| 1926 protocol.parseResponseArguments = (strings, protocol) | |
| 1927 return strings | |
| 1928 parseResponse = classmethod(parseResponse) | |
| 1929 | |
| 1930 | |
| 1931 def parseArguments(cls, strings, protocol): | |
| 1932 """ | |
| 1933 Don't do any parsing, just jam the input strings and protocol | |
| 1934 onto the C{protocol.parseArgumentsArguments} attribute as a | |
| 1935 two-tuple. Return the original strings. | |
| 1936 """ | |
| 1937 protocol.parseArgumentsArguments = (strings, protocol) | |
| 1938 return strings | |
| 1939 parseArguments = classmethod(parseArguments) | |
| 1940 | |
| 1941 | |
| 1942 def makeArguments(cls, objects, protocol): | |
| 1943 """ | |
| 1944 Don't do any serializing, just jam the input strings and protocol | |
| 1945 onto the C{protocol.makeArgumentsArguments} attribute as a | |
| 1946 two-tuple. Return the original strings. | |
| 1947 """ | |
| 1948 protocol.makeArgumentsArguments = (objects, protocol) | |
| 1949 return objects | |
| 1950 makeArguments = classmethod(makeArguments) | |
| 1951 | |
| 1952 | |
| 1953 | |
| 1954 class NoNetworkProtocol(amp.AMP): | |
| 1955 """ | |
| 1956 An L{amp.AMP} subclass which overrides private methods to avoid | |
| 1957 testing the network. It also provides a responder for | |
| 1958 L{MagicSchemaCommand} that does nothing, so that tests can test | |
| 1959 aspects of the interaction of L{amp.Command}s and L{amp.AMP}. | |
| 1960 | |
| 1961 @ivar parseArgumentsArguments: Arguments that have been passed to any | |
| 1962 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by | |
| 1963 this protocol. | |
| 1964 | |
| 1965 @ivar parseResponseArguments: Responses that have been returned from a | |
| 1966 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by | |
| 1967 this protocol. | |
| 1968 | |
| 1969 @ivar makeArgumentsArguments: Arguments that have been serialized by any | |
| 1970 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by | |
| 1971 this protocol. | |
| 1972 """ | |
| 1973 def _sendBoxCommand(self, commandName, strings, requiresAnswer): | |
| 1974 """ | |
| 1975 Return a Deferred which fires with the original strings. | |
| 1976 """ | |
| 1977 return defer.succeed(strings) | |
| 1978 | |
| 1979 MagicSchemaCommand.responder(lambda s, weird: {}) | |
| 1980 | |
| 1981 | |
| 1982 | |
| 1983 class MyBox(dict): | |
| 1984 """ | |
| 1985 A unique dict subclass. | |
| 1986 """ | |
| 1987 | |
| 1988 | |
| 1989 | |
| 1990 class ProtocolIncludingCommandWithDifferentCommandType( | |
| 1991 ProtocolIncludingCommand): | |
| 1992 """ | |
| 1993 A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox} | |
| 1994 """ | |
| 1995 commandType = MyBox | |
| 1996 | |
| 1997 | |
| 1998 | |
| 1999 class CommandTestCase(unittest.TestCase): | |
| 2000 """ | |
| 2001 Tests for L{amp.Command}. | |
| 2002 """ | |
| 2003 | |
| 2004 def test_parseResponse(self): | |
| 2005 """ | |
| 2006 There should be a class method of Command which accepts a | |
| 2007 mapping of argument names to serialized forms and returns a | |
| 2008 similar mapping whose values have been parsed via the | |
| 2009 Command's response schema. | |
| 2010 """ | |
| 2011 protocol = object() | |
| 2012 result = 'whatever' | |
| 2013 strings = {'weird': result} | |
| 2014 self.assertEqual( | |
| 2015 ProtocolIncludingCommand.parseResponse(strings, protocol), | |
| 2016 {'weird': (result, protocol)}) | |
| 2017 | |
| 2018 | |
| 2019 def test_callRemoteCallsParseResponse(self): | |
| 2020 """ | |
| 2021 Making a remote call on a L{amp.Command} subclass which | |
| 2022 overrides the C{parseResponse} method should call that | |
| 2023 C{parseResponse} method to get the response. | |
| 2024 """ | |
| 2025 client = NoNetworkProtocol() | |
| 2026 thingy = "weeoo" | |
| 2027 response = client.callRemote(MagicSchemaCommand, weird=thingy) | |
| 2028 def gotResponse(ign): | |
| 2029 self.assertEquals(client.parseResponseArguments, | |
| 2030 ({"weird": thingy}, client)) | |
| 2031 response.addCallback(gotResponse) | |
| 2032 return response | |
| 2033 | |
| 2034 | |
| 2035 def test_parseArguments(self): | |
| 2036 """ | |
| 2037 There should be a class method of L{amp.Command} which accepts | |
| 2038 a mapping of argument names to serialized forms and returns a | |
| 2039 similar mapping whose values have been parsed via the | |
| 2040 command's argument schema. | |
| 2041 """ | |
| 2042 protocol = object() | |
| 2043 result = 'whatever' | |
| 2044 strings = {'weird': result} | |
| 2045 self.assertEqual( | |
| 2046 ProtocolIncludingCommand.parseArguments(strings, protocol), | |
| 2047 {'weird': (result, protocol)}) | |
| 2048 | |
| 2049 | |
| 2050 def test_responderCallsParseArguments(self): | |
| 2051 """ | |
| 2052 Making a remote call on a L{amp.Command} subclass which | |
| 2053 overrides the C{parseArguments} method should call that | |
| 2054 C{parseArguments} method to get the arguments. | |
| 2055 """ | |
| 2056 protocol = NoNetworkProtocol() | |
| 2057 responder = protocol.locateResponder(MagicSchemaCommand.commandName) | |
| 2058 argument = object() | |
| 2059 response = responder(dict(weird=argument)) | |
| 2060 response.addCallback( | |
| 2061 lambda ign: self.assertEqual(protocol.parseArgumentsArguments, | |
| 2062 ({"weird": argument}, protocol))) | |
| 2063 return response | |
| 2064 | |
| 2065 | |
| 2066 def test_makeArguments(self): | |
| 2067 """ | |
| 2068 There should be a class method of L{amp.Command} which accepts | |
| 2069 a mapping of argument names to objects and returns a similar | |
| 2070 mapping whose values have been serialized via the command's | |
| 2071 argument schema. | |
| 2072 """ | |
| 2073 protocol = object() | |
| 2074 argument = object() | |
| 2075 objects = {'weird': argument} | |
| 2076 self.assertEqual( | |
| 2077 ProtocolIncludingCommand.makeArguments(objects, protocol), | |
| 2078 {'weird': "%d:%d" % (id(argument), id(protocol))}) | |
| 2079 | |
| 2080 | |
| 2081 def test_makeArgumentsUsesCommandType(self): | |
| 2082 """ | |
| 2083 L{amp.Command.makeArguments}'s return type should be the type | |
| 2084 of the result of L{amp.Command.commandType}. | |
| 2085 """ | |
| 2086 protocol = object() | |
| 2087 objects = {"weird": "whatever"} | |
| 2088 | |
| 2089 result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments( | |
| 2090 objects, protocol) | |
| 2091 self.assertIdentical(type(result), MyBox) | |
| 2092 | |
| 2093 | |
| 2094 def test_callRemoteCallsMakeArguments(self): | |
| 2095 """ | |
| 2096 Making a remote call on a L{amp.Command} subclass which | |
| 2097 overrides the C{makeArguments} method should call that | |
| 2098 C{makeArguments} method to get the response. | |
| 2099 """ | |
| 2100 client = NoNetworkProtocol() | |
| 2101 argument = object() | |
| 2102 response = client.callRemote(MagicSchemaCommand, weird=argument) | |
| 2103 def gotResponse(ign): | |
| 2104 self.assertEqual(client.makeArgumentsArguments, | |
| 2105 ({"weird": argument}, client)) | |
| 2106 response.addCallback(gotResponse) | |
| 2107 return response | |
| 2108 | |
| 2109 if not interfaces.IReactorSSL.providedBy(reactor): | |
| 2110 skipMsg = 'This test case requires SSL support in the reactor' | |
| 2111 TLSTest.skip = skipMsg | |
| 2112 LiveFireTLSTestCase.skip = skipMsg | |
| 2113 PlainVanillaLiveFire.skip = skipMsg | |
| 2114 WithServerTLSVerification.skip = skipMsg | |
| 2115 | |
| OLD | NEW |