| OLD | NEW |
| (Empty) |
| 1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
| 2 # See LICENSE for details. | |
| 3 | |
| 4 from twisted.trial import unittest | |
| 5 from cStringIO import StringIO | |
| 6 | |
| 7 from twisted.web import server, resource, util | |
| 8 from twisted.internet import defer, interfaces, error, task | |
| 9 from twisted.web import http | |
| 10 from twisted.python import log | |
| 11 from twisted.internet.address import IPv4Address | |
| 12 from zope.interface import implements | |
| 13 | |
| 14 class DummyRequest: | |
| 15 uri='http://dummy/' | |
| 16 method = 'GET' | |
| 17 | |
| 18 def getHeader(self, h): | |
| 19 return None | |
| 20 | |
| 21 def registerProducer(self, prod,s): | |
| 22 self.go = 1 | |
| 23 while self.go: | |
| 24 prod.resumeProducing() | |
| 25 | |
| 26 def unregisterProducer(self): | |
| 27 self.go = 0 | |
| 28 | |
| 29 def __init__(self, postpath, session=None): | |
| 30 self.sitepath = [] | |
| 31 self.written = [] | |
| 32 self.finished = 0 | |
| 33 self.postpath = postpath | |
| 34 self.prepath = [] | |
| 35 self.session = None | |
| 36 self.protoSession = session or server.Session(0, self) | |
| 37 self.args = {} | |
| 38 self.outgoingHeaders = {} | |
| 39 | |
| 40 def setHeader(self, name, value): | |
| 41 """TODO: make this assert on write() if the header is content-length | |
| 42 """ | |
| 43 self.outgoingHeaders[name.lower()] = value | |
| 44 | |
| 45 def getSession(self): | |
| 46 if self.session: | |
| 47 return self.session | |
| 48 assert not self.written, "Session cannot be requested after data has bee
n written." | |
| 49 self.session = self.protoSession | |
| 50 return self.session | |
| 51 def write(self, data): | |
| 52 self.written.append(data) | |
| 53 def finish(self): | |
| 54 self.finished = self.finished + 1 | |
| 55 def addArg(self, name, value): | |
| 56 self.args[name] = [value] | |
| 57 def setResponseCode(self, code): | |
| 58 assert not self.written, "Response code cannot be set after data has bee
n written: %s." % "@@@@".join(self.written) | |
| 59 def setLastModified(self, when): | |
| 60 assert not self.written, "Last-Modified cannot be set after data has bee
n written: %s." % "@@@@".join(self.written) | |
| 61 def setETag(self, tag): | |
| 62 assert not self.written, "ETag cannot be set after data has been written
: %s." % "@@@@".join(self.written) | |
| 63 | |
| 64 class ResourceTestCase(unittest.TestCase): | |
| 65 def testListEntities(self): | |
| 66 r = resource.Resource() | |
| 67 self.failUnlessEqual([], r.listEntities()) | |
| 68 | |
| 69 | |
| 70 class SimpleResource(resource.Resource): | |
| 71 def render(self, request): | |
| 72 if http.CACHED in (request.setLastModified(10), | |
| 73 request.setETag('MatchingTag')): | |
| 74 return '' | |
| 75 else: | |
| 76 return "correct" | |
| 77 | |
| 78 class SiteTest(unittest.TestCase): | |
| 79 def testSimplestSite(self): | |
| 80 sres1 = SimpleResource() | |
| 81 sres2 = SimpleResource() | |
| 82 sres1.putChild("",sres2) | |
| 83 site = server.Site(sres1) | |
| 84 assert site.getResourceFor(DummyRequest([''])) is sres2, "Got the wrong
resource." | |
| 85 | |
| 86 | |
| 87 | |
| 88 class SessionTest(unittest.TestCase): | |
| 89 | |
| 90 def setUp(self): | |
| 91 """ | |
| 92 Set up a session using a simulated scheduler. Creates a | |
| 93 C{times} attribute which specifies the return values of the | |
| 94 session's C{_getTime} method. | |
| 95 """ | |
| 96 clock = self.clock = task.Clock() | |
| 97 times = self.times = [] | |
| 98 | |
| 99 class MockSession(server.Session): | |
| 100 """ | |
| 101 A mock L{server.Session} object which fakes out scheduling | |
| 102 with the C{clock} attribute and fakes out the current time | |
| 103 to be the elements of L{SessionTest}'s C{times} attribute. | |
| 104 """ | |
| 105 def loopFactory(self, *a, **kw): | |
| 106 """ | |
| 107 Create a L{task.LoopingCall} which uses | |
| 108 L{SessionTest}'s C{clock} attribute. | |
| 109 """ | |
| 110 call = task.LoopingCall(*a, **kw) | |
| 111 call.clock = clock | |
| 112 return call | |
| 113 | |
| 114 def _getTime(self): | |
| 115 return times.pop(0) | |
| 116 | |
| 117 self.site = server.Site(SimpleResource()) | |
| 118 self.site.sessionFactory = MockSession | |
| 119 | |
| 120 | |
| 121 def test_basicExpiration(self): | |
| 122 """ | |
| 123 Test session expiration: setup a session, and simulate an expiration | |
| 124 time. | |
| 125 """ | |
| 126 self.times.extend([0, server.Session.sessionTimeout + 1]) | |
| 127 session = self.site.makeSession() | |
| 128 hasExpired = [False] | |
| 129 def cbExpire(): | |
| 130 hasExpired[0] = True | |
| 131 session.notifyOnExpire(cbExpire) | |
| 132 self.clock.advance(server.Site.sessionCheckTime - 1) | |
| 133 # Looping call should not have been executed | |
| 134 self.failIf(hasExpired[0]) | |
| 135 | |
| 136 self.clock.advance(1) | |
| 137 | |
| 138 self.failUnless(hasExpired[0]) | |
| 139 | |
| 140 | |
| 141 def test_delayedCallCleanup(self): | |
| 142 """ | |
| 143 Checking to make sure Sessions do not leave extra DelayedCalls. | |
| 144 """ | |
| 145 self.times.extend([0, 100]) | |
| 146 | |
| 147 session = self.site.makeSession() | |
| 148 loop = session.checkExpiredLoop | |
| 149 session.touch() | |
| 150 self.failUnless(loop.running) | |
| 151 | |
| 152 session.expire() | |
| 153 | |
| 154 self.failIf(self.clock.calls) | |
| 155 self.failIf(loop.running) | |
| 156 | |
| 157 | |
| 158 | |
| 159 # Conditional requests: | |
| 160 # If-None-Match, If-Modified-Since | |
| 161 | |
| 162 # make conditional request: | |
| 163 # normal response if condition succeeds | |
| 164 # if condition fails: | |
| 165 # response code | |
| 166 # no body | |
| 167 | |
| 168 def httpBody(whole): | |
| 169 return whole.split('\r\n\r\n', 1)[1] | |
| 170 | |
| 171 def httpHeader(whole, key): | |
| 172 key = key.lower() | |
| 173 headers = whole.split('\r\n\r\n', 1)[0] | |
| 174 for header in headers.split('\r\n'): | |
| 175 if header.lower().startswith(key): | |
| 176 return header.split(':', 1)[1].strip() | |
| 177 return None | |
| 178 | |
| 179 def httpCode(whole): | |
| 180 l1 = whole.split('\r\n', 1)[0] | |
| 181 return int(l1.split()[1]) | |
| 182 | |
| 183 class ConditionalTest(unittest.TestCase): | |
| 184 """web.server's handling of conditional requests for cache validation.""" | |
| 185 | |
| 186 # XXX: test web.distrib. | |
| 187 | |
| 188 def setUp(self): | |
| 189 self.resrc = SimpleResource() | |
| 190 self.resrc.putChild('', self.resrc) | |
| 191 self.site = server.Site(self.resrc) | |
| 192 self.site = server.Site(self.resrc) | |
| 193 self.site.logFile = log.logfile | |
| 194 | |
| 195 # HELLLLLLLLLLP! This harness is Very Ugly. | |
| 196 self.channel = self.site.buildProtocol(None) | |
| 197 self.transport = http.StringTransport() | |
| 198 self.transport.close = lambda *a, **kw: None | |
| 199 self.transport.disconnecting = lambda *a, **kw: 0 | |
| 200 self.transport.getPeer = lambda *a, **kw: "peer" | |
| 201 self.transport.getHost = lambda *a, **kw: "host" | |
| 202 self.channel.makeConnection(self.transport) | |
| 203 for l in ["GET / HTTP/1.1", | |
| 204 "Accept: text/html"]: | |
| 205 self.channel.lineReceived(l) | |
| 206 | |
| 207 def tearDown(self): | |
| 208 self.channel.connectionLost(None) | |
| 209 | |
| 210 def test_modified(self): | |
| 211 """If-Modified-Since cache validator (positive)""" | |
| 212 self.channel.lineReceived("If-Modified-Since: %s" | |
| 213 % http.datetimeToString(1)) | |
| 214 self.channel.lineReceived('') | |
| 215 result = self.transport.getvalue() | |
| 216 self.failUnlessEqual(httpCode(result), http.OK) | |
| 217 self.failUnlessEqual(httpBody(result), "correct") | |
| 218 | |
| 219 def test_unmodified(self): | |
| 220 """If-Modified-Since cache validator (negative)""" | |
| 221 self.channel.lineReceived("If-Modified-Since: %s" | |
| 222 % http.datetimeToString(100)) | |
| 223 self.channel.lineReceived('') | |
| 224 result = self.transport.getvalue() | |
| 225 self.failUnlessEqual(httpCode(result), http.NOT_MODIFIED) | |
| 226 self.failUnlessEqual(httpBody(result), "") | |
| 227 | |
| 228 def test_etagMatchedNot(self): | |
| 229 """If-None-Match ETag cache validator (positive)""" | |
| 230 self.channel.lineReceived("If-None-Match: unmatchedTag") | |
| 231 self.channel.lineReceived('') | |
| 232 result = self.transport.getvalue() | |
| 233 self.failUnlessEqual(httpCode(result), http.OK) | |
| 234 self.failUnlessEqual(httpBody(result), "correct") | |
| 235 | |
| 236 def test_etagMatched(self): | |
| 237 """If-None-Match ETag cache validator (negative)""" | |
| 238 self.channel.lineReceived("If-None-Match: MatchingTag") | |
| 239 self.channel.lineReceived('') | |
| 240 result = self.transport.getvalue() | |
| 241 self.failUnlessEqual(httpHeader(result, "ETag"), "MatchingTag") | |
| 242 self.failUnlessEqual(httpCode(result), http.NOT_MODIFIED) | |
| 243 self.failUnlessEqual(httpBody(result), "") | |
| 244 | |
| 245 from twisted.web import google | |
| 246 class GoogleTestCase(unittest.TestCase): | |
| 247 def testCheckGoogle(self): | |
| 248 raise unittest.SkipTest("no violation of google ToS") | |
| 249 d = google.checkGoogle('site:www.twistedmatrix.com twisted') | |
| 250 d.addCallback(self.assertEquals, 'http://twistedmatrix.com/') | |
| 251 return d | |
| 252 | |
| 253 from twisted.web import static | |
| 254 from twisted.web import script | |
| 255 | |
| 256 class StaticFileTest(unittest.TestCase): | |
| 257 | |
| 258 def testStaticPaths(self): | |
| 259 import os | |
| 260 dp = os.path.join(self.mktemp(),"hello") | |
| 261 ddp = os.path.join(dp, "goodbye") | |
| 262 tp = os.path.abspath(os.path.join(dp,"world.txt")) | |
| 263 tpy = os.path.join(dp,"wyrld.rpy") | |
| 264 os.makedirs(dp) | |
| 265 f = open(tp,"wb") | |
| 266 f.write("hello world") | |
| 267 f = open(tpy, "wb") | |
| 268 f.write(""" | |
| 269 from twisted.web.static import Data | |
| 270 resource = Data('dynamic world','text/plain') | |
| 271 """) | |
| 272 f = static.File(dp) | |
| 273 f.processors = { | |
| 274 '.rpy': script.ResourceScript, | |
| 275 } | |
| 276 | |
| 277 f.indexNames = f.indexNames + ['world.txt'] | |
| 278 self.assertEquals(f.getChild('', DummyRequest([''])).path, | |
| 279 tp) | |
| 280 self.assertEquals(f.getChild('wyrld.rpy', DummyRequest(['wyrld.rpy']) | |
| 281 ).__class__, | |
| 282 static.Data) | |
| 283 f = static.File(dp) | |
| 284 wtextr = DummyRequest(['world.txt']) | |
| 285 wtext = f.getChild('world.txt', wtextr) | |
| 286 self.assertEquals(wtext.path, tp) | |
| 287 wtext.render(wtextr) | |
| 288 self.assertEquals(wtextr.outgoingHeaders.get('content-length'), | |
| 289 str(len('hello world'))) | |
| 290 self.assertNotEquals(f.getChild('', DummyRequest([''])).__class__, | |
| 291 static.File) | |
| 292 | |
| 293 def testIgnoreExt(self): | |
| 294 f = static.File(".") | |
| 295 f.ignoreExt(".foo") | |
| 296 self.assertEquals(f.ignoredExts, [".foo"]) | |
| 297 f = static.File(".") | |
| 298 self.assertEquals(f.ignoredExts, []) | |
| 299 f = static.File(".", ignoredExts=(".bar", ".baz")) | |
| 300 self.assertEquals(f.ignoredExts, [".bar", ".baz"]) | |
| 301 | |
| 302 def testIgnoredExts(self): | |
| 303 import os | |
| 304 dp = os.path.join(self.mktemp(), 'allYourBase') | |
| 305 fp = os.path.join(dp, 'AreBelong.ToUs') | |
| 306 os.makedirs(dp) | |
| 307 open(fp, 'wb').write("Take off every 'Zig'!!") | |
| 308 f = static.File(dp) | |
| 309 f.ignoreExt('.ToUs') | |
| 310 dreq = DummyRequest(['']) | |
| 311 child_without_ext = f.getChild('AreBelong', dreq) | |
| 312 self.assertNotEquals(child_without_ext, f.childNotFound) | |
| 313 | |
| 314 class DummyChannel: | |
| 315 class TCP: | |
| 316 port = 80 | |
| 317 def getPeer(self): | |
| 318 return IPv4Address("TCP", 'client.example.com', 12344) | |
| 319 def getHost(self): | |
| 320 return IPv4Address("TCP", 'example.com', self.port) | |
| 321 class SSL(TCP): | |
| 322 implements(interfaces.ISSLTransport) | |
| 323 transport = TCP() | |
| 324 site = server.Site(resource.Resource()) | |
| 325 | |
| 326 class TestRequest(unittest.TestCase): | |
| 327 | |
| 328 def testChildLink(self): | |
| 329 request = server.Request(DummyChannel(), 1) | |
| 330 request.gotLength(0) | |
| 331 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 332 self.assertEqual(request.childLink('baz'), 'bar/baz') | |
| 333 request = server.Request(DummyChannel(), 1) | |
| 334 request.gotLength(0) | |
| 335 request.requestReceived('GET', '/foo/bar/', 'HTTP/1.0') | |
| 336 self.assertEqual(request.childLink('baz'), 'baz') | |
| 337 | |
| 338 def testPrePathURLSimple(self): | |
| 339 request = server.Request(DummyChannel(), 1) | |
| 340 request.gotLength(0) | |
| 341 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 342 request.setHost('example.com', 80) | |
| 343 self.assertEqual(request.prePathURL(), 'http://example.com/foo/bar') | |
| 344 | |
| 345 def testPrePathURLNonDefault(self): | |
| 346 d = DummyChannel() | |
| 347 d.transport = DummyChannel.TCP() | |
| 348 d.transport.port = 81 | |
| 349 request = server.Request(d, 1) | |
| 350 request.setHost('example.com', 81) | |
| 351 request.gotLength(0) | |
| 352 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 353 self.assertEqual(request.prePathURL(), 'http://example.com:81/foo/bar') | |
| 354 | |
| 355 def testPrePathURLSSLPort(self): | |
| 356 d = DummyChannel() | |
| 357 d.transport = DummyChannel.TCP() | |
| 358 d.transport.port = 443 | |
| 359 request = server.Request(d, 1) | |
| 360 request.setHost('example.com', 443) | |
| 361 request.gotLength(0) | |
| 362 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 363 self.assertEqual(request.prePathURL(), 'http://example.com:443/foo/bar') | |
| 364 | |
| 365 def testPrePathURLSSLPortAndSSL(self): | |
| 366 d = DummyChannel() | |
| 367 d.transport = DummyChannel.SSL() | |
| 368 d.transport.port = 443 | |
| 369 request = server.Request(d, 1) | |
| 370 request.setHost('example.com', 443) | |
| 371 request.gotLength(0) | |
| 372 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 373 self.assertEqual(request.prePathURL(), 'https://example.com/foo/bar') | |
| 374 | |
| 375 def testPrePathURLHTTPPortAndSSL(self): | |
| 376 d = DummyChannel() | |
| 377 d.transport = DummyChannel.SSL() | |
| 378 d.transport.port = 80 | |
| 379 request = server.Request(d, 1) | |
| 380 request.setHost('example.com', 80) | |
| 381 request.gotLength(0) | |
| 382 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 383 self.assertEqual(request.prePathURL(), 'https://example.com:80/foo/bar') | |
| 384 | |
| 385 def testPrePathURLSSLNonDefault(self): | |
| 386 d = DummyChannel() | |
| 387 d.transport = DummyChannel.SSL() | |
| 388 d.transport.port = 81 | |
| 389 request = server.Request(d, 1) | |
| 390 request.setHost('example.com', 81) | |
| 391 request.gotLength(0) | |
| 392 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 393 self.assertEqual(request.prePathURL(), 'https://example.com:81/foo/bar') | |
| 394 | |
| 395 def testPrePathURLSetSSLHost(self): | |
| 396 d = DummyChannel() | |
| 397 d.transport = DummyChannel.TCP() | |
| 398 d.transport.port = 81 | |
| 399 request = server.Request(d, 1) | |
| 400 request.setHost('foo.com', 81, 1) | |
| 401 request.gotLength(0) | |
| 402 request.requestReceived('GET', '/foo/bar', 'HTTP/1.0') | |
| 403 self.assertEqual(request.prePathURL(), 'https://foo.com:81/foo/bar') | |
| 404 | |
| 405 | |
| 406 def test_prePathURLQuoting(self): | |
| 407 """ | |
| 408 L{Request.prePathURL} quotes special characters in the URL segments to | |
| 409 preserve the original meaning. | |
| 410 """ | |
| 411 d = DummyChannel() | |
| 412 request = server.Request(d, 1) | |
| 413 request.setHost('example.com', 80) | |
| 414 request.gotLength(0) | |
| 415 request.requestReceived('GET', '/foo%2Fbar', 'HTTP/1.0') | |
| 416 self.assertEqual(request.prePathURL(), 'http://example.com/foo%2Fbar') | |
| 417 | |
| 418 | |
| 419 def testNotifyFinishConnectionLost(self): | |
| 420 d = DummyChannel() | |
| 421 d.transport = DummyChannel.TCP() | |
| 422 request = server.Request(d, 1) | |
| 423 finished = request.notifyFinish() | |
| 424 request.connectionLost(error.ConnectionDone("Connection done")) | |
| 425 return self.assertFailure(finished, error.ConnectionDone) | |
| 426 | |
| 427 | |
| 428 class RootResource(resource.Resource): | |
| 429 isLeaf=0 | |
| 430 def getChildWithDefault(self, name, request): | |
| 431 request.rememberRootURL() | |
| 432 return resource.Resource.getChildWithDefault(self, name, request) | |
| 433 def render(self, request): | |
| 434 return '' | |
| 435 | |
| 436 class RememberURLTest(unittest.TestCase): | |
| 437 def createServer(self, r): | |
| 438 chan = DummyChannel() | |
| 439 chan.transport = DummyChannel.TCP() | |
| 440 chan.site = server.Site(r) | |
| 441 return chan | |
| 442 | |
| 443 def testSimple(self): | |
| 444 r = resource.Resource() | |
| 445 r.isLeaf=0 | |
| 446 rr = RootResource() | |
| 447 r.putChild('foo', rr) | |
| 448 rr.putChild('', rr) | |
| 449 rr.putChild('bar', resource.Resource()) | |
| 450 chan = self.createServer(r) | |
| 451 for url in ['/foo/', '/foo/bar', '/foo/bar/baz', '/foo/bar/']: | |
| 452 request = server.Request(chan, 1) | |
| 453 request.setHost('example.com', 81) | |
| 454 request.gotLength(0) | |
| 455 request.requestReceived('GET', url, 'HTTP/1.0') | |
| 456 self.assertEqual(request.getRootURL(), "http://example.com/foo") | |
| 457 | |
| 458 def testRoot(self): | |
| 459 rr = RootResource() | |
| 460 rr.putChild('', rr) | |
| 461 rr.putChild('bar', resource.Resource()) | |
| 462 chan = self.createServer(rr) | |
| 463 for url in ['/', '/bar', '/bar/baz', '/bar/']: | |
| 464 request = server.Request(chan, 1) | |
| 465 request.setHost('example.com', 81) | |
| 466 request.gotLength(0) | |
| 467 request.requestReceived('GET', url, 'HTTP/1.0') | |
| 468 self.assertEqual(request.getRootURL(), "http://example.com/") | |
| 469 | |
| 470 | |
| 471 class NewRenderResource(resource.Resource): | |
| 472 def render_GET(self, request): | |
| 473 return "hi hi" | |
| 474 | |
| 475 def render_HEH(self, request): | |
| 476 return "ho ho" | |
| 477 | |
| 478 | |
| 479 class NewRenderTestCase(unittest.TestCase): | |
| 480 def _getReq(self): | |
| 481 d = DummyChannel() | |
| 482 d.site.resource.putChild('newrender', NewRenderResource()) | |
| 483 d.transport = DummyChannel.TCP() | |
| 484 d.transport.port = 81 | |
| 485 request = server.Request(d, 1) | |
| 486 request.setHost('example.com', 81) | |
| 487 request.gotLength(0) | |
| 488 return request | |
| 489 | |
| 490 def testGoodMethods(self): | |
| 491 req = self._getReq() | |
| 492 req.requestReceived('GET', '/newrender', 'HTTP/1.0') | |
| 493 self.assertEquals(req.transport.getvalue().splitlines()[-1], 'hi hi') | |
| 494 | |
| 495 req = self._getReq() | |
| 496 req.requestReceived('HEH', '/newrender', 'HTTP/1.0') | |
| 497 self.assertEquals(req.transport.getvalue().splitlines()[-1], 'ho ho') | |
| 498 | |
| 499 def testBadMethods(self): | |
| 500 req = self._getReq() | |
| 501 req.requestReceived('CONNECT', '/newrender', 'HTTP/1.0') | |
| 502 self.assertEquals(req.code, 501) | |
| 503 | |
| 504 req = self._getReq() | |
| 505 req.requestReceived('hlalauguG', '/newrender', 'HTTP/1.0') | |
| 506 self.assertEquals(req.code, 501) | |
| 507 | |
| 508 def testImplicitHead(self): | |
| 509 req = self._getReq() | |
| 510 req.requestReceived('HEAD', '/newrender', 'HTTP/1.0') | |
| 511 self.assertEquals(req.code, 200) | |
| 512 self.assertEquals(-1, req.transport.getvalue().find('hi hi')) | |
| 513 | |
| 514 | |
| 515 class SDResource(resource.Resource): | |
| 516 def __init__(self,default): self.default=default | |
| 517 def getChildWithDefault(self,name,request): | |
| 518 d=defer.succeed(self.default) | |
| 519 return util.DeferredResource(d).getChildWithDefault(name, request) | |
| 520 | |
| 521 class SDTest(unittest.TestCase): | |
| 522 | |
| 523 def testDeferredResource(self): | |
| 524 r = resource.Resource() | |
| 525 r.isLeaf = 1 | |
| 526 s = SDResource(r) | |
| 527 d = DummyRequest(['foo', 'bar', 'baz']) | |
| 528 resource.getChildForRequest(s, d) | |
| 529 self.assertEqual(d.postpath, ['bar', 'baz']) | |
| 530 | |
| 531 class DummyRequestForLogTest(DummyRequest): | |
| 532 uri='/dummy' # parent class uri has "http://", which doesn't really happen | |
| 533 code = 123 | |
| 534 client = '1.2.3.4' | |
| 535 clientproto = 'HTTP/1.0' | |
| 536 sentLength = None | |
| 537 | |
| 538 def __init__(self, *a, **kw): | |
| 539 DummyRequest.__init__(self, *a, **kw) | |
| 540 self.headers = {} | |
| 541 | |
| 542 def getHeader(self, h): | |
| 543 return self.headers.get(h.lower(), None) | |
| 544 | |
| 545 def getClientIP(self): | |
| 546 return self.client | |
| 547 | |
| 548 class TestLogEscaping(unittest.TestCase): | |
| 549 def setUp(self): | |
| 550 self.site = http.HTTPFactory() | |
| 551 self.site.logFile = StringIO() | |
| 552 self.request = DummyRequestForLogTest(self.site, False) | |
| 553 | |
| 554 def testSimple(self): | |
| 555 http._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % ( | |
| 556 25, 'Oct', 2004, 12, 31, 59) | |
| 557 self.site.log(self.request) | |
| 558 self.site.logFile.seek(0) | |
| 559 self.assertEqual( | |
| 560 self.site.logFile.read(), | |
| 561 '1.2.3.4 - - [25/Oct/2004:12:31:59 +0000] "GET /dummy HTTP/1.0" 123
- "-" "-"\n') | |
| 562 | |
| 563 def testMethodQuote(self): | |
| 564 http._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % ( | |
| 565 25, 'Oct', 2004, 12, 31, 59) | |
| 566 self.request.method = 'G"T' | |
| 567 self.site.log(self.request) | |
| 568 self.site.logFile.seek(0) | |
| 569 self.assertEqual( | |
| 570 self.site.logFile.read(), | |
| 571 '1.2.3.4 - - [25/Oct/2004:12:31:59 +0000] "G\\"T /dummy HTTP/1.0" 12
3 - "-" "-"\n') | |
| 572 | |
| 573 def testRequestQuote(self): | |
| 574 http._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % ( | |
| 575 25, 'Oct', 2004, 12, 31, 59) | |
| 576 self.request.uri='/dummy"withquote' | |
| 577 self.site.log(self.request) | |
| 578 self.site.logFile.seek(0) | |
| 579 self.assertEqual( | |
| 580 self.site.logFile.read(), | |
| 581 '1.2.3.4 - - [25/Oct/2004:12:31:59 +0000] "GET /dummy\\"withquote HT
TP/1.0" 123 - "-" "-"\n') | |
| 582 | |
| 583 def testProtoQuote(self): | |
| 584 http._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % ( | |
| 585 25, 'Oct', 2004, 12, 31, 59) | |
| 586 self.request.clientproto='HT"P/1.0' | |
| 587 self.site.log(self.request) | |
| 588 self.site.logFile.seek(0) | |
| 589 self.assertEqual( | |
| 590 self.site.logFile.read(), | |
| 591 '1.2.3.4 - - [25/Oct/2004:12:31:59 +0000] "GET /dummy HT\\"P/1.0" 12
3 - "-" "-"\n') | |
| 592 | |
| 593 def testRefererQuote(self): | |
| 594 http._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % ( | |
| 595 25, 'Oct', 2004, 12, 31, 59) | |
| 596 self.request.headers['referer'] = 'http://malicious" ".website.invalid' | |
| 597 self.site.log(self.request) | |
| 598 self.site.logFile.seek(0) | |
| 599 self.assertEqual( | |
| 600 self.site.logFile.read(), | |
| 601 '1.2.3.4 - - [25/Oct/2004:12:31:59 +0000] "GET /dummy HTTP/1.0" 123
- "http://malicious\\" \\".website.invalid" "-"\n') | |
| 602 | |
| 603 def testUserAgentQuote(self): | |
| 604 http._logDateTime = "[%02d/%3s/%4d:%02d:%02d:%02d +0000]" % ( | |
| 605 25, 'Oct', 2004, 12, 31, 59) | |
| 606 self.request.headers['user-agent'] = 'Malicious Web" Evil' | |
| 607 self.site.log(self.request) | |
| 608 self.site.logFile.seek(0) | |
| 609 self.assertEqual( | |
| 610 self.site.logFile.read(), | |
| 611 '1.2.3.4 - - [25/Oct/2004:12:31:59 +0000] "GET /dummy HTTP/1.0" 123
- "-" "Malicious Web\\" Evil"\n') | |
| OLD | NEW |