OLD | NEW |
| (Empty) |
1 # Copyright (c) 2005 Divmod, Inc. | |
2 # Copyright (c) 2007 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 from twisted.python import filepath | |
6 from twisted.python.failure import Failure | |
7 from twisted.protocols import amp | |
8 from twisted.test import iosim | |
9 from twisted.trial import unittest | |
10 from twisted.internet import protocol, defer, error, reactor, interfaces | |
11 | |
12 | |
13 class TestProto(protocol.Protocol): | |
14 def __init__(self, onConnLost, dataToSend): | |
15 self.onConnLost = onConnLost | |
16 self.dataToSend = dataToSend | |
17 | |
18 def connectionMade(self): | |
19 self.data = [] | |
20 self.transport.write(self.dataToSend) | |
21 | |
22 def dataReceived(self, bytes): | |
23 self.data.append(bytes) | |
24 # self.transport.loseConnection() | |
25 | |
26 def connectionLost(self, reason): | |
27 self.onConnLost.callback(self.data) | |
28 | |
29 | |
30 | |
31 class SimpleSymmetricProtocol(amp.AMP): | |
32 | |
33 def sendHello(self, text): | |
34 return self.callRemoteString( | |
35 "hello", | |
36 hello=text) | |
37 | |
38 def amp_HELLO(self, box): | |
39 return amp.Box(hello=box['hello']) | |
40 | |
41 def amp_HOWDOYOUDO(self, box): | |
42 return amp.QuitBox(howdoyoudo='world') | |
43 | |
44 | |
45 | |
46 class UnfriendlyGreeting(Exception): | |
47 """Greeting was insufficiently kind. | |
48 """ | |
49 | |
50 class DeathThreat(Exception): | |
51 """Greeting was insufficiently kind. | |
52 """ | |
53 | |
54 class UnknownProtocol(Exception): | |
55 """Asked to switch to the wrong protocol. | |
56 """ | |
57 | |
58 | |
59 class TransportPeer(amp.Argument): | |
60 # this serves as some informal documentation for how to get variables from | |
61 # the protocol or your environment and pass them to methods as arguments. | |
62 def retrieve(self, d, name, proto): | |
63 return '' | |
64 | |
65 def fromStringProto(self, notAString, proto): | |
66 return proto.transport.getPeer() | |
67 | |
68 def toBox(self, name, strings, objects, proto): | |
69 return | |
70 | |
71 | |
72 | |
73 class Hello(amp.Command): | |
74 | |
75 commandName = 'hello' | |
76 | |
77 arguments = [('hello', amp.String()), | |
78 ('optional', amp.Boolean(optional=True)), | |
79 ('print', amp.Unicode(optional=True)), | |
80 ('from', TransportPeer(optional=True)), | |
81 ('mixedCase', amp.String(optional=True)), | |
82 ('dash-arg', amp.String(optional=True)), | |
83 ('underscore_arg', amp.String(optional=True))] | |
84 | |
85 response = [('hello', amp.String()), | |
86 ('print', amp.Unicode(optional=True))] | |
87 | |
88 errors = {UnfriendlyGreeting: 'UNFRIENDLY'} | |
89 | |
90 fatalErrors = {DeathThreat: 'DEAD'} | |
91 | |
92 class NoAnswerHello(Hello): | |
93 commandName = Hello.commandName | |
94 requiresAnswer = False | |
95 | |
96 class FutureHello(amp.Command): | |
97 commandName = 'hello' | |
98 | |
99 arguments = [('hello', amp.String()), | |
100 ('optional', amp.Boolean(optional=True)), | |
101 ('print', amp.Unicode(optional=True)), | |
102 ('from', TransportPeer(optional=True)), | |
103 ('bonus', amp.String(optional=True)), # addt'l arguments | |
104 # should generally be | |
105 # added at the end, and | |
106 # be optional... | |
107 ] | |
108 | |
109 response = [('hello', amp.String()), | |
110 ('print', amp.Unicode(optional=True))] | |
111 | |
112 errors = {UnfriendlyGreeting: 'UNFRIENDLY'} | |
113 | |
114 class WTF(amp.Command): | |
115 """ | |
116 An example of an invalid command. | |
117 """ | |
118 | |
119 | |
120 class BrokenReturn(amp.Command): | |
121 """ An example of a perfectly good command, but the handler is going to retu
rn | |
122 None... | |
123 """ | |
124 | |
125 commandName = 'broken_return' | |
126 | |
127 class Goodbye(amp.Command): | |
128 # commandName left blank on purpose: this tests implicit command names. | |
129 response = [('goodbye', amp.String())] | |
130 responseType = amp.QuitBox | |
131 | |
132 class Howdoyoudo(amp.Command): | |
133 commandName = 'howdoyoudo' | |
134 # responseType = amp.QuitBox | |
135 | |
136 class WaitForever(amp.Command): | |
137 commandName = 'wait_forever' | |
138 | |
139 class GetList(amp.Command): | |
140 commandName = 'getlist' | |
141 arguments = [('length', amp.Integer())] | |
142 response = [('body', amp.AmpList([('x', amp.Integer())]))] | |
143 | |
144 class SecuredPing(amp.Command): | |
145 # XXX TODO: actually make this refuse to send over an insecure connection | |
146 response = [('pinged', amp.Boolean())] | |
147 | |
148 class TestSwitchProto(amp.ProtocolSwitchCommand): | |
149 commandName = 'Switch-Proto' | |
150 | |
151 arguments = [ | |
152 ('name', amp.String()), | |
153 ] | |
154 errors = {UnknownProtocol: 'UNKNOWN'} | |
155 | |
156 class SingleUseFactory(protocol.ClientFactory): | |
157 def __init__(self, proto): | |
158 self.proto = proto | |
159 self.proto.factory = self | |
160 | |
161 def buildProtocol(self, addr): | |
162 p, self.proto = self.proto, None | |
163 return p | |
164 | |
165 reasonFailed = None | |
166 | |
167 def clientConnectionFailed(self, connector, reason): | |
168 self.reasonFailed = reason | |
169 return | |
170 | |
171 THING_I_DONT_UNDERSTAND = 'gwebol nargo' | |
172 class ThingIDontUnderstandError(Exception): | |
173 pass | |
174 | |
175 class FactoryNotifier(amp.AMP): | |
176 factory = None | |
177 def connectionMade(self): | |
178 if self.factory is not None: | |
179 self.factory.theProto = self | |
180 if hasattr(self.factory, 'onMade'): | |
181 self.factory.onMade.callback(None) | |
182 | |
183 def emitpong(self): | |
184 from twisted.internet.interfaces import ISSLTransport | |
185 if not ISSLTransport.providedBy(self.transport): | |
186 raise DeathThreat("only send secure pings over secure channels") | |
187 return {'pinged': True} | |
188 SecuredPing.responder(emitpong) | |
189 | |
190 | |
191 class SimpleSymmetricCommandProtocol(FactoryNotifier): | |
192 maybeLater = None | |
193 def __init__(self, onConnLost=None): | |
194 amp.AMP.__init__(self) | |
195 self.onConnLost = onConnLost | |
196 | |
197 def sendHello(self, text): | |
198 return self.callRemote(Hello, hello=text) | |
199 | |
200 def sendUnicodeHello(self, text, translation): | |
201 return self.callRemote(Hello, hello=text, Print=translation) | |
202 | |
203 greeted = False | |
204 | |
205 def cmdHello(self, hello, From, optional=None, Print=None, | |
206 mixedCase=None, dash_arg=None, underscore_arg=None): | |
207 assert From == self.transport.getPeer() | |
208 if hello == THING_I_DONT_UNDERSTAND: | |
209 raise ThingIDontUnderstandError() | |
210 if hello.startswith('fuck'): | |
211 raise UnfriendlyGreeting("Don't be a dick.") | |
212 if hello == 'die': | |
213 raise DeathThreat("aieeeeeeeee") | |
214 result = dict(hello=hello) | |
215 if Print is not None: | |
216 result.update(dict(Print=Print)) | |
217 self.greeted = True | |
218 return result | |
219 Hello.responder(cmdHello) | |
220 | |
221 def cmdGetlist(self, length): | |
222 return {'body': [dict(x=1)] * length} | |
223 GetList.responder(cmdGetlist) | |
224 | |
225 def waitforit(self): | |
226 self.waiting = defer.Deferred() | |
227 return self.waiting | |
228 WaitForever.responder(waitforit) | |
229 | |
230 def howdo(self): | |
231 return dict(howdoyoudo='world') | |
232 Howdoyoudo.responder(howdo) | |
233 | |
234 def saybye(self): | |
235 return dict(goodbye="everyone") | |
236 Goodbye.responder(saybye) | |
237 | |
238 def switchToTestProtocol(self, fail=False): | |
239 if fail: | |
240 name = 'no-proto' | |
241 else: | |
242 name = 'test-proto' | |
243 p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA) | |
244 return self.callRemote( | |
245 TestSwitchProto, | |
246 SingleUseFactory(p), name=name).addCallback(lambda ign: p) | |
247 | |
248 def switchit(self, name): | |
249 if name == 'test-proto': | |
250 return TestProto(self.onConnLost, SWITCH_SERVER_DATA) | |
251 raise UnknownProtocol(name) | |
252 TestSwitchProto.responder(switchit) | |
253 | |
254 def donothing(self): | |
255 return None | |
256 BrokenReturn.responder(donothing) | |
257 | |
258 | |
259 class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol): | |
260 def switchit(self, name): | |
261 if name == 'test-proto': | |
262 self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA
) | |
263 self.maybeLater = defer.Deferred() | |
264 return self.maybeLater | |
265 raise UnknownProtocol(name) | |
266 TestSwitchProto.responder(switchit) | |
267 | |
268 class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): | |
269 def badResponder(self, hello, From, optional=None, Print=None, | |
270 mixedCase=None, dash_arg=None, underscore_arg=None): | |
271 """ | |
272 This responder does nothing and forgets to return a dictionary. | |
273 """ | |
274 NoAnswerHello.responder(badResponder) | |
275 | |
276 class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): | |
277 def goodNoAnswerResponder(self, hello, From, optional=None, Print=None, | |
278 mixedCase=None, dash_arg=None, underscore_arg=None
): | |
279 return dict(hello=hello+"-noanswer") | |
280 NoAnswerHello.responder(goodNoAnswerResponder) | |
281 | |
282 def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol, | |
283 ClientClass=SimpleSymmetricProtocol, | |
284 *a, **kw): | |
285 """Returns a 3-tuple: (client, server, pump) | |
286 """ | |
287 return iosim.connectedServerAndClient( | |
288 ServerClass, ClientClass, | |
289 *a, **kw) | |
290 | |
291 class TotallyDumbProtocol(protocol.Protocol): | |
292 buf = '' | |
293 def dataReceived(self, data): | |
294 self.buf += data | |
295 | |
296 class LiteralAmp(amp.AMP): | |
297 def __init__(self): | |
298 self.boxes = [] | |
299 | |
300 def ampBoxReceived(self, box): | |
301 self.boxes.append(box) | |
302 return | |
303 | |
304 class ParsingTest(unittest.TestCase): | |
305 | |
306 def test_booleanValues(self): | |
307 """ | |
308 Verify that the Boolean parser parses 'True' and 'False', but nothing | |
309 else. | |
310 """ | |
311 b = amp.Boolean() | |
312 self.assertEquals(b.fromString("True"), True) | |
313 self.assertEquals(b.fromString("False"), False) | |
314 self.assertRaises(TypeError, b.fromString, "ninja") | |
315 self.assertRaises(TypeError, b.fromString, "true") | |
316 self.assertRaises(TypeError, b.fromString, "TRUE") | |
317 self.assertEquals(b.toString(True), 'True') | |
318 self.assertEquals(b.toString(False), 'False') | |
319 | |
320 def test_pathValueRoundTrip(self): | |
321 """ | |
322 Verify the 'Path' argument can parse and emit a file path. | |
323 """ | |
324 fp = filepath.FilePath(self.mktemp()) | |
325 p = amp.Path() | |
326 s = p.toString(fp) | |
327 v = p.fromString(s) | |
328 self.assertNotIdentical(fp, v) # sanity check | |
329 self.assertEquals(fp, v) | |
330 | |
331 | |
332 def test_sillyEmptyThing(self): | |
333 """ | |
334 Test that empty boxes raise an error; they aren't supposed to be sent | |
335 on purpose. | |
336 """ | |
337 a = amp.AMP() | |
338 return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box()) | |
339 | |
340 | |
341 def test_ParsingRoundTrip(self): | |
342 """ | |
343 Verify that various kinds of data make it through the encode/parse | |
344 round-trip unharmed. | |
345 """ | |
346 c, s, p = connectedServerAndClient(ClientClass=LiteralAmp, | |
347 ServerClass=LiteralAmp) | |
348 | |
349 SIMPLE = ('simple', 'test') | |
350 CE = ('ceq', ': ') | |
351 CR = ('crtest', 'test\r') | |
352 LF = ('lftest', 'hello\n') | |
353 NEWLINE = ('newline', 'test\r\none\r\ntwo') | |
354 NEWLINE2 = ('newline2', 'test\r\none\r\n two') | |
355 BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline') | |
356 BODYTEST = ('body', 'blah\r\n\r\ntesttest') | |
357 | |
358 testData = [ | |
359 [SIMPLE], | |
360 [SIMPLE, BODYTEST], | |
361 [SIMPLE, CE], | |
362 [SIMPLE, CR], | |
363 [SIMPLE, CE, CR, LF], | |
364 [CE, CR, LF], | |
365 [SIMPLE, NEWLINE, CE, NEWLINE2], | |
366 [BODYTEST, SIMPLE, NEWLINE] | |
367 ] | |
368 | |
369 for test in testData: | |
370 jb = amp.Box() | |
371 jb.update(dict(test)) | |
372 jb._sendTo(c) | |
373 p.flush() | |
374 self.assertEquals(s.boxes[-1], jb) | |
375 | |
376 | |
377 | |
378 class FakeLocator(object): | |
379 """ | |
380 This is a fake implementation of the interface implied by | |
381 L{CommandLocator}. | |
382 """ | |
383 def __init__(self): | |
384 """ | |
385 Remember the given keyword arguments as a set of responders. | |
386 """ | |
387 self.commands = {} | |
388 | |
389 | |
390 def locateResponder(self, commandName): | |
391 """ | |
392 Look up and return a function passed as a keyword argument of the given | |
393 name to the constructor. | |
394 """ | |
395 return self.commands[commandName] | |
396 | |
397 | |
398 class FakeSender: | |
399 """ | |
400 This is a fake implementation of the 'box sender' interface implied by | |
401 L{AMP}. | |
402 """ | |
403 def __init__(self): | |
404 """ | |
405 Create a fake sender and initialize the list of received boxes and | |
406 unhandled errors. | |
407 """ | |
408 self.sentBoxes = [] | |
409 self.unhandledErrors = [] | |
410 self.expectedErrors = 0 | |
411 | |
412 | |
413 def expectError(self): | |
414 """ | |
415 Expect one error, so that the test doesn't fail. | |
416 """ | |
417 self.expectedErrors += 1 | |
418 | |
419 | |
420 def sendBox(self, box): | |
421 """ | |
422 Accept a box, but don't do anything. | |
423 """ | |
424 self.sentBoxes.append(box) | |
425 | |
426 | |
427 def unhandledError(self, failure): | |
428 """ | |
429 Deal with failures by instantly re-raising them for easier debugging. | |
430 """ | |
431 self.expectedErrors -= 1 | |
432 if self.expectedErrors < 0: | |
433 failure.raiseException() | |
434 else: | |
435 self.unhandledErrors.append(failure) | |
436 | |
437 | |
438 | |
439 class CommandDispatchTests(unittest.TestCase): | |
440 """ | |
441 The AMP CommandDispatcher class dispatches converts AMP boxes into commands | |
442 and responses using Command.responder decorator. | |
443 | |
444 Note: Originally, AMP's factoring was such that many tests for this | |
445 functionality are now implemented as full round-trip tests in L{AMPTest}. | |
446 Future tests should be written at this level instead, to ensure API | |
447 compatibility and to provide more granular, readable units of test | |
448 coverage. | |
449 """ | |
450 | |
451 def setUp(self): | |
452 """ | |
453 Create a dispatcher to use. | |
454 """ | |
455 self.locator = FakeLocator() | |
456 self.sender = FakeSender() | |
457 self.dispatcher = amp.BoxDispatcher(self.locator) | |
458 self.dispatcher.startReceivingBoxes(self.sender) | |
459 | |
460 | |
461 def test_receivedAsk(self): | |
462 """ | |
463 L{CommandDispatcher.ampBoxReceived} should locate the appropriate | |
464 command in its responder lookup, based on the '_ask' key. | |
465 """ | |
466 received = [] | |
467 def thunk(box): | |
468 received.append(box) | |
469 return amp.Box({"hello": "goodbye"}) | |
470 input = amp.Box(_command="hello", | |
471 _ask="test-command-id", | |
472 hello="world") | |
473 self.locator.commands['hello'] = thunk | |
474 self.dispatcher.ampBoxReceived(input) | |
475 self.assertEquals(received, [input]) | |
476 | |
477 | |
478 def test_sendUnhandledError(self): | |
479 """ | |
480 L{CommandDispatcher} should relay its unhandled errors in responding to | |
481 boxes to its boxSender. | |
482 """ | |
483 err = RuntimeError("something went wrong, oh no") | |
484 self.sender.expectError() | |
485 self.dispatcher.unhandledError(Failure(err)) | |
486 self.assertEqual(len(self.sender.unhandledErrors), 1) | |
487 self.assertEqual(self.sender.unhandledErrors[0].value, err) | |
488 | |
489 | |
490 def test_unhandledSerializationError(self): | |
491 """ | |
492 Errors during serialization ought to be relayed to the sender's | |
493 unhandledError method. | |
494 """ | |
495 err = RuntimeError("something undefined went wrong") | |
496 def thunk(result): | |
497 class BrokenBox(amp.Box): | |
498 def _sendTo(self, proto): | |
499 raise err | |
500 return BrokenBox() | |
501 self.locator.commands['hello'] = thunk | |
502 input = amp.Box(_command="hello", | |
503 _ask="test-command-id", | |
504 hello="world") | |
505 self.sender.expectError() | |
506 self.dispatcher.ampBoxReceived(input) | |
507 self.assertEquals(len(self.sender.unhandledErrors), 1) | |
508 self.assertEquals(self.sender.unhandledErrors[0].value, err) | |
509 | |
510 | |
511 def test_callRemote(self): | |
512 """ | |
513 L{CommandDispatcher.callRemote} should emit a properly formatted '_ask' | |
514 box to its boxSender and record an outstanding L{Deferred}. When a | |
515 corresponding '_answer' packet is received, the L{Deferred} should be | |
516 fired, and the results translated via the given L{Command}'s response | |
517 de-serialization. | |
518 """ | |
519 D = self.dispatcher.callRemote(Hello, hello='world') | |
520 self.assertEquals(self.sender.sentBoxes, | |
521 [amp.AmpBox(_command="hello", | |
522 _ask="1", | |
523 hello="world")]) | |
524 answers = [] | |
525 D.addCallback(answers.append) | |
526 self.assertEquals(answers, []) | |
527 self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay", | |
528 'print': "ignored", | |
529 '_answer': "1"})) | |
530 self.assertEquals(answers, [dict(hello="yay", | |
531 Print=u"ignored")]) | |
532 | |
533 | |
534 class SimpleGreeting(amp.Command): | |
535 """ | |
536 A very simple greeting command that uses a few basic argument types. | |
537 """ | |
538 commandName = 'simple' | |
539 arguments = [('greeting', amp.Unicode()), | |
540 ('cookie', amp.Integer())] | |
541 response = [('cookieplus', amp.Integer())] | |
542 | |
543 | |
544 class TestLocator(amp.CommandLocator): | |
545 """ | |
546 A locator which implements a responder to a 'hello' command. | |
547 """ | |
548 def __init__(self): | |
549 self.greetings = [] | |
550 | |
551 | |
552 def greetingResponder(self, greeting, cookie): | |
553 self.greetings.append((greeting, cookie)) | |
554 return dict(cookieplus=cookie + 3) | |
555 greetingResponder = SimpleGreeting.responder(greetingResponder) | |
556 | |
557 | |
558 | |
559 class OverrideLocatorAMP(amp.AMP): | |
560 def __init__(self): | |
561 amp.AMP.__init__(self) | |
562 self.customResponder = object() | |
563 self.expectations = {"custom": self.customResponder} | |
564 self.greetings = [] | |
565 | |
566 | |
567 def lookupFunction(self, name): | |
568 """ | |
569 Override the deprecated lookupFunction function. | |
570 """ | |
571 if name in self.expectations: | |
572 result = self.expectations[name] | |
573 return result | |
574 else: | |
575 return super(OverrideLocatorAMP, self).lookupFunction(name) | |
576 | |
577 | |
578 def greetingResponder(self, greeting, cookie): | |
579 self.greetings.append((greeting, cookie)) | |
580 return dict(cookieplus=cookie + 3) | |
581 greetingResponder = SimpleGreeting.responder(greetingResponder) | |
582 | |
583 | |
584 | |
585 | |
586 class CommandLocatorTests(unittest.TestCase): | |
587 """ | |
588 The CommandLocator should enable users to specify responders to commands as | |
589 functions that take structured objects, annotated with metadata. | |
590 """ | |
591 | |
592 def test_responderDecorator(self): | |
593 """ | |
594 A method on a L{CommandLocator} subclass decorated with a L{Command} | |
595 subclass's L{responder} decorator should be returned from | |
596 locateResponder, wrapped in logic to serialize and deserialize its | |
597 arguments. | |
598 """ | |
599 locator = TestLocator() | |
600 responderCallable = locator.locateResponder("simple") | |
601 result = responderCallable(amp.Box(greeting="ni hao", cookie="5")) | |
602 def done(values): | |
603 self.assertEquals(values, amp.AmpBox(cookieplus='8')) | |
604 return result.addCallback(done) | |
605 | |
606 | |
607 def test_lookupFunctionDeprecatedOverride(self): | |
608 """ | |
609 Subclasses which override locateResponder under its old name, | |
610 lookupFunction, should have the override invoked instead. (This tests | |
611 an AMP subclass, because in the version of the code that could invoke | |
612 this deprecated code path, there was no L{CommandLocator}.) | |
613 """ | |
614 locator = OverrideLocatorAMP() | |
615 customResponderObject = self.assertWarns( | |
616 PendingDeprecationWarning, | |
617 "Override locateResponder, not lookupFunction.", | |
618 __file__, lambda : locator.locateResponder("custom")) | |
619 self.assertEquals(locator.customResponder, customResponderObject) | |
620 # Make sure upcalling works too | |
621 normalResponderObject = self.assertWarns( | |
622 PendingDeprecationWarning, | |
623 "Override locateResponder, not lookupFunction.", | |
624 __file__, lambda : locator.locateResponder("simple")) | |
625 result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5")) | |
626 def done(values): | |
627 self.assertEquals(values, amp.AmpBox(cookieplus='8')) | |
628 return result.addCallback(done) | |
629 | |
630 | |
631 def test_lookupFunctionDeprecatedInvoke(self): | |
632 """ | |
633 Invoking locateResponder under its old name, lookupFunction, should | |
634 emit a deprecation warning, but do the same thing. | |
635 """ | |
636 locator = TestLocator() | |
637 responderCallable = self.assertWarns( | |
638 PendingDeprecationWarning, | |
639 "Call locateResponder, not lookupFunction.", __file__, | |
640 lambda : locator.lookupFunction("simple")) | |
641 result = responderCallable(amp.Box(greeting="ni hao", cookie="5")) | |
642 def done(values): | |
643 self.assertEquals(values, amp.AmpBox(cookieplus='8')) | |
644 return result.addCallback(done) | |
645 | |
646 | |
647 | |
648 SWITCH_CLIENT_DATA = 'Success!' | |
649 SWITCH_SERVER_DATA = 'No, really. Success.' | |
650 | |
651 | |
652 class BinaryProtocolTests(unittest.TestCase): | |
653 """ | |
654 Tests for L{amp.BinaryBoxProtocol}. | |
655 """ | |
656 | |
657 def setUp(self): | |
658 """ | |
659 Keep track of all boxes received by this test in its capacity as an | |
660 L{IBoxReceiver} implementor. | |
661 """ | |
662 self.boxes = [] | |
663 self.data = [] | |
664 | |
665 | |
666 def startReceivingBoxes(self, sender): | |
667 """ | |
668 Implement L{IBoxReceiver.startReceivingBoxes} to do nothing. | |
669 """ | |
670 | |
671 | |
672 def ampBoxReceived(self, box): | |
673 """ | |
674 A box was received by the protocol. | |
675 """ | |
676 self.boxes.append(box) | |
677 | |
678 stopReason = None | |
679 def stopReceivingBoxes(self, reason): | |
680 """ | |
681 Record the reason that we stopped receiving boxes. | |
682 """ | |
683 self.stopReason = reason | |
684 | |
685 | |
686 # fake ITransport | |
687 def getPeer(self): | |
688 return 'no peer' | |
689 | |
690 | |
691 def getHost(self): | |
692 return 'no host' | |
693 | |
694 | |
695 def write(self, data): | |
696 self.data.append(data) | |
697 | |
698 | |
699 def test_receiveBoxStateMachine(self): | |
700 """ | |
701 When a binary box protocol receives: | |
702 * a key | |
703 * a value | |
704 * an empty string | |
705 it should emit a box and send it to its boxReceiver. | |
706 """ | |
707 a = amp.BinaryBoxProtocol(self) | |
708 a.stringReceived("hello") | |
709 a.stringReceived("world") | |
710 a.stringReceived("") | |
711 self.assertEquals(self.boxes, [amp.AmpBox(hello="world")]) | |
712 | |
713 | |
714 def test_receiveBoxData(self): | |
715 """ | |
716 When a binary box protocol receives the serialized form of an AMP box, | |
717 it should emit a similar box to its boxReceiver. | |
718 """ | |
719 a = amp.BinaryBoxProtocol(self) | |
720 a.dataReceived(amp.Box({"testKey": "valueTest", | |
721 "anotherKey": "anotherValue"}).serialize()) | |
722 self.assertEquals(self.boxes, | |
723 [amp.Box({"testKey": "valueTest", | |
724 "anotherKey": "anotherValue"})]) | |
725 | |
726 | |
727 def test_sendBox(self): | |
728 """ | |
729 When a binary box protocol sends a box, it should emit the serialized | |
730 bytes of that box to its transport. | |
731 """ | |
732 a = amp.BinaryBoxProtocol(self) | |
733 a.makeConnection(self) | |
734 aBox = amp.Box({"testKey": "valueTest", | |
735 "someData": "hello"}) | |
736 a.makeConnection(self) | |
737 a.sendBox(aBox) | |
738 self.assertEquals(''.join(self.data), aBox.serialize()) | |
739 | |
740 | |
741 def test_connectionLostStopSendingBoxes(self): | |
742 """ | |
743 When a binary box protocol loses its connection, it should notify its | |
744 box receiver that it has stopped receiving boxes. | |
745 """ | |
746 a = amp.BinaryBoxProtocol(self) | |
747 a.makeConnection(self) | |
748 aBox = amp.Box({"sample": "data"}) | |
749 a.makeConnection(self) | |
750 connectionFailure = Failure(RuntimeError()) | |
751 a.connectionLost(connectionFailure) | |
752 self.assertIdentical(self.stopReason, connectionFailure) | |
753 | |
754 | |
755 def test_protocolSwitch(self): | |
756 """ | |
757 L{BinaryBoxProtocol} has the capacity to switch to a different protocol | |
758 on a box boundary. When a protocol is in the process of switching, it | |
759 cannot receive traffic. | |
760 """ | |
761 otherProto = TestProto(None, "outgoing data") | |
762 test = self | |
763 class SwitchyReceiver: | |
764 switched = False | |
765 def startReceivingBoxes(self, sender): | |
766 pass | |
767 def ampBoxReceived(self, box): | |
768 test.assertFalse(self.switched, | |
769 "Should only receive one box!") | |
770 self.switched = True | |
771 a._lockForSwitch() | |
772 a._switchTo(otherProto) | |
773 a = amp.BinaryBoxProtocol(SwitchyReceiver()) | |
774 anyOldBox = amp.Box({"include": "lots", | |
775 "of": "data"}) | |
776 a.makeConnection(self) | |
777 # Include a 0-length box at the beginning of the next protocol's data, | |
778 # to make sure that AMP doesn't eat the data or try to deliver extra | |
779 # boxes either... | |
780 moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!" | |
781 a.dataReceived(moreThanOneBox) | |
782 self.assertIdentical(otherProto.transport, self) | |
783 self.assertEquals("".join(otherProto.data), "\x00\x00Hello, world!") | |
784 self.assertEquals(self.data, ["outgoing data"]) | |
785 a.dataReceived("more data") | |
786 self.assertEquals("".join(otherProto.data), | |
787 "\x00\x00Hello, world!more data") | |
788 self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox) | |
789 | |
790 | |
791 def test_protocolSwitchInvalidStates(self): | |
792 """ | |
793 In order to make sure the protocol never gets any invalid data sent | |
794 into the middle of a box, it must be locked for switching before it is | |
795 switched. It can only be unlocked if the switch failed, and attempting | |
796 to send a box while it is locked should raise an exception. | |
797 """ | |
798 a = amp.BinaryBoxProtocol(self) | |
799 a.makeConnection(self) | |
800 sampleBox = amp.Box({"some": "data"}) | |
801 a._lockForSwitch() | |
802 self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox) | |
803 a._unlockFromSwitch() | |
804 a.sendBox(sampleBox) | |
805 self.assertEquals(''.join(self.data), sampleBox.serialize()) | |
806 a._lockForSwitch() | |
807 otherProto = TestProto(None, "outgoing data") | |
808 a._switchTo(otherProto) | |
809 self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch) | |
810 | |
811 | |
812 def test_protocolSwitchLoseConnection(self): | |
813 """ | |
814 When the protocol is switched, it should notify its nested protocol of | |
815 disconnection. | |
816 """ | |
817 class Loser(protocol.Protocol): | |
818 reason = None | |
819 def connectionLost(self, reason): | |
820 self.reason = reason | |
821 connectionLoser = Loser() | |
822 a = amp.BinaryBoxProtocol(self) | |
823 a.makeConnection(self) | |
824 a._lockForSwitch() | |
825 a._switchTo(connectionLoser) | |
826 connectionFailure = Failure(RuntimeError()) | |
827 a.connectionLost(connectionFailure) | |
828 self.assertEquals(connectionLoser.reason, connectionFailure) | |
829 | |
830 | |
831 def test_protocolSwitchLoseClientConnection(self): | |
832 """ | |
833 When the protocol is switched, it should notify its nested client | |
834 protocol factory of disconnection. | |
835 """ | |
836 class ClientLoser: | |
837 reason = None | |
838 def clientConnectionLost(self, connector, reason): | |
839 self.reason = reason | |
840 a = amp.BinaryBoxProtocol(self) | |
841 connectionLoser = protocol.Protocol() | |
842 clientLoser = ClientLoser() | |
843 a.makeConnection(self) | |
844 a._lockForSwitch() | |
845 a._switchTo(connectionLoser, clientLoser) | |
846 connectionFailure = Failure(RuntimeError()) | |
847 a.connectionLost(connectionFailure) | |
848 self.assertEquals(clientLoser.reason, connectionFailure) | |
849 | |
850 | |
851 | |
852 class AMPTest(unittest.TestCase): | |
853 | |
854 def test_interfaceDeclarations(self): | |
855 """ | |
856 The classes in the amp module ought to implement the interfaces that | |
857 are declared for their benefit. | |
858 """ | |
859 for interface, implementation in [(amp.IBoxSender, amp.BinaryBoxProtocol
), | |
860 (amp.IBoxReceiver, amp.BoxDispatcher), | |
861 (amp.IResponderLocator, amp.CommandLoc
ator), | |
862 (amp.IResponderLocator, amp.SimpleStri
ngLocator), | |
863 (amp.IBoxSender, amp.AMP), | |
864 (amp.IBoxReceiver, amp.AMP), | |
865 (amp.IResponderLocator, amp.AMP)]: | |
866 self.failUnless(interface.implementedBy(implementation), | |
867 "%s does not implements(%s)" % (implementation, inte
rface)) | |
868 | |
869 | |
870 def test_helloWorld(self): | |
871 """ | |
872 Verify that a simple command can be sent and its response received with | |
873 the simple low-level string-based API. | |
874 """ | |
875 c, s, p = connectedServerAndClient() | |
876 L = [] | |
877 HELLO = 'world' | |
878 c.sendHello(HELLO).addCallback(L.append) | |
879 p.flush() | |
880 self.assertEquals(L[0]['hello'], HELLO) | |
881 | |
882 | |
883 def test_wireFormatRoundTrip(self): | |
884 """ | |
885 Verify that mixed-case, underscored and dashed arguments are mapped to | |
886 their python names properly. | |
887 """ | |
888 c, s, p = connectedServerAndClient() | |
889 L = [] | |
890 HELLO = 'world' | |
891 c.sendHello(HELLO).addCallback(L.append) | |
892 p.flush() | |
893 self.assertEquals(L[0]['hello'], HELLO) | |
894 | |
895 | |
896 def test_helloWorldUnicode(self): | |
897 """ | |
898 Verify that unicode arguments can be encoded and decoded. | |
899 """ | |
900 c, s, p = connectedServerAndClient( | |
901 ServerClass=SimpleSymmetricCommandProtocol, | |
902 ClientClass=SimpleSymmetricCommandProtocol) | |
903 L = [] | |
904 HELLO = 'world' | |
905 HELLO_UNICODE = 'wor\u1234ld' | |
906 c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append) | |
907 p.flush() | |
908 self.assertEquals(L[0]['hello'], HELLO) | |
909 self.assertEquals(L[0]['Print'], HELLO_UNICODE) | |
910 | |
911 | |
912 def test_unknownCommandLow(self): | |
913 """ | |
914 Verify that unknown commands using low-level APIs will be rejected with
an | |
915 error, but will NOT terminate the connection. | |
916 """ | |
917 c, s, p = connectedServerAndClient() | |
918 L = [] | |
919 def clearAndAdd(e): | |
920 """ | |
921 You can't propagate the error... | |
922 """ | |
923 e.trap(amp.UnhandledCommand) | |
924 return "OK" | |
925 c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append) | |
926 p.flush() | |
927 self.assertEquals(L.pop(), "OK") | |
928 HELLO = 'world' | |
929 c.sendHello(HELLO).addCallback(L.append) | |
930 p.flush() | |
931 self.assertEquals(L[0]['hello'], HELLO) | |
932 | |
933 | |
934 def test_unknownCommandHigh(self): | |
935 """ | |
936 Verify that unknown commands using high-level APIs will be rejected with
an | |
937 error, but will NOT terminate the connection. | |
938 """ | |
939 c, s, p = connectedServerAndClient() | |
940 L = [] | |
941 def clearAndAdd(e): | |
942 """ | |
943 You can't propagate the error... | |
944 """ | |
945 e.trap(amp.UnhandledCommand) | |
946 return "OK" | |
947 c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append) | |
948 p.flush() | |
949 self.assertEquals(L.pop(), "OK") | |
950 HELLO = 'world' | |
951 c.sendHello(HELLO).addCallback(L.append) | |
952 p.flush() | |
953 self.assertEquals(L[0]['hello'], HELLO) | |
954 | |
955 | |
956 def test_brokenReturnValue(self): | |
957 """ | |
958 It can be very confusing if you write some code which responds to a | |
959 command, but gets the return value wrong. Most commonly you end up | |
960 returning None instead of a dictionary. | |
961 | |
962 Verify that if that happens, the framework logs a useful error. | |
963 """ | |
964 L = [] | |
965 SimpleSymmetricCommandProtocol().dispatchCommand( | |
966 amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append) | |
967 blr = L[0].trap(amp.BadLocalReturn) | |
968 self.failUnlessIn('None', repr(L[0].value)) | |
969 | |
970 | |
971 | |
972 def test_unknownArgument(self): | |
973 """ | |
974 Verify that unknown arguments are ignored, and not passed to a Python | |
975 function which can't accept them. | |
976 """ | |
977 c, s, p = connectedServerAndClient( | |
978 ServerClass=SimpleSymmetricCommandProtocol, | |
979 ClientClass=SimpleSymmetricCommandProtocol) | |
980 L = [] | |
981 HELLO = 'world' | |
982 # c.sendHello(HELLO).addCallback(L.append) | |
983 c.callRemote(FutureHello, | |
984 hello=HELLO, | |
985 bonus="I'm not in the book!").addCallback( | |
986 L.append) | |
987 p.flush() | |
988 self.assertEquals(L[0]['hello'], HELLO) | |
989 | |
990 | |
991 def test_simpleReprs(self): | |
992 """ | |
993 Verify that the various Box objects repr properly, for debugging. | |
994 """ | |
995 self.assertEquals(type(repr(amp._TLSBox())), str) | |
996 self.assertEquals(type(repr(amp._SwitchBox('a'))), str) | |
997 self.assertEquals(type(repr(amp.QuitBox())), str) | |
998 self.assertEquals(type(repr(amp.AmpBox())), str) | |
999 self.failUnless("AmpBox" in repr(amp.AmpBox())) | |
1000 | |
1001 def test_keyTooLong(self): | |
1002 """ | |
1003 Verify that a key that is too long will immediately raise a synchronous | |
1004 exception. | |
1005 """ | |
1006 c, s, p = connectedServerAndClient() | |
1007 L = [] | |
1008 x = "H" * (0xff+1) | |
1009 tl = self.assertRaises(amp.TooLong, | |
1010 c.callRemoteString, "Hello", | |
1011 **{x: "hi"}) | |
1012 self.failUnless(tl.isKey) | |
1013 self.failUnless(tl.isLocal) | |
1014 self.failUnlessIdentical(tl.keyName, None) | |
1015 self.failUnlessIdentical(tl.value, x) | |
1016 self.failUnless(str(len(x)) in repr(tl)) | |
1017 self.failUnless("key" in repr(tl)) | |
1018 | |
1019 | |
1020 def test_valueTooLong(self): | |
1021 """ | |
1022 Verify that attempting to send value longer than 64k will immediately | |
1023 raise an exception. | |
1024 """ | |
1025 c, s, p = connectedServerAndClient() | |
1026 L = [] | |
1027 x = "H" * (0xffff+1) | |
1028 tl = self.assertRaises(amp.TooLong, c.sendHello, x) | |
1029 p.flush() | |
1030 self.failIf(tl.isKey) | |
1031 self.failUnless(tl.isLocal) | |
1032 self.failUnlessIdentical(tl.keyName, 'hello') | |
1033 self.failUnlessIdentical(tl.value, x) | |
1034 self.failUnless(str(len(x)) in repr(tl)) | |
1035 self.failUnless("value" in repr(tl)) | |
1036 self.failUnless('hello' in repr(tl)) | |
1037 | |
1038 | |
1039 def test_helloWorldCommand(self): | |
1040 """ | |
1041 Verify that a simple command can be sent and its response received with | |
1042 the high-level value parsing API. | |
1043 """ | |
1044 c, s, p = connectedServerAndClient( | |
1045 ServerClass=SimpleSymmetricCommandProtocol, | |
1046 ClientClass=SimpleSymmetricCommandProtocol) | |
1047 L = [] | |
1048 HELLO = 'world' | |
1049 c.sendHello(HELLO).addCallback(L.append) | |
1050 p.flush() | |
1051 self.assertEquals(L[0]['hello'], HELLO) | |
1052 | |
1053 | |
1054 def test_helloErrorHandling(self): | |
1055 """ | |
1056 Verify that if a known error type is raised and handled, it will be | |
1057 properly relayed to the other end of the connection and translated into | |
1058 an exception, and no error will be logged. | |
1059 """ | |
1060 L=[] | |
1061 c, s, p = connectedServerAndClient( | |
1062 ServerClass=SimpleSymmetricCommandProtocol, | |
1063 ClientClass=SimpleSymmetricCommandProtocol) | |
1064 HELLO = 'fuck you' | |
1065 c.sendHello(HELLO).addErrback(L.append) | |
1066 p.flush() | |
1067 L[0].trap(UnfriendlyGreeting) | |
1068 self.assertEquals(str(L[0].value), "Don't be a dick.") | |
1069 | |
1070 | |
1071 def test_helloFatalErrorHandling(self): | |
1072 """ | |
1073 Verify that if a known, fatal error type is raised and handled, it will | |
1074 be properly relayed to the other end of the connection and translated | |
1075 into an exception, no error will be logged, and the connection will be | |
1076 terminated. | |
1077 """ | |
1078 L=[] | |
1079 c, s, p = connectedServerAndClient( | |
1080 ServerClass=SimpleSymmetricCommandProtocol, | |
1081 ClientClass=SimpleSymmetricCommandProtocol) | |
1082 HELLO = 'die' | |
1083 c.sendHello(HELLO).addErrback(L.append) | |
1084 p.flush() | |
1085 L.pop().trap(DeathThreat) | |
1086 c.sendHello(HELLO).addErrback(L.append) | |
1087 p.flush() | |
1088 L.pop().trap(error.ConnectionDone) | |
1089 | |
1090 | |
1091 | |
1092 def test_helloNoErrorHandling(self): | |
1093 """ | |
1094 Verify that if an unknown error type is raised, it will be relayed to | |
1095 the other end of the connection and translated into an exception, it | |
1096 will be logged, and then the connection will be dropped. | |
1097 """ | |
1098 L=[] | |
1099 c, s, p = connectedServerAndClient( | |
1100 ServerClass=SimpleSymmetricCommandProtocol, | |
1101 ClientClass=SimpleSymmetricCommandProtocol) | |
1102 HELLO = THING_I_DONT_UNDERSTAND | |
1103 c.sendHello(HELLO).addErrback(L.append) | |
1104 p.flush() | |
1105 ure = L.pop() | |
1106 ure.trap(amp.UnknownRemoteError) | |
1107 c.sendHello(HELLO).addErrback(L.append) | |
1108 cl = L.pop() | |
1109 cl.trap(error.ConnectionDone) | |
1110 # The exception should have been logged. | |
1111 self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError)) | |
1112 | |
1113 | |
1114 | |
1115 def test_lateAnswer(self): | |
1116 """ | |
1117 Verify that a command that does not get answered until after the | |
1118 connection terminates will not cause any errors. | |
1119 """ | |
1120 c, s, p = connectedServerAndClient( | |
1121 ServerClass=SimpleSymmetricCommandProtocol, | |
1122 ClientClass=SimpleSymmetricCommandProtocol) | |
1123 L = [] | |
1124 HELLO = 'world' | |
1125 c.callRemote(WaitForever).addErrback(L.append) | |
1126 p.flush() | |
1127 self.assertEquals(L, []) | |
1128 s.transport.loseConnection() | |
1129 p.flush() | |
1130 L.pop().trap(error.ConnectionDone) | |
1131 # Just make sure that it doesn't error... | |
1132 s.waiting.callback({}) | |
1133 return s.waiting | |
1134 | |
1135 | |
1136 def test_requiresNoAnswer(self): | |
1137 """ | |
1138 Verify that a command that requires no answer is run. | |
1139 """ | |
1140 L=[] | |
1141 c, s, p = connectedServerAndClient( | |
1142 ServerClass=SimpleSymmetricCommandProtocol, | |
1143 ClientClass=SimpleSymmetricCommandProtocol) | |
1144 HELLO = 'world' | |
1145 c.callRemote(NoAnswerHello, hello=HELLO) | |
1146 p.flush() | |
1147 self.failUnless(s.greeted) | |
1148 | |
1149 | |
1150 def test_requiresNoAnswerFail(self): | |
1151 """ | |
1152 Verify that commands sent after a failed no-answer request do not comple
te. | |
1153 """ | |
1154 L=[] | |
1155 c, s, p = connectedServerAndClient( | |
1156 ServerClass=SimpleSymmetricCommandProtocol, | |
1157 ClientClass=SimpleSymmetricCommandProtocol) | |
1158 HELLO = 'fuck you' | |
1159 c.callRemote(NoAnswerHello, hello=HELLO) | |
1160 p.flush() | |
1161 # This should be logged locally. | |
1162 self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError)) | |
1163 HELLO = 'world' | |
1164 c.callRemote(Hello, hello=HELLO).addErrback(L.append) | |
1165 p.flush() | |
1166 L.pop().trap(error.ConnectionDone) | |
1167 self.failIf(s.greeted) | |
1168 | |
1169 | |
1170 def test_noAnswerResponderBadAnswer(self): | |
1171 """ | |
1172 Verify that responders of requiresAnswer=False commands have to return | |
1173 a dictionary anyway. | |
1174 | |
1175 (requiresAnswer is a hint from the _client_ - the server may be called | |
1176 upon to answer commands in any case, if the client wants to know when | |
1177 they complete.) | |
1178 """ | |
1179 c, s, p = connectedServerAndClient( | |
1180 ServerClass=BadNoAnswerCommandProtocol, | |
1181 ClientClass=SimpleSymmetricCommandProtocol) | |
1182 c.callRemote(NoAnswerHello, hello="hello") | |
1183 p.flush() | |
1184 le = self.flushLoggedErrors(amp.BadLocalReturn) | |
1185 self.assertEquals(len(le), 1) | |
1186 | |
1187 | |
1188 def test_noAnswerResponderAskedForAnswer(self): | |
1189 """ | |
1190 Verify that responders with requiresAnswer=False will actually respond | |
1191 if the client sets requiresAnswer=True. In other words, verify that | |
1192 requiresAnswer is a hint honored only by the client. | |
1193 """ | |
1194 c, s, p = connectedServerAndClient( | |
1195 ServerClass=NoAnswerCommandProtocol, | |
1196 ClientClass=SimpleSymmetricCommandProtocol) | |
1197 L = [] | |
1198 c.callRemote(Hello, hello="Hello!").addCallback(L.append) | |
1199 p.flush() | |
1200 self.assertEquals(len(L), 1) | |
1201 self.assertEquals(L, [dict(hello="Hello!-noanswer", | |
1202 Print=None)]) # Optional response argument | |
1203 | |
1204 | |
1205 def test_ampListCommand(self): | |
1206 """ | |
1207 Test encoding of an argument that uses the AmpList encoding. | |
1208 """ | |
1209 c, s, p = connectedServerAndClient( | |
1210 ServerClass=SimpleSymmetricCommandProtocol, | |
1211 ClientClass=SimpleSymmetricCommandProtocol) | |
1212 L = [] | |
1213 c.callRemote(GetList, length=10).addCallback(L.append) | |
1214 p.flush() | |
1215 values = L.pop().get('body') | |
1216 self.assertEquals(values, [{'x': 1}] * 10) | |
1217 | |
1218 | |
1219 def test_failEarlyOnArgSending(self): | |
1220 """ | |
1221 Verify that if we pass an invalid argument list (omitting an argument),
an | |
1222 exception will be raised. | |
1223 """ | |
1224 okayCommand = Hello(hello="What?") | |
1225 self.assertRaises(amp.InvalidSignature, Hello) | |
1226 | |
1227 | |
1228 def test_doubleProtocolSwitch(self): | |
1229 """ | |
1230 As a debugging aid, a protocol system should raise a | |
1231 L{ProtocolSwitched} exception when asked to switch a protocol that is | |
1232 already switched. | |
1233 """ | |
1234 serverDeferred = defer.Deferred() | |
1235 serverProto = SimpleSymmetricCommandProtocol(serverDeferred) | |
1236 clientDeferred = defer.Deferred() | |
1237 clientProto = SimpleSymmetricCommandProtocol(clientDeferred) | |
1238 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto, | |
1239 ClientClass=lambda: clientProto) | |
1240 def switched(result): | |
1241 self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol) | |
1242 self.testSucceeded = True | |
1243 c.switchToTestProtocol().addCallback(switched) | |
1244 p.flush() | |
1245 self.failUnless(self.testSucceeded) | |
1246 | |
1247 | |
1248 def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol, | |
1249 spuriousTraffic=False, | |
1250 spuriousError=False): | |
1251 """ | |
1252 Verify that it is possible to switch to another protocol mid-connection
and | |
1253 send data to it successfully. | |
1254 """ | |
1255 self.testSucceeded = False | |
1256 | |
1257 serverDeferred = defer.Deferred() | |
1258 serverProto = switcher(serverDeferred) | |
1259 clientDeferred = defer.Deferred() | |
1260 clientProto = switcher(clientDeferred) | |
1261 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto, | |
1262 ClientClass=lambda: clientProto) | |
1263 | |
1264 if spuriousTraffic: | |
1265 wfdr = [] # remote | |
1266 wfd = c.callRemote(WaitForever).addErrback(wfdr.append) | |
1267 switchDeferred = c.switchToTestProtocol() | |
1268 if spuriousTraffic: | |
1269 self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world') | |
1270 | |
1271 def cbConnsLost(((serverSuccess, serverData), | |
1272 (clientSuccess, clientData))): | |
1273 self.failUnless(serverSuccess) | |
1274 self.failUnless(clientSuccess) | |
1275 self.assertEquals(''.join(serverData), SWITCH_CLIENT_DATA) | |
1276 self.assertEquals(''.join(clientData), SWITCH_SERVER_DATA) | |
1277 self.testSucceeded = True | |
1278 | |
1279 def cbSwitch(proto): | |
1280 return defer.DeferredList( | |
1281 [serverDeferred, clientDeferred]).addCallback(cbConnsLost) | |
1282 | |
1283 switchDeferred.addCallback(cbSwitch) | |
1284 p.flush() | |
1285 if serverProto.maybeLater is not None: | |
1286 serverProto.maybeLater.callback(serverProto.maybeLaterProto) | |
1287 p.flush() | |
1288 if spuriousTraffic: | |
1289 # switch is done here; do this here to make sure that if we're | |
1290 # going to corrupt the connection, we do it before it's closed. | |
1291 if spuriousError: | |
1292 s.waiting.errback(amp.RemoteAmpError( | |
1293 "SPURIOUS", | |
1294 "Here's some traffic in the form of an error.")) | |
1295 else: | |
1296 s.waiting.callback({}) | |
1297 p.flush() | |
1298 c.transport.loseConnection() # close it | |
1299 p.flush() | |
1300 self.failUnless(self.testSucceeded) | |
1301 | |
1302 | |
1303 def test_protocolSwitchDeferred(self): | |
1304 """ | |
1305 Verify that protocol-switching even works if the value returned from | |
1306 the command that does the switch is deferred. | |
1307 """ | |
1308 return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtoco
l) | |
1309 | |
1310 | |
1311 def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol): | |
1312 """ | |
1313 Verify that if we try to switch protocols and it fails, the connection | |
1314 stays up and we can go back to speaking AMP. | |
1315 """ | |
1316 self.testSucceeded = False | |
1317 | |
1318 serverDeferred = defer.Deferred() | |
1319 serverProto = switcher(serverDeferred) | |
1320 clientDeferred = defer.Deferred() | |
1321 clientProto = switcher(clientDeferred) | |
1322 c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto, | |
1323 ClientClass=lambda: clientProto) | |
1324 L = [] | |
1325 switchDeferred = c.switchToTestProtocol(fail=True).addErrback(L.append) | |
1326 p.flush() | |
1327 L.pop().trap(UnknownProtocol) | |
1328 self.failIf(self.testSucceeded) | |
1329 # It's a known error, so let's send a "hello" on the same connection; | |
1330 # it should work. | |
1331 c.sendHello('world').addCallback(L.append) | |
1332 p.flush() | |
1333 self.assertEqual(L.pop()['hello'], 'world') | |
1334 | |
1335 | |
1336 def test_trafficAfterSwitch(self): | |
1337 """ | |
1338 Verify that attempts to send traffic after a switch will not corrupt | |
1339 the nested protocol. | |
1340 """ | |
1341 return self.test_protocolSwitch(spuriousTraffic=True) | |
1342 | |
1343 | |
1344 def test_errorAfterSwitch(self): | |
1345 """ | |
1346 Returning an error after a protocol switch should record the underlying | |
1347 error. | |
1348 """ | |
1349 return self.test_protocolSwitch(spuriousTraffic=True, | |
1350 spuriousError=True) | |
1351 | |
1352 | |
1353 def test_quitBoxQuits(self): | |
1354 """ | |
1355 Verify that commands with a responseType of QuitBox will in fact | |
1356 terminate the connection. | |
1357 """ | |
1358 c, s, p = connectedServerAndClient( | |
1359 ServerClass=SimpleSymmetricCommandProtocol, | |
1360 ClientClass=SimpleSymmetricCommandProtocol) | |
1361 | |
1362 L = [] | |
1363 HELLO = 'world' | |
1364 GOODBYE = 'everyone' | |
1365 c.sendHello(HELLO).addCallback(L.append) | |
1366 p.flush() | |
1367 self.assertEquals(L.pop()['hello'], HELLO) | |
1368 c.callRemote(Goodbye).addCallback(L.append) | |
1369 p.flush() | |
1370 self.assertEquals(L.pop()['goodbye'], GOODBYE) | |
1371 c.sendHello(HELLO).addErrback(L.append) | |
1372 L.pop().trap(error.ConnectionDone) | |
1373 | |
1374 | |
1375 def test_basicLiteralEmit(self): | |
1376 """ | |
1377 Verify that the command dictionaries for a callRemoteN look correct | |
1378 after being serialized and parsed. | |
1379 """ | |
1380 c, s, p = connectedServerAndClient() | |
1381 L = [] | |
1382 s.ampBoxReceived = L.append | |
1383 c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test', | |
1384 dash_arg='x', underscore_arg='y') | |
1385 p.flush() | |
1386 self.assertEquals(len(L), 1) | |
1387 for k, v in [('_command', Hello.commandName), | |
1388 ('hello', 'hello test'), | |
1389 ('mixedCase', 'mixed case arg test'), | |
1390 ('dash-arg', 'x'), | |
1391 ('underscore_arg', 'y')]: | |
1392 self.assertEquals(L[-1].pop(k), v) | |
1393 L[-1].pop('_ask') | |
1394 self.assertEquals(L[-1], {}) | |
1395 | |
1396 | |
1397 def test_basicStructuredEmit(self): | |
1398 """ | |
1399 Verify that a call similar to basicLiteralEmit's is handled properly wit
h | |
1400 high-level quoting and passing to Python methods, and that argument | |
1401 names are correctly handled. | |
1402 """ | |
1403 L = [] | |
1404 class StructuredHello(amp.AMP): | |
1405 def h(self, *a, **k): | |
1406 L.append((a, k)) | |
1407 return dict(hello='aaa') | |
1408 Hello.responder(h) | |
1409 c, s, p = connectedServerAndClient(ServerClass=StructuredHello) | |
1410 c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test', | |
1411 dash_arg='x', underscore_arg='y').addCallback(L.append) | |
1412 p.flush() | |
1413 self.assertEquals(len(L), 2) | |
1414 self.assertEquals(L[0], | |
1415 ((), dict( | |
1416 hello='hello test', | |
1417 mixedCase='mixed case arg test', | |
1418 dash_arg='x', | |
1419 underscore_arg='y', | |
1420 | |
1421 # XXX - should optional arguments just not be passed? | |
1422 # passing None seems a little odd, looking at the way it | |
1423 # turns out here... -glyph | |
1424 From=('file', 'file'), | |
1425 Print=None, | |
1426 optional=None, | |
1427 ))) | |
1428 self.assertEquals(L[1], dict(Print=None, hello='aaa')) | |
1429 | |
1430 class PretendRemoteCertificateAuthority: | |
1431 def checkIsPretendRemote(self): | |
1432 return True | |
1433 | |
1434 class IOSimCert: | |
1435 verifyCount = 0 | |
1436 | |
1437 def options(self, *ign): | |
1438 return self | |
1439 | |
1440 def iosimVerify(self, otherCert): | |
1441 """ | |
1442 This isn't a real certificate, and wouldn't work on a real socket, but | |
1443 iosim specifies a different API so that we don't have to do any crypto | |
1444 math to demonstrate that the right functions get called in the right | |
1445 places. | |
1446 """ | |
1447 assert otherCert is self | |
1448 self.verifyCount += 1 | |
1449 return True | |
1450 | |
1451 class OKCert(IOSimCert): | |
1452 def options(self, x): | |
1453 assert x.checkIsPretendRemote() | |
1454 return self | |
1455 | |
1456 class GrumpyCert(IOSimCert): | |
1457 def iosimVerify(self, otherCert): | |
1458 self.verifyCount += 1 | |
1459 return False | |
1460 | |
1461 class DroppyCert(IOSimCert): | |
1462 def __init__(self, toDrop): | |
1463 self.toDrop = toDrop | |
1464 | |
1465 def iosimVerify(self, otherCert): | |
1466 self.verifyCount += 1 | |
1467 self.toDrop.loseConnection() | |
1468 return True | |
1469 | |
1470 class SecurableProto(FactoryNotifier): | |
1471 | |
1472 factory = None | |
1473 | |
1474 def verifyFactory(self): | |
1475 return [PretendRemoteCertificateAuthority()] | |
1476 | |
1477 def getTLSVars(self): | |
1478 cert = self.certFactory() | |
1479 verify = self.verifyFactory() | |
1480 return dict( | |
1481 tls_localCertificate=cert, | |
1482 tls_verifyAuthorities=verify) | |
1483 amp.StartTLS.responder(getTLSVars) | |
1484 | |
1485 | |
1486 | |
1487 class TLSTest(unittest.TestCase): | |
1488 def test_startingTLS(self): | |
1489 """ | |
1490 Verify that starting TLS and succeeding at handshaking sends all the | |
1491 notifications to all the right places. | |
1492 """ | |
1493 cli, svr, p = connectedServerAndClient( | |
1494 ServerClass=SecurableProto, | |
1495 ClientClass=SecurableProto) | |
1496 | |
1497 okc = OKCert() | |
1498 svr.certFactory = lambda : okc | |
1499 | |
1500 cli.callRemote( | |
1501 amp.StartTLS, tls_localCertificate=okc, | |
1502 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()]) | |
1503 | |
1504 # let's buffer something to be delivered securely | |
1505 L = [] | |
1506 d = cli.callRemote(SecuredPing).addCallback(L.append) | |
1507 p.flush() | |
1508 # once for client once for server | |
1509 self.assertEquals(okc.verifyCount, 2) | |
1510 L = [] | |
1511 d = cli.callRemote(SecuredPing).addCallback(L.append) | |
1512 p.flush() | |
1513 self.assertEqual(L[0], {'pinged': True}) | |
1514 | |
1515 | |
1516 def test_startTooManyTimes(self): | |
1517 """ | |
1518 Verify that the protocol will complain if we attempt to renegotiate TLS, | |
1519 which we don't support. | |
1520 """ | |
1521 cli, svr, p = connectedServerAndClient( | |
1522 ServerClass=SecurableProto, | |
1523 ClientClass=SecurableProto) | |
1524 | |
1525 okc = OKCert() | |
1526 svr.certFactory = lambda : okc | |
1527 | |
1528 cli.callRemote(amp.StartTLS, | |
1529 tls_localCertificate=okc, | |
1530 tls_verifyAuthorities=[PretendRemoteCertificateAuthority(
)]) | |
1531 p.flush() | |
1532 cli.noPeerCertificate = True # this is totally fake | |
1533 self.assertRaises( | |
1534 amp.OnlyOneTLS, | |
1535 cli.callRemote, | |
1536 amp.StartTLS, | |
1537 tls_localCertificate=okc, | |
1538 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()]) | |
1539 | |
1540 | |
1541 def test_negotiationFailed(self): | |
1542 """ | |
1543 Verify that starting TLS and failing on both sides at handshaking sends | |
1544 notifications to all the right places and terminates the connection. | |
1545 """ | |
1546 | |
1547 badCert = GrumpyCert() | |
1548 | |
1549 cli, svr, p = connectedServerAndClient( | |
1550 ServerClass=SecurableProto, | |
1551 ClientClass=SecurableProto) | |
1552 svr.certFactory = lambda : badCert | |
1553 | |
1554 cli.callRemote(amp.StartTLS, | |
1555 tls_localCertificate=badCert) | |
1556 | |
1557 p.flush() | |
1558 # once for client once for server - but both fail | |
1559 self.assertEquals(badCert.verifyCount, 2) | |
1560 d = cli.callRemote(SecuredPing) | |
1561 p.flush() | |
1562 self.assertFailure(d, iosim.OpenSSLVerifyError) | |
1563 | |
1564 | |
1565 def test_negotiationFailedByClosing(self): | |
1566 """ | |
1567 Verify that starting TLS and failing by way of a lost connection | |
1568 notices that it is probably an SSL problem. | |
1569 """ | |
1570 | |
1571 cli, svr, p = connectedServerAndClient( | |
1572 ServerClass=SecurableProto, | |
1573 ClientClass=SecurableProto) | |
1574 droppyCert = DroppyCert(svr.transport) | |
1575 svr.certFactory = lambda : droppyCert | |
1576 | |
1577 secure = cli.callRemote(amp.StartTLS, | |
1578 tls_localCertificate=droppyCert) | |
1579 | |
1580 p.flush() | |
1581 | |
1582 self.assertEquals(droppyCert.verifyCount, 2) | |
1583 | |
1584 d = cli.callRemote(SecuredPing) | |
1585 p.flush() | |
1586 | |
1587 # it might be a good idea to move this exception somewhere more | |
1588 # reasonable. | |
1589 self.assertFailure(d, error.PeerVerifyError) | |
1590 | |
1591 | |
1592 | |
1593 class InheritedError(Exception): | |
1594 """ | |
1595 This error is used to check inheritance. | |
1596 """ | |
1597 | |
1598 | |
1599 | |
1600 class OtherInheritedError(Exception): | |
1601 """ | |
1602 This is a distinct error for checking inheritance. | |
1603 """ | |
1604 | |
1605 | |
1606 | |
1607 class BaseCommand(amp.Command): | |
1608 """ | |
1609 This provides a command that will be subclassed. | |
1610 """ | |
1611 errors = {InheritedError: 'INHERITED_ERROR'} | |
1612 | |
1613 | |
1614 | |
1615 class InheritedCommand(BaseCommand): | |
1616 """ | |
1617 This is a command which subclasses another command but does not override | |
1618 anything. | |
1619 """ | |
1620 | |
1621 | |
1622 | |
1623 class AddErrorsCommand(BaseCommand): | |
1624 """ | |
1625 This is a command which subclasses another command but adds errors to the | |
1626 list. | |
1627 """ | |
1628 arguments = [('other', amp.Boolean())] | |
1629 errors = {OtherInheritedError: 'OTHER_INHERITED_ERROR'} | |
1630 | |
1631 | |
1632 | |
1633 class NormalCommandProtocol(amp.AMP): | |
1634 """ | |
1635 This is a protocol which responds to L{BaseCommand}, and is used to test | |
1636 that inheritance does not interfere with the normal handling of errors. | |
1637 """ | |
1638 def resp(self): | |
1639 raise InheritedError() | |
1640 BaseCommand.responder(resp) | |
1641 | |
1642 | |
1643 | |
1644 class InheritedCommandProtocol(amp.AMP): | |
1645 """ | |
1646 This is a protocol which responds to L{InheritedCommand}, and is used to | |
1647 test that inherited commands inherit their bases' errors if they do not | |
1648 respond to any of their own. | |
1649 """ | |
1650 def resp(self): | |
1651 raise InheritedError() | |
1652 InheritedCommand.responder(resp) | |
1653 | |
1654 | |
1655 | |
1656 class AddedCommandProtocol(amp.AMP): | |
1657 """ | |
1658 This is a protocol which responds to L{AddErrorsCommand}, and is used to | |
1659 test that inherited commands can add their own new types of errors, but | |
1660 still respond in the same way to their parents types of errors. | |
1661 """ | |
1662 def resp(self, other): | |
1663 if other: | |
1664 raise OtherInheritedError() | |
1665 else: | |
1666 raise InheritedError() | |
1667 AddErrorsCommand.responder(resp) | |
1668 | |
1669 | |
1670 | |
1671 class CommandInheritanceTests(unittest.TestCase): | |
1672 """ | |
1673 These tests verify that commands inherit error conditions properly. | |
1674 """ | |
1675 | |
1676 def errorCheck(self, err, proto, cmd, **kw): | |
1677 """ | |
1678 Check that the appropriate kind of error is raised when a given command | |
1679 is sent to a given protocol. | |
1680 """ | |
1681 c, s, p = connectedServerAndClient(ServerClass=proto, | |
1682 ClientClass=proto) | |
1683 d = c.callRemote(cmd, **kw) | |
1684 d2 = self.failUnlessFailure(d, err) | |
1685 p.flush() | |
1686 return d2 | |
1687 | |
1688 | |
1689 def test_basicErrorPropagation(self): | |
1690 """ | |
1691 Verify that errors specified in a superclass are respected normally | |
1692 even if it has subclasses. | |
1693 """ | |
1694 return self.errorCheck( | |
1695 InheritedError, NormalCommandProtocol, BaseCommand) | |
1696 | |
1697 | |
1698 def test_inheritedErrorPropagation(self): | |
1699 """ | |
1700 Verify that errors specified in a superclass command are propagated to | |
1701 its subclasses. | |
1702 """ | |
1703 return self.errorCheck( | |
1704 InheritedError, InheritedCommandProtocol, InheritedCommand) | |
1705 | |
1706 | |
1707 def test_inheritedErrorAddition(self): | |
1708 """ | |
1709 Verify that new errors specified in a subclass of an existing command | |
1710 are honored even if the superclass defines some errors. | |
1711 """ | |
1712 return self.errorCheck( | |
1713 OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=T
rue) | |
1714 | |
1715 | |
1716 def test_additionWithOriginalError(self): | |
1717 """ | |
1718 Verify that errors specified in a command's superclass are respected | |
1719 even if that command defines new errors itself. | |
1720 """ | |
1721 return self.errorCheck( | |
1722 InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False) | |
1723 | |
1724 | |
1725 def _loseAndPass(err, proto): | |
1726 # be specific, pass on the error to the client. | |
1727 err.trap(error.ConnectionLost, error.ConnectionDone) | |
1728 del proto.connectionLost | |
1729 proto.connectionLost(err) | |
1730 | |
1731 | |
1732 class LiveFireBase: | |
1733 """ | |
1734 Utility for connected reactor-using tests. | |
1735 """ | |
1736 | |
1737 def setUp(self): | |
1738 """ | |
1739 Create an amp server and connect a client to it. | |
1740 """ | |
1741 from twisted.internet import reactor | |
1742 self.serverFactory = protocol.ServerFactory() | |
1743 self.serverFactory.protocol = self.serverProto | |
1744 self.clientFactory = protocol.ClientFactory() | |
1745 self.clientFactory.protocol = self.clientProto | |
1746 self.clientFactory.onMade = defer.Deferred() | |
1747 self.serverFactory.onMade = defer.Deferred() | |
1748 self.serverPort = reactor.listenTCP(0, self.serverFactory) | |
1749 self.addCleanup(self.serverPort.stopListening) | |
1750 self.clientConn = reactor.connectTCP( | |
1751 '127.0.0.1', self.serverPort.getHost().port, | |
1752 self.clientFactory) | |
1753 self.addCleanup(self.clientConn.disconnect) | |
1754 def getProtos(rlst): | |
1755 self.cli = self.clientFactory.theProto | |
1756 self.svr = self.serverFactory.theProto | |
1757 dl = defer.DeferredList([self.clientFactory.onMade, | |
1758 self.serverFactory.onMade]) | |
1759 return dl.addCallback(getProtos) | |
1760 | |
1761 def tearDown(self): | |
1762 """ | |
1763 Cleanup client and server connections, and check the error got at | |
1764 C{connectionLost}. | |
1765 """ | |
1766 L = [] | |
1767 for conn in self.cli, self.svr: | |
1768 if conn.transport is not None: | |
1769 # depend on amp's function connection-dropping behavior | |
1770 d = defer.Deferred().addErrback(_loseAndPass, conn) | |
1771 conn.connectionLost = d.errback | |
1772 conn.transport.loseConnection() | |
1773 L.append(d) | |
1774 return defer.gatherResults(L | |
1775 ).addErrback(lambda first: first.value.subFailure) | |
1776 | |
1777 | |
1778 def show(x): | |
1779 import sys | |
1780 sys.stdout.write(x+'\n') | |
1781 sys.stdout.flush() | |
1782 | |
1783 | |
1784 def tempSelfSigned(): | |
1785 from twisted.internet import ssl | |
1786 | |
1787 sharedDN = ssl.DN(CN='shared') | |
1788 key = ssl.KeyPair.generate() | |
1789 cr = key.certificateRequest(sharedDN) | |
1790 sscrd = key.signCertificateRequest( | |
1791 sharedDN, cr, lambda dn: True, 1234567) | |
1792 cert = key.newCertificate(sscrd) | |
1793 return cert | |
1794 | |
1795 tempcert = tempSelfSigned() | |
1796 | |
1797 | |
1798 class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase): | |
1799 clientProto = SecurableProto | |
1800 serverProto = SecurableProto | |
1801 def test_liveFireCustomTLS(self): | |
1802 """ | |
1803 Using real, live TLS, actually negotiate a connection. | |
1804 | |
1805 This also looks at the 'peerCertificate' attribute's correctness, since | |
1806 that's actually loaded using OpenSSL calls, but the main purpose is to | |
1807 make sure that we didn't miss anything obvious in iosim about TLS | |
1808 negotiations. | |
1809 """ | |
1810 | |
1811 cert = tempcert | |
1812 | |
1813 self.svr.verifyFactory = lambda : [cert] | |
1814 self.svr.certFactory = lambda : cert | |
1815 # only needed on the server, we specify the client below. | |
1816 | |
1817 def secured(rslt): | |
1818 x = cert.digest() | |
1819 def pinged(rslt2): | |
1820 # Interesting. OpenSSL won't even _tell_ us about the peer | |
1821 # cert until we negotiate. we should be able to do this in | |
1822 # 'secured' instead, but it looks like we can't. I think this | |
1823 # is a bug somewhere far deeper than here. | |
1824 self.failUnlessEqual(x, self.cli.hostCertificate.digest()) | |
1825 self.failUnlessEqual(x, self.cli.peerCertificate.digest()) | |
1826 self.failUnlessEqual(x, self.svr.hostCertificate.digest()) | |
1827 self.failUnlessEqual(x, self.svr.peerCertificate.digest()) | |
1828 return self.cli.callRemote(SecuredPing).addCallback(pinged) | |
1829 return self.cli.callRemote(amp.StartTLS, | |
1830 tls_localCertificate=cert, | |
1831 tls_verifyAuthorities=[cert]).addCallback(sec
ured) | |
1832 | |
1833 | |
1834 class SlightlySmartTLS(SimpleSymmetricCommandProtocol): | |
1835 """ | |
1836 Specific implementation of server side protocol with different | |
1837 management of TLS. | |
1838 """ | |
1839 def getTLSVars(self): | |
1840 """ | |
1841 @return: the global C{tempcert} certificate as local certificate. | |
1842 """ | |
1843 return dict(tls_localCertificate=tempcert) | |
1844 amp.StartTLS.responder(getTLSVars) | |
1845 | |
1846 | |
1847 class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase): | |
1848 | |
1849 clientProto = SimpleSymmetricCommandProtocol | |
1850 serverProto = SimpleSymmetricCommandProtocol | |
1851 | |
1852 def test_liveFireDefaultTLS(self): | |
1853 """ | |
1854 Verify that out of the box, we can start TLS to at least encrypt the | |
1855 connection, even if we don't have any certificates to use. | |
1856 """ | |
1857 def secured(result): | |
1858 return self.cli.callRemote(SecuredPing) | |
1859 return self.cli.callRemote(amp.StartTLS).addCallback(secured) | |
1860 | |
1861 | |
1862 class WithServerTLSVerification(LiveFireBase, unittest.TestCase): | |
1863 clientProto = SimpleSymmetricCommandProtocol | |
1864 serverProto = SlightlySmartTLS | |
1865 | |
1866 def test_anonymousVerifyingClient(self): | |
1867 """ | |
1868 Verify that anonymous clients can verify server certificates. | |
1869 """ | |
1870 def secured(result): | |
1871 return self.cli.callRemote(SecuredPing) | |
1872 return self.cli.callRemote(amp.StartTLS, | |
1873 tls_verifyAuthorities=[tempcert] | |
1874 ).addCallback(secured) | |
1875 | |
1876 | |
1877 | |
1878 class ProtocolIncludingArgument(amp.Argument): | |
1879 """ | |
1880 An L{amp.Argument} which encodes its parser and serializer | |
1881 arguments *including the protocol* into its parsed and serialized | |
1882 forms. | |
1883 """ | |
1884 | |
1885 def fromStringProto(self, string, protocol): | |
1886 """ | |
1887 Don't decode anything; just return all possible information. | |
1888 | |
1889 @return: A two-tuple of the input string and the protocol. | |
1890 """ | |
1891 return (string, protocol) | |
1892 | |
1893 def toStringProto(self, obj, protocol): | |
1894 """ | |
1895 Encode identifying information about L{object} and protocol | |
1896 into a string for later verification. | |
1897 | |
1898 @type obj: L{object} | |
1899 @type protocol: L{amp.AMP} | |
1900 """ | |
1901 return "%s:%s" % (id(obj), id(protocol)) | |
1902 | |
1903 | |
1904 | |
1905 class ProtocolIncludingCommand(amp.Command): | |
1906 """ | |
1907 A command that has argument and response schemas which use | |
1908 L{ProtocolIncludingArgument}. | |
1909 """ | |
1910 arguments = [('weird', ProtocolIncludingArgument())] | |
1911 response = [('weird', ProtocolIncludingArgument())] | |
1912 | |
1913 | |
1914 | |
1915 class MagicSchemaCommand(amp.Command): | |
1916 """ | |
1917 A command which overrides L{parseResponse}, L{parseArguments}, and | |
1918 L{makeResponse}. | |
1919 """ | |
1920 def parseResponse(self, strings, protocol): | |
1921 """ | |
1922 Don't do any parsing, just jam the input strings and protocol | |
1923 onto the C{protocol.parseResponseArguments} attribute as a | |
1924 two-tuple. Return the original strings. | |
1925 """ | |
1926 protocol.parseResponseArguments = (strings, protocol) | |
1927 return strings | |
1928 parseResponse = classmethod(parseResponse) | |
1929 | |
1930 | |
1931 def parseArguments(cls, strings, protocol): | |
1932 """ | |
1933 Don't do any parsing, just jam the input strings and protocol | |
1934 onto the C{protocol.parseArgumentsArguments} attribute as a | |
1935 two-tuple. Return the original strings. | |
1936 """ | |
1937 protocol.parseArgumentsArguments = (strings, protocol) | |
1938 return strings | |
1939 parseArguments = classmethod(parseArguments) | |
1940 | |
1941 | |
1942 def makeArguments(cls, objects, protocol): | |
1943 """ | |
1944 Don't do any serializing, just jam the input strings and protocol | |
1945 onto the C{protocol.makeArgumentsArguments} attribute as a | |
1946 two-tuple. Return the original strings. | |
1947 """ | |
1948 protocol.makeArgumentsArguments = (objects, protocol) | |
1949 return objects | |
1950 makeArguments = classmethod(makeArguments) | |
1951 | |
1952 | |
1953 | |
1954 class NoNetworkProtocol(amp.AMP): | |
1955 """ | |
1956 An L{amp.AMP} subclass which overrides private methods to avoid | |
1957 testing the network. It also provides a responder for | |
1958 L{MagicSchemaCommand} that does nothing, so that tests can test | |
1959 aspects of the interaction of L{amp.Command}s and L{amp.AMP}. | |
1960 | |
1961 @ivar parseArgumentsArguments: Arguments that have been passed to any | |
1962 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by | |
1963 this protocol. | |
1964 | |
1965 @ivar parseResponseArguments: Responses that have been returned from a | |
1966 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by | |
1967 this protocol. | |
1968 | |
1969 @ivar makeArgumentsArguments: Arguments that have been serialized by any | |
1970 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by | |
1971 this protocol. | |
1972 """ | |
1973 def _sendBoxCommand(self, commandName, strings, requiresAnswer): | |
1974 """ | |
1975 Return a Deferred which fires with the original strings. | |
1976 """ | |
1977 return defer.succeed(strings) | |
1978 | |
1979 MagicSchemaCommand.responder(lambda s, weird: {}) | |
1980 | |
1981 | |
1982 | |
1983 class MyBox(dict): | |
1984 """ | |
1985 A unique dict subclass. | |
1986 """ | |
1987 | |
1988 | |
1989 | |
1990 class ProtocolIncludingCommandWithDifferentCommandType( | |
1991 ProtocolIncludingCommand): | |
1992 """ | |
1993 A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox} | |
1994 """ | |
1995 commandType = MyBox | |
1996 | |
1997 | |
1998 | |
1999 class CommandTestCase(unittest.TestCase): | |
2000 """ | |
2001 Tests for L{amp.Command}. | |
2002 """ | |
2003 | |
2004 def test_parseResponse(self): | |
2005 """ | |
2006 There should be a class method of Command which accepts a | |
2007 mapping of argument names to serialized forms and returns a | |
2008 similar mapping whose values have been parsed via the | |
2009 Command's response schema. | |
2010 """ | |
2011 protocol = object() | |
2012 result = 'whatever' | |
2013 strings = {'weird': result} | |
2014 self.assertEqual( | |
2015 ProtocolIncludingCommand.parseResponse(strings, protocol), | |
2016 {'weird': (result, protocol)}) | |
2017 | |
2018 | |
2019 def test_callRemoteCallsParseResponse(self): | |
2020 """ | |
2021 Making a remote call on a L{amp.Command} subclass which | |
2022 overrides the C{parseResponse} method should call that | |
2023 C{parseResponse} method to get the response. | |
2024 """ | |
2025 client = NoNetworkProtocol() | |
2026 thingy = "weeoo" | |
2027 response = client.callRemote(MagicSchemaCommand, weird=thingy) | |
2028 def gotResponse(ign): | |
2029 self.assertEquals(client.parseResponseArguments, | |
2030 ({"weird": thingy}, client)) | |
2031 response.addCallback(gotResponse) | |
2032 return response | |
2033 | |
2034 | |
2035 def test_parseArguments(self): | |
2036 """ | |
2037 There should be a class method of L{amp.Command} which accepts | |
2038 a mapping of argument names to serialized forms and returns a | |
2039 similar mapping whose values have been parsed via the | |
2040 command's argument schema. | |
2041 """ | |
2042 protocol = object() | |
2043 result = 'whatever' | |
2044 strings = {'weird': result} | |
2045 self.assertEqual( | |
2046 ProtocolIncludingCommand.parseArguments(strings, protocol), | |
2047 {'weird': (result, protocol)}) | |
2048 | |
2049 | |
2050 def test_responderCallsParseArguments(self): | |
2051 """ | |
2052 Making a remote call on a L{amp.Command} subclass which | |
2053 overrides the C{parseArguments} method should call that | |
2054 C{parseArguments} method to get the arguments. | |
2055 """ | |
2056 protocol = NoNetworkProtocol() | |
2057 responder = protocol.locateResponder(MagicSchemaCommand.commandName) | |
2058 argument = object() | |
2059 response = responder(dict(weird=argument)) | |
2060 response.addCallback( | |
2061 lambda ign: self.assertEqual(protocol.parseArgumentsArguments, | |
2062 ({"weird": argument}, protocol))) | |
2063 return response | |
2064 | |
2065 | |
2066 def test_makeArguments(self): | |
2067 """ | |
2068 There should be a class method of L{amp.Command} which accepts | |
2069 a mapping of argument names to objects and returns a similar | |
2070 mapping whose values have been serialized via the command's | |
2071 argument schema. | |
2072 """ | |
2073 protocol = object() | |
2074 argument = object() | |
2075 objects = {'weird': argument} | |
2076 self.assertEqual( | |
2077 ProtocolIncludingCommand.makeArguments(objects, protocol), | |
2078 {'weird': "%d:%d" % (id(argument), id(protocol))}) | |
2079 | |
2080 | |
2081 def test_makeArgumentsUsesCommandType(self): | |
2082 """ | |
2083 L{amp.Command.makeArguments}'s return type should be the type | |
2084 of the result of L{amp.Command.commandType}. | |
2085 """ | |
2086 protocol = object() | |
2087 objects = {"weird": "whatever"} | |
2088 | |
2089 result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments( | |
2090 objects, protocol) | |
2091 self.assertIdentical(type(result), MyBox) | |
2092 | |
2093 | |
2094 def test_callRemoteCallsMakeArguments(self): | |
2095 """ | |
2096 Making a remote call on a L{amp.Command} subclass which | |
2097 overrides the C{makeArguments} method should call that | |
2098 C{makeArguments} method to get the response. | |
2099 """ | |
2100 client = NoNetworkProtocol() | |
2101 argument = object() | |
2102 response = client.callRemote(MagicSchemaCommand, weird=argument) | |
2103 def gotResponse(ign): | |
2104 self.assertEqual(client.makeArgumentsArguments, | |
2105 ({"weird": argument}, client)) | |
2106 response.addCallback(gotResponse) | |
2107 return response | |
2108 | |
2109 if not interfaces.IReactorSSL.providedBy(reactor): | |
2110 skipMsg = 'This test case requires SSL support in the reactor' | |
2111 TLSTest.skip = skipMsg | |
2112 LiveFireTLSTestCase.skip = skipMsg | |
2113 PlainVanillaLiveFire.skip = skipMsg | |
2114 WithServerTLSVerification.skip = skipMsg | |
2115 | |
OLD | NEW |