OLD | NEW |
| (Empty) |
1 '''SSL with SNI-support for Python 2. | |
2 | |
3 This needs the following packages installed: | |
4 | |
5 * pyOpenSSL (tested with 0.13) | |
6 * ndg-httpsclient (tested with 0.3.2) | |
7 * pyasn1 (tested with 0.1.6) | |
8 | |
9 To activate it call :func:`~urllib3.contrib.pyopenssl.inject_into_urllib3`. | |
10 This can be done in a ``sitecustomize`` module, or at any other time before | |
11 your application begins using ``urllib3``, like this:: | |
12 | |
13 try: | |
14 import urllib3.contrib.pyopenssl | |
15 urllib3.contrib.pyopenssl.inject_into_urllib3() | |
16 except ImportError: | |
17 pass | |
18 | |
19 Now you can use :mod:`urllib3` as you normally would, and it will support SNI | |
20 when the required modules are installed. | |
21 ''' | |
22 | |
23 from ndg.httpsclient.ssl_peer_verification import SUBJ_ALT_NAME_SUPPORT | |
24 from ndg.httpsclient.subj_alt_name import SubjectAltName | |
25 import OpenSSL.SSL | |
26 from pyasn1.codec.der import decoder as der_decoder | |
27 from socket import _fileobject | |
28 import ssl | |
29 from cStringIO import StringIO | |
30 | |
31 from .. import connectionpool | |
32 from .. import util | |
33 | |
34 __all__ = ['inject_into_urllib3', 'extract_from_urllib3'] | |
35 | |
36 # SNI only *really* works if we can read the subjectAltName of certificates. | |
37 HAS_SNI = SUBJ_ALT_NAME_SUPPORT | |
38 | |
39 # Map from urllib3 to PyOpenSSL compatible parameter-values. | |
40 _openssl_versions = { | |
41 ssl.PROTOCOL_SSLv23: OpenSSL.SSL.SSLv23_METHOD, | |
42 ssl.PROTOCOL_SSLv3: OpenSSL.SSL.SSLv3_METHOD, | |
43 ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, | |
44 } | |
45 _openssl_verify = { | |
46 ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, | |
47 ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, | |
48 ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER | |
49 + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, | |
50 } | |
51 | |
52 | |
53 orig_util_HAS_SNI = util.HAS_SNI | |
54 orig_connectionpool_ssl_wrap_socket = connectionpool.ssl_wrap_socket | |
55 | |
56 | |
57 def inject_into_urllib3(): | |
58 'Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.' | |
59 | |
60 connectionpool.ssl_wrap_socket = ssl_wrap_socket | |
61 util.HAS_SNI = HAS_SNI | |
62 | |
63 | |
64 def extract_from_urllib3(): | |
65 'Undo monkey-patching by :func:`inject_into_urllib3`.' | |
66 | |
67 connectionpool.ssl_wrap_socket = orig_connectionpool_ssl_wrap_socket | |
68 util.HAS_SNI = orig_util_HAS_SNI | |
69 | |
70 | |
71 ### Note: This is a slightly bug-fixed version of same from ndg-httpsclient. | |
72 def get_subj_alt_name(peer_cert): | |
73 # Search through extensions | |
74 dns_name = [] | |
75 if not SUBJ_ALT_NAME_SUPPORT: | |
76 return dns_name | |
77 | |
78 general_names = SubjectAltName() | |
79 for i in range(peer_cert.get_extension_count()): | |
80 ext = peer_cert.get_extension(i) | |
81 ext_name = ext.get_short_name() | |
82 if ext_name != 'subjectAltName': | |
83 continue | |
84 | |
85 # PyOpenSSL returns extension data in ASN.1 encoded form | |
86 ext_dat = ext.get_data() | |
87 decoded_dat = der_decoder.decode(ext_dat, | |
88 asn1Spec=general_names) | |
89 | |
90 for name in decoded_dat: | |
91 if not isinstance(name, SubjectAltName): | |
92 continue | |
93 for entry in range(len(name)): | |
94 component = name.getComponentByPosition(entry) | |
95 if component.getName() != 'dNSName': | |
96 continue | |
97 dns_name.append(str(component.getComponent())) | |
98 | |
99 return dns_name | |
100 | |
101 | |
102 class fileobject(_fileobject): | |
103 | |
104 def read(self, size=-1): | |
105 # Use max, disallow tiny reads in a loop as they are very inefficient. | |
106 # We never leave read() with any leftover data from a new recv() call | |
107 # in our internal buffer. | |
108 rbufsize = max(self._rbufsize, self.default_bufsize) | |
109 # Our use of StringIO rather than lists of string objects returned by | |
110 # recv() minimizes memory usage and fragmentation that occurs when | |
111 # rbufsize is large compared to the typical return value of recv(). | |
112 buf = self._rbuf | |
113 buf.seek(0, 2) # seek end | |
114 if size < 0: | |
115 # Read until EOF | |
116 self._rbuf = StringIO() # reset _rbuf. we consume it via buf. | |
117 while True: | |
118 try: | |
119 data = self._sock.recv(rbufsize) | |
120 except OpenSSL.SSL.WantReadError: | |
121 continue | |
122 if not data: | |
123 break | |
124 buf.write(data) | |
125 return buf.getvalue() | |
126 else: | |
127 # Read until size bytes or EOF seen, whichever comes first | |
128 buf_len = buf.tell() | |
129 if buf_len >= size: | |
130 # Already have size bytes in our buffer? Extract and return. | |
131 buf.seek(0) | |
132 rv = buf.read(size) | |
133 self._rbuf = StringIO() | |
134 self._rbuf.write(buf.read()) | |
135 return rv | |
136 | |
137 self._rbuf = StringIO() # reset _rbuf. we consume it via buf. | |
138 while True: | |
139 left = size - buf_len | |
140 # recv() will malloc the amount of memory given as its | |
141 # parameter even though it often returns much less data | |
142 # than that. The returned data string is short lived | |
143 # as we copy it into a StringIO and free it. This avoids | |
144 # fragmentation issues on many platforms. | |
145 try: | |
146 data = self._sock.recv(left) | |
147 except OpenSSL.SSL.WantReadError: | |
148 continue | |
149 if not data: | |
150 break | |
151 n = len(data) | |
152 if n == size and not buf_len: | |
153 # Shortcut. Avoid buffer data copies when: | |
154 # - We have no data in our buffer. | |
155 # AND | |
156 # - Our call to recv returned exactly the | |
157 # number of bytes we were asked to read. | |
158 return data | |
159 if n == left: | |
160 buf.write(data) | |
161 del data # explicit free | |
162 break | |
163 assert n <= left, "recv(%d) returned %d bytes" % (left, n) | |
164 buf.write(data) | |
165 buf_len += n | |
166 del data # explicit free | |
167 #assert buf_len == buf.tell() | |
168 return buf.getvalue() | |
169 | |
170 def readline(self, size=-1): | |
171 buf = self._rbuf | |
172 buf.seek(0, 2) # seek end | |
173 if buf.tell() > 0: | |
174 # check if we already have it in our buffer | |
175 buf.seek(0) | |
176 bline = buf.readline(size) | |
177 if bline.endswith('\n') or len(bline) == size: | |
178 self._rbuf = StringIO() | |
179 self._rbuf.write(buf.read()) | |
180 return bline | |
181 del bline | |
182 if size < 0: | |
183 # Read until \n or EOF, whichever comes first | |
184 if self._rbufsize <= 1: | |
185 # Speed up unbuffered case | |
186 buf.seek(0) | |
187 buffers = [buf.read()] | |
188 self._rbuf = StringIO() # reset _rbuf. we consume it via buf. | |
189 data = None | |
190 recv = self._sock.recv | |
191 while True: | |
192 try: | |
193 while data != "\n": | |
194 data = recv(1) | |
195 if not data: | |
196 break | |
197 buffers.append(data) | |
198 except OpenSSL.SSL.WantReadError: | |
199 continue | |
200 break | |
201 return "".join(buffers) | |
202 | |
203 buf.seek(0, 2) # seek end | |
204 self._rbuf = StringIO() # reset _rbuf. we consume it via buf. | |
205 while True: | |
206 try: | |
207 data = self._sock.recv(self._rbufsize) | |
208 except OpenSSL.SSL.WantReadError: | |
209 continue | |
210 if not data: | |
211 break | |
212 nl = data.find('\n') | |
213 if nl >= 0: | |
214 nl += 1 | |
215 buf.write(data[:nl]) | |
216 self._rbuf.write(data[nl:]) | |
217 del data | |
218 break | |
219 buf.write(data) | |
220 return buf.getvalue() | |
221 else: | |
222 # Read until size bytes or \n or EOF seen, whichever comes first | |
223 buf.seek(0, 2) # seek end | |
224 buf_len = buf.tell() | |
225 if buf_len >= size: | |
226 buf.seek(0) | |
227 rv = buf.read(size) | |
228 self._rbuf = StringIO() | |
229 self._rbuf.write(buf.read()) | |
230 return rv | |
231 self._rbuf = StringIO() # reset _rbuf. we consume it via buf. | |
232 while True: | |
233 try: | |
234 data = self._sock.recv(self._rbufsize) | |
235 except OpenSSL.SSL.WantReadError: | |
236 continue | |
237 if not data: | |
238 break | |
239 left = size - buf_len | |
240 # did we just receive a newline? | |
241 nl = data.find('\n', 0, left) | |
242 if nl >= 0: | |
243 nl += 1 | |
244 # save the excess data to _rbuf | |
245 self._rbuf.write(data[nl:]) | |
246 if buf_len: | |
247 buf.write(data[:nl]) | |
248 break | |
249 else: | |
250 # Shortcut. Avoid data copy through buf when returning | |
251 # a substring of our first recv(). | |
252 return data[:nl] | |
253 n = len(data) | |
254 if n == size and not buf_len: | |
255 # Shortcut. Avoid data copy through buf when | |
256 # returning exactly all of our first recv(). | |
257 return data | |
258 if n >= left: | |
259 buf.write(data[:left]) | |
260 self._rbuf.write(data[left:]) | |
261 break | |
262 buf.write(data) | |
263 buf_len += n | |
264 #assert buf_len == buf.tell() | |
265 return buf.getvalue() | |
266 | |
267 | |
268 class WrappedSocket(object): | |
269 '''API-compatibility wrapper for Python OpenSSL's Connection-class.''' | |
270 | |
271 def __init__(self, connection, socket): | |
272 self.connection = connection | |
273 self.socket = socket | |
274 | |
275 def fileno(self): | |
276 return self.socket.fileno() | |
277 | |
278 def makefile(self, mode, bufsize=-1): | |
279 return fileobject(self.connection, mode, bufsize) | |
280 | |
281 def settimeout(self, timeout): | |
282 return self.socket.settimeout(timeout) | |
283 | |
284 def sendall(self, data): | |
285 return self.connection.sendall(data) | |
286 | |
287 def close(self): | |
288 return self.connection.shutdown() | |
289 | |
290 def getpeercert(self, binary_form=False): | |
291 x509 = self.connection.get_peer_certificate() | |
292 | |
293 if not x509: | |
294 return x509 | |
295 | |
296 if binary_form: | |
297 return OpenSSL.crypto.dump_certificate( | |
298 OpenSSL.crypto.FILETYPE_ASN1, | |
299 x509) | |
300 | |
301 return { | |
302 'subject': ( | |
303 (('commonName', x509.get_subject().CN),), | |
304 ), | |
305 'subjectAltName': [ | |
306 ('DNS', value) | |
307 for value in get_subj_alt_name(x509) | |
308 ] | |
309 } | |
310 | |
311 | |
312 def _verify_callback(cnx, x509, err_no, err_depth, return_code): | |
313 return err_no == 0 | |
314 | |
315 | |
316 def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None, | |
317 ca_certs=None, server_hostname=None, | |
318 ssl_version=None): | |
319 ctx = OpenSSL.SSL.Context(_openssl_versions[ssl_version]) | |
320 if certfile: | |
321 ctx.use_certificate_file(certfile) | |
322 if keyfile: | |
323 ctx.use_privatekey_file(keyfile) | |
324 if cert_reqs != ssl.CERT_NONE: | |
325 ctx.set_verify(_openssl_verify[cert_reqs], _verify_callback) | |
326 if ca_certs: | |
327 try: | |
328 ctx.load_verify_locations(ca_certs, None) | |
329 except OpenSSL.SSL.Error as e: | |
330 raise ssl.SSLError('bad ca_certs: %r' % ca_certs, e) | |
331 | |
332 cnx = OpenSSL.SSL.Connection(ctx, sock) | |
333 cnx.set_tlsext_host_name(server_hostname) | |
334 cnx.set_connect_state() | |
335 while True: | |
336 try: | |
337 cnx.do_handshake() | |
338 except OpenSSL.SSL.WantReadError: | |
339 continue | |
340 except OpenSSL.SSL.Error as e: | |
341 raise ssl.SSLError('bad handshake', e) | |
342 break | |
343 | |
344 return WrappedSocket(cnx, sock) | |
OLD | NEW |