OLD | NEW |
| (Empty) |
1 # -*- test-case-name: twisted.web.test.test_webclient -*- | |
2 # Copyright (c) 2001-2008 Twisted Matrix Laboratories. | |
3 # See LICENSE for details. | |
4 | |
5 """ | |
6 HTTP client. | |
7 """ | |
8 | |
9 import os, types | |
10 from urlparse import urlunparse | |
11 | |
12 from twisted.web import http | |
13 from twisted.internet import defer, protocol, reactor | |
14 from twisted.python import failure | |
15 from twisted.python.util import InsensitiveDict | |
16 from twisted.web import error | |
17 | |
18 | |
class PartialDownloadError(error.Error):
    """
    Raised when the connection was dropped before the whole page arrived.

    Whatever portion of the body had been received by then is kept in the
    C{response} attribute.
    """
24 | |
25 | |
class HTTPPageGetter(http.HTTPClient):
    """
    HTTP client protocol that sends the single request described by its
    factory and reports the outcome back to that factory.

    @ivar quietLoss: When true, C{connectionLost} and C{handleResponse}
        report nothing to the factory.
    @ivar followRedirect: When true, 301/302/303 responses are followed by
        reconnecting the factory to the URL from the C{Location} header.
    @ivar failed: Set by C{handleStatusDefault}, i.e. for any status code
        without a dedicated C{handleStatus_<code>} method and for
        redirects that are not followed.
    """

    quietLoss = 0
    followRedirect = 1
    failed = 0

    def connectionMade(self):
        """
        Write the request line, the request headers and, when the factory
        carries C{postdata}, the request body.
        """
        factory = self.factory
        self.sendCommand(getattr(factory, 'method', 'GET'), factory.path)
        self.sendHeader('Host', factory.headers.get("host", factory.host))
        self.sendHeader('User-Agent', factory.agent)
        if factory.cookies:
            pairs = ['%s=%s' % (name, val)
                     for name, val in factory.cookies.items()]
            self.sendHeader('Cookie', '; '.join(pairs))
        body = getattr(factory, 'postdata', None)
        if body is not None:
            self.sendHeader("Content-Length", str(len(body)))
        for name, val in factory.headers.items():
            # Content-Length was already sent above, computed from the
            # actual body, so any caller-supplied value is dropped.
            if name.lower() == "content-length":
                continue
            self.sendHeader(name, val)
        self.endHeaders()
        # From here on self.headers collects the *response* headers.
        self.headers = {}

        if body is not None:
            self.transport.write(body)

    def handleHeader(self, key, value):
        """
        Record one response header; values are accumulated in a list under
        the lower-cased header name.
        """
        self.headers.setdefault(key.lower(), []).append(value)

    def handleStatus(self, version, status, message):
        """
        Remember the status line and pass it on to the factory.
        """
        self.version = version
        self.status = status
        self.message = message
        self.factory.gotStatus(version, status, message)

    def handleEndHeaders(self):
        """
        Hand the collected headers to the factory, then dispatch to the
        C{handleStatus_<code>} method matching the response status, or to
        C{handleStatusDefault} when there is none.
        """
        self.factory.gotHeaders(self.headers)
        handler = getattr(self, 'handleStatus_' + self.status,
                          self.handleStatusDefault)
        handler()

    def handleStatus_200(self):
        """
        Success: nothing to do until the body arrives.
        """

    # 201 and 202 look up handleStatus_200 at call time, so a subclass
    # overriding handleStatus_200 changes these two as well.
    def handleStatus_201(self):
        return self.handleStatus_200()

    def handleStatus_202(self):
        return self.handleStatus_200()

    def handleStatusDefault(self):
        """
        Mark the request as failed.
        """
        self.failed = 1

    def handleStatus_301(self):
        """
        Handle a redirect: either reconnect the factory to the new location
        or, when redirects are not followed, fail with C{PageRedirect}.
        """
        locations = self.headers.get('location')
        if not locations:
            self.handleStatusDefault()
            return
        target = locations[0]
        if not self.followRedirect:
            self.handleStatusDefault()
            self.factory.noPage(
                failure.Failure(
                    error.PageRedirect(
                        self.status, self.message, location=target)))
        else:
            # The parsed pieces are not used here; setURL() parses the URL
            # again and stores the result on the factory.
            _parse(target, defaultPort=self.transport.getPeer().port)
            factory = self.factory
            factory.setURL(target)

            if factory.scheme == 'https':
                from twisted.internet import ssl
                reactor.connectSSL(factory.host, factory.port,
                                   factory, ssl.ClientContextFactory())
            else:
                reactor.connectTCP(factory.host, factory.port, factory)
        # Either way this connection is finished and must not report its
        # own closing to the factory.
        self.quietLoss = 1
        self.transport.loseConnection()

    def handleStatus_302(self):
        return self.handleStatus_301()

    def handleStatus_303(self):
        """
        Like 301, except that the follow-up request is always a GET.
        """
        self.factory.method = 'GET'
        self.handleStatus_301()

    def connectionLost(self, reason):
        """
        Unless the loss is flagged as quiet, let the base class finish the
        response and then report the loss reason to the factory.
        """
        if self.quietLoss:
            return
        http.HTTPClient.connectionLost(self, reason)
        self.factory.noPage(reason)

    def handleResponse(self, response):
        """
        Deliver the finished response to the factory as a page or as a
        failure, then close the connection.
        """
        if self.quietLoss:
            return
        factory = self.factory
        if self.failed:
            factory.noPage(
                failure.Failure(
                    error.Error(self.status, self.message, response)))
        if factory.method.upper() == 'HEAD':
            # Callback with empty string, since there is never a response
            # body for HEAD requests.
            factory.page('')
        elif self.length is not None and self.length != 0:
            # Some of the announced body is still outstanding.
            factory.noPage(failure.Failure(
                PartialDownloadError(self.status, self.message, response)))
        else:
            factory.page(response)
        # server might be stupid and not close connection. admittedly
        # the fact we do only one request per connection is also
        # stupid...
        self.transport.loseConnection()

    def timeout(self):
        """
        Give up on the request: drop the connection quietly and fail the
        factory with a C{TimeoutError}.
        """
        self.quietLoss = True
        self.transport.loseConnection()
        message = "Getting %s took longer than %s seconds." % (
            self.factory.url, self.factory.timeout)
        self.factory.noPage(defer.TimeoutError(message))
143 | |
144 | |
class HTTPPageDownloader(HTTPPageGetter):
    """
    Page getter that streams the response body to its factory piece by
    piece instead of delivering it in one string.

    @ivar transmittingPage: True between a successful status (200/206) and
        the end of the response, while body data is forwarded.
    """

    transmittingPage = 0

    def handleStatus_200(self, partialContent=0):
        """
        Start forwarding the body and tell the factory whether it is a
        partial (206) response.
        """
        HTTPPageGetter.handleStatus_200(self)
        self.transmittingPage = 1
        self.factory.pageStart(partialContent)

    def handleStatus_206(self):
        """
        Partial content: same as 200, flagged as partial.
        """
        self.handleStatus_200(partialContent=1)

    def handleResponsePart(self, data):
        """
        Forward a chunk of body data, but only once forwarding has started.
        """
        if not self.transmittingPage:
            return
        self.factory.pagePart(data)

    def handleResponseEnd(self):
        """
        Finish the transfer and, for a failed request, report the error and
        close the connection.
        """
        if self.transmittingPage:
            self.factory.pageEnd()
            self.transmittingPage = 0
        if not self.failed:
            return
        problem = error.Error(self.status, self.message, None)
        self.factory.noPage(failure.Failure(problem))
        self.transport.loseConnection()
171 | |
172 | |
class HTTPClientFactory(protocol.ClientFactory):
    """Download a given URL.

    @type deferred: Deferred
    @ivar deferred: A Deferred that will fire when the content has
          been retrieved. Once this is fired, the ivars `status', `version',
          and `message' will be set.

    @type status: str
    @ivar status: The status of the response.

    @type version: str
    @ivar version: The version of the response.

    @type message: str
    @ivar message: The text message returned with the status.

    @type response_headers: dict
    @ivar response_headers: The headers that were specified in the
          response from the server.

    @ivar followRedirect: Whether protocols built by this factory follow
          redirect responses.

    @ivar waiting: True until C{deferred} has been fired; guards against
          firing it a second time.
    """

    protocol = HTTPPageGetter

    url = None
    scheme = None
    host = ''
    port = None
    path = None
    followRedirect = 1

    def __init__(self, url, method='GET', postdata=None, headers=None,
                 agent="Twisted PageGetter", timeout=0, cookies=None,
                 followRedirect=1):
        """
        @param url: The URL to request.
        @param method: The HTTP method to use.
        @param postdata: Request body, or C{None} for a request without one.
        @param headers: Mapping of extra request headers, or C{None}.
        @param agent: Value sent in the C{User-Agent} header.
        @param timeout: Seconds after which the request is abandoned; a
            false value disables the timeout.
        @param cookies: Dict of cookies to send; it is updated in place with
            cookies set by the server. C{None} means start with no cookies.
        @param followRedirect: Whether redirect responses are followed.
        """
        # Kept on the factory and applied to each protocol instance in
        # buildProtocol(); writing it onto the protocol class would change
        # the setting for every other factory sharing that class.
        self.followRedirect = followRedirect
        self.timeout = timeout
        self.agent = agent

        if cookies is None:
            cookies = {}
        self.cookies = cookies
        if headers is not None:
            self.headers = InsensitiveDict(headers)
        else:
            self.headers = InsensitiveDict()
        if postdata is not None:
            self.headers.setdefault('Content-Length', len(postdata))
            # just in case a broken http/1.1 decides to keep connection alive
            self.headers.setdefault("connection", "close")
        self.postdata = postdata
        self.method = method

        self.setURL(url)

        self.waiting = 1
        self.deferred = defer.Deferred()
        self.response_headers = None

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self.url)

    def setURL(self, url):
        """
        Point the factory at C{url}.

        Scheme, host and port are only replaced when the URL provides both
        a scheme and a host; the path is always replaced.
        """
        self.url = url
        scheme, host, port, path = _parse(url)
        if scheme and host:
            self.scheme = scheme
            self.host = host
            self.port = port
        self.path = path

    def buildProtocol(self, addr):
        """
        Build a protocol, apply this factory's redirect setting to it and,
        if a timeout is configured, schedule the protocol's C{timeout}.
        """
        p = protocol.ClientFactory.buildProtocol(self, addr)
        p.followRedirect = self.followRedirect
        if self.timeout:
            timeoutCall = reactor.callLater(self.timeout, p.timeout)
            # Whatever the outcome, the pending timeout must not fire later.
            self.deferred.addBoth(self._cancelTimeout, timeoutCall)
        return p

    def _cancelTimeout(self, result, timeoutCall):
        """
        Cancel C{timeoutCall} if it is still pending and pass C{result}
        through unchanged.
        """
        if timeoutCall.active():
            timeoutCall.cancel()
        return result

    def gotHeaders(self, headers):
        """
        Store the response headers and record any cookies the server set.

        @param headers: Dict mapping header names to lists of values; the
            C{'set-cookie'} key is the one inspected here.
        """
        self.response_headers = headers
        if 'set-cookie' in headers:
            for cookie in headers['set-cookie']:
                # Only the leading name=value pair matters; attributes such
                # as path or expiry follow after the first ';'.
                pair = cookie.split(';')[0]
                if '=' not in pair:
                    # Not a name=value pair; skip it instead of failing the
                    # whole response on a ValueError.
                    continue
                name, value = pair.split('=', 1)
                self.cookies[name.lstrip()] = value.lstrip()

    def gotStatus(self, version, status, message):
        """
        Store the parts of the response status line.
        """
        self.version, self.status, self.message = version, status, message

    def page(self, page):
        """
        Fire the deferred with the downloaded page, unless already fired.
        """
        if self.waiting:
            self.waiting = 0
            self.deferred.callback(page)

    def noPage(self, reason):
        """
        Fire the deferred with a failure, unless already fired.
        """
        if self.waiting:
            self.waiting = 0
            self.deferred.errback(reason)

    def clientConnectionFailed(self, _, reason):
        """
        Report a failed connection attempt through the deferred, unless it
        has already fired.
        """
        if self.waiting:
            self.waiting = 0
            self.deferred.errback(reason)
281 | |
282 | |
class HTTPDownloader(HTTPClientFactory):
    """
    Download to a file.

    @ivar value: The result the deferred fires with once the file has been
        written and closed.
    @ivar requestedPartial: The byte offset a resumed download was asked to
        start from, or 0 when the whole resource is being requested.
    """

    protocol = HTTPPageDownloader
    value = None

    def __init__(self, url, fileOrName,
                 method='GET', postdata=None, headers=None,
                 agent="Twisted client", supportPartial=0):
        """
        @param url: The URL to download.
        @param fileOrName: Either the name of the file to write to or an
            object to write to directly.
        @param method: The HTTP method to use.
        @param postdata: Request body, or C{None}.
        @param headers: Mapping of extra request headers, or C{None}. It is
            not modified.
        @param agent: Value sent in the C{User-Agent} header.
        @param supportPartial: When true and C{fileOrName} names an existing
            non-empty file, request only the bytes after its current end.
        """
        self.requestedPartial = 0
        if isinstance(fileOrName, types.StringTypes):
            self.fileName = fileOrName
            self.file = None
            if supportPartial and os.path.exists(self.fileName):
                fileLength = os.path.getsize(self.fileName)
                if fileLength:
                    self.requestedPartial = fileLength
                    # Work on a copy so the caller's mapping does not gain
                    # a Range header as a side effect.
                    if headers is None:
                        headers = {}
                    else:
                        headers = dict(headers)
                    headers["range"] = "bytes=%d-" % fileLength
        else:
            self.file = fileOrName
        # The base initializer also creates the deferred and sets waiting.
        HTTPClientFactory.__init__(self, url, method=method,
                                   postdata=postdata, headers=headers,
                                   agent=agent)

    def gotHeaders(self, headers):
        """
        Record the response headers, then check whether the server honoured
        the requested byte range; if not, fall back to a full download.
        """
        # Let the base class store response_headers and cookies first.
        HTTPClientFactory.gotHeaders(self, headers)
        if self.requestedPartial:
            contentRange = headers.get("content-range", None)
            if not contentRange:
                # server doesn't support partial requests, oh well
                self.requestedPartial = 0
                return
            start, end, realLength = http.parseContentRange(contentRange[0])
            if start != self.requestedPartial:
                # server is acting weirdly
                self.requestedPartial = 0

    def openFile(self, partialContent):
        """
        Open the target file: positioned at its end for a partial download,
        truncated otherwise.
        """
        if partialContent:
            target = open(self.fileName, 'rb+')
            target.seek(0, 2)
        else:
            target = open(self.fileName, 'wb')
        return target

    def pageStart(self, partialContent):
        """Called on page download start.

        @param partialContent: tells us if the download is partial download
            we requested.

        @raise ValueError: if a partial response arrives although none was
            requested.
        """
        if partialContent and not self.requestedPartial:
            raise ValueError(
                "we shouldn't get partial content response if we didn't want it!")
        if self.waiting:
            # From now on the outcome is reported by the page* methods, so
            # page()/noPage() must no longer fire the deferred.
            self.waiting = 0
            try:
                if not self.file:
                    self.file = self.openFile(partialContent)
            except IOError:
                self.deferred.errback(failure.Failure())

    def pagePart(self, data):
        """
        Write one chunk of the body; a write error fails the deferred and
        stops further writing.
        """
        if not self.file:
            return
        try:
            self.file.write(data)
        except IOError:
            self.file = None
            self.deferred.errback(failure.Failure())

    def pageEnd(self):
        """
        Close the file and fire the deferred with C{self.value}, or with a
        failure when closing raised C{IOError}.
        """
        if not self.file:
            return
        try:
            self.file.close()
        except IOError:
            self.deferred.errback(failure.Failure())
            return
        self.deferred.callback(self.value)
364 | |
365 | |
def _parse(url, defaultPort=None):
    """
    Split the given URL into the scheme, host, port, and path.

    @type url: C{str}
    @param url: An URL to parse.

    @type defaultPort: C{int} or C{None}
    @param defaultPort: An alternate value to use as the port if the URL does
        not include one, or includes one that is not a number. When C{None},
        443 is used for C{https} URLs and 80 for everything else.

    @return: A four-tuple of the scheme, host, port, and path of the URL.  All
        of these are C{str} instances except for port, which is an C{int}.

    @raise ValueError: if the network location contains more than one colon.
    """
    url = url.strip()
    parsed = http.urlparse(url)
    scheme = parsed[0]
    # Everything after the network location, re-assembled as the path.
    path = urlunparse(('', '') + parsed[2:])
    if defaultPort is None:
        if scheme == 'https':
            defaultPort = 443
        else:
            defaultPort = 80
    host, port = parsed[1], defaultPort
    if ':' in host:
        host, port = host.split(':')
        try:
            port = int(port)
        except ValueError:
            # Empty or non-numeric port, e.g. "http://example.com:/".
            port = defaultPort
    if path == "":
        path = "/"
    return scheme, host, port, path
396 | |
397 | |
def getPage(url, contextFactory=None, *args, **kwargs):
    """
    Fetch a web page and deliver it as a string.

    @param url: The URL to fetch.
    @param contextFactory: SSL context factory used for C{https} URLs; a
        default C{ClientContextFactory} is created when C{None}.

    Any further arguments are handed to L{HTTPClientFactory}.

    @return: The factory's deferred, which calls back with the page (as a
        string) or errbacks with a description of the error.
    """
    scheme, host, port = _parse(url)[:3]
    client = HTTPClientFactory(url, *args, **kwargs)
    if scheme != 'https':
        reactor.connectTCP(host, port, client)
        return client.deferred
    # Imported here so that plain HTTP works without SSL support.
    from twisted.internet import ssl
    if contextFactory is None:
        contextFactory = ssl.ClientContextFactory()
    reactor.connectSSL(host, port, client, contextFactory)
    return client.deferred
416 | |
417 | |
def downloadPage(url, file, contextFactory=None, *args, **kwargs):
    """
    Fetch a web page and store it in a file.

    @param url: The URL to fetch.
    @param file: path to file on filesystem, or file-like object.
    @param contextFactory: SSL context factory used for C{https} URLs; a
        default C{ClientContextFactory} is created when C{None}.

    Any further arguments are handed to L{HTTPDownloader}.

    @return: The downloader factory's deferred.
    """
    scheme, host, port = _parse(url)[:3]
    downloader = HTTPDownloader(url, file, *args, **kwargs)
    if scheme != 'https':
        reactor.connectTCP(host, port, downloader)
        return downloader.deferred
    # Imported here so that plain HTTP works without SSL support.
    from twisted.internet import ssl
    if contextFactory is None:
        contextFactory = ssl.ClientContextFactory()
    reactor.connectSSL(host, port, downloader, contextFactory)
    return downloader.deferred
OLD | NEW |