| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
| 2 # See LICENSE for details. | |
| 3 | |
| 4 | |
| 5 """ | |
| 6 Test cases for twisted.protocols package. | |
| 7 """ | |
| 8 | |
| 9 from twisted.trial import unittest | |
| 10 from twisted.protocols import basic, wire, portforward | |
| 11 from twisted.internet import reactor, protocol, defer, task, error | |
| 12 from twisted.test import proto_helpers | |
| 13 | |
| 14 import struct | |
| 15 import StringIO | |
| 16 | |
| 17 class StringIOWithoutClosing(StringIO.StringIO): | |
| 18 """ | |
| 19 A StringIO that can't be closed. | |
| 20 """ | |
| 21 def close(self): | |
| 22 """ | |
| 23 Do nothing. | |
| 24 """ | |
| 25 | |
| 26 class LineTester(basic.LineReceiver): | |
| 27 """ | |
| 28 A line receiver that parses data received and make actions on some tokens. | |
| 29 | |
| 30 @type delimiter: C{str} | |
| 31 @ivar delimiter: character used between received lines. | |
| 32 @type MAX_LENGTH: C{int} | |
| 33 @ivar MAX_LENGTH: size of a line when C{lineLengthExceeded} will be called. | |
| 34 @type clock: L{twisted.internet.task.Clock} | |
| 35 @ivar clock: clock simulating reactor callLater. Pass it to constructor if | |
| 36 you want to use the pause/rawpause functionalities. | |
| 37 """ | |
| 38 | |
| 39 delimiter = '\n' | |
| 40 MAX_LENGTH = 64 | |
| 41 | |
| 42 def __init__(self, clock=None): | |
| 43 """ | |
| 44 If given, use a clock to make callLater calls. | |
| 45 """ | |
| 46 self.clock = clock | |
| 47 | |
| 48 def connectionMade(self): | |
| 49 """ | |
| 50 Create/clean data received on connection. | |
| 51 """ | |
| 52 self.received = [] | |
| 53 | |
| 54 def lineReceived(self, line): | |
| 55 """ | |
| 56 Receive line and make some action for some tokens: pause, rawpause, | |
| 57 stop, len, produce, unproduce. | |
| 58 """ | |
| 59 self.received.append(line) | |
| 60 if line == '': | |
| 61 self.setRawMode() | |
| 62 elif line == 'pause': | |
| 63 self.pauseProducing() | |
| 64 self.clock.callLater(0, self.resumeProducing) | |
| 65 elif line == 'rawpause': | |
| 66 self.pauseProducing() | |
| 67 self.setRawMode() | |
| 68 self.received.append('') | |
| 69 self.clock.callLater(0, self.resumeProducing) | |
| 70 elif line == 'stop': | |
| 71 self.stopProducing() | |
| 72 elif line[:4] == 'len ': | |
| 73 self.length = int(line[4:]) | |
| 74 elif line.startswith('produce'): | |
| 75 self.transport.registerProducer(self, False) | |
| 76 elif line.startswith('unproduce'): | |
| 77 self.transport.unregisterProducer() | |
| 78 | |
| 79 def rawDataReceived(self, data): | |
| 80 """ | |
| 81 Read raw data, until the quantity specified by a previous 'len' line is | |
| 82 reached. | |
| 83 """ | |
| 84 data, rest = data[:self.length], data[self.length:] | |
| 85 self.length = self.length - len(data) | |
| 86 self.received[-1] = self.received[-1] + data | |
| 87 if self.length == 0: | |
| 88 self.setLineMode(rest) | |
| 89 | |
| 90 def lineLengthExceeded(self, line): | |
| 91 """ | |
| 92 Adjust line mode when long lines received. | |
| 93 """ | |
| 94 if len(line) > self.MAX_LENGTH + 1: | |
| 95 self.setLineMode(line[self.MAX_LENGTH + 1:]) | |
| 96 | |
| 97 | |
| 98 class LineOnlyTester(basic.LineOnlyReceiver): | |
| 99 """ | |
| 100 A buffering line only receiver. | |
| 101 """ | |
| 102 delimiter = '\n' | |
| 103 MAX_LENGTH = 64 | |
| 104 | |
| 105 def connectionMade(self): | |
| 106 """ | |
| 107 Create/clean data received on connection. | |
| 108 """ | |
| 109 self.received = [] | |
| 110 | |
| 111 def lineReceived(self, line): | |
| 112 """ | |
| 113 Save received data. | |
| 114 """ | |
| 115 self.received.append(line) | |
| 116 | |
| 117 class WireTestCase(unittest.TestCase): | |
| 118 """ | |
| 119 Test wire protocols. | |
| 120 """ | |
| 121 def testEcho(self): | |
| 122 """ | |
| 123 Test wire.Echo protocol: send some data and check it send it back. | |
| 124 """ | |
| 125 t = StringIOWithoutClosing() | |
| 126 a = wire.Echo() | |
| 127 a.makeConnection(protocol.FileWrapper(t)) | |
| 128 a.dataReceived("hello") | |
| 129 a.dataReceived("world") | |
| 130 a.dataReceived("how") | |
| 131 a.dataReceived("are") | |
| 132 a.dataReceived("you") | |
| 133 self.failUnlessEqual(t.getvalue(), "helloworldhowareyou") | |
| 134 | |
| 135 def testWho(self): | |
| 136 """ | |
| 137 Test wire.Who protocol. | |
| 138 """ | |
| 139 t = StringIOWithoutClosing() | |
| 140 a = wire.Who() | |
| 141 a.makeConnection(protocol.FileWrapper(t)) | |
| 142 self.failUnlessEqual(t.getvalue(), "root\r\n") | |
| 143 | |
| 144 def testQOTD(self): | |
| 145 """ | |
| 146 Test wire.QOTD protocol. | |
| 147 """ | |
| 148 t = StringIOWithoutClosing() | |
| 149 a = wire.QOTD() | |
| 150 a.makeConnection(protocol.FileWrapper(t)) | |
| 151 self.failUnlessEqual(t.getvalue(), | |
| 152 "An apple a day keeps the doctor away.\r\n") | |
| 153 | |
| 154 def testDiscard(self): | |
| 155 """ | |
| 156 Test wire.Discard protocol. | |
| 157 """ | |
| 158 t = StringIOWithoutClosing() | |
| 159 a = wire.Discard() | |
| 160 a.makeConnection(protocol.FileWrapper(t)) | |
| 161 a.dataReceived("hello") | |
| 162 a.dataReceived("world") | |
| 163 a.dataReceived("how") | |
| 164 a.dataReceived("are") | |
| 165 a.dataReceived("you") | |
| 166 self.failUnlessEqual(t.getvalue(), "") | |
| 167 | |
| 168 class LineReceiverTestCase(unittest.TestCase): | |
| 169 """ | |
| 170 Test LineReceiver, using the C{LineTester} wrapper. | |
| 171 """ | |
| 172 buffer = '''\ | |
| 173 len 10 | |
| 174 | |
| 175 0123456789len 5 | |
| 176 | |
| 177 1234 | |
| 178 len 20 | |
| 179 foo 123 | |
| 180 | |
| 181 0123456789 | |
| 182 012345678len 0 | |
| 183 foo 5 | |
| 184 | |
| 185 1234567890123456789012345678901234567890123456789012345678901234567890 | |
| 186 len 1 | |
| 187 | |
| 188 a''' | |
| 189 | |
| 190 output = ['len 10', '0123456789', 'len 5', '1234\n', | |
| 191 'len 20', 'foo 123', '0123456789\n012345678', | |
| 192 'len 0', 'foo 5', '', '67890', 'len 1', 'a'] | |
| 193 | |
| 194 def testBuffer(self): | |
| 195 """ | |
| 196 Test buffering for different packet size, checking received matches | |
| 197 expected data. | |
| 198 """ | |
| 199 for packet_size in range(1, 10): | |
| 200 t = StringIOWithoutClosing() | |
| 201 a = LineTester() | |
| 202 a.makeConnection(protocol.FileWrapper(t)) | |
| 203 for i in range(len(self.buffer)/packet_size + 1): | |
| 204 s = self.buffer[i*packet_size:(i+1)*packet_size] | |
| 205 a.dataReceived(s) | |
| 206 self.failUnlessEqual(self.output, a.received) | |
| 207 | |
| 208 | |
| 209 pause_buf = 'twiddle1\ntwiddle2\npause\ntwiddle3\n' | |
| 210 | |
| 211 pause_output1 = ['twiddle1', 'twiddle2', 'pause'] | |
| 212 pause_output2 = pause_output1+['twiddle3'] | |
| 213 | |
| 214 def testPausing(self): | |
| 215 """ | |
| 216 Test pause inside data receiving. It uses fake clock to see if | |
| 217 pausing/resuming work. | |
| 218 """ | |
| 219 for packet_size in range(1, 10): | |
| 220 t = StringIOWithoutClosing() | |
| 221 clock = task.Clock() | |
| 222 a = LineTester(clock) | |
| 223 a.makeConnection(protocol.FileWrapper(t)) | |
| 224 for i in range(len(self.pause_buf)/packet_size + 1): | |
| 225 s = self.pause_buf[i*packet_size:(i+1)*packet_size] | |
| 226 a.dataReceived(s) | |
| 227 self.failUnlessEqual(self.pause_output1, a.received) | |
| 228 clock.advance(0) | |
| 229 self.failUnlessEqual(self.pause_output2, a.received) | |
| 230 | |
| 231 rawpause_buf = 'twiddle1\ntwiddle2\nlen 5\nrawpause\n12345twiddle3\n' | |
| 232 | |
| 233 rawpause_output1 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', ''] | |
| 234 rawpause_output2 = ['twiddle1', 'twiddle2', 'len 5', 'rawpause', '12345', | |
| 235 'twiddle3'] | |
| 236 | |
| 237 def testRawPausing(self): | |
| 238 """ | |
| 239 Test pause inside raw date receiving. | |
| 240 """ | |
| 241 for packet_size in range(1, 10): | |
| 242 t = StringIOWithoutClosing() | |
| 243 clock = task.Clock() | |
| 244 a = LineTester(clock) | |
| 245 a.makeConnection(protocol.FileWrapper(t)) | |
| 246 for i in range(len(self.rawpause_buf)/packet_size + 1): | |
| 247 s = self.rawpause_buf[i*packet_size:(i+1)*packet_size] | |
| 248 a.dataReceived(s) | |
| 249 self.failUnlessEqual(self.rawpause_output1, a.received) | |
| 250 clock.advance(0) | |
| 251 self.failUnlessEqual(self.rawpause_output2, a.received) | |
| 252 | |
| 253 stop_buf = 'twiddle1\ntwiddle2\nstop\nmore\nstuff\n' | |
| 254 | |
| 255 stop_output = ['twiddle1', 'twiddle2', 'stop'] | |
| 256 | |
| 257 def testStopProducing(self): | |
| 258 """ | |
| 259 Test stop inside producing. | |
| 260 """ | |
| 261 for packet_size in range(1, 10): | |
| 262 t = StringIOWithoutClosing() | |
| 263 a = LineTester() | |
| 264 a.makeConnection(protocol.FileWrapper(t)) | |
| 265 for i in range(len(self.stop_buf)/packet_size + 1): | |
| 266 s = self.stop_buf[i*packet_size:(i+1)*packet_size] | |
| 267 a.dataReceived(s) | |
| 268 self.failUnlessEqual(self.stop_output, a.received) | |
| 269 | |
| 270 | |
| 271 def testLineReceiverAsProducer(self): | |
| 272 """ | |
| 273 Test produce/unproduce in receiving. | |
| 274 """ | |
| 275 a = LineTester() | |
| 276 t = StringIOWithoutClosing() | |
| 277 a.makeConnection(protocol.FileWrapper(t)) | |
| 278 a.dataReceived('produce\nhello world\nunproduce\ngoodbye\n') | |
| 279 self.assertEquals(a.received, | |
| 280 ['produce', 'hello world', 'unproduce', 'goodbye']) | |
| 281 | |
| 282 | |
| 283 class LineOnlyReceiverTestCase(unittest.TestCase): | |
| 284 """ | |
| 285 Test line only receiveer. | |
| 286 """ | |
| 287 buffer = """foo | |
| 288 bleakness | |
| 289 desolation | |
| 290 plastic forks | |
| 291 """ | |
| 292 | |
| 293 def testBuffer(self): | |
| 294 """ | |
| 295 Test buffering over line protocol: data received should match buffer. | |
| 296 """ | |
| 297 t = StringIOWithoutClosing() | |
| 298 a = LineOnlyTester() | |
| 299 a.makeConnection(protocol.FileWrapper(t)) | |
| 300 for c in self.buffer: | |
| 301 a.dataReceived(c) | |
| 302 self.failUnlessEqual(a.received, self.buffer.split('\n')[:-1]) | |
| 303 | |
| 304 def testLineTooLong(self): | |
| 305 """ | |
| 306 Test sending a line too long: it should close the connection. | |
| 307 """ | |
| 308 t = StringIOWithoutClosing() | |
| 309 a = LineOnlyTester() | |
| 310 a.makeConnection(protocol.FileWrapper(t)) | |
| 311 res = a.dataReceived('x'*200) | |
| 312 self.assertTrue(isinstance(res, error.ConnectionLost)) | |
| 313 | |
| 314 | |
| 315 | |
| 316 class TestMixin: | |
| 317 | |
| 318 def connectionMade(self): | |
| 319 self.received = [] | |
| 320 | |
| 321 def stringReceived(self, s): | |
| 322 self.received.append(s) | |
| 323 | |
| 324 MAX_LENGTH = 50 | |
| 325 closed = 0 | |
| 326 | |
| 327 def connectionLost(self, reason): | |
| 328 self.closed = 1 | |
| 329 | |
| 330 | |
| 331 class TestNetstring(TestMixin, basic.NetstringReceiver): | |
| 332 pass | |
| 333 | |
| 334 | |
| 335 class LPTestCaseMixin: | |
| 336 | |
| 337 illegalStrings = [] | |
| 338 protocol = None | |
| 339 | |
| 340 def getProtocol(self): | |
| 341 t = StringIOWithoutClosing() | |
| 342 a = self.protocol() | |
| 343 a.makeConnection(protocol.FileWrapper(t)) | |
| 344 return a | |
| 345 | |
| 346 def test_illegal(self): | |
| 347 """ | |
| 348 Assert that illegal strings cause the transport to be closed. | |
| 349 """ | |
| 350 for s in self.illegalStrings: | |
| 351 r = self.getProtocol() | |
| 352 for c in s: | |
| 353 r.dataReceived(c) | |
| 354 self.assertEquals(r.transport.closed, 1) | |
| 355 | |
| 356 | |
| 357 class NetstringReceiverTestCase(unittest.TestCase, LPTestCaseMixin): | |
| 358 | |
| 359 strings = ['hello', 'world', 'how', 'are', 'you123', ':today', "a"*515] | |
| 360 | |
| 361 illegalStrings = [ | |
| 362 '9999999999999999999999', 'abc', '4:abcde', | |
| 363 '51:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab,',] | |
| 364 | |
| 365 protocol = TestNetstring | |
| 366 | |
| 367 def testBuffer(self): | |
| 368 for packet_size in range(1, 10): | |
| 369 t = StringIOWithoutClosing() | |
| 370 a = TestNetstring() | |
| 371 a.MAX_LENGTH = 699 | |
| 372 a.makeConnection(protocol.FileWrapper(t)) | |
| 373 for s in self.strings: | |
| 374 a.sendString(s) | |
| 375 out = t.getvalue() | |
| 376 for i in range(len(out)/packet_size + 1): | |
| 377 s = out[i*packet_size:(i+1)*packet_size] | |
| 378 if s: | |
| 379 a.dataReceived(s) | |
| 380 self.assertEquals(a.received, self.strings) | |
| 381 | |
| 382 | |
| 383 class IntNTestCaseMixin(LPTestCaseMixin): | |
| 384 """ | |
| 385 TestCase mixin for int-prefixed protocols. | |
| 386 """ | |
| 387 | |
| 388 protocol = None | |
| 389 strings = None | |
| 390 illegalStrings = None | |
| 391 partialStrings = None | |
| 392 | |
| 393 def test_receive(self): | |
| 394 """ | |
| 395 Test receiving data find the same data send. | |
| 396 """ | |
| 397 r = self.getProtocol() | |
| 398 for s in self.strings: | |
| 399 for c in struct.pack(self.protocol.structFormat,len(s)) + s: | |
| 400 r.dataReceived(c) | |
| 401 self.assertEquals(r.received, self.strings) | |
| 402 | |
| 403 def test_partial(self): | |
| 404 """ | |
| 405 Send partial data, nothing should be definitely received. | |
| 406 """ | |
| 407 for s in self.partialStrings: | |
| 408 r = self.getProtocol() | |
| 409 for c in s: | |
| 410 r.dataReceived(c) | |
| 411 self.assertEquals(r.received, []) | |
| 412 | |
| 413 def test_send(self): | |
| 414 """ | |
| 415 Test sending data over protocol. | |
| 416 """ | |
| 417 r = self.getProtocol() | |
| 418 r.sendString("b" * 16) | |
| 419 self.assertEquals(r.transport.file.getvalue(), | |
| 420 struct.pack(self.protocol.structFormat, 16) + "b" * 16) | |
| 421 | |
| 422 | |
| 423 class TestInt32(TestMixin, basic.Int32StringReceiver): | |
| 424 """ | |
| 425 A L{basic.Int32StringReceiver} storing received strings in an array. | |
| 426 | |
| 427 @ivar received: array holding received strings. | |
| 428 """ | |
| 429 | |
| 430 | |
| 431 class Int32TestCase(unittest.TestCase, IntNTestCaseMixin): | |
| 432 """ | |
| 433 Test case for int32-prefixed protocol | |
| 434 """ | |
| 435 protocol = TestInt32 | |
| 436 strings = ["a", "b" * 16] | |
| 437 illegalStrings = ["\x10\x00\x00\x00aaaaaa"] | |
| 438 partialStrings = ["\x00\x00\x00", "hello there", ""] | |
| 439 | |
| 440 def test_data(self): | |
| 441 """ | |
| 442 Test specific behavior of the 32-bits length. | |
| 443 """ | |
| 444 r = self.getProtocol() | |
| 445 r.sendString("foo") | |
| 446 self.assertEquals(r.transport.file.getvalue(), "\x00\x00\x00\x03foo") | |
| 447 r.dataReceived("\x00\x00\x00\x04ubar") | |
| 448 self.assertEquals(r.received, ["ubar"]) | |
| 449 | |
| 450 | |
| 451 class TestInt16(TestMixin, basic.Int16StringReceiver): | |
| 452 """ | |
| 453 A L{basic.Int16StringReceiver} storing received strings in an array. | |
| 454 | |
| 455 @ivar received: array holding received strings. | |
| 456 """ | |
| 457 | |
| 458 | |
| 459 class Int16TestCase(unittest.TestCase, IntNTestCaseMixin): | |
| 460 """ | |
| 461 Test case for int16-prefixed protocol | |
| 462 """ | |
| 463 protocol = TestInt16 | |
| 464 strings = ["a", "b" * 16] | |
| 465 illegalStrings = ["\x10\x00aaaaaa"] | |
| 466 partialStrings = ["\x00", "hello there", ""] | |
| 467 | |
| 468 def test_data(self): | |
| 469 """ | |
| 470 Test specific behavior of the 16-bits length. | |
| 471 """ | |
| 472 r = self.getProtocol() | |
| 473 r.sendString("foo") | |
| 474 self.assertEquals(r.transport.file.getvalue(), "\x00\x03foo") | |
| 475 r.dataReceived("\x00\x04ubar") | |
| 476 self.assertEquals(r.received, ["ubar"]) | |
| 477 | |
| 478 def test_tooLongSend(self): | |
| 479 """ | |
| 480 Send too much data: that should cause an error. | |
| 481 """ | |
| 482 r = self.getProtocol() | |
| 483 tooSend = "b" * (2**(r.prefixLength*8) + 1) | |
| 484 self.assertRaises(AssertionError, r.sendString, tooSend) | |
| 485 | |
| 486 | |
| 487 class TestInt8(TestMixin, basic.Int8StringReceiver): | |
| 488 """ | |
| 489 A L{basic.Int8StringReceiver} storing received strings in an array. | |
| 490 | |
| 491 @ivar received: array holding received strings. | |
| 492 """ | |
| 493 | |
| 494 | |
| 495 class Int8TestCase(unittest.TestCase, IntNTestCaseMixin): | |
| 496 """ | |
| 497 Test case for int8-prefixed protocol | |
| 498 """ | |
| 499 protocol = TestInt8 | |
| 500 strings = ["a", "b" * 16] | |
| 501 illegalStrings = ["\x00\x00aaaaaa"] | |
| 502 partialStrings = ["\x08", "dzadz", ""] | |
| 503 | |
| 504 def test_data(self): | |
| 505 """ | |
| 506 Test specific behavior of the 8-bits length. | |
| 507 """ | |
| 508 r = self.getProtocol() | |
| 509 r.sendString("foo") | |
| 510 self.assertEquals(r.transport.file.getvalue(), "\x03foo") | |
| 511 r.dataReceived("\x04ubar") | |
| 512 self.assertEquals(r.received, ["ubar"]) | |
| 513 | |
| 514 def test_tooLongSend(self): | |
| 515 """ | |
| 516 Send too much data: that should cause an error. | |
| 517 """ | |
| 518 r = self.getProtocol() | |
| 519 tooSend = "b" * (2**(r.prefixLength*8) + 1) | |
| 520 self.assertRaises(AssertionError, r.sendString, tooSend) | |
| 521 | |
| 522 | |
| 523 class OnlyProducerTransport(object): | |
| 524 # Transport which isn't really a transport, just looks like one to | |
| 525 # someone not looking very hard. | |
| 526 | |
| 527 paused = False | |
| 528 disconnecting = False | |
| 529 | |
| 530 def __init__(self): | |
| 531 self.data = [] | |
| 532 | |
| 533 def pauseProducing(self): | |
| 534 self.paused = True | |
| 535 | |
| 536 def resumeProducing(self): | |
| 537 self.paused = False | |
| 538 | |
| 539 def write(self, bytes): | |
| 540 self.data.append(bytes) | |
| 541 | |
| 542 | |
| 543 class ConsumingProtocol(basic.LineReceiver): | |
| 544 # Protocol that really, really doesn't want any more bytes. | |
| 545 | |
| 546 def lineReceived(self, line): | |
| 547 self.transport.write(line) | |
| 548 self.pauseProducing() | |
| 549 | |
| 550 | |
| 551 class ProducerTestCase(unittest.TestCase): | |
| 552 def testPauseResume(self): | |
| 553 p = ConsumingProtocol() | |
| 554 t = OnlyProducerTransport() | |
| 555 p.makeConnection(t) | |
| 556 | |
| 557 p.dataReceived('hello, ') | |
| 558 self.failIf(t.data) | |
| 559 self.failIf(t.paused) | |
| 560 self.failIf(p.paused) | |
| 561 | |
| 562 p.dataReceived('world\r\n') | |
| 563 | |
| 564 self.assertEquals(t.data, ['hello, world']) | |
| 565 self.failUnless(t.paused) | |
| 566 self.failUnless(p.paused) | |
| 567 | |
| 568 p.resumeProducing() | |
| 569 | |
| 570 self.failIf(t.paused) | |
| 571 self.failIf(p.paused) | |
| 572 | |
| 573 p.dataReceived('hello\r\nworld\r\n') | |
| 574 | |
| 575 self.assertEquals(t.data, ['hello, world', 'hello']) | |
| 576 self.failUnless(t.paused) | |
| 577 self.failUnless(p.paused) | |
| 578 | |
| 579 p.resumeProducing() | |
| 580 p.dataReceived('goodbye\r\n') | |
| 581 | |
| 582 self.assertEquals(t.data, ['hello, world', 'hello', 'world']) | |
| 583 self.failUnless(t.paused) | |
| 584 self.failUnless(p.paused) | |
| 585 | |
| 586 p.resumeProducing() | |
| 587 | |
| 588 self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye']) | |
| 589 self.failUnless(t.paused) | |
| 590 self.failUnless(p.paused) | |
| 591 | |
| 592 p.resumeProducing() | |
| 593 | |
| 594 self.assertEquals(t.data, ['hello, world', 'hello', 'world', 'goodbye']) | |
| 595 self.failIf(t.paused) | |
| 596 self.failIf(p.paused) | |
| 597 | |
| 598 | |
| 599 | |
| 600 class TestableProxyClientFactory(portforward.ProxyClientFactory): | |
| 601 """ | |
| 602 Test proxy client factory that keeps the last created protocol instance. | |
| 603 | |
| 604 @ivar protoInstance: the last instance of the protocol. | |
| 605 @type protoInstance: L{portforward.ProxyClient} | |
| 606 """ | |
| 607 | |
| 608 def buildProtocol(self, addr): | |
| 609 """ | |
| 610 Create the protocol instance and keeps track of it. | |
| 611 """ | |
| 612 proto = portforward.ProxyClientFactory.buildProtocol(self, addr) | |
| 613 self.protoInstance = proto | |
| 614 return proto | |
| 615 | |
| 616 | |
| 617 | |
| 618 class TestableProxyFactory(portforward.ProxyFactory): | |
| 619 """ | |
| 620 Test proxy factory that keeps the last created protocol instance. | |
| 621 | |
| 622 @ivar protoInstance: the last instance of the protocol. | |
| 623 @type protoInstance: L{portforward.ProxyServer} | |
| 624 | |
| 625 @ivar clientFactoryInstance: client factory used by C{protoInstance} to | |
| 626 create forward connections. | |
| 627 @type clientFactoryInstance: L{TestableProxyClientFactory} | |
| 628 """ | |
| 629 | |
| 630 def buildProtocol(self, addr): | |
| 631 """ | |
| 632 Create the protocol instance, keeps track of it, and makes it use | |
| 633 C{clientFactoryInstance} as client factory. | |
| 634 """ | |
| 635 proto = portforward.ProxyFactory.buildProtocol(self, addr) | |
| 636 self.clientFactoryInstance = TestableProxyClientFactory() | |
| 637 # Force the use of this specific instance | |
| 638 proto.clientProtocolFactory = lambda: self.clientFactoryInstance | |
| 639 self.protoInstance = proto | |
| 640 return proto | |
| 641 | |
| 642 | |
| 643 | |
| 644 class Portforwarding(unittest.TestCase): | |
| 645 """ | |
| 646 Test port forwarding. | |
| 647 """ | |
| 648 | |
| 649 def setUp(self): | |
| 650 self.serverProtocol = wire.Echo() | |
| 651 self.clientProtocol = protocol.Protocol() | |
| 652 self.openPorts = [] | |
| 653 | |
| 654 | |
| 655 def tearDown(self): | |
| 656 try: | |
| 657 self.proxyServerFactory.protoInstance.transport.loseConnection() | |
| 658 except AttributeError: | |
| 659 pass | |
| 660 try: | |
| 661 self.proxyServerFactory.clientFactoryInstance.protoInstance.transpor
t.loseConnection() | |
| 662 except AttributeError: | |
| 663 pass | |
| 664 try: | |
| 665 self.clientProtocol.transport.loseConnection() | |
| 666 except AttributeError: | |
| 667 pass | |
| 668 try: | |
| 669 self.serverProtocol.transport.loseConnection() | |
| 670 except AttributeError: | |
| 671 pass | |
| 672 return defer.gatherResults( | |
| 673 [defer.maybeDeferred(p.stopListening) for p in self.openPorts]) | |
| 674 | |
| 675 | |
| 676 def test_portforward(self): | |
| 677 """ | |
| 678 Test port forwarding through Echo protocol. | |
| 679 """ | |
| 680 realServerFactory = protocol.ServerFactory() | |
| 681 realServerFactory.protocol = lambda: self.serverProtocol | |
| 682 realServerPort = reactor.listenTCP(0, realServerFactory, | |
| 683 interface='127.0.0.1') | |
| 684 self.openPorts.append(realServerPort) | |
| 685 self.proxyServerFactory = TestableProxyFactory('127.0.0.1', | |
| 686 realServerPort.getHost().port) | |
| 687 proxyServerPort = reactor.listenTCP(0, self.proxyServerFactory, | |
| 688 interface='127.0.0.1') | |
| 689 self.openPorts.append(proxyServerPort) | |
| 690 | |
| 691 nBytes = 1000 | |
| 692 received = [] | |
| 693 d = defer.Deferred() | |
| 694 def testDataReceived(data): | |
| 695 received.extend(data) | |
| 696 if len(received) >= nBytes: | |
| 697 self.assertEquals(''.join(received), 'x' * nBytes) | |
| 698 d.callback(None) | |
| 699 self.clientProtocol.dataReceived = testDataReceived | |
| 700 | |
| 701 def testConnectionMade(): | |
| 702 self.clientProtocol.transport.write('x' * nBytes) | |
| 703 self.clientProtocol.connectionMade = testConnectionMade | |
| 704 | |
| 705 clientFactory = protocol.ClientFactory() | |
| 706 clientFactory.protocol = lambda: self.clientProtocol | |
| 707 | |
| 708 reactor.connectTCP( | |
| 709 '127.0.0.1', proxyServerPort.getHost().port, clientFactory) | |
| 710 | |
| 711 return d | |
| 712 | |
| 713 | |
| 714 | |
| 715 class StringTransportTestCase(unittest.TestCase): | |
| 716 """ | |
| 717 Test L{proto_helpers.StringTransport} helper behaviour. | |
| 718 """ | |
| 719 | |
| 720 def test_noUnicode(self): | |
| 721 """ | |
| 722 Test that L{proto_helpers.StringTransport} doesn't accept unicode data. | |
| 723 """ | |
| 724 s = proto_helpers.StringTransport() | |
| 725 self.assertRaises(TypeError, s.write, u'foo') | |
| OLD | NEW |