OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2007 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 | |
5 """ | |
6 Test HTTP support. | |
7 """ | |
8 | |
9 from urlparse import urlparse, urlunsplit, clear_cache | |
10 import string, random, urllib, cgi | |
11 | |
12 from twisted.trial import unittest | |
13 from twisted.web import http | |
14 from twisted.protocols import loopback | |
15 from twisted.internet import protocol | |
16 from twisted.test.test_protocols import StringIOWithoutClosing | |
17 | |
18 | |
19 class DateTimeTest(unittest.TestCase): | |
20 """Test date parsing functions.""" | |
21 | |
22 def testRoundtrip(self): | |
23 for i in range(10000): | |
24 time = random.randint(0, 2000000000) | |
25 timestr = http.datetimeToString(time) | |
26 time2 = http.stringToDatetime(timestr) | |
27 self.assertEquals(time, time2) | |
28 | |
29 | |
30 class OrderedDict: | |
31 | |
32 def __init__(self, dict): | |
33 self.dict = dict | |
34 self.l = dict.keys() | |
35 | |
36 def __setitem__(self, k, v): | |
37 self.l.append(k) | |
38 self.dict[k] = v | |
39 | |
40 def __getitem__(self, k): | |
41 return self.dict[k] | |
42 | |
43 def items(self): | |
44 result = [] | |
45 for i in self.l: | |
46 result.append((i, self.dict[i])) | |
47 return result | |
48 | |
49 def __getattr__(self, attr): | |
50 return getattr(self.dict, attr) | |
51 | |
52 | |
53 class DummyHTTPHandler(http.Request): | |
54 | |
55 def process(self): | |
56 self.headers = OrderedDict(self.headers) | |
57 self.content.seek(0, 0) | |
58 data = self.content.read() | |
59 length = self.getHeader('content-length') | |
60 request = "'''\n"+str(length)+"\n"+data+"'''\n" | |
61 self.setResponseCode(200) | |
62 self.setHeader("Request", self.uri) | |
63 self.setHeader("Command", self.method) | |
64 self.setHeader("Version", self.clientproto) | |
65 self.setHeader("Content-Length", len(request)) | |
66 self.write(request) | |
67 self.finish() | |
68 | |
69 | |
70 class LoopbackHTTPClient(http.HTTPClient): | |
71 | |
72 def connectionMade(self): | |
73 self.sendCommand("GET", "/foo/bar") | |
74 self.sendHeader("Content-Length", 10) | |
75 self.endHeaders() | |
76 self.transport.write("0123456789") | |
77 | |
78 | |
79 class HTTP1_0TestCase(unittest.TestCase): | |
80 | |
81 requests = '''\ | |
82 GET / HTTP/1.0 | |
83 | |
84 GET / HTTP/1.1 | |
85 Accept: text/html | |
86 | |
87 ''' | |
88 requests = string.replace(requests, '\n', '\r\n') | |
89 | |
90 expected_response = "HTTP/1.0 200 OK\015\012Request: /\015\012Command: GET\0
15\012Version: HTTP/1.0\015\012Content-length: 13\015\012\015\012'''\012None\012
'''\012" | |
91 | |
92 def test_buffer(self): | |
93 """ | |
94 Send requests over a channel and check responses match what is expected. | |
95 """ | |
96 b = StringIOWithoutClosing() | |
97 a = http.HTTPChannel() | |
98 a.requestFactory = DummyHTTPHandler | |
99 a.makeConnection(protocol.FileWrapper(b)) | |
100 # one byte at a time, to stress it. | |
101 for byte in self.requests: | |
102 a.dataReceived(byte) | |
103 a.connectionLost(IOError("all one")) | |
104 value = b.getvalue() | |
105 self.assertEquals(value, self.expected_response) | |
106 | |
107 | |
108 class HTTP1_1TestCase(HTTP1_0TestCase): | |
109 | |
110 requests = '''\ | |
111 GET / HTTP/1.1 | |
112 Accept: text/html | |
113 | |
114 POST / HTTP/1.1 | |
115 Content-Length: 10 | |
116 | |
117 0123456789POST / HTTP/1.1 | |
118 Content-Length: 10 | |
119 | |
120 0123456789HEAD / HTTP/1.1 | |
121 | |
122 ''' | |
123 requests = string.replace(requests, '\n', '\r\n') | |
124 | |
125 expected_response = "HTTP/1.1 200 OK\015\012Request: /\015\012Command: GET\0
15\012Version: HTTP/1.1\015\012Content-length: 13\015\012\015\012'''\012None\012
'''\012HTTP/1.1 200 OK\015\012Request: /\015\012Command: POST\015\012Version: HT
TP/1.1\015\012Content-length: 21\015\012\015\012'''\01210\0120123456789'''\012HT
TP/1.1 200 OK\015\012Request: /\015\012Command: POST\015\012Version: HTTP/1.1\01
5\012Content-length: 21\015\012\015\012'''\01210\0120123456789'''\012HTTP/1.1 20
0 OK\015\012Request: /\015\012Command: HEAD\015\012Version: HTTP/1.1\015\012Cont
ent-length: 13\015\012\015\012" | |
126 | |
127 class HTTP1_1_close_TestCase(HTTP1_0TestCase): | |
128 | |
129 requests = '''\ | |
130 GET / HTTP/1.1 | |
131 Accept: text/html | |
132 Connection: close | |
133 | |
134 GET / HTTP/1.0 | |
135 | |
136 ''' | |
137 | |
138 requests = string.replace(requests, '\n', '\r\n') | |
139 | |
140 expected_response = "HTTP/1.1 200 OK\015\012Connection: close\015\012Request
: /\015\012Command: GET\015\012Version: HTTP/1.1\015\012Content-length: 13\015\0
12\015\012'''\012None\012'''\012" | |
141 | |
142 | |
143 class HTTP0_9TestCase(HTTP1_0TestCase): | |
144 | |
145 requests = '''\ | |
146 GET / | |
147 ''' | |
148 requests = string.replace(requests, '\n', '\r\n') | |
149 | |
150 expected_response = "HTTP/1.1 400 Bad Request\r\n\r\n" | |
151 | |
152 | |
153 class HTTPLoopbackTestCase(unittest.TestCase): | |
154 | |
155 expectedHeaders = {'request' : '/foo/bar', | |
156 'command' : 'GET', | |
157 'version' : 'HTTP/1.0', | |
158 'content-length' : '21'} | |
159 numHeaders = 0 | |
160 gotStatus = 0 | |
161 gotResponse = 0 | |
162 gotEndHeaders = 0 | |
163 | |
164 def _handleStatus(self, version, status, message): | |
165 self.gotStatus = 1 | |
166 self.assertEquals(version, "HTTP/1.0") | |
167 self.assertEquals(status, "200") | |
168 | |
169 def _handleResponse(self, data): | |
170 self.gotResponse = 1 | |
171 self.assertEquals(data, "'''\n10\n0123456789'''\n") | |
172 | |
173 def _handleHeader(self, key, value): | |
174 self.numHeaders = self.numHeaders + 1 | |
175 self.assertEquals(self.expectedHeaders[string.lower(key)], value) | |
176 | |
177 def _handleEndHeaders(self): | |
178 self.gotEndHeaders = 1 | |
179 self.assertEquals(self.numHeaders, 4) | |
180 | |
181 def testLoopback(self): | |
182 server = http.HTTPChannel() | |
183 server.requestFactory = DummyHTTPHandler | |
184 client = LoopbackHTTPClient() | |
185 client.handleResponse = self._handleResponse | |
186 client.handleHeader = self._handleHeader | |
187 client.handleEndHeaders = self._handleEndHeaders | |
188 client.handleStatus = self._handleStatus | |
189 d = loopback.loopbackAsync(server, client) | |
190 d.addCallback(self._cbTestLoopback) | |
191 return d | |
192 | |
193 def _cbTestLoopback(self, ignored): | |
194 if not (self.gotStatus and self.gotResponse and self.gotEndHeaders): | |
195 raise RuntimeError( | |
196 "didn't got all callbacks %s" | |
197 % [self.gotStatus, self.gotResponse, self.gotEndHeaders]) | |
198 del self.gotEndHeaders | |
199 del self.gotResponse | |
200 del self.gotStatus | |
201 del self.numHeaders | |
202 | |
203 | |
204 class PRequest: | |
205 """Dummy request for persistence tests.""" | |
206 | |
207 def __init__(self, **headers): | |
208 self.received_headers = headers | |
209 self.headers = {} | |
210 | |
211 def getHeader(self, k): | |
212 return self.received_headers.get(k, '') | |
213 | |
214 def setHeader(self, k, v): | |
215 self.headers[k] = v | |
216 | |
217 | |
218 class PersistenceTestCase(unittest.TestCase): | |
219 """Tests for persistent HTTP connections.""" | |
220 | |
221 ptests = [#(PRequest(connection="Keep-Alive"), "HTTP/1.0", 1, {'connection'
: 'Keep-Alive'}), | |
222 (PRequest(), "HTTP/1.0", 0, {'connection': None}), | |
223 (PRequest(connection="close"), "HTTP/1.1", 0, {'connection' : 'clo
se'}), | |
224 (PRequest(), "HTTP/1.1", 1, {'connection': None}), | |
225 (PRequest(), "HTTP/0.9", 0, {'connection': None}), | |
226 ] | |
227 | |
228 | |
229 def testAlgorithm(self): | |
230 c = http.HTTPChannel() | |
231 for req, version, correctResult, resultHeaders in self.ptests: | |
232 result = c.checkPersistence(req, version) | |
233 self.assertEquals(result, correctResult) | |
234 for header in resultHeaders.keys(): | |
235 self.assertEquals(req.headers.get(header, None), resultHeaders[h
eader]) | |
236 | |
237 | |
238 class ChunkingTestCase(unittest.TestCase): | |
239 | |
240 strings = ["abcv", "", "fdfsd423", "Ffasfas\r\n", | |
241 "523523\n\rfsdf", "4234"] | |
242 | |
243 def testChunks(self): | |
244 for s in self.strings: | |
245 self.assertEquals((s, ''), http.fromChunk(''.join(http.toChunk(s)))) | |
246 self.assertRaises(ValueError, http.fromChunk, '-5\r\nmalformed!\r\n') | |
247 | |
248 def testConcatenatedChunks(self): | |
249 chunked = ''.join([''.join(http.toChunk(t)) for t in self.strings]) | |
250 result = [] | |
251 buffer = "" | |
252 for c in chunked: | |
253 buffer = buffer + c | |
254 try: | |
255 data, buffer = http.fromChunk(buffer) | |
256 result.append(data) | |
257 except ValueError: | |
258 pass | |
259 self.assertEquals(result, self.strings) | |
260 | |
261 | |
262 | |
263 class ParsingTestCase(unittest.TestCase): | |
264 | |
265 def runRequest(self, httpRequest, requestClass, success=1): | |
266 httpRequest = httpRequest.replace("\n", "\r\n") | |
267 b = StringIOWithoutClosing() | |
268 a = http.HTTPChannel() | |
269 a.requestFactory = requestClass | |
270 a.makeConnection(protocol.FileWrapper(b)) | |
271 # one byte at a time, to stress it. | |
272 for byte in httpRequest: | |
273 if a.transport.closed: | |
274 break | |
275 a.dataReceived(byte) | |
276 a.connectionLost(IOError("all done")) | |
277 if success: | |
278 self.assertEquals(self.didRequest, 1) | |
279 del self.didRequest | |
280 else: | |
281 self.assert_(not hasattr(self, "didRequest")) | |
282 | |
283 def testBasicAuth(self): | |
284 testcase = self | |
285 class Request(http.Request): | |
286 l = [] | |
287 def process(self): | |
288 testcase.assertEquals(self.getUser(), self.l[0]) | |
289 testcase.assertEquals(self.getPassword(), self.l[1]) | |
290 for u, p in [("foo", "bar"), ("hello", "there:z")]: | |
291 Request.l[:] = [u, p] | |
292 s = "%s:%s" % (u, p) | |
293 f = "GET / HTTP/1.0\nAuthorization: Basic %s\n\n" % (s.encode("base6
4").strip(), ) | |
294 self.runRequest(f, Request, 0) | |
295 | |
296 def testTooManyHeaders(self): | |
297 httpRequest = "GET / HTTP/1.0\n" | |
298 for i in range(502): | |
299 httpRequest += "%s: foo\n" % i | |
300 httpRequest += "\n" | |
301 class MyRequest(http.Request): | |
302 def process(self): | |
303 raise RuntimeError, "should not get called" | |
304 self.runRequest(httpRequest, MyRequest, 0) | |
305 | |
306 def testHeaders(self): | |
307 httpRequest = """\ | |
308 GET / HTTP/1.0 | |
309 Foo: bar | |
310 baz: 1 2 3 | |
311 | |
312 """ | |
313 testcase = self | |
314 | |
315 class MyRequest(http.Request): | |
316 def process(self): | |
317 testcase.assertEquals(self.getHeader('foo'), 'bar') | |
318 testcase.assertEquals(self.getHeader('Foo'), 'bar') | |
319 testcase.assertEquals(self.getHeader('bAz'), '1 2 3') | |
320 testcase.didRequest = 1 | |
321 self.finish() | |
322 | |
323 self.runRequest(httpRequest, MyRequest) | |
324 | |
325 def testCookies(self): | |
326 """ | |
327 Test cookies parsing and reading. | |
328 """ | |
329 httpRequest = '''\ | |
330 GET / HTTP/1.0 | |
331 Cookie: rabbit="eat carrot"; ninja=secret; spam="hey 1=1!" | |
332 | |
333 ''' | |
334 testcase = self | |
335 | |
336 class MyRequest(http.Request): | |
337 def process(self): | |
338 testcase.assertEquals(self.getCookie('rabbit'), '"eat carrot"') | |
339 testcase.assertEquals(self.getCookie('ninja'), 'secret') | |
340 testcase.assertEquals(self.getCookie('spam'), '"hey 1=1!"') | |
341 testcase.didRequest = 1 | |
342 self.finish() | |
343 | |
344 self.runRequest(httpRequest, MyRequest) | |
345 | |
346 def testGET(self): | |
347 httpRequest = '''\ | |
348 GET /?key=value&multiple=two+words&multiple=more%20words&empty= HTTP/1.0 | |
349 | |
350 ''' | |
351 testcase = self | |
352 class MyRequest(http.Request): | |
353 def process(self): | |
354 testcase.assertEquals(self.method, "GET") | |
355 testcase.assertEquals(self.args["key"], ["value"]) | |
356 testcase.assertEquals(self.args["empty"], [""]) | |
357 testcase.assertEquals(self.args["multiple"], ["two words", "more
words"]) | |
358 testcase.didRequest = 1 | |
359 self.finish() | |
360 | |
361 self.runRequest(httpRequest, MyRequest) | |
362 | |
363 | |
364 def test_extraQuestionMark(self): | |
365 """ | |
366 While only a single '?' is allowed in an URL, several other servers | |
367 allow several and pass all after the first through as part of the | |
368 query arguments. Test that we emulate this behavior. | |
369 """ | |
370 httpRequest = 'GET /foo?bar=?&baz=quux HTTP/1.0\n\n' | |
371 | |
372 testcase = self | |
373 class MyRequest(http.Request): | |
374 def process(self): | |
375 testcase.assertEqual(self.method, 'GET') | |
376 testcase.assertEqual(self.path, '/foo') | |
377 testcase.assertEqual(self.args['bar'], ['?']) | |
378 testcase.assertEqual(self.args['baz'], ['quux']) | |
379 testcase.didRequest = 1 | |
380 self.finish() | |
381 | |
382 self.runRequest(httpRequest, MyRequest) | |
383 | |
384 | |
385 def testPOST(self): | |
386 query = 'key=value&multiple=two+words&multiple=more%20words&empty=' | |
387 httpRequest = '''\ | |
388 POST / HTTP/1.0 | |
389 Content-Length: %d | |
390 Content-Type: application/x-www-form-urlencoded | |
391 | |
392 %s''' % (len(query), query) | |
393 | |
394 testcase = self | |
395 class MyRequest(http.Request): | |
396 def process(self): | |
397 testcase.assertEquals(self.method, "POST") | |
398 testcase.assertEquals(self.args["key"], ["value"]) | |
399 testcase.assertEquals(self.args["empty"], [""]) | |
400 testcase.assertEquals(self.args["multiple"], ["two words", "more
words"]) | |
401 testcase.didRequest = 1 | |
402 self.finish() | |
403 | |
404 self.runRequest(httpRequest, MyRequest) | |
405 | |
406 def testMissingContentDisposition(self): | |
407 req = '''\ | |
408 POST / HTTP/1.0 | |
409 Content-Type: multipart/form-data; boundary=AaB03x | |
410 Content-Length: 103 | |
411 | |
412 --AaB03x | |
413 Content-Type: text/plain | |
414 Content-Transfer-Encoding: quoted-printable | |
415 | |
416 abasdfg | |
417 --AaB03x-- | |
418 ''' | |
419 self.runRequest(req, http.Request, success=False) | |
420 | |
421 class QueryArgumentsTestCase(unittest.TestCase): | |
422 def testUnquote(self): | |
423 try: | |
424 from twisted.protocols import _c_urlarg | |
425 except ImportError: | |
426 raise unittest.SkipTest("_c_urlarg module is not available") | |
427 # work exactly like urllib.unquote, including stupid things | |
428 # % followed by a non-hexdigit in the middle and in the end | |
429 self.failUnlessEqual(urllib.unquote("%notreally%n"), | |
430 _c_urlarg.unquote("%notreally%n")) | |
431 # % followed by hexdigit, followed by non-hexdigit | |
432 self.failUnlessEqual(urllib.unquote("%1quite%1"), | |
433 _c_urlarg.unquote("%1quite%1")) | |
434 # unquoted text, followed by some quoted chars, ends in a trailing % | |
435 self.failUnlessEqual(urllib.unquote("blah%21%40%23blah%"), | |
436 _c_urlarg.unquote("blah%21%40%23blah%")) | |
437 # Empty string | |
438 self.failUnlessEqual(urllib.unquote(""), _c_urlarg.unquote("")) | |
439 | |
440 def testParseqs(self): | |
441 self.failUnlessEqual(cgi.parse_qs("a=b&d=c;+=f"), | |
442 http.parse_qs("a=b&d=c;+=f")) | |
443 self.failUnlessRaises(ValueError, http.parse_qs, "blah", | |
444 strict_parsing = 1) | |
445 self.failUnlessEqual(cgi.parse_qs("a=&b=c", keep_blank_values = 1), | |
446 http.parse_qs("a=&b=c", keep_blank_values = 1)) | |
447 self.failUnlessEqual(cgi.parse_qs("a=&b=c"), | |
448 http.parse_qs("a=&b=c")) | |
449 | |
450 | |
451 def test_urlparse(self): | |
452 """ | |
453 For a given URL, L{http.urlparse} should behave the same as | |
454 L{urlparse}, except it should always return C{str}, never C{unicode}. | |
455 """ | |
456 def urls(): | |
457 for scheme in ('http', 'https'): | |
458 for host in ('example.com',): | |
459 for port in (None, 100): | |
460 for path in ('', 'path'): | |
461 if port is not None: | |
462 host = host + ':' + str(port) | |
463 yield urlunsplit((scheme, host, path, '', '')) | |
464 | |
465 | |
466 def assertSameParsing(url, decode): | |
467 """ | |
468 Verify that C{url} is parsed into the same objects by both | |
469 L{http.urlparse} and L{urlparse}. | |
470 """ | |
471 urlToStandardImplementation = url | |
472 if decode: | |
473 urlToStandardImplementation = url.decode('ascii') | |
474 standardResult = urlparse(urlToStandardImplementation) | |
475 scheme, netloc, path, params, query, fragment = http.urlparse(url) | |
476 self.assertEqual( | |
477 (scheme, netloc, path, params, query, fragment), | |
478 standardResult) | |
479 self.assertTrue(isinstance(scheme, str)) | |
480 self.assertTrue(isinstance(netloc, str)) | |
481 self.assertTrue(isinstance(path, str)) | |
482 self.assertTrue(isinstance(params, str)) | |
483 self.assertTrue(isinstance(query, str)) | |
484 self.assertTrue(isinstance(fragment, str)) | |
485 | |
486 # With caching, unicode then str | |
487 clear_cache() | |
488 for url in urls(): | |
489 assertSameParsing(url, True) | |
490 assertSameParsing(url, False) | |
491 | |
492 # With caching, str then unicode | |
493 clear_cache() | |
494 for url in urls(): | |
495 assertSameParsing(url, False) | |
496 assertSameParsing(url, True) | |
497 | |
498 # Without caching | |
499 for url in urls(): | |
500 clear_cache() | |
501 assertSameParsing(url, True) | |
502 clear_cache() | |
503 assertSameParsing(url, False) | |
504 | |
505 | |
506 def test_urlparseRejectsUnicode(self): | |
507 """ | |
508 L{http.urlparse} should reject unicode input early. | |
509 """ | |
510 self.assertRaises(TypeError, http.urlparse, u'http://example.org/path') | |
511 | |
512 | |
513 def testEscchar(self): | |
514 try: | |
515 from twisted.protocols import _c_urlarg | |
516 except ImportError: | |
517 raise unittest.SkipTest("_c_urlarg module is not available") | |
518 self.failUnlessEqual("!@#+b", | |
519 _c_urlarg.unquote("+21+40+23+b", "+")) | |
520 | |
521 class ClientDriver(http.HTTPClient): | |
522 def handleStatus(self, version, status, message): | |
523 self.version = version | |
524 self.status = status | |
525 self.message = message | |
526 | |
527 class ClientStatusParsing(unittest.TestCase): | |
528 def testBaseline(self): | |
529 c = ClientDriver() | |
530 c.lineReceived('HTTP/1.0 201 foo') | |
531 self.failUnlessEqual(c.version, 'HTTP/1.0') | |
532 self.failUnlessEqual(c.status, '201') | |
533 self.failUnlessEqual(c.message, 'foo') | |
534 | |
535 def testNoMessage(self): | |
536 c = ClientDriver() | |
537 c.lineReceived('HTTP/1.0 201') | |
538 self.failUnlessEqual(c.version, 'HTTP/1.0') | |
539 self.failUnlessEqual(c.status, '201') | |
540 self.failUnlessEqual(c.message, '') | |
541 | |
542 def testNoMessage_trailingSpace(self): | |
543 c = ClientDriver() | |
544 c.lineReceived('HTTP/1.0 201 ') | |
545 self.failUnlessEqual(c.version, 'HTTP/1.0') | |
546 self.failUnlessEqual(c.status, '201') | |
547 self.failUnlessEqual(c.message, '') | |
548 | |
OLD | NEW |