OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """ | |
6 Test cases for twisted.protocols package. | |
7 """ | |
8 | |
9 from twisted.trial import unittest | |
10 from twisted.protocols import basic, wire, portforward | |
11 from twisted.internet import reactor, protocol, defer, task, error | |
12 from twisted.test import proto_helpers | |
13 | |
14 import struct | |
15 import StringIO | |
16 | |
17 class StringIOWithoutClosing(StringIO.StringIO): | |
18 """ | |
19 A StringIO that can't be closed. | |
20 """ | |
21 def close(self): | |
22 """ | |
23 Do nothing. | |
24 """ | |
25 | |
26 class LineTester(basic.LineReceiver): | |
27 """ | |
28 A line receiver that parses data received and make actions on some tokens. | |
29 | |
30 @type delimiter: C{str} | |
31 @ivar delimiter: character used between received lines. | |
32 @type MAX_LENGTH: C{int} | |
33 @ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called. | |
34 @type clock: L{twisted.internet.task.Clock} | |
35 @ivar clock: clock simulating reactor callLater. Pass it to constructor if | |
36 you want to use the pause/rawpause functionalities. | |
37 """ | |
38 | |
39 delimiter = '\n' | |
40 MAX_LENGTH = 64 | |
41 | |
42 def __init__(self, clock=None): | |
43 """ | |
44 If given, use a clock to make callLater calls. | |
45 """ | |
46 self.clock = clock | |
47 | |
48 def connectionMade(self): | |
49 """ | |
50 Create/clean data received on connection. | |
51 """ | |
52 self.received = [] | |
53 | |
54 def lineReceived(self, line): | |
55 """ | |
56 Receive line and make some action for some tokens: pause, rawpause, | |
57 stop, len, produce, unproduce. | |
58 """ | |
59 self.received.append(line) | |
60 if line == '': | |
61 self.setRawMode() | |
62 elif line == 'pause': | |
63 self.pauseProducing() | |
64 self.clock.callLater(0, self.resumeProducing) | |
65 elif line == 'rawpause': | |
66 self.pauseProducing() | |
67 self.setRawMode() | |
68 self.received.append('') | |
69 self.clock.callLater(0, self.resumeProducing) | |
70 elif line == 'stop': | |
71 self.stopProducing() | |
72 elif line[:4] == 'len ': | |
73 self.length = int(line[4:]) | |
74 elif line.startswith('produce'): | |
75 self.transport.registerProducer(self, False) | |
76 elif line.startswith('unproduce'): | |
77 self.transport.unregisterProducer() | |
78 | |
79 def rawDataReceived(self, data): | |
80 """ | |
81 Read raw data, until the quantity specified by a previous 'len' line is | |
82 reached. | |
83 """ | |
84 data, rest = data[:self.length], data[self.length:] | |
85 self.length = self.length - len(data) | |
86 self.received[-1] = self.received[-1] + data | |
87 if self.length == 0: | |
88 self.setLineMode(rest) | |
89 | |
90 def lineLengthExceeded(self, line): | |
91 """ | |
92 Adjust line mode when long lines received. | |
93 """ | |
94 if len(line) > self.MAX_LENGTH + 1: | |
95 self.setLineMode(line[self.MAX_LENGTH + 1:]) | |
96 | |
97 | |
98 class LineOnlyTester(basic.LineOnlyReceiver): | |
99 """ | |
100 A buffering line only receiver. | |
101 """ | |
102 delimiter = '\n' | |
103 MAX_LENGTH = 64 | |
104 | |
105 def connectionMade(self): | |
106 """ | |
107 Create/clean data received on connection. | |
108 """ | |
109 self.received = [] | |
110 | |
111 def lineReceived(self, line): | |
112 """ | |
113 Save received data. | |
114 """ | |
115 self.received.append(line) | |
116 | |
117 class WireTestCase(unittest.TestCase): | |
118 """ | |
119 Test wire protocols. | |
120 """ | |
121 def testEcho(self): | |
122 """ | |
123 Test wire.Echo protocol: send some data and check it send it back. | |
124 """ | |
125 t = StringIOWithoutClosing() | |
126 a = wire.Echo() | |
127 a.makeConnection(protocol.FileWrapper(t)) | |
128 a.dataReceived("hello") | |
129 a.dataReceived("world") | |
130 a.dataReceived("how") | |
131 a.dataReceived("are") | |
132 a.dataReceived("you") | |
133 self.failUnlessEqual(t.getvalue(), "helloworldhowareyou") | |
134 | |
135 def testWho(self): | |
136 """ | |
137 Test wire.Who protocol. | |
138 """ | |
139 t = StringIOWithoutClosing() | |
140 a = wire.Who() | |
141 a.makeConnection(protocol.FileWrapper(t)) | |
142 self.failUnlessEqual(t.getvalue(), "root\r\n") | |
143 | |
144 def testQOTD(self): | |
145 """ | |
146 Test wire.QOTD protocol. | |
147 """ | |
148 t = StringIOWithoutClosing() | |
149 a = wire.QOTD() | |
150 a.makeConnection(protocol.FileWrapper(t)) | |
151 self.failUnlessEqual(t.getvalue(), | |
152 "An apple a day keeps the doctor away.\r\n") | |
153 | |
154 def testDiscard(self): | |
155 """ | |
156 Test wire.Discard protocol. | |
157 """ | |
158 t = StringIOWithoutClosing() | |
159 a = wire.Discard() | |
160 a.makeConnection(protocol.FileWrapper(t)) | |
161 a.dataReceived("hello") | |
162 a.dataReceived("world") | |
163 a.dataReceived("how") | |
164 a.dataReceived("are") | |
165 a.dataReceived("you") | |
166 self.failUnlessEqual(t.getvalue(), "") | |
167 | |
168 class LineReceiverTestCase(unittest.TestCase): | |
169 """ | |
170 Test LineReceiver, using the C{LineTester} wrapper. | |
171 """ | |
172 buffer = '''\ | |
173 len 10 | |
174 | |
175 0123456789len 5 | |
176 | |
177 1234 | |
178 len 20 | |
179 foo 123 | |
180 | |
181 0123456789 | |
182 012345678len 0 | |
183 foo 5 | |
184 | |
185 1234567890123456789012345678901234567890123456789012345678901234567890 | |
186 len 1 | |
187 | |
188 a''' | |
189 | |
190 output = ['len 10', '0123456789', 'len 5', '1234\n', | |
191 'len 20', 'foo 123', '0123456789\n012345678', | |
192 'len 0', 'foo 5', '', '67890', 'len 1', 'a'] | |
193 | |
194 def testBuffer(self): | |
195 """ | |
196 Test buffering for different packet size, checking received matches | |
197 expected data. | |
198 """ | |
199 for packet_size in range(1, 10): | |
200 t = StringIOWithoutClosing() | |
201 a = LineTester() | |
202 a.makeConnection(protocol.FileWrapper(t)) | |
203 for i in range(len(self.buffer)/packet_size + 1): | |
204 s = self.buffer[i*packet_size:(i+1)*packet_size] | |
205 a.dataReceived(s) | |
206 self.failUnlessEqual(self.output, a.received) | |
207 | |
208 | |
209 pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n' | |
210 | |
211 pause_output1 = ['twiddle1', 'twiddle2', 'pause'] | |
212 pause_output2 = pause_output1+['twiddle3'] | |
213 | |
214 def testPausing(self): | |
215 """ | |
216 Test pause inside data receiving. It uses fake clock to see if | |
217 pausing/resuming work. | |
218 """ | |
219 for packet_size in range(1, 10): | |
220 t = StringIOWithoutClosing() | |
221 clock = task.Clock() | |
222 a = LineTester(clock) | |
223 a.makeConnection(protocol.FileWrapper(t)) | |
224 for i in range(len(self.pause_buf)/packet_size + 1): | |
225 s = self.pause_buf[i*packet_size:(i+1)*packet_size] | |
226 a.dataReceived(s) | |
227 self.failUnlessEqual(self.pause_output1, a.received) | |
228 clock.advance(0) | |
229 self.failUnlessEqual(self.pause_output2, a.received) | |
230 | |
231 rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n' | |
232 | |
233 rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', ''] | |
234 rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345', | |
235 'twiddle3'] | |
236 | |
237 def testRawPausing(self): | |
238 """ | |
239 Test pause inside raw date receiving. | |
240 """ | |
241 for packet_size in range(1, 10): | |
242 t = StringIOWithoutClosing() | |
243 clock = task.Clock() | |
244 a = LineTester(clock) | |
245 a.makeConnection(protocol.FileWrapper(t)) | |
246 for i in range(len(self.rawpause_buf)/packet_size + 1): | |
247 s = self.rawpause_buf[i*packet_size:(i+1)*packet_size] | |
248 a.dataReceived(s) | |
249 self.failUnlessEqual(self.rawpause_output1, a.received) | |
250 clock.advance(0) | |
251 self.failUnlessEqual(self.rawpause_output2, a.received) | |
252 | |
253 stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n' | |
254 | |
255 stop_output = ['twiddle1', 'twiddle2', 'stop'] | |
256 | |
257 def testStopProducing(self): | |
258 """ | |
259 Test stop inside producing. | |
260 """ | |
261 for packet_size in range(1, 10): | |
262 t = StringIOWithoutClosing() | |
263 a = LineTester() | |
264 a.makeConnection(protocol.FileWrapper(t)) | |
265 for i in range(len(self.stop_buf)/packet_size + 1): | |
266 s = self.stop_buf[i*packet_size:(i+1)*packet_size] | |
267 a.dataReceived(s) | |
268 self.failUnlessEqual(self.stop_output, a.received) | |
269 | |
270 | |
271 def testLineReceiverAsProducer(self): | |
272 """ | |
273 Test produce/unproduce in receiving. | |
274 """ | |
275 a = LineTester() | |
276 t = StringIOWithoutClosing() | |
277 a.makeConnection(protocol.FileWrapper(t)) | |
278 a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n') | |
279 self.assertEquals(a.received, | |
280 ['produce', 'hello world', 'unproduce', 'goodbye']) | |
281 | |
282 | |
283 class LineOnlyReceiverTestCase(unittest.TestCase): | |
284 """ | |
285 Test line only receiveer. | |
286 """ | |
287 buffer = """foo | |
288 bleakness | |
289 desolation | |
290 plastic forks | |
291 """ | |
292 | |
293 def testBuffer(self): | |
294 """ | |
295 Test buffering over line protocol: data received should match buffer. | |
296 """ | |
297 t = StringIOWithoutClosing() | |
298 a = LineOnlyTester() | |
299 a.makeConnection(protocol.FileWrapper(t)) | |
300 for c in self.buffer: | |
301 a.dataReceived(c) | |
302 self.failUnlessEqual(a.received, self.buffer.split('\n')[:-1]) | |
303 | |
304 def testLineTooLong(self): | |
305 """ | |
306 Test sending a line too long: it should close the connection. | |
307 """ | |
308 t = StringIOWithoutClosing() | |
309 a = LineOnlyTester() | |
310 a.makeConnection(protocol.FileWrapper(t)) | |
311 res = a.dataReceived('x'*200) | |
312 self.assertTrue(isinstance(res, error.ConnectionLost)) | |
313 | |
314 | |
315 | |
316 class TestMixin: | |
317 | |
318 def connectionMade(self): | |
319 self.received = [] | |
320 | |
321 def stringReceived(self, s): | |
322 self.received.append(s) | |
323 | |
324 MAX_LENGTH = 50 | |
325 closed = 0 | |
326 | |
327 def connectionLost(self, reason): | |
328 self.closed = 1 | |
329 | |
330 | |
331 class TestNetstring(TestMixin, basic.NetstringReceiver): | |
332 pass | |
333 | |
334 | |
335 class LPTestCaseMixin: | |
336 | |
337 illegalStrings = [] | |
338 protocol = None | |
339 | |
340 def getProtocol(self): | |
341 t = StringIOWithoutClosing() | |
342 a = self.protocol() | |
343 a.makeConnection(protocol.FileWrapper(t)) | |
344 return a | |
345 | |
346 def test_illegal(self): | |
347 """ | |
348 Assert that illegal strings cause the transport to be closed. | |
349 """ | |
350 for s in self.illegalStrings: | |
351 r = self.getProtocol() | |
352 for c in s: | |
353 r.dataReceived(c) | |
354 self.assertEquals(r.transport.closed, 1) | |
355 | |
356 | |
357 class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin): | |
358 | |
359 strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515] | |
360 | |
361 illegalStrings = [ | |
362 '9999999999999999999999', 'abc', '4:abcde', | |
363 '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',] | |
364 | |
365 protocol = TestNetstring | |
366 | |
367 def testBuffer(self): | |
368 for packet_size in range(1, 10): | |
369 t = StringIOWithoutClosing() | |
370 a = TestNetstring() | |
371 a.MAX_LENGTH = 699 | |
372 a.makeConnection(protocol.FileWrapper(t)) | |
373 for s in self.strings: | |
374 a.sendString(s) | |
375 out = t.getvalue() | |
376 for i in range(len(out)/packet_size + 1): | |
377 s = out[i*packet_size:(i+1)*packet_size] | |
378 if s: | |
379 a.dataReceived(s) | |
380 self.assertEquals(a.received, self.strings) | |
381 | |
382 | |
383 class IntNTestCaseMixin(LPTestCaseMixin): | |
384 """ | |
385 TestCase mixin for int-prefixed protocols. | |
386 """ | |
387 | |
388 protocol = None | |
389 strings = None | |
390 illegalStrings = None | |
391 partialStrings = None | |
392 | |
393 def test_receive(self): | |
394 """ | |
395 Test receiving data find the same data send. | |
396 """ | |
397 r = self.getProtocol() | |
398 for s in self.strings: | |
399 for c in struct.pack(self.protocol.structFormat,len(s)) + s: | |
400 r.dataReceived(c) | |
401 self.assertEquals(r.received, self.strings) | |
402 | |
403 def test_partial(self): | |
404 """ | |
405 Send partial data, nothing should be definitely received. | |
406 """ | |
407 for s in self.partialStrings: | |
408 r = self.getProtocol() | |
409 for c in s: | |
410 r.dataReceived(c) | |
411 self.assertEquals(r.received, []) | |
412 | |
413 def test_send(self): | |
414 """ | |
415 Test sending data over protocol. | |
416 """ | |
417 r = self.getProtocol() | |
418 r.sendString("b" * 16) | |
419 self.assertEquals(r.transport.file.getvalue(), | |
420 struct.pack(self.protocol.structFormat, 16) + "b" * 16) | |
421 | |
422 | |
423 class TestInt32(TestMixin, basic.Int32StringReceiver): | |
424 """ | |
425 A L{basic.Int32StringReceiver} storing received strings in an array. | |
426 | |
427 @ivar received: array holding received strings. | |
428 """ | |
429 | |
430 | |
431 class Int32TestCase(unittest.TestCase, IntNTestCaseMixin): | |
432 """ | |
433 Test case for int32-prefixed protocol | |
434 """ | |
435 protocol = TestInt32 | |
436 strings = ["a", "b" * 16] | |
437 illegalStrings = ["\x10\x00\x00\x00aaaaaa"] | |
438 partialStrings = ["\x00\x00\x00", "hello there", ""] | |
439 | |
440 def test_data(self): | |
441 """ | |
442 Test specific behavior of the 32-bits length. | |
443 """ | |
444 r = self.getProtocol() | |
445 r.sendString("foo") | |
446 self.assertEquals(r.transport.file.getvalue(), "\x00\x00\x00\x03foo") | |
447 r.dataReceived("\x00\x00\x00\x04ubar") | |
448 self.assertEquals(r.received, ["ubar"]) | |
449 | |
450 | |
451 class TestInt16(TestMixin, basic.Int16StringReceiver): | |
452 """ | |
453 A L{basic.Int16StringReceiver} storing received strings in an array. | |
454 | |
455 @ivar received: array holding received strings. | |
456 """ | |
457 | |
458 | |
459 class Int16TestCase(unittest.TestCase, IntNTestCaseMixin): | |
460 """ | |
461 Test case for int16-prefixed protocol | |
462 """ | |
463 protocol = TestInt16 | |
464 strings = ["a", "b" * 16] | |
465 illegalStrings = ["\x10\x00aaaaaa"] | |
466 partialStrings = ["\x00", "hello there", ""] | |
467 | |
468 def test_data(self): | |
469 """ | |
470 Test specific behavior of the 16-bits length. | |
471 """ | |
472 r = self.getProtocol() | |
473 r.sendString("foo") | |
474 self.assertEquals(r.transport.file.getvalue(), "\x00\x03foo") | |
475 r.dataReceived("\x00\x04ubar") | |
476 self.assertEquals(r.received, ["ubar"]) | |
477 | |
478 def test_tooLongSend(self): | |
479 """ | |
480 Send too much data: that should cause an error. | |
481 """ | |
482 r = self.getProtocol() | |
483 tooSend = "b" * (2**(r.prefixLength*8) + 1) | |
484 self.assertRaises(AssertionError, r.sendString, tooSend) | |
485 | |
486 | |
487 class TestInt8(TestMixin, basic.Int8StringReceiver): | |
488 """ | |
489 A L{basic.Int8StringReceiver} storing received strings in an array. | |
490 | |
491 @ivar received: array holding received strings. | |
492 """ | |
493 | |
494 | |
495 class Int8TestCase(unittest.TestCase, IntNTestCaseMixin): | |
496 """ | |
497 Test case for int8-prefixed protocol | |
498 """ | |
499 protocol = TestInt8 | |
500 strings = ["a", "b" * 16] | |
501 illegalStrings = ["\x00\x00aaaaaa"] | |
502 partialStrings = ["\x08", "dzadz", ""] | |
503 | |
504 def test_data(self): | |
505 """ | |
506 Test specific behavior of the 8-bits length. | |
507 """ | |
508 r = self.getProtocol() | |
509 r.sendString("foo") | |
510 self.assertEquals(r.transport.file.getvalue(), "\x03foo") | |
511 r.dataReceived("\x04ubar") | |
512 self.assertEquals(r.received, ["ubar"]) | |
513 | |
514 def test_tooLongSend(self): | |
515 """ | |
516 Send too much data: that should cause an error. | |
517 """ | |
518 r = self.getProtocol() | |
519 tooSend = "b" * (2**(r.prefixLength*8) + 1) | |
520 self.assertRaises(AssertionError, r.sendString, tooSend) | |
521 | |
522 | |
523 class OnlyProducerTransport(object): | |
524 # Transport which isn't really a transport, just looks like one to | |
525 # someone not looking very hard. | |
526 | |
527 paused = False | |
528 disconnecting = False | |
529 | |
530 def __init__(self): | |
531 self.data = [] | |
532 | |
533 def pauseProducing(self): | |
534 self.paused = True | |
535 | |
536 def resumeProducing(self): | |
537 self.paused = False | |
538 | |
539 def write(self, bytes): | |
540 self.data.append(bytes) | |
541 | |
542 | |
543 class ConsumingProtocol(basic.LineReceiver): | |
544 # Protocol that really, really doesn't want any more bytes. | |
545 | |
546 def lineReceived(self, line): | |
547 self.transport.write(line) | |
548 self.pauseProducing() | |
549 | |
550 | |
551 class ProducerTestCase(unittest.TestCase): | |
552 def testPauseResume(self): | |
553 p = ConsumingProtocol() | |
554 t = OnlyProducerTransport() | |
555 p.makeConnection(t) | |
556 | |
557 p.dataReceived('hello, ') | |
558 self.failIf(t.data) | |
559 self.failIf(t.paused) | |
560 self.failIf(p.paused) | |
561 | |
562 p.dataReceived('world\r\n') | |
563 | |
564 self.assertEquals(t.data, ['hello, world']) | |
565 self.failUnless(t.paused) | |
566 self.failUnless(p.paused) | |
567 | |
568 p.resumeProducing() | |
569 | |
570 self.failIf(t.paused) | |
571 self.failIf(p.paused) | |
572 | |
573 p.dataReceived('hello\r\nworld\r\n') | |
574 | |
575 self.assertEquals(t.data, ['hello, world', 'hello']) | |
576 self.failUnless(t.paused) | |
577 self.failUnless(p.paused) | |
578 | |
579 p.resumeProducing() | |
580 p.dataReceived('goodbye\r\n') | |
581 | |
582 self.assertEquals(t.data, ['hello, world', 'hello', 'world']) | |
583 self.failUnless(t.paused) | |
584 self.failUnless(p.paused) | |
585 | |
586 p.resumeProducing() | |
587 | |
588 self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye']) | |
589 self.failUnless(t.paused) | |
590 self.failUnless(p.paused) | |
591 | |
592 p.resumeProducing() | |
593 | |
594 self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye']) | |
595 self.failIf(t.paused) | |
596 self.failIf(p.paused) | |
597 | |
598 | |
599 | |
600 class TestableProxyClientFactory(portforward.ProxyClientFactory): | |
601 """ | |
602 Test proxy client factory that keeps the last created protocol instance. | |
603 | |
604 @ivar protoInstance: the last instance of the protocol. | |
605 @type protoInstance: L{portforward.ProxyClient} | |
606 """ | |
607 | |
608 def buildProtocol(self, addr): | |
609 """ | |
610 Create the protocol instance and keeps track of it. | |
611 """ | |
612 proto = portforward.ProxyClientFactory.buildProtocol(self, addr) | |
613 self.protoInstance = proto | |
614 return proto | |
615 | |
616 | |
617 | |
618 class TestableProxyFactory(portforward.ProxyFactory): | |
619 """ | |
620 Test proxy factory that keeps the last created protocol instance. | |
621 | |
622 @ivar protoInstance: the last instance of the protocol. | |
623 @type protoInstance: L{portforward.ProxyServer} | |
624 | |
625 @ivar clientFactoryInstance: client factory used by C{protoInstance} to | |
626 create forward connections. | |
627 @type clientFactoryInstance: L{TestableProxyClientFactory} | |
628 """ | |
629 | |
630 def buildProtocol(self, addr): | |
631 """ | |
632 Create the protocol instance, keeps track of it, and makes it use | |
633 C{clientFactoryInstance} as client factory. | |
634 """ | |
635 proto = portforward.ProxyFactory.buildProtocol(self, addr) | |
636 self.clientFactoryInstance = TestableProxyClientFactory() | |
637 # Force the use of this specific instance | |
638 proto.clientProtocolFactory = lambda: self.clientFactoryInstance | |
639 self.protoInstance = proto | |
640 return proto | |
641 | |
642 | |
643 | |
644 class Portforwarding(unittest.TestCase): | |
645 """ | |
646 Test port forwarding. | |
647 """ | |
648 | |
649 def setUp(self): | |
650 self.serverProtocol = wire.Echo() | |
651 self.clientProtocol = protocol.Protocol() | |
652 self.openPorts = [] | |
653 | |
654 | |
655 def tearDown(self): | |
656 try: | |
657 self.proxyServerFactory.protoInstance.transport.loseConnection() | |
658 except AttributeError: | |
659 pass | |
660 try: | |
661 self.proxyServerFactory.clientFactoryInstance.protoInstance.transpor
t.loseConnection() | |
662 except AttributeError: | |
663 pass | |
664 try: | |
665 self.clientProtocol.transport.loseConnection() | |
666 except AttributeError: | |
667 pass | |
668 try: | |
669 self.serverProtocol.transport.loseConnection() | |
670 except AttributeError: | |
671 pass | |
672 return defer.gatherResults( | |
673 [defer.maybeDeferred(p.stopListening) for p in self.openPorts]) | |
674 | |
675 | |
676 def test_portforward(self): | |
677 """ | |
678 Test port forwarding through Echo protocol. | |
679 """ | |
680 realServerFactory = protocol.ServerFactory() | |
681 realServerFactory.protocol = lambda: self.serverProtocol | |
682 realServerPort = reactor.listenTCP(0, realServerFactory, | |
683 interface='127.0.0.1') | |
684 self.openPorts.append(realServerPort) | |
685 self.proxyServerFactory = TestableProxyFactory('127.0.0.1', | |
686 realServerPort.getHost().port) | |
687 proxyServerPort = reactor.listenTCP(0, self.proxyServerFactory, | |
688 interface='127.0.0.1') | |
689 self.openPorts.append(proxyServerPort) | |
690 | |
691 nBytes = 1000 | |
692 received = [] | |
693 d = defer.Deferred() | |
694 def testDataReceived(data): | |
695 received.extend(data) | |
696 if len(received) >= nBytes: | |
697 self.assertEquals(''.join(received), 'x' * nBytes) | |
698 d.callback(None) | |
699 self.clientProtocol.dataReceived = testDataReceived | |
700 | |
701 def testConnectionMade(): | |
702 self.clientProtocol.transport.write('x' * nBytes) | |
703 self.clientProtocol.connectionMade = testConnectionMade | |
704 | |
705 clientFactory = protocol.ClientFactory() | |
706 clientFactory.protocol = lambda: self.clientProtocol | |
707 | |
708 reactor.connectTCP( | |
709 '127.0.0.1', proxyServerPort.getHost().port, clientFactory) | |
710 | |
711 return d | |
712 | |
713 | |
714 | |
715 class StringTransportTestCase(unittest.TestCase): | |
716 """ | |
717 Test L{proto_helpers.StringTransport} helper behaviour. | |
718 """ | |
719 | |
720 def test_noUnicode(self): | |
721 """ | |
722 Test that L{proto_helpers.StringTransport} doesn't accept unicode data. | |
723 """ | |
724 s = proto_helpers.StringTransport() | |
725 self.assertRaises(TypeError, s.write, u'foo') | |
OLD | NEW |