OLD | NEW |
| (Empty) |
1 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
2 # See LICENSE for details. | |
3 | |
4 """ | |
5 Tests for L{twisted.web.client}. | |
6 """ | |
7 | |
8 import os | |
9 | |
10 from urlparse import urlparse | |
11 | |
12 from twisted.trial import unittest | |
13 from twisted.web import server, static, client, error, util, resource | |
14 from twisted.internet import reactor, defer, interfaces | |
15 from twisted.python.filepath import FilePath | |
16 | |
17 try: | |
18 from twisted.internet import ssl | |
19 except: | |
20 ssl = None | |
21 | |
22 serverCallID = None | |
23 | |
24 class LongTimeTakingResource(resource.Resource): | |
25 def render(self, request): | |
26 global serverCallID | |
27 serverCallID = reactor.callLater(1, self.writeIt, request) | |
28 return server.NOT_DONE_YET | |
29 | |
30 def writeIt(self, request): | |
31 request.write("hello!!!") | |
32 request.finish() | |
33 | |
34 class CookieMirrorResource(resource.Resource): | |
35 def render(self, request): | |
36 l = [] | |
37 for k,v in request.received_cookies.items(): | |
38 l.append((k, v)) | |
39 l.sort() | |
40 return repr(l) | |
41 | |
42 class RawCookieMirrorResource(resource.Resource): | |
43 def render(self, request): | |
44 return repr(request.getHeader('cookie')) | |
45 | |
46 class ErrorResource(resource.Resource): | |
47 | |
48 def render(self, request): | |
49 request.setResponseCode(401) | |
50 if request.args.get("showlength"): | |
51 request.setHeader("content-length", "0") | |
52 return "" | |
53 | |
54 class NoLengthResource(resource.Resource): | |
55 | |
56 def render(self, request): | |
57 return "nolength" | |
58 | |
59 class HostHeaderResource(resource.Resource): | |
60 | |
61 def render(self, request): | |
62 return request.received_headers["host"] | |
63 | |
64 class PayloadResource(resource.Resource): | |
65 | |
66 def render(self, request): | |
67 data = request.content.read() | |
68 if len(data) != 100 or int(request.received_headers["content-length"]) !
= 100: | |
69 return "ERROR" | |
70 return data | |
71 | |
72 class BrokenDownloadResource(resource.Resource): | |
73 | |
74 def render(self, request): | |
75 # only sends 3 bytes even though it claims to send 5 | |
76 request.setHeader("content-length", "5") | |
77 request.write('abc') | |
78 return '' | |
79 | |
80 | |
81 | |
82 class ParseUrlTestCase(unittest.TestCase): | |
83 """ | |
84 Test URL parsing facility and defaults values. | |
85 """ | |
86 | |
87 def testParse(self): | |
88 scheme, host, port, path = client._parse("http://127.0.0.1/") | |
89 self.assertEquals(path, "/") | |
90 self.assertEquals(port, 80) | |
91 scheme, host, port, path = client._parse("https://127.0.0.1/") | |
92 self.assertEquals(path, "/") | |
93 self.assertEquals(port, 443) | |
94 scheme, host, port, path = client._parse("http://spam:12345/") | |
95 self.assertEquals(port, 12345) | |
96 scheme, host, port, path = client._parse("http://foo ") | |
97 self.assertEquals(host, "foo") | |
98 self.assertEquals(path, "/") | |
99 scheme, host, port, path = client._parse("http://egg:7890") | |
100 self.assertEquals(port, 7890) | |
101 self.assertEquals(host, "egg") | |
102 self.assertEquals(path, "/") | |
103 | |
104 | |
105 def test_externalUnicodeInterference(self): | |
106 """ | |
107 L{client._parse} should return C{str} for the scheme, host, and path | |
108 elements of its return tuple, even when passed an URL which has | |
109 previously been passed to L{urlparse} as a C{unicode} string. | |
110 """ | |
111 badInput = u'http://example.com/path' | |
112 goodInput = badInput.encode('ascii') | |
113 urlparse(badInput) | |
114 scheme, host, port, path = client._parse(goodInput) | |
115 self.assertTrue(isinstance(scheme, str)) | |
116 self.assertTrue(isinstance(host, str)) | |
117 self.assertTrue(isinstance(path, str)) | |
118 | |
119 | |
120 | |
121 class WebClientTestCase(unittest.TestCase): | |
122 def _listen(self, site): | |
123 return reactor.listenTCP(0, site, interface="127.0.0.1") | |
124 | |
125 def setUp(self): | |
126 name = self.mktemp() | |
127 os.mkdir(name) | |
128 FilePath(name).child("file").setContent("0123456789") | |
129 r = static.File(name) | |
130 r.putChild("redirect", util.Redirect("/file")) | |
131 r.putChild("wait", LongTimeTakingResource()) | |
132 r.putChild("error", ErrorResource()) | |
133 r.putChild("nolength", NoLengthResource()) | |
134 r.putChild("host", HostHeaderResource()) | |
135 r.putChild("payload", PayloadResource()) | |
136 r.putChild("broken", BrokenDownloadResource()) | |
137 site = server.Site(r, timeout=None) | |
138 self.port = self._listen(site) | |
139 self.portno = self.port.getHost().port | |
140 | |
141 def tearDown(self): | |
142 if serverCallID and serverCallID.active(): | |
143 serverCallID.cancel() | |
144 return self.port.stopListening() | |
145 | |
146 def getURL(self, path): | |
147 return "http://127.0.0.1:%d/%s" % (self.portno, path) | |
148 | |
149 def testPayload(self): | |
150 s = "0123456789" * 10 | |
151 return client.getPage(self.getURL("payload"), postdata=s | |
152 ).addCallback(self.assertEquals, s | |
153 ) | |
154 | |
155 def testBrokenDownload(self): | |
156 # test what happens when download gets disconnected in the middle | |
157 d = client.getPage(self.getURL("broken")) | |
158 d = self.assertFailure(d, client.PartialDownloadError) | |
159 d.addCallback(lambda exc: self.assertEquals(exc.response, "abc")) | |
160 return d | |
161 | |
162 def testHostHeader(self): | |
163 # if we pass Host header explicitly, it should be used, otherwise | |
164 # it should extract from url | |
165 return defer.gatherResults([ | |
166 client.getPage(self.getURL("host")).addCallback(self.assertEquals, "
127.0.0.1"), | |
167 client.getPage(self.getURL("host"), headers={"Host": "www.example.co
m"}).addCallback(self.assertEquals, "www.example.com")]) | |
168 | |
169 | |
170 def test_getPage(self): | |
171 """ | |
172 L{client.getPage} returns a L{Deferred} which is called back with | |
173 the body of the response if the default method B{GET} is used. | |
174 """ | |
175 d = client.getPage(self.getURL("file")) | |
176 d.addCallback(self.assertEquals, "0123456789") | |
177 return d | |
178 | |
179 | |
180 def test_getPageHead(self): | |
181 """ | |
182 L{client.getPage} returns a L{Deferred} which is called back with | |
183 the empty string if the method is C{HEAD} and there is a successful | |
184 response code. | |
185 """ | |
186 def getPage(method): | |
187 return client.getPage(self.getURL("file"), method=method) | |
188 return defer.gatherResults([ | |
189 getPage("head").addCallback(self.assertEqual, ""), | |
190 getPage("HEAD").addCallback(self.assertEqual, "")]) | |
191 | |
192 | |
193 def testTimeoutNotTriggering(self): | |
194 # Test that when the timeout doesn't trigger, things work as expected. | |
195 d = client.getPage(self.getURL("wait"), timeout=100) | |
196 d.addCallback(self.assertEquals, "hello!!!") | |
197 return d | |
198 | |
199 def testTimeoutTriggering(self): | |
200 # Test that when the timeout does trigger, we get a defer.TimeoutError. | |
201 return self.assertFailure( | |
202 client.getPage(self.getURL("wait"), timeout=0.5), | |
203 defer.TimeoutError) | |
204 | |
205 def testDownloadPage(self): | |
206 downloads = [] | |
207 downloadData = [("file", self.mktemp(), "0123456789"), | |
208 ("nolength", self.mktemp(), "nolength")] | |
209 | |
210 for (url, name, data) in downloadData: | |
211 d = client.downloadPage(self.getURL(url), name) | |
212 d.addCallback(self._cbDownloadPageTest, data, name) | |
213 downloads.append(d) | |
214 return defer.gatherResults(downloads) | |
215 | |
216 def _cbDownloadPageTest(self, ignored, data, name): | |
217 bytes = file(name, "rb").read() | |
218 self.assertEquals(bytes, data) | |
219 | |
220 def testDownloadPageError1(self): | |
221 class errorfile: | |
222 def write(self, data): | |
223 raise IOError, "badness happened during write" | |
224 def close(self): | |
225 pass | |
226 ef = errorfile() | |
227 return self.assertFailure( | |
228 client.downloadPage(self.getURL("file"), ef), | |
229 IOError) | |
230 | |
231 def testDownloadPageError2(self): | |
232 class errorfile: | |
233 def write(self, data): | |
234 pass | |
235 def close(self): | |
236 raise IOError, "badness happened during close" | |
237 ef = errorfile() | |
238 return self.assertFailure( | |
239 client.downloadPage(self.getURL("file"), ef), | |
240 IOError) | |
241 | |
242 def testDownloadPageError3(self): | |
243 # make sure failures in open() are caught too. This is tricky. | |
244 # Might only work on posix. | |
245 tmpfile = open("unwritable", "wb") | |
246 tmpfile.close() | |
247 os.chmod("unwritable", 0) # make it unwritable (to us) | |
248 d = self.assertFailure( | |
249 client.downloadPage(self.getURL("file"), "unwritable"), | |
250 IOError) | |
251 d.addBoth(self._cleanupDownloadPageError3) | |
252 return d | |
253 | |
254 def _cleanupDownloadPageError3(self, ignored): | |
255 os.chmod("unwritable", 0700) | |
256 os.unlink("unwritable") | |
257 return ignored | |
258 | |
259 def _downloadTest(self, method): | |
260 dl = [] | |
261 for (url, code) in [("nosuchfile", "404"), ("error", "401"), | |
262 ("error?showlength=1", "401")]: | |
263 d = method(url) | |
264 d = self.assertFailure(d, error.Error) | |
265 d.addCallback(lambda exc, code=code: self.assertEquals(exc.args[0],
code)) | |
266 dl.append(d) | |
267 return defer.DeferredList(dl, fireOnOneErrback=True) | |
268 | |
269 def testServerError(self): | |
270 return self._downloadTest(lambda url: client.getPage(self.getURL(url))) | |
271 | |
272 def testDownloadServerError(self): | |
273 return self._downloadTest(lambda url: client.downloadPage(self.getURL(ur
l), url.split('?')[0])) | |
274 | |
275 def testFactoryInfo(self): | |
276 url = self.getURL('file') | |
277 scheme, host, port, path = client._parse(url) | |
278 factory = client.HTTPClientFactory(url) | |
279 reactor.connectTCP(host, port, factory) | |
280 return factory.deferred.addCallback(self._cbFactoryInfo, factory) | |
281 | |
282 def _cbFactoryInfo(self, ignoredResult, factory): | |
283 self.assertEquals(factory.status, '200') | |
284 self.assert_(factory.version.startswith('HTTP/')) | |
285 self.assertEquals(factory.message, 'OK') | |
286 self.assertEquals(factory.response_headers['content-length'][0], '10') | |
287 | |
288 | |
289 def testRedirect(self): | |
290 return client.getPage(self.getURL("redirect")).addCallback(self._cbRedir
ect) | |
291 | |
292 def _cbRedirect(self, pageData): | |
293 self.assertEquals(pageData, "0123456789") | |
294 d = self.assertFailure( | |
295 client.getPage(self.getURL("redirect"), followRedirect=0), | |
296 error.PageRedirect) | |
297 d.addCallback(self._cbCheckLocation) | |
298 return d | |
299 | |
300 def _cbCheckLocation(self, exc): | |
301 self.assertEquals(exc.location, "/file") | |
302 | |
303 def testPartial(self): | |
304 name = self.mktemp() | |
305 f = open(name, "wb") | |
306 f.write("abcd") | |
307 f.close() | |
308 | |
309 downloads = [] | |
310 partialDownload = [(True, "abcd456789"), | |
311 (True, "abcd456789"), | |
312 (False, "0123456789")] | |
313 | |
314 d = defer.succeed(None) | |
315 for (partial, expectedData) in partialDownload: | |
316 d.addCallback(self._cbRunPartial, name, partial) | |
317 d.addCallback(self._cbPartialTest, expectedData, name) | |
318 | |
319 return d | |
320 | |
321 testPartial.skip = "Cannot test until webserver can serve partial data prope
rly" | |
322 | |
323 def _cbRunPartial(self, ignored, name, partial): | |
324 return client.downloadPage(self.getURL("file"), name, supportPartial=par
tial) | |
325 | |
326 def _cbPartialTest(self, ignored, expectedData, filename): | |
327 bytes = file(filename, "rb").read() | |
328 self.assertEquals(bytes, expectedData) | |
329 | |
330 class WebClientSSLTestCase(WebClientTestCase): | |
331 def _listen(self, site): | |
332 from twisted import test | |
333 return reactor.listenSSL(0, site, | |
334 contextFactory=ssl.DefaultOpenSSLContextFactory
( | |
335 FilePath(test.__file__).sibling('server.pem').path, | |
336 FilePath(test.__file__).sibling('server.pem').path, | |
337 ), | |
338 interface="127.0.0.1") | |
339 | |
340 def getURL(self, path): | |
341 return "https://127.0.0.1:%d/%s" % (self.portno, path) | |
342 | |
343 def testFactoryInfo(self): | |
344 url = self.getURL('file') | |
345 scheme, host, port, path = client._parse(url) | |
346 factory = client.HTTPClientFactory(url) | |
347 reactor.connectSSL(host, port, factory, ssl.ClientContextFactory()) | |
348 # The base class defines _cbFactoryInfo correctly for this | |
349 return factory.deferred.addCallback(self._cbFactoryInfo, factory) | |
350 | |
351 class WebClientRedirectBetweenSSLandPlainText(unittest.TestCase): | |
352 def getHTTPS(self, path): | |
353 return "https://127.0.0.1:%d/%s" % (self.tlsPortno, path) | |
354 | |
355 def getHTTP(self, path): | |
356 return "http://127.0.0.1:%d/%s" % (self.plainPortno, path) | |
357 | |
358 def setUp(self): | |
359 plainRoot = static.Data('not me', 'text/plain') | |
360 tlsRoot = static.Data('me neither', 'text/plain') | |
361 | |
362 plainSite = server.Site(plainRoot, timeout=None) | |
363 tlsSite = server.Site(tlsRoot, timeout=None) | |
364 | |
365 from twisted import test | |
366 self.tlsPort = reactor.listenSSL(0, tlsSite, | |
367 contextFactory=ssl.DefaultOpenSSLContex
tFactory( | |
368 FilePath(test.__file__).sibling('server.pem').path, | |
369 FilePath(test.__file__).sibling('server.pem').path, | |
370 ), | |
371 interface="127.0.0.1") | |
372 self.plainPort = reactor.listenTCP(0, plainSite, interface="127.0.0.1") | |
373 | |
374 self.plainPortno = self.plainPort.getHost().port | |
375 self.tlsPortno = self.tlsPort.getHost().port | |
376 | |
377 plainRoot.putChild('one', util.Redirect(self.getHTTPS('two'))) | |
378 tlsRoot.putChild('two', util.Redirect(self.getHTTP('three'))) | |
379 plainRoot.putChild('three', util.Redirect(self.getHTTPS('four'))) | |
380 tlsRoot.putChild('four', static.Data('FOUND IT!', 'text/plain')) | |
381 | |
382 def tearDown(self): | |
383 ds = map(defer.maybeDeferred, | |
384 [self.plainPort.stopListening, self.tlsPort.stopListening]) | |
385 return defer.gatherResults(ds) | |
386 | |
387 def testHoppingAround(self): | |
388 return client.getPage(self.getHTTP("one") | |
389 ).addCallback(self.assertEquals, "FOUND IT!" | |
390 ) | |
391 | |
392 class FakeTransport: | |
393 disconnecting = False | |
394 def __init__(self): | |
395 self.data = [] | |
396 def write(self, stuff): | |
397 self.data.append(stuff) | |
398 | |
399 class CookieTestCase(unittest.TestCase): | |
400 def _listen(self, site): | |
401 return reactor.listenTCP(0, site, interface="127.0.0.1") | |
402 | |
403 def setUp(self): | |
404 root = static.Data('El toro!', 'text/plain') | |
405 root.putChild("cookiemirror", CookieMirrorResource()) | |
406 root.putChild("rawcookiemirror", RawCookieMirrorResource()) | |
407 site = server.Site(root, timeout=None) | |
408 self.port = self._listen(site) | |
409 self.portno = self.port.getHost().port | |
410 | |
411 def tearDown(self): | |
412 return self.port.stopListening() | |
413 | |
414 def getHTTP(self, path): | |
415 return "http://127.0.0.1:%d/%s" % (self.portno, path) | |
416 | |
417 def testNoCookies(self): | |
418 return client.getPage(self.getHTTP("cookiemirror") | |
419 ).addCallback(self.assertEquals, "[]" | |
420 ) | |
421 | |
422 def testSomeCookies(self): | |
423 cookies = {'foo': 'bar', 'baz': 'quux'} | |
424 return client.getPage(self.getHTTP("cookiemirror"), cookies=cookies | |
425 ).addCallback(self.assertEquals, "[('baz', 'quux'), ('foo', 'bar')]" | |
426 ) | |
427 | |
428 def testRawNoCookies(self): | |
429 return client.getPage(self.getHTTP("rawcookiemirror") | |
430 ).addCallback(self.assertEquals, "None" | |
431 ) | |
432 | |
433 def testRawSomeCookies(self): | |
434 cookies = {'foo': 'bar', 'baz': 'quux'} | |
435 return client.getPage(self.getHTTP("rawcookiemirror"), cookies=cookies | |
436 ).addCallback(self.assertEquals, "'foo=bar; baz=quux'" | |
437 ) | |
438 | |
439 def testCookieHeaderParsing(self): | |
440 d = defer.Deferred() | |
441 factory = client.HTTPClientFactory('http://foo.example.com/') | |
442 proto = factory.buildProtocol('127.42.42.42') | |
443 proto.transport = FakeTransport() | |
444 proto.connectionMade() | |
445 for line in [ | |
446 '200 Ok', | |
447 'Squash: yes', | |
448 'Hands: stolen', | |
449 'Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/; expires=Wednesday, 09-N
ov-99 23:12:40 GMT', | |
450 'Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/', | |
451 'Set-Cookie: SHIPPING=FEDEX; path=/foo', | |
452 '', | |
453 'body', | |
454 'more body', | |
455 ]: | |
456 proto.dataReceived(line + '\r\n') | |
457 self.assertEquals(proto.transport.data, | |
458 ['GET / HTTP/1.0\r\n', | |
459 'Host: foo.example.com\r\n', | |
460 'User-Agent: Twisted PageGetter\r\n', | |
461 '\r\n']) | |
462 self.assertEquals(factory.cookies, | |
463 { | |
464 'CUSTOMER': 'WILE_E_COYOTE', | |
465 'PART_NUMBER': 'ROCKET_LAUNCHER_0001', | |
466 'SHIPPING': 'FEDEX', | |
467 }) | |
468 | |
469 if ssl is None or not hasattr(ssl, 'DefaultOpenSSLContextFactory'): | |
470 for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]: | |
471 case.skip = "OpenSSL not present" | |
472 | |
473 if not interfaces.IReactorSSL(reactor, None): | |
474 for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]: | |
475 case.skip = "Reactor doesn't support SSL" | |
OLD | NEW |