Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(370)

Side by Side Diff: third_party/twisted_8_1/twisted/test/test_amp.py

Issue 12261012: Remove third_party/twisted_8_1 (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/tools/build
Patch Set: Created 7 years, 10 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
OLDNEW
(Empty)
1 # Copyright (c) 2005 Divmod, Inc.
2 # Copyright (c) 2007 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 from twisted.python import filepath
6 from twisted.python.failure import Failure
7 from twisted.protocols import amp
8 from twisted.test import iosim
9 from twisted.trial import unittest
10 from twisted.internet import protocol, defer, error, reactor, interfaces
11
12
13 class TestProto(protocol.Protocol):
14 def __init__(self, onConnLost, dataToSend):
15 self.onConnLost = onConnLost
16 self.dataToSend = dataToSend
17
18 def connectionMade(self):
19 self.data = []
20 self.transport.write(self.dataToSend)
21
22 def dataReceived(self, bytes):
23 self.data.append(bytes)
24 # self.transport.loseConnection()
25
26 def connectionLost(self, reason):
27 self.onConnLost.callback(self.data)
28
29
30
31 class SimpleSymmetricProtocol(amp.AMP):
32
33 def sendHello(self, text):
34 return self.callRemoteString(
35 "hello",
36 hello=text)
37
38 def amp_HELLO(self, box):
39 return amp.Box(hello=box['hello'])
40
41 def amp_HOWDOYOUDO(self, box):
42 return amp.QuitBox(howdoyoudo='world')
43
44
45
46 class UnfriendlyGreeting(Exception):
47 """Greeting was insufficiently kind.
48 """
49
50 class DeathThreat(Exception):
51 """Greeting was insufficiently kind.
52 """
53
54 class UnknownProtocol(Exception):
55 """Asked to switch to the wrong protocol.
56 """
57
58
59 class TransportPeer(amp.Argument):
60 # this serves as some informal documentation for how to get variables from
61 # the protocol or your environment and pass them to methods as arguments.
62 def retrieve(self, d, name, proto):
63 return ''
64
65 def fromStringProto(self, notAString, proto):
66 return proto.transport.getPeer()
67
68 def toBox(self, name, strings, objects, proto):
69 return
70
71
72
73 class Hello(amp.Command):
74
75 commandName = 'hello'
76
77 arguments = [('hello', amp.String()),
78 ('optional', amp.Boolean(optional=True)),
79 ('print', amp.Unicode(optional=True)),
80 ('from', TransportPeer(optional=True)),
81 ('mixedCase', amp.String(optional=True)),
82 ('dash-arg', amp.String(optional=True)),
83 ('underscore_arg', amp.String(optional=True))]
84
85 response = [('hello', amp.String()),
86 ('print', amp.Unicode(optional=True))]
87
88 errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
89
90 fatalErrors = {DeathThreat: 'DEAD'}
91
92 class NoAnswerHello(Hello):
93 commandName = Hello.commandName
94 requiresAnswer = False
95
96 class FutureHello(amp.Command):
97 commandName = 'hello'
98
99 arguments = [('hello', amp.String()),
100 ('optional', amp.Boolean(optional=True)),
101 ('print', amp.Unicode(optional=True)),
102 ('from', TransportPeer(optional=True)),
103 ('bonus', amp.String(optional=True)), # addt'l arguments
104 # should generally be
105 # added at the end, and
106 # be optional...
107 ]
108
109 response = [('hello', amp.String()),
110 ('print', amp.Unicode(optional=True))]
111
112 errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
113
114 class WTF(amp.Command):
115 """
116 An example of an invalid command.
117 """
118
119
120 class BrokenReturn(amp.Command):
121 """ An example of a perfectly good command, but the handler is going to retu rn
122 None...
123 """
124
125 commandName = 'broken_return'
126
127 class Goodbye(amp.Command):
128 # commandName left blank on purpose: this tests implicit command names.
129 response = [('goodbye', amp.String())]
130 responseType = amp.QuitBox
131
132 class Howdoyoudo(amp.Command):
133 commandName = 'howdoyoudo'
134 # responseType = amp.QuitBox
135
136 class WaitForever(amp.Command):
137 commandName = 'wait_forever'
138
139 class GetList(amp.Command):
140 commandName = 'getlist'
141 arguments = [('length', amp.Integer())]
142 response = [('body', amp.AmpList([('x', amp.Integer())]))]
143
144 class SecuredPing(amp.Command):
145 # XXX TODO: actually make this refuse to send over an insecure connection
146 response = [('pinged', amp.Boolean())]
147
148 class TestSwitchProto(amp.ProtocolSwitchCommand):
149 commandName = 'Switch-Proto'
150
151 arguments = [
152 ('name', amp.String()),
153 ]
154 errors = {UnknownProtocol: 'UNKNOWN'}
155
156 class SingleUseFactory(protocol.ClientFactory):
157 def __init__(self, proto):
158 self.proto = proto
159 self.proto.factory = self
160
161 def buildProtocol(self, addr):
162 p, self.proto = self.proto, None
163 return p
164
165 reasonFailed = None
166
167 def clientConnectionFailed(self, connector, reason):
168 self.reasonFailed = reason
169 return
170
171 THING_I_DONT_UNDERSTAND = 'gwebol nargo'
172 class ThingIDontUnderstandError(Exception):
173 pass
174
175 class FactoryNotifier(amp.AMP):
176 factory = None
177 def connectionMade(self):
178 if self.factory is not None:
179 self.factory.theProto = self
180 if hasattr(self.factory, 'onMade'):
181 self.factory.onMade.callback(None)
182
183 def emitpong(self):
184 from twisted.internet.interfaces import ISSLTransport
185 if not ISSLTransport.providedBy(self.transport):
186 raise DeathThreat("only send secure pings over secure channels")
187 return {'pinged': True}
188 SecuredPing.responder(emitpong)
189
190
191 class SimpleSymmetricCommandProtocol(FactoryNotifier):
192 maybeLater = None
193 def __init__(self, onConnLost=None):
194 amp.AMP.__init__(self)
195 self.onConnLost = onConnLost
196
197 def sendHello(self, text):
198 return self.callRemote(Hello, hello=text)
199
200 def sendUnicodeHello(self, text, translation):
201 return self.callRemote(Hello, hello=text, Print=translation)
202
203 greeted = False
204
205 def cmdHello(self, hello, From, optional=None, Print=None,
206 mixedCase=None, dash_arg=None, underscore_arg=None):
207 assert From == self.transport.getPeer()
208 if hello == THING_I_DONT_UNDERSTAND:
209 raise ThingIDontUnderstandError()
210 if hello.startswith('fuck'):
211 raise UnfriendlyGreeting("Don't be a dick.")
212 if hello == 'die':
213 raise DeathThreat("aieeeeeeeee")
214 result = dict(hello=hello)
215 if Print is not None:
216 result.update(dict(Print=Print))
217 self.greeted = True
218 return result
219 Hello.responder(cmdHello)
220
221 def cmdGetlist(self, length):
222 return {'body': [dict(x=1)] * length}
223 GetList.responder(cmdGetlist)
224
225 def waitforit(self):
226 self.waiting = defer.Deferred()
227 return self.waiting
228 WaitForever.responder(waitforit)
229
230 def howdo(self):
231 return dict(howdoyoudo='world')
232 Howdoyoudo.responder(howdo)
233
234 def saybye(self):
235 return dict(goodbye="everyone")
236 Goodbye.responder(saybye)
237
238 def switchToTestProtocol(self, fail=False):
239 if fail:
240 name = 'no-proto'
241 else:
242 name = 'test-proto'
243 p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
244 return self.callRemote(
245 TestSwitchProto,
246 SingleUseFactory(p), name=name).addCallback(lambda ign: p)
247
248 def switchit(self, name):
249 if name == 'test-proto':
250 return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
251 raise UnknownProtocol(name)
252 TestSwitchProto.responder(switchit)
253
254 def donothing(self):
255 return None
256 BrokenReturn.responder(donothing)
257
258
259 class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
260 def switchit(self, name):
261 if name == 'test-proto':
262 self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA )
263 self.maybeLater = defer.Deferred()
264 return self.maybeLater
265 raise UnknownProtocol(name)
266 TestSwitchProto.responder(switchit)
267
268 class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
269 def badResponder(self, hello, From, optional=None, Print=None,
270 mixedCase=None, dash_arg=None, underscore_arg=None):
271 """
272 This responder does nothing and forgets to return a dictionary.
273 """
274 NoAnswerHello.responder(badResponder)
275
276 class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
277 def goodNoAnswerResponder(self, hello, From, optional=None, Print=None,
278 mixedCase=None, dash_arg=None, underscore_arg=None ):
279 return dict(hello=hello+"-noanswer")
280 NoAnswerHello.responder(goodNoAnswerResponder)
281
282 def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol,
283 ClientClass=SimpleSymmetricProtocol,
284 *a, **kw):
285 """Returns a 3-tuple: (client, server, pump)
286 """
287 return iosim.connectedServerAndClient(
288 ServerClass, ClientClass,
289 *a, **kw)
290
291 class TotallyDumbProtocol(protocol.Protocol):
292 buf = ''
293 def dataReceived(self, data):
294 self.buf += data
295
296 class LiteralAmp(amp.AMP):
297 def __init__(self):
298 self.boxes = []
299
300 def ampBoxReceived(self, box):
301 self.boxes.append(box)
302 return
303
304 class ParsingTest(unittest.TestCase):
305
306 def test_booleanValues(self):
307 """
308 Verify that the Boolean parser parses 'True' and 'False', but nothing
309 else.
310 """
311 b = amp.Boolean()
312 self.assertEquals(b.fromString("True"), True)
313 self.assertEquals(b.fromString("False"), False)
314 self.assertRaises(TypeError, b.fromString, "ninja")
315 self.assertRaises(TypeError, b.fromString, "true")
316 self.assertRaises(TypeError, b.fromString, "TRUE")
317 self.assertEquals(b.toString(True), 'True')
318 self.assertEquals(b.toString(False), 'False')
319
320 def test_pathValueRoundTrip(self):
321 """
322 Verify the 'Path' argument can parse and emit a file path.
323 """
324 fp = filepath.FilePath(self.mktemp())
325 p = amp.Path()
326 s = p.toString(fp)
327 v = p.fromString(s)
328 self.assertNotIdentical(fp, v) # sanity check
329 self.assertEquals(fp, v)
330
331
332 def test_sillyEmptyThing(self):
333 """
334 Test that empty boxes raise an error; they aren't supposed to be sent
335 on purpose.
336 """
337 a = amp.AMP()
338 return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())
339
340
341 def test_ParsingRoundTrip(self):
342 """
343 Verify that various kinds of data make it through the encode/parse
344 round-trip unharmed.
345 """
346 c, s, p = connectedServerAndClient(ClientClass=LiteralAmp,
347 ServerClass=LiteralAmp)
348
349 SIMPLE = ('simple', 'test')
350 CE = ('ceq', ': ')
351 CR = ('crtest', 'test\r')
352 LF = ('lftest', 'hello\n')
353 NEWLINE = ('newline', 'test\r\none\r\ntwo')
354 NEWLINE2 = ('newline2', 'test\r\none\r\n two')
355 BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
356 BODYTEST = ('body', 'blah\r\n\r\ntesttest')
357
358 testData = [
359 [SIMPLE],
360 [SIMPLE, BODYTEST],
361 [SIMPLE, CE],
362 [SIMPLE, CR],
363 [SIMPLE, CE, CR, LF],
364 [CE, CR, LF],
365 [SIMPLE, NEWLINE, CE, NEWLINE2],
366 [BODYTEST, SIMPLE, NEWLINE]
367 ]
368
369 for test in testData:
370 jb = amp.Box()
371 jb.update(dict(test))
372 jb._sendTo(c)
373 p.flush()
374 self.assertEquals(s.boxes[-1], jb)
375
376
377
378 class FakeLocator(object):
379 """
380 This is a fake implementation of the interface implied by
381 L{CommandLocator}.
382 """
383 def __init__(self):
384 """
385 Remember the given keyword arguments as a set of responders.
386 """
387 self.commands = {}
388
389
390 def locateResponder(self, commandName):
391 """
392 Look up and return a function passed as a keyword argument of the given
393 name to the constructor.
394 """
395 return self.commands[commandName]
396
397
398 class FakeSender:
399 """
400 This is a fake implementation of the 'box sender' interface implied by
401 L{AMP}.
402 """
403 def __init__(self):
404 """
405 Create a fake sender and initialize the list of received boxes and
406 unhandled errors.
407 """
408 self.sentBoxes = []
409 self.unhandledErrors = []
410 self.expectedErrors = 0
411
412
413 def expectError(self):
414 """
415 Expect one error, so that the test doesn't fail.
416 """
417 self.expectedErrors += 1
418
419
420 def sendBox(self, box):
421 """
422 Accept a box, but don't do anything.
423 """
424 self.sentBoxes.append(box)
425
426
427 def unhandledError(self, failure):
428 """
429 Deal with failures by instantly re-raising them for easier debugging.
430 """
431 self.expectedErrors -= 1
432 if self.expectedErrors < 0:
433 failure.raiseException()
434 else:
435 self.unhandledErrors.append(failure)
436
437
438
439 class CommandDispatchTests(unittest.TestCase):
440 """
441 The AMP CommandDispatcher class dispatches converts AMP boxes into commands
442 and responses using Command.responder decorator.
443
444 Note: Originally, AMP's factoring was such that many tests for this
445 functionality are now implemented as full round-trip tests in L{AMPTest}.
446 Future tests should be written at this level instead, to ensure API
447 compatibility and to provide more granular, readable units of test
448 coverage.
449 """
450
451 def setUp(self):
452 """
453 Create a dispatcher to use.
454 """
455 self.locator = FakeLocator()
456 self.sender = FakeSender()
457 self.dispatcher = amp.BoxDispatcher(self.locator)
458 self.dispatcher.startReceivingBoxes(self.sender)
459
460
461 def test_receivedAsk(self):
462 """
463 L{CommandDispatcher.ampBoxReceived} should locate the appropriate
464 command in its responder lookup, based on the '_ask' key.
465 """
466 received = []
467 def thunk(box):
468 received.append(box)
469 return amp.Box({"hello": "goodbye"})
470 input = amp.Box(_command="hello",
471 _ask="test-command-id",
472 hello="world")
473 self.locator.commands['hello'] = thunk
474 self.dispatcher.ampBoxReceived(input)
475 self.assertEquals(received, [input])
476
477
478 def test_sendUnhandledError(self):
479 """
480 L{CommandDispatcher} should relay its unhandled errors in responding to
481 boxes to its boxSender.
482 """
483 err = RuntimeError("something went wrong, oh no")
484 self.sender.expectError()
485 self.dispatcher.unhandledError(Failure(err))
486 self.assertEqual(len(self.sender.unhandledErrors), 1)
487 self.assertEqual(self.sender.unhandledErrors[0].value, err)
488
489
490 def test_unhandledSerializationError(self):
491 """
492 Errors during serialization ought to be relayed to the sender's
493 unhandledError method.
494 """
495 err = RuntimeError("something undefined went wrong")
496 def thunk(result):
497 class BrokenBox(amp.Box):
498 def _sendTo(self, proto):
499 raise err
500 return BrokenBox()
501 self.locator.commands['hello'] = thunk
502 input = amp.Box(_command="hello",
503 _ask="test-command-id",
504 hello="world")
505 self.sender.expectError()
506 self.dispatcher.ampBoxReceived(input)
507 self.assertEquals(len(self.sender.unhandledErrors), 1)
508 self.assertEquals(self.sender.unhandledErrors[0].value, err)
509
510
511 def test_callRemote(self):
512 """
513 L{CommandDispatcher.callRemote} should emit a properly formatted '_ask'
514 box to its boxSender and record an outstanding L{Deferred}. When a
515 corresponding '_answer' packet is received, the L{Deferred} should be
516 fired, and the results translated via the given L{Command}'s response
517 de-serialization.
518 """
519 D = self.dispatcher.callRemote(Hello, hello='world')
520 self.assertEquals(self.sender.sentBoxes,
521 [amp.AmpBox(_command="hello",
522 _ask="1",
523 hello="world")])
524 answers = []
525 D.addCallback(answers.append)
526 self.assertEquals(answers, [])
527 self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay",
528 'print': "ignored",
529 '_answer': "1"}))
530 self.assertEquals(answers, [dict(hello="yay",
531 Print=u"ignored")])
532
533
534 class SimpleGreeting(amp.Command):
535 """
536 A very simple greeting command that uses a few basic argument types.
537 """
538 commandName = 'simple'
539 arguments = [('greeting', amp.Unicode()),
540 ('cookie', amp.Integer())]
541 response = [('cookieplus', amp.Integer())]
542
543
544 class TestLocator(amp.CommandLocator):
545 """
546 A locator which implements a responder to a 'hello' command.
547 """
548 def __init__(self):
549 self.greetings = []
550
551
552 def greetingResponder(self, greeting, cookie):
553 self.greetings.append((greeting, cookie))
554 return dict(cookieplus=cookie + 3)
555 greetingResponder = SimpleGreeting.responder(greetingResponder)
556
557
558
559 class OverrideLocatorAMP(amp.AMP):
560 def __init__(self):
561 amp.AMP.__init__(self)
562 self.customResponder = object()
563 self.expectations = {"custom": self.customResponder}
564 self.greetings = []
565
566
567 def lookupFunction(self, name):
568 """
569 Override the deprecated lookupFunction function.
570 """
571 if name in self.expectations:
572 result = self.expectations[name]
573 return result
574 else:
575 return super(OverrideLocatorAMP, self).lookupFunction(name)
576
577
578 def greetingResponder(self, greeting, cookie):
579 self.greetings.append((greeting, cookie))
580 return dict(cookieplus=cookie + 3)
581 greetingResponder = SimpleGreeting.responder(greetingResponder)
582
583
584
585
586 class CommandLocatorTests(unittest.TestCase):
587 """
588 The CommandLocator should enable users to specify responders to commands as
589 functions that take structured objects, annotated with metadata.
590 """
591
592 def test_responderDecorator(self):
593 """
594 A method on a L{CommandLocator} subclass decorated with a L{Command}
595 subclass's L{responder} decorator should be returned from
596 locateResponder, wrapped in logic to serialize and deserialize its
597 arguments.
598 """
599 locator = TestLocator()
600 responderCallable = locator.locateResponder("simple")
601 result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
602 def done(values):
603 self.assertEquals(values, amp.AmpBox(cookieplus='8'))
604 return result.addCallback(done)
605
606
607 def test_lookupFunctionDeprecatedOverride(self):
608 """
609 Subclasses which override locateResponder under its old name,
610 lookupFunction, should have the override invoked instead. (This tests
611 an AMP subclass, because in the version of the code that could invoke
612 this deprecated code path, there was no L{CommandLocator}.)
613 """
614 locator = OverrideLocatorAMP()
615 customResponderObject = self.assertWarns(
616 PendingDeprecationWarning,
617 "Override locateResponder, not lookupFunction.",
618 __file__, lambda : locator.locateResponder("custom"))
619 self.assertEquals(locator.customResponder, customResponderObject)
620 # Make sure upcalling works too
621 normalResponderObject = self.assertWarns(
622 PendingDeprecationWarning,
623 "Override locateResponder, not lookupFunction.",
624 __file__, lambda : locator.locateResponder("simple"))
625 result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5"))
626 def done(values):
627 self.assertEquals(values, amp.AmpBox(cookieplus='8'))
628 return result.addCallback(done)
629
630
631 def test_lookupFunctionDeprecatedInvoke(self):
632 """
633 Invoking locateResponder under its old name, lookupFunction, should
634 emit a deprecation warning, but do the same thing.
635 """
636 locator = TestLocator()
637 responderCallable = self.assertWarns(
638 PendingDeprecationWarning,
639 "Call locateResponder, not lookupFunction.", __file__,
640 lambda : locator.lookupFunction("simple"))
641 result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
642 def done(values):
643 self.assertEquals(values, amp.AmpBox(cookieplus='8'))
644 return result.addCallback(done)
645
646
647
648 SWITCH_CLIENT_DATA = 'Success!'
649 SWITCH_SERVER_DATA = 'No, really. Success.'
650
651
652 class BinaryProtocolTests(unittest.TestCase):
653 """
654 Tests for L{amp.BinaryBoxProtocol}.
655 """
656
657 def setUp(self):
658 """
659 Keep track of all boxes received by this test in its capacity as an
660 L{IBoxReceiver} implementor.
661 """
662 self.boxes = []
663 self.data = []
664
665
666 def startReceivingBoxes(self, sender):
667 """
668 Implement L{IBoxReceiver.startReceivingBoxes} to do nothing.
669 """
670
671
672 def ampBoxReceived(self, box):
673 """
674 A box was received by the protocol.
675 """
676 self.boxes.append(box)
677
678 stopReason = None
679 def stopReceivingBoxes(self, reason):
680 """
681 Record the reason that we stopped receiving boxes.
682 """
683 self.stopReason = reason
684
685
686 # fake ITransport
687 def getPeer(self):
688 return 'no peer'
689
690
691 def getHost(self):
692 return 'no host'
693
694
695 def write(self, data):
696 self.data.append(data)
697
698
699 def test_receiveBoxStateMachine(self):
700 """
701 When a binary box protocol receives:
702 * a key
703 * a value
704 * an empty string
705 it should emit a box and send it to its boxReceiver.
706 """
707 a = amp.BinaryBoxProtocol(self)
708 a.stringReceived("hello")
709 a.stringReceived("world")
710 a.stringReceived("")
711 self.assertEquals(self.boxes, [amp.AmpBox(hello="world")])
712
713
714 def test_receiveBoxData(self):
715 """
716 When a binary box protocol receives the serialized form of an AMP box,
717 it should emit a similar box to its boxReceiver.
718 """
719 a = amp.BinaryBoxProtocol(self)
720 a.dataReceived(amp.Box({"testKey": "valueTest",
721 "anotherKey": "anotherValue"}).serialize())
722 self.assertEquals(self.boxes,
723 [amp.Box({"testKey": "valueTest",
724 "anotherKey": "anotherValue"})])
725
726
727 def test_sendBox(self):
728 """
729 When a binary box protocol sends a box, it should emit the serialized
730 bytes of that box to its transport.
731 """
732 a = amp.BinaryBoxProtocol(self)
733 a.makeConnection(self)
734 aBox = amp.Box({"testKey": "valueTest",
735 "someData": "hello"})
736 a.makeConnection(self)
737 a.sendBox(aBox)
738 self.assertEquals(''.join(self.data), aBox.serialize())
739
740
741 def test_connectionLostStopSendingBoxes(self):
742 """
743 When a binary box protocol loses its connection, it should notify its
744 box receiver that it has stopped receiving boxes.
745 """
746 a = amp.BinaryBoxProtocol(self)
747 a.makeConnection(self)
748 aBox = amp.Box({"sample": "data"})
749 a.makeConnection(self)
750 connectionFailure = Failure(RuntimeError())
751 a.connectionLost(connectionFailure)
752 self.assertIdentical(self.stopReason, connectionFailure)
753
754
755 def test_protocolSwitch(self):
756 """
757 L{BinaryBoxProtocol} has the capacity to switch to a different protocol
758 on a box boundary. When a protocol is in the process of switching, it
759 cannot receive traffic.
760 """
761 otherProto = TestProto(None, "outgoing data")
762 test = self
763 class SwitchyReceiver:
764 switched = False
765 def startReceivingBoxes(self, sender):
766 pass
767 def ampBoxReceived(self, box):
768 test.assertFalse(self.switched,
769 "Should only receive one box!")
770 self.switched = True
771 a._lockForSwitch()
772 a._switchTo(otherProto)
773 a = amp.BinaryBoxProtocol(SwitchyReceiver())
774 anyOldBox = amp.Box({"include": "lots",
775 "of": "data"})
776 a.makeConnection(self)
777 # Include a 0-length box at the beginning of the next protocol's data,
778 # to make sure that AMP doesn't eat the data or try to deliver extra
779 # boxes either...
780 moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!"
781 a.dataReceived(moreThanOneBox)
782 self.assertIdentical(otherProto.transport, self)
783 self.assertEquals("".join(otherProto.data), "\x00\x00Hello, world!")
784 self.assertEquals(self.data, ["outgoing data"])
785 a.dataReceived("more data")
786 self.assertEquals("".join(otherProto.data),
787 "\x00\x00Hello, world!more data")
788 self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)
789
790
791 def test_protocolSwitchInvalidStates(self):
792 """
793 In order to make sure the protocol never gets any invalid data sent
794 into the middle of a box, it must be locked for switching before it is
795 switched. It can only be unlocked if the switch failed, and attempting
796 to send a box while it is locked should raise an exception.
797 """
798 a = amp.BinaryBoxProtocol(self)
799 a.makeConnection(self)
800 sampleBox = amp.Box({"some": "data"})
801 a._lockForSwitch()
802 self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
803 a._unlockFromSwitch()
804 a.sendBox(sampleBox)
805 self.assertEquals(''.join(self.data), sampleBox.serialize())
806 a._lockForSwitch()
807 otherProto = TestProto(None, "outgoing data")
808 a._switchTo(otherProto)
809 self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)
810
811
812 def test_protocolSwitchLoseConnection(self):
813 """
814 When the protocol is switched, it should notify its nested protocol of
815 disconnection.
816 """
817 class Loser(protocol.Protocol):
818 reason = None
819 def connectionLost(self, reason):
820 self.reason = reason
821 connectionLoser = Loser()
822 a = amp.BinaryBoxProtocol(self)
823 a.makeConnection(self)
824 a._lockForSwitch()
825 a._switchTo(connectionLoser)
826 connectionFailure = Failure(RuntimeError())
827 a.connectionLost(connectionFailure)
828 self.assertEquals(connectionLoser.reason, connectionFailure)
829
830
831 def test_protocolSwitchLoseClientConnection(self):
832 """
833 When the protocol is switched, it should notify its nested client
834 protocol factory of disconnection.
835 """
836 class ClientLoser:
837 reason = None
838 def clientConnectionLost(self, connector, reason):
839 self.reason = reason
840 a = amp.BinaryBoxProtocol(self)
841 connectionLoser = protocol.Protocol()
842 clientLoser = ClientLoser()
843 a.makeConnection(self)
844 a._lockForSwitch()
845 a._switchTo(connectionLoser, clientLoser)
846 connectionFailure = Failure(RuntimeError())
847 a.connectionLost(connectionFailure)
848 self.assertEquals(clientLoser.reason, connectionFailure)
849
850
851
852 class AMPTest(unittest.TestCase):
853
854 def test_interfaceDeclarations(self):
855 """
856 The classes in the amp module ought to implement the interfaces that
857 are declared for their benefit.
858 """
859 for interface, implementation in [(amp.IBoxSender, amp.BinaryBoxProtocol ),
860 (amp.IBoxReceiver, amp.BoxDispatcher),
861 (amp.IResponderLocator, amp.CommandLoc ator),
862 (amp.IResponderLocator, amp.SimpleStri ngLocator),
863 (amp.IBoxSender, amp.AMP),
864 (amp.IBoxReceiver, amp.AMP),
865 (amp.IResponderLocator, amp.AMP)]:
866 self.failUnless(interface.implementedBy(implementation),
867 "%s does not implements(%s)" % (implementation, inte rface))
868
869
870 def test_helloWorld(self):
871 """
872 Verify that a simple command can be sent and its response received with
873 the simple low-level string-based API.
874 """
875 c, s, p = connectedServerAndClient()
876 L = []
877 HELLO = 'world'
878 c.sendHello(HELLO).addCallback(L.append)
879 p.flush()
880 self.assertEquals(L[0]['hello'], HELLO)
881
882
883 def test_wireFormatRoundTrip(self):
884 """
885 Verify that mixed-case, underscored and dashed arguments are mapped to
886 their python names properly.
887 """
888 c, s, p = connectedServerAndClient()
889 L = []
890 HELLO = 'world'
891 c.sendHello(HELLO).addCallback(L.append)
892 p.flush()
893 self.assertEquals(L[0]['hello'], HELLO)
894
895
896 def test_helloWorldUnicode(self):
897 """
898 Verify that unicode arguments can be encoded and decoded.
899 """
900 c, s, p = connectedServerAndClient(
901 ServerClass=SimpleSymmetricCommandProtocol,
902 ClientClass=SimpleSymmetricCommandProtocol)
903 L = []
904 HELLO = 'world'
905 HELLO_UNICODE = 'wor\u1234ld'
906 c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
907 p.flush()
908 self.assertEquals(L[0]['hello'], HELLO)
909 self.assertEquals(L[0]['Print'], HELLO_UNICODE)
910
911
912 def test_unknownCommandLow(self):
913 """
914 Verify that unknown commands using low-level APIs will be rejected with an
915 error, but will NOT terminate the connection.
916 """
917 c, s, p = connectedServerAndClient()
918 L = []
919 def clearAndAdd(e):
920 """
921 You can't propagate the error...
922 """
923 e.trap(amp.UnhandledCommand)
924 return "OK"
925 c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append)
926 p.flush()
927 self.assertEquals(L.pop(), "OK")
928 HELLO = 'world'
929 c.sendHello(HELLO).addCallback(L.append)
930 p.flush()
931 self.assertEquals(L[0]['hello'], HELLO)
932
933
934 def test_unknownCommandHigh(self):
935 """
936 Verify that unknown commands using high-level APIs will be rejected with an
937 error, but will NOT terminate the connection.
938 """
939 c, s, p = connectedServerAndClient()
940 L = []
941 def clearAndAdd(e):
942 """
943 You can't propagate the error...
944 """
945 e.trap(amp.UnhandledCommand)
946 return "OK"
947 c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
948 p.flush()
949 self.assertEquals(L.pop(), "OK")
950 HELLO = 'world'
951 c.sendHello(HELLO).addCallback(L.append)
952 p.flush()
953 self.assertEquals(L[0]['hello'], HELLO)
954
955
956 def test_brokenReturnValue(self):
957 """
958 It can be very confusing if you write some code which responds to a
959 command, but gets the return value wrong. Most commonly you end up
960 returning None instead of a dictionary.
961
962 Verify that if that happens, the framework logs a useful error.
963 """
964 L = []
965 SimpleSymmetricCommandProtocol().dispatchCommand(
966 amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append)
967 blr = L[0].trap(amp.BadLocalReturn)
968 self.failUnlessIn('None', repr(L[0].value))
969
970
971
972 def test_unknownArgument(self):
973 """
974 Verify that unknown arguments are ignored, and not passed to a Python
975 function which can't accept them.
976 """
977 c, s, p = connectedServerAndClient(
978 ServerClass=SimpleSymmetricCommandProtocol,
979 ClientClass=SimpleSymmetricCommandProtocol)
980 L = []
981 HELLO = 'world'
982 # c.sendHello(HELLO).addCallback(L.append)
983 c.callRemote(FutureHello,
984 hello=HELLO,
985 bonus="I'm not in the book!").addCallback(
986 L.append)
987 p.flush()
988 self.assertEquals(L[0]['hello'], HELLO)
989
990
991 def test_simpleReprs(self):
992 """
993 Verify that the various Box objects repr properly, for debugging.
994 """
995 self.assertEquals(type(repr(amp._TLSBox())), str)
996 self.assertEquals(type(repr(amp._SwitchBox('a'))), str)
997 self.assertEquals(type(repr(amp.QuitBox())), str)
998 self.assertEquals(type(repr(amp.AmpBox())), str)
999 self.failUnless("AmpBox" in repr(amp.AmpBox()))
1000
1001 def test_keyTooLong(self):
1002 """
1003 Verify that a key that is too long will immediately raise a synchronous
1004 exception.
1005 """
1006 c, s, p = connectedServerAndClient()
1007 L = []
1008 x = "H" * (0xff+1)
1009 tl = self.assertRaises(amp.TooLong,
1010 c.callRemoteString, "Hello",
1011 **{x: "hi"})
1012 self.failUnless(tl.isKey)
1013 self.failUnless(tl.isLocal)
1014 self.failUnlessIdentical(tl.keyName, None)
1015 self.failUnlessIdentical(tl.value, x)
1016 self.failUnless(str(len(x)) in repr(tl))
1017 self.failUnless("key" in repr(tl))
1018
1019
1020 def test_valueTooLong(self):
1021 """
1022 Verify that attempting to send value longer than 64k will immediately
1023 raise an exception.
1024 """
1025 c, s, p = connectedServerAndClient()
1026 L = []
1027 x = "H" * (0xffff+1)
1028 tl = self.assertRaises(amp.TooLong, c.sendHello, x)
1029 p.flush()
1030 self.failIf(tl.isKey)
1031 self.failUnless(tl.isLocal)
1032 self.failUnlessIdentical(tl.keyName, 'hello')
1033 self.failUnlessIdentical(tl.value, x)
1034 self.failUnless(str(len(x)) in repr(tl))
1035 self.failUnless("value" in repr(tl))
1036 self.failUnless('hello' in repr(tl))
1037
1038
1039 def test_helloWorldCommand(self):
1040 """
1041 Verify that a simple command can be sent and its response received with
1042 the high-level value parsing API.
1043 """
1044 c, s, p = connectedServerAndClient(
1045 ServerClass=SimpleSymmetricCommandProtocol,
1046 ClientClass=SimpleSymmetricCommandProtocol)
1047 L = []
1048 HELLO = 'world'
1049 c.sendHello(HELLO).addCallback(L.append)
1050 p.flush()
1051 self.assertEquals(L[0]['hello'], HELLO)
1052
1053
1054 def test_helloErrorHandling(self):
1055 """
1056 Verify that if a known error type is raised and handled, it will be
1057 properly relayed to the other end of the connection and translated into
1058 an exception, and no error will be logged.
1059 """
1060 L=[]
1061 c, s, p = connectedServerAndClient(
1062 ServerClass=SimpleSymmetricCommandProtocol,
1063 ClientClass=SimpleSymmetricCommandProtocol)
1064 HELLO = 'fuck you'
1065 c.sendHello(HELLO).addErrback(L.append)
1066 p.flush()
1067 L[0].trap(UnfriendlyGreeting)
1068 self.assertEquals(str(L[0].value), "Don't be a dick.")
1069
1070
1071 def test_helloFatalErrorHandling(self):
1072 """
1073 Verify that if a known, fatal error type is raised and handled, it will
1074 be properly relayed to the other end of the connection and translated
1075 into an exception, no error will be logged, and the connection will be
1076 terminated.
1077 """
1078 L=[]
1079 c, s, p = connectedServerAndClient(
1080 ServerClass=SimpleSymmetricCommandProtocol,
1081 ClientClass=SimpleSymmetricCommandProtocol)
1082 HELLO = 'die'
1083 c.sendHello(HELLO).addErrback(L.append)
1084 p.flush()
1085 L.pop().trap(DeathThreat)
1086 c.sendHello(HELLO).addErrback(L.append)
1087 p.flush()
1088 L.pop().trap(error.ConnectionDone)
1089
1090
1091
1092 def test_helloNoErrorHandling(self):
1093 """
1094 Verify that if an unknown error type is raised, it will be relayed to
1095 the other end of the connection and translated into an exception, it
1096 will be logged, and then the connection will be dropped.
1097 """
1098 L=[]
1099 c, s, p = connectedServerAndClient(
1100 ServerClass=SimpleSymmetricCommandProtocol,
1101 ClientClass=SimpleSymmetricCommandProtocol)
1102 HELLO = THING_I_DONT_UNDERSTAND
1103 c.sendHello(HELLO).addErrback(L.append)
1104 p.flush()
1105 ure = L.pop()
1106 ure.trap(amp.UnknownRemoteError)
1107 c.sendHello(HELLO).addErrback(L.append)
1108 cl = L.pop()
1109 cl.trap(error.ConnectionDone)
1110 # The exception should have been logged.
1111 self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError))
1112
1113
1114
1115 def test_lateAnswer(self):
1116 """
1117 Verify that a command that does not get answered until after the
1118 connection terminates will not cause any errors.
1119 """
1120 c, s, p = connectedServerAndClient(
1121 ServerClass=SimpleSymmetricCommandProtocol,
1122 ClientClass=SimpleSymmetricCommandProtocol)
1123 L = []
1124 HELLO = 'world'
1125 c.callRemote(WaitForever).addErrback(L.append)
1126 p.flush()
1127 self.assertEquals(L, [])
1128 s.transport.loseConnection()
1129 p.flush()
1130 L.pop().trap(error.ConnectionDone)
1131 # Just make sure that it doesn't error...
1132 s.waiting.callback({})
1133 return s.waiting
1134
1135
1136 def test_requiresNoAnswer(self):
1137 """
1138 Verify that a command that requires no answer is run.
1139 """
1140 L=[]
1141 c, s, p = connectedServerAndClient(
1142 ServerClass=SimpleSymmetricCommandProtocol,
1143 ClientClass=SimpleSymmetricCommandProtocol)
1144 HELLO = 'world'
1145 c.callRemote(NoAnswerHello, hello=HELLO)
1146 p.flush()
1147 self.failUnless(s.greeted)
1148
1149
1150 def test_requiresNoAnswerFail(self):
1151 """
1152 Verify that commands sent after a failed no-answer request do not comple te.
1153 """
1154 L=[]
1155 c, s, p = connectedServerAndClient(
1156 ServerClass=SimpleSymmetricCommandProtocol,
1157 ClientClass=SimpleSymmetricCommandProtocol)
1158 HELLO = 'fuck you'
1159 c.callRemote(NoAnswerHello, hello=HELLO)
1160 p.flush()
1161 # This should be logged locally.
1162 self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError))
1163 HELLO = 'world'
1164 c.callRemote(Hello, hello=HELLO).addErrback(L.append)
1165 p.flush()
1166 L.pop().trap(error.ConnectionDone)
1167 self.failIf(s.greeted)
1168
1169
1170 def test_noAnswerResponderBadAnswer(self):
1171 """
1172 Verify that responders of requiresAnswer=False commands have to return
1173 a dictionary anyway.
1174
1175 (requiresAnswer is a hint from the _client_ - the server may be called
1176 upon to answer commands in any case, if the client wants to know when
1177 they complete.)
1178 """
1179 c, s, p = connectedServerAndClient(
1180 ServerClass=BadNoAnswerCommandProtocol,
1181 ClientClass=SimpleSymmetricCommandProtocol)
1182 c.callRemote(NoAnswerHello, hello="hello")
1183 p.flush()
1184 le = self.flushLoggedErrors(amp.BadLocalReturn)
1185 self.assertEquals(len(le), 1)
1186
1187
1188 def test_noAnswerResponderAskedForAnswer(self):
1189 """
1190 Verify that responders with requiresAnswer=False will actually respond
1191 if the client sets requiresAnswer=True. In other words, verify that
1192 requiresAnswer is a hint honored only by the client.
1193 """
1194 c, s, p = connectedServerAndClient(
1195 ServerClass=NoAnswerCommandProtocol,
1196 ClientClass=SimpleSymmetricCommandProtocol)
1197 L = []
1198 c.callRemote(Hello, hello="Hello!").addCallback(L.append)
1199 p.flush()
1200 self.assertEquals(len(L), 1)
1201 self.assertEquals(L, [dict(hello="Hello!-noanswer",
1202 Print=None)]) # Optional response argument
1203
1204
1205 def test_ampListCommand(self):
1206 """
1207 Test encoding of an argument that uses the AmpList encoding.
1208 """
1209 c, s, p = connectedServerAndClient(
1210 ServerClass=SimpleSymmetricCommandProtocol,
1211 ClientClass=SimpleSymmetricCommandProtocol)
1212 L = []
1213 c.callRemote(GetList, length=10).addCallback(L.append)
1214 p.flush()
1215 values = L.pop().get('body')
1216 self.assertEquals(values, [{'x': 1}] * 10)
1217
1218
1219 def test_failEarlyOnArgSending(self):
1220 """
1221 Verify that if we pass an invalid argument list (omitting an argument), an
1222 exception will be raised.
1223 """
1224 okayCommand = Hello(hello="What?")
1225 self.assertRaises(amp.InvalidSignature, Hello)
1226
1227
1228 def test_doubleProtocolSwitch(self):
1229 """
1230 As a debugging aid, a protocol system should raise a
1231 L{ProtocolSwitched} exception when asked to switch a protocol that is
1232 already switched.
1233 """
1234 serverDeferred = defer.Deferred()
1235 serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
1236 clientDeferred = defer.Deferred()
1237 clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
1238 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1239 ClientClass=lambda: clientProto)
1240 def switched(result):
1241 self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
1242 self.testSucceeded = True
1243 c.switchToTestProtocol().addCallback(switched)
1244 p.flush()
1245 self.failUnless(self.testSucceeded)
1246
1247
1248 def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
1249 spuriousTraffic=False,
1250 spuriousError=False):
1251 """
1252 Verify that it is possible to switch to another protocol mid-connection and
1253 send data to it successfully.
1254 """
1255 self.testSucceeded = False
1256
1257 serverDeferred = defer.Deferred()
1258 serverProto = switcher(serverDeferred)
1259 clientDeferred = defer.Deferred()
1260 clientProto = switcher(clientDeferred)
1261 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1262 ClientClass=lambda: clientProto)
1263
1264 if spuriousTraffic:
1265 wfdr = [] # remote
1266 wfd = c.callRemote(WaitForever).addErrback(wfdr.append)
1267 switchDeferred = c.switchToTestProtocol()
1268 if spuriousTraffic:
1269 self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')
1270
1271 def cbConnsLost(((serverSuccess, serverData),
1272 (clientSuccess, clientData))):
1273 self.failUnless(serverSuccess)
1274 self.failUnless(clientSuccess)
1275 self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA)
1276 self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA)
1277 self.testSucceeded = True
1278
1279 def cbSwitch(proto):
1280 return defer.DeferredList(
1281 [serverDeferred, clientDeferred]).addCallback(cbConnsLost)
1282
1283 switchDeferred.addCallback(cbSwitch)
1284 p.flush()
1285 if serverProto.maybeLater is not None:
1286 serverProto.maybeLater.callback(serverProto.maybeLaterProto)
1287 p.flush()
1288 if spuriousTraffic:
1289 # switch is done here; do this here to make sure that if we're
1290 # going to corrupt the connection, we do it before it's closed.
1291 if spuriousError:
1292 s.waiting.errback(amp.RemoteAmpError(
1293 "SPURIOUS",
1294 "Here's some traffic in the form of an error."))
1295 else:
1296 s.waiting.callback({})
1297 p.flush()
1298 c.transport.loseConnection() # close it
1299 p.flush()
1300 self.failUnless(self.testSucceeded)
1301
1302
1303 def test_protocolSwitchDeferred(self):
1304 """
1305 Verify that protocol-switching even works if the value returned from
1306 the command that does the switch is deferred.
1307 """
1308 return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtoco l)
1309
1310
1311 def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
1312 """
1313 Verify that if we try to switch protocols and it fails, the connection
1314 stays up and we can go back to speaking AMP.
1315 """
1316 self.testSucceeded = False
1317
1318 serverDeferred = defer.Deferred()
1319 serverProto = switcher(serverDeferred)
1320 clientDeferred = defer.Deferred()
1321 clientProto = switcher(clientDeferred)
1322 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1323 ClientClass=lambda: clientProto)
1324 L = []
1325 switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append)
1326 p.flush()
1327 L.pop().trap(UnknownProtocol)
1328 self.failIf(self.testSucceeded)
1329 # It's a known error, so let's send a "hello" on the same connection;
1330 # it should work.
1331 c.sendHello('world').addCallback(L.append)
1332 p.flush()
1333 self.assertEqual(L.pop()['hello'], 'world')
1334
1335
1336 def test_trafficAfterSwitch(self):
1337 """
1338 Verify that attempts to send traffic after a switch will not corrupt
1339 the nested protocol.
1340 """
1341 return self.test_protocolSwitch(spuriousTraffic=True)
1342
1343
1344 def test_errorAfterSwitch(self):
1345 """
1346 Returning an error after a protocol switch should record the underlying
1347 error.
1348 """
1349 return self.test_protocolSwitch(spuriousTraffic=True,
1350 spuriousError=True)
1351
1352
1353 def test_quitBoxQuits(self):
1354 """
1355 Verify that commands with a responseType of QuitBox will in fact
1356 terminate the connection.
1357 """
1358 c, s, p = connectedServerAndClient(
1359 ServerClass=SimpleSymmetricCommandProtocol,
1360 ClientClass=SimpleSymmetricCommandProtocol)
1361
1362 L = []
1363 HELLO = 'world'
1364 GOODBYE = 'everyone'
1365 c.sendHello(HELLO).addCallback(L.append)
1366 p.flush()
1367 self.assertEquals(L.pop()['hello'], HELLO)
1368 c.callRemote(Goodbye).addCallback(L.append)
1369 p.flush()
1370 self.assertEquals(L.pop()['goodbye'], GOODBYE)
1371 c.sendHello(HELLO).addErrback(L.append)
1372 L.pop().trap(error.ConnectionDone)
1373
1374
1375 def test_basicLiteralEmit(self):
1376 """
1377 Verify that the command dictionaries for a callRemoteN look correct
1378 after being serialized and parsed.
1379 """
1380 c, s, p = connectedServerAndClient()
1381 L = []
1382 s.ampBoxReceived = L.append
1383 c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
1384 dash_arg='x', underscore_arg='y')
1385 p.flush()
1386 self.assertEquals(len(L), 1)
1387 for k, v in [('_command', Hello.commandName),
1388 ('hello', 'hello test'),
1389 ('mixedCase', 'mixed case arg test'),
1390 ('dash-arg', 'x'),
1391 ('underscore_arg', 'y')]:
1392 self.assertEquals(L[-1].pop(k), v)
1393 L[-1].pop('_ask')
1394 self.assertEquals(L[-1], {})
1395
1396
1397 def test_basicStructuredEmit(self):
1398 """
1399 Verify that a call similar to basicLiteralEmit's is handled properly wit h
1400 high-level quoting and passing to Python methods, and that argument
1401 names are correctly handled.
1402 """
1403 L = []
1404 class StructuredHello(amp.AMP):
1405 def h(self, *a, **k):
1406 L.append((a, k))
1407 return dict(hello='aaa')
1408 Hello.responder(h)
1409 c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
1410 c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
1411 dash_arg='x', underscore_arg='y').addCallback(L.append)
1412 p.flush()
1413 self.assertEquals(len(L), 2)
1414 self.assertEquals(L[0],
1415 ((), dict(
1416 hello='hello test',
1417 mixedCase='mixed case arg test',
1418 dash_arg='x',
1419 underscore_arg='y',
1420
1421 # XXX - should optional arguments just not be passed?
1422 # passing None seems a little odd, looking at the way it
1423 # turns out here... -glyph
1424 From=('file', 'file'),
1425 Print=None,
1426 optional=None,
1427 )))
1428 self.assertEquals(L[1], dict(Print=None, hello='aaa'))
1429
1430 class PretendRemoteCertificateAuthority:
1431 def checkIsPretendRemote(self):
1432 return True
1433
1434 class IOSimCert:
1435 verifyCount = 0
1436
1437 def options(self, *ign):
1438 return self
1439
1440 def iosimVerify(self, otherCert):
1441 """
1442 This isn't a real certificate, and wouldn't work on a real socket, but
1443 iosim specifies a different API so that we don't have to do any crypto
1444 math to demonstrate that the right functions get called in the right
1445 places.
1446 """
1447 assert otherCert is self
1448 self.verifyCount += 1
1449 return True
1450
1451 class OKCert(IOSimCert):
1452 def options(self, x):
1453 assert x.checkIsPretendRemote()
1454 return self
1455
1456 class GrumpyCert(IOSimCert):
1457 def iosimVerify(self, otherCert):
1458 self.verifyCount += 1
1459 return False
1460
1461 class DroppyCert(IOSimCert):
1462 def __init__(self, toDrop):
1463 self.toDrop = toDrop
1464
1465 def iosimVerify(self, otherCert):
1466 self.verifyCount += 1
1467 self.toDrop.loseConnection()
1468 return True
1469
1470 class SecurableProto(FactoryNotifier):
1471
1472 factory = None
1473
1474 def verifyFactory(self):
1475 return [PretendRemoteCertificateAuthority()]
1476
1477 def getTLSVars(self):
1478 cert = self.certFactory()
1479 verify = self.verifyFactory()
1480 return dict(
1481 tls_localCertificate=cert,
1482 tls_verifyAuthorities=verify)
1483 amp.StartTLS.responder(getTLSVars)
1484
1485
1486
1487 class TLSTest(unittest.TestCase):
1488 def test_startingTLS(self):
1489 """
1490 Verify that starting TLS and succeeding at handshaking sends all the
1491 notifications to all the right places.
1492 """
1493 cli, svr, p = connectedServerAndClient(
1494 ServerClass=SecurableProto,
1495 ClientClass=SecurableProto)
1496
1497 okc = OKCert()
1498 svr.certFactory = lambda : okc
1499
1500 cli.callRemote(
1501 amp.StartTLS, tls_localCertificate=okc,
1502 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1503
1504 # let's buffer something to be delivered securely
1505 L = []
1506 d = cli.callRemote(SecuredPing).addCallback(L.append)
1507 p.flush()
1508 # once for client once for server
1509 self.assertEquals(okc.verifyCount, 2)
1510 L = []
1511 d = cli.callRemote(SecuredPing).addCallback(L.append)
1512 p.flush()
1513 self.assertEqual(L[0], {'pinged': True})
1514
1515
1516 def test_startTooManyTimes(self):
1517 """
1518 Verify that the protocol will complain if we attempt to renegotiate TLS,
1519 which we don't support.
1520 """
1521 cli, svr, p = connectedServerAndClient(
1522 ServerClass=SecurableProto,
1523 ClientClass=SecurableProto)
1524
1525 okc = OKCert()
1526 svr.certFactory = lambda : okc
1527
1528 cli.callRemote(amp.StartTLS,
1529 tls_localCertificate=okc,
1530 tls_verifyAuthorities=[PretendRemoteCertificateAuthority( )])
1531 p.flush()
1532 cli.noPeerCertificate = True # this is totally fake
1533 self.assertRaises(
1534 amp.OnlyOneTLS,
1535 cli.callRemote,
1536 amp.StartTLS,
1537 tls_localCertificate=okc,
1538 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
1539
1540
1541 def test_negotiationFailed(self):
1542 """
1543 Verify that starting TLS and failing on both sides at handshaking sends
1544 notifications to all the right places and terminates the connection.
1545 """
1546
1547 badCert = GrumpyCert()
1548
1549 cli, svr, p = connectedServerAndClient(
1550 ServerClass=SecurableProto,
1551 ClientClass=SecurableProto)
1552 svr.certFactory = lambda : badCert
1553
1554 cli.callRemote(amp.StartTLS,
1555 tls_localCertificate=badCert)
1556
1557 p.flush()
1558 # once for client once for server - but both fail
1559 self.assertEquals(badCert.verifyCount, 2)
1560 d = cli.callRemote(SecuredPing)
1561 p.flush()
1562 self.assertFailure(d, iosim.OpenSSLVerifyError)
1563
1564
1565 def test_negotiationFailedByClosing(self):
1566 """
1567 Verify that starting TLS and failing by way of a lost connection
1568 notices that it is probably an SSL problem.
1569 """
1570
1571 cli, svr, p = connectedServerAndClient(
1572 ServerClass=SecurableProto,
1573 ClientClass=SecurableProto)
1574 droppyCert = DroppyCert(svr.transport)
1575 svr.certFactory = lambda : droppyCert
1576
1577 secure = cli.callRemote(amp.StartTLS,
1578 tls_localCertificate=droppyCert)
1579
1580 p.flush()
1581
1582 self.assertEquals(droppyCert.verifyCount, 2)
1583
1584 d = cli.callRemote(SecuredPing)
1585 p.flush()
1586
1587 # it might be a good idea to move this exception somewhere more
1588 # reasonable.
1589 self.assertFailure(d, error.PeerVerifyError)
1590
1591
1592
1593 class InheritedError(Exception):
1594 """
1595 This error is used to check inheritance.
1596 """
1597
1598
1599
1600 class OtherInheritedError(Exception):
1601 """
1602 This is a distinct error for checking inheritance.
1603 """
1604
1605
1606
1607 class BaseCommand(amp.Command):
1608 """
1609 This provides a command that will be subclassed.
1610 """
1611 errors = {InheritedError: 'INHERITED_ERROR'}
1612
1613
1614
1615 class InheritedCommand(BaseCommand):
1616 """
1617 This is a command which subclasses another command but does not override
1618 anything.
1619 """
1620
1621
1622
1623 class AddErrorsCommand(BaseCommand):
1624 """
1625 This is a command which subclasses another command but adds errors to the
1626 list.
1627 """
1628 arguments = [('other', amp.Boolean())]
1629 errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'}
1630
1631
1632
1633 class NormalCommandProtocol(amp.AMP):
1634 """
1635 This is a protocol which responds to L{BaseCommand}, and is used to test
1636 that inheritance does not interfere with the normal handling of errors.
1637 """
1638 def resp(self):
1639 raise InheritedError()
1640 BaseCommand.responder(resp)
1641
1642
1643
1644 class InheritedCommandProtocol(amp.AMP):
1645 """
1646 This is a protocol which responds to L{InheritedCommand}, and is used to
1647 test that inherited commands inherit their bases' errors if they do not
1648 respond to any of their own.
1649 """
1650 def resp(self):
1651 raise InheritedError()
1652 InheritedCommand.responder(resp)
1653
1654
1655
1656 class AddedCommandProtocol(amp.AMP):
1657 """
1658 This is a protocol which responds to L{AddErrorsCommand}, and is used to
1659 test that inherited commands can add their own new types of errors, but
1660 still respond in the same way to their parents types of errors.
1661 """
1662 def resp(self, other):
1663 if other:
1664 raise OtherInheritedError()
1665 else:
1666 raise InheritedError()
1667 AddErrorsCommand.responder(resp)
1668
1669
1670
1671 class CommandInheritanceTests(unittest.TestCase):
1672 """
1673 These tests verify that commands inherit error conditions properly.
1674 """
1675
1676 def errorCheck(self, err, proto, cmd, **kw):
1677 """
1678 Check that the appropriate kind of error is raised when a given command
1679 is sent to a given protocol.
1680 """
1681 c, s, p = connectedServerAndClient(ServerClass=proto,
1682 ClientClass=proto)
1683 d = c.callRemote(cmd, **kw)
1684 d2 = self.failUnlessFailure(d, err)
1685 p.flush()
1686 return d2
1687
1688
1689 def test_basicErrorPropagation(self):
1690 """
1691 Verify that errors specified in a superclass are respected normally
1692 even if it has subclasses.
1693 """
1694 return self.errorCheck(
1695 InheritedError, NormalCommandProtocol, BaseCommand)
1696
1697
1698 def test_inheritedErrorPropagation(self):
1699 """
1700 Verify that errors specified in a superclass command are propagated to
1701 its subclasses.
1702 """
1703 return self.errorCheck(
1704 InheritedError, InheritedCommandProtocol, InheritedCommand)
1705
1706
1707 def test_inheritedErrorAddition(self):
1708 """
1709 Verify that new errors specified in a subclass of an existing command
1710 are honored even if the superclass defines some errors.
1711 """
1712 return self.errorCheck(
1713 OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=T rue)
1714
1715
1716 def test_additionWithOriginalError(self):
1717 """
1718 Verify that errors specified in a command's superclass are respected
1719 even if that command defines new errors itself.
1720 """
1721 return self.errorCheck(
1722 InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False)
1723
1724
1725 def _loseAndPass(err, proto):
1726 # be specific, pass on the error to the client.
1727 err.trap(error.ConnectionLost, error.ConnectionDone)
1728 del proto.connectionLost
1729 proto.connectionLost(err)
1730
1731
1732 class LiveFireBase:
1733 """
1734 Utility for connected reactor-using tests.
1735 """
1736
1737 def setUp(self):
1738 """
1739 Create an amp server and connect a client to it.
1740 """
1741 from twisted.internet import reactor
1742 self.serverFactory = protocol.ServerFactory()
1743 self.serverFactory.protocol = self.serverProto
1744 self.clientFactory = protocol.ClientFactory()
1745 self.clientFactory.protocol = self.clientProto
1746 self.clientFactory.onMade = defer.Deferred()
1747 self.serverFactory.onMade = defer.Deferred()
1748 self.serverPort = reactor.listenTCP(0, self.serverFactory)
1749 self.addCleanup(self.serverPort.stopListening)
1750 self.clientConn = reactor.connectTCP(
1751 '127.0.0.1', self.serverPort.getHost().port,
1752 self.clientFactory)
1753 self.addCleanup(self.clientConn.disconnect)
1754 def getProtos(rlst):
1755 self.cli = self.clientFactory.theProto
1756 self.svr = self.serverFactory.theProto
1757 dl = defer.DeferredList([self.clientFactory.onMade,
1758 self.serverFactory.onMade])
1759 return dl.addCallback(getProtos)
1760
1761 def tearDown(self):
1762 """
1763 Cleanup client and server connections, and check the error got at
1764 C{connectionLost}.
1765 """
1766 L = []
1767 for conn in self.cli, self.svr:
1768 if conn.transport is not None:
1769 # depend on amp's function connection-dropping behavior
1770 d = defer.Deferred().addErrback(_loseAndPass, conn)
1771 conn.connectionLost = d.errback
1772 conn.transport.loseConnection()
1773 L.append(d)
1774 return defer.gatherResults(L
1775 ).addErrback(lambda first: first.value.subFailure)
1776
1777
1778 def show(x):
1779 import sys
1780 sys.stdout.write(x+'\n')
1781 sys.stdout.flush()
1782
1783
1784 def tempSelfSigned():
1785 from twisted.internet import ssl
1786
1787 sharedDN = ssl.DN(CN='shared')
1788 key = ssl.KeyPair.generate()
1789 cr = key.certificateRequest(sharedDN)
1790 sscrd = key.signCertificateRequest(
1791 sharedDN, cr, lambda dn: True, 1234567)
1792 cert = key.newCertificate(sscrd)
1793 return cert
1794
1795 tempcert = tempSelfSigned()
1796
1797
1798 class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase):
1799 clientProto = SecurableProto
1800 serverProto = SecurableProto
1801 def test_liveFireCustomTLS(self):
1802 """
1803 Using real, live TLS, actually negotiate a connection.
1804
1805 This also looks at the 'peerCertificate' attribute's correctness, since
1806 that's actually loaded using OpenSSL calls, but the main purpose is to
1807 make sure that we didn't miss anything obvious in iosim about TLS
1808 negotiations.
1809 """
1810
1811 cert = tempcert
1812
1813 self.svr.verifyFactory = lambda : [cert]
1814 self.svr.certFactory = lambda : cert
1815 # only needed on the server, we specify the client below.
1816
1817 def secured(rslt):
1818 x = cert.digest()
1819 def pinged(rslt2):
1820 # Interesting. OpenSSL won't even _tell_ us about the peer
1821 # cert until we negotiate. we should be able to do this in
1822 # 'secured' instead, but it looks like we can't. I think this
1823 # is a bug somewhere far deeper than here.
1824 self.failUnlessEqual(x, self.cli.hostCertificate.digest())
1825 self.failUnlessEqual(x, self.cli.peerCertificate.digest())
1826 self.failUnlessEqual(x, self.svr.hostCertificate.digest())
1827 self.failUnlessEqual(x, self.svr.peerCertificate.digest())
1828 return self.cli.callRemote(SecuredPing).addCallback(pinged)
1829 return self.cli.callRemote(amp.StartTLS,
1830 tls_localCertificate=cert,
1831 tls_verifyAuthorities=[cert]).addCallback(sec ured)
1832
1833
1834 class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
1835 """
1836 Specific implementation of server side protocol with different
1837 management of TLS.
1838 """
1839 def getTLSVars(self):
1840 """
1841 @return: the global C{tempcert} certificate as local certificate.
1842 """
1843 return dict(tls_localCertificate=tempcert)
1844 amp.StartTLS.responder(getTLSVars)
1845
1846
1847 class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase):
1848
1849 clientProto = SimpleSymmetricCommandProtocol
1850 serverProto = SimpleSymmetricCommandProtocol
1851
1852 def test_liveFireDefaultTLS(self):
1853 """
1854 Verify that out of the box, we can start TLS to at least encrypt the
1855 connection, even if we don't have any certificates to use.
1856 """
1857 def secured(result):
1858 return self.cli.callRemote(SecuredPing)
1859 return self.cli.callRemote(amp.StartTLS).addCallback(secured)
1860
1861
1862 class WithServerTLSVerification(LiveFireBase, unittest.TestCase):
1863 clientProto = SimpleSymmetricCommandProtocol
1864 serverProto = SlightlySmartTLS
1865
1866 def test_anonymousVerifyingClient(self):
1867 """
1868 Verify that anonymous clients can verify server certificates.
1869 """
1870 def secured(result):
1871 return self.cli.callRemote(SecuredPing)
1872 return self.cli.callRemote(amp.StartTLS,
1873 tls_verifyAuthorities=[tempcert]
1874 ).addCallback(secured)
1875
1876
1877
1878 class ProtocolIncludingArgument(amp.Argument):
1879 """
1880 An L{amp.Argument} which encodes its parser and serializer
1881 arguments *including the protocol* into its parsed and serialized
1882 forms.
1883 """
1884
1885 def fromStringProto(self, string, protocol):
1886 """
1887 Don't decode anything; just return all possible information.
1888
1889 @return: A two-tuple of the input string and the protocol.
1890 """
1891 return (string, protocol)
1892
1893 def toStringProto(self, obj, protocol):
1894 """
1895 Encode identifying information about L{object} and protocol
1896 into a string for later verification.
1897
1898 @type obj: L{object}
1899 @type protocol: L{amp.AMP}
1900 """
1901 return "%s:%s" % (id(obj), id(protocol))
1902
1903
1904
1905 class ProtocolIncludingCommand(amp.Command):
1906 """
1907 A command that has argument and response schemas which use
1908 L{ProtocolIncludingArgument}.
1909 """
1910 arguments = [('weird', ProtocolIncludingArgument())]
1911 response = [('weird', ProtocolIncludingArgument())]
1912
1913
1914
1915 class MagicSchemaCommand(amp.Command):
1916 """
1917 A command which overrides L{parseResponse}, L{parseArguments}, and
1918 L{makeResponse}.
1919 """
1920 def parseResponse(self, strings, protocol):
1921 """
1922 Don't do any parsing, just jam the input strings and protocol
1923 onto the C{protocol.parseResponseArguments} attribute as a
1924 two-tuple. Return the original strings.
1925 """
1926 protocol.parseResponseArguments = (strings, protocol)
1927 return strings
1928 parseResponse = classmethod(parseResponse)
1929
1930
1931 def parseArguments(cls, strings, protocol):
1932 """
1933 Don't do any parsing, just jam the input strings and protocol
1934 onto the C{protocol.parseArgumentsArguments} attribute as a
1935 two-tuple. Return the original strings.
1936 """
1937 protocol.parseArgumentsArguments = (strings, protocol)
1938 return strings
1939 parseArguments = classmethod(parseArguments)
1940
1941
1942 def makeArguments(cls, objects, protocol):
1943 """
1944 Don't do any serializing, just jam the input strings and protocol
1945 onto the C{protocol.makeArgumentsArguments} attribute as a
1946 two-tuple. Return the original strings.
1947 """
1948 protocol.makeArgumentsArguments = (objects, protocol)
1949 return objects
1950 makeArguments = classmethod(makeArguments)
1951
1952
1953
1954 class NoNetworkProtocol(amp.AMP):
1955 """
1956 An L{amp.AMP} subclass which overrides private methods to avoid
1957 testing the network. It also provides a responder for
1958 L{MagicSchemaCommand} that does nothing, so that tests can test
1959 aspects of the interaction of L{amp.Command}s and L{amp.AMP}.
1960
1961 @ivar parseArgumentsArguments: Arguments that have been passed to any
1962 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
1963 this protocol.
1964
1965 @ivar parseResponseArguments: Responses that have been returned from a
1966 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
1967 this protocol.
1968
1969 @ivar makeArgumentsArguments: Arguments that have been serialized by any
1970 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
1971 this protocol.
1972 """
1973 def _sendBoxCommand(self, commandName, strings, requiresAnswer):
1974 """
1975 Return a Deferred which fires with the original strings.
1976 """
1977 return defer.succeed(strings)
1978
1979 MagicSchemaCommand.responder(lambda s, weird: {})
1980
1981
1982
1983 class MyBox(dict):
1984 """
1985 A unique dict subclass.
1986 """
1987
1988
1989
1990 class ProtocolIncludingCommandWithDifferentCommandType(
1991 ProtocolIncludingCommand):
1992 """
1993 A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox}
1994 """
1995 commandType = MyBox
1996
1997
1998
1999 class CommandTestCase(unittest.TestCase):
2000 """
2001 Tests for L{amp.Command}.
2002 """
2003
2004 def test_parseResponse(self):
2005 """
2006 There should be a class method of Command which accepts a
2007 mapping of argument names to serialized forms and returns a
2008 similar mapping whose values have been parsed via the
2009 Command's response schema.
2010 """
2011 protocol = object()
2012 result = 'whatever'
2013 strings = {'weird': result}
2014 self.assertEqual(
2015 ProtocolIncludingCommand.parseResponse(strings, protocol),
2016 {'weird': (result, protocol)})
2017
2018
2019 def test_callRemoteCallsParseResponse(self):
2020 """
2021 Making a remote call on a L{amp.Command} subclass which
2022 overrides the C{parseResponse} method should call that
2023 C{parseResponse} method to get the response.
2024 """
2025 client = NoNetworkProtocol()
2026 thingy = "weeoo"
2027 response = client.callRemote(MagicSchemaCommand, weird=thingy)
2028 def gotResponse(ign):
2029 self.assertEquals(client.parseResponseArguments,
2030 ({"weird": thingy}, client))
2031 response.addCallback(gotResponse)
2032 return response
2033
2034
2035 def test_parseArguments(self):
2036 """
2037 There should be a class method of L{amp.Command} which accepts
2038 a mapping of argument names to serialized forms and returns a
2039 similar mapping whose values have been parsed via the
2040 command's argument schema.
2041 """
2042 protocol = object()
2043 result = 'whatever'
2044 strings = {'weird': result}
2045 self.assertEqual(
2046 ProtocolIncludingCommand.parseArguments(strings, protocol),
2047 {'weird': (result, protocol)})
2048
2049
2050 def test_responderCallsParseArguments(self):
2051 """
2052 Making a remote call on a L{amp.Command} subclass which
2053 overrides the C{parseArguments} method should call that
2054 C{parseArguments} method to get the arguments.
2055 """
2056 protocol = NoNetworkProtocol()
2057 responder = protocol.locateResponder(MagicSchemaCommand.commandName)
2058 argument = object()
2059 response = responder(dict(weird=argument))
2060 response.addCallback(
2061 lambda ign: self.assertEqual(protocol.parseArgumentsArguments,
2062 ({"weird": argument}, protocol)))
2063 return response
2064
2065
2066 def test_makeArguments(self):
2067 """
2068 There should be a class method of L{amp.Command} which accepts
2069 a mapping of argument names to objects and returns a similar
2070 mapping whose values have been serialized via the command's
2071 argument schema.
2072 """
2073 protocol = object()
2074 argument = object()
2075 objects = {'weird': argument}
2076 self.assertEqual(
2077 ProtocolIncludingCommand.makeArguments(objects, protocol),
2078 {'weird': "%d:%d" % (id(argument), id(protocol))})
2079
2080
2081 def test_makeArgumentsUsesCommandType(self):
2082 """
2083 L{amp.Command.makeArguments}'s return type should be the type
2084 of the result of L{amp.Command.commandType}.
2085 """
2086 protocol = object()
2087 objects = {"weird": "whatever"}
2088
2089 result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
2090 objects, protocol)
2091 self.assertIdentical(type(result), MyBox)
2092
2093
2094 def test_callRemoteCallsMakeArguments(self):
2095 """
2096 Making a remote call on a L{amp.Command} subclass which
2097 overrides the C{makeArguments} method should call that
2098 C{makeArguments} method to get the response.
2099 """
2100 client = NoNetworkProtocol()
2101 argument = object()
2102 response = client.callRemote(MagicSchemaCommand, weird=argument)
2103 def gotResponse(ign):
2104 self.assertEqual(client.makeArgumentsArguments,
2105 ({"weird": argument}, client))
2106 response.addCallback(gotResponse)
2107 return response
2108
2109 if not interfaces.IReactorSSL.providedBy(reactor):
2110 skipMsg = 'This test case requires SSL support in the reactor'
2111 TLSTest.skip = skipMsg
2112 LiveFireTLSTestCase.skip = skipMsg
2113 PlainVanillaLiveFire.skip = skipMsg
2114 WithServerTLSVerification.skip = skipMsg
2115
OLDNEW
« no previous file with comments | « third_party/twisted_8_1/twisted/test/test_adbapi.py ('k') | third_party/twisted_8_1/twisted/test/test_application.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698