| OLD | NEW |
| (Empty) |
| 1 # Copyright 2014 Google Inc. All Rights Reserved. | |
| 2 # | |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
| 4 # you may not use this file except in compliance with the License. | |
| 5 # You may obtain a copy of the License at | |
| 6 # | |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 | |
| 8 # | |
| 9 # Unless required by applicable law or agreed to in writing, software | |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, | |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 12 # See the License for the specific language governing permissions and | |
| 13 # limitations under the License. | |
| 14 | |
| 15 """Test routines to generate dummy certificates.""" | |
| 16 | |
| 17 import BaseHTTPServer | |
| 18 import shutil | |
| 19 import signal | |
| 20 import socket | |
| 21 import tempfile | |
| 22 import threading | |
| 23 import time | |
| 24 import unittest | |
| 25 | |
| 26 import certutils | |
| 27 import sslproxy | |
| 28 | |
| 29 | |
| 30 class Client(object): | |
| 31 | |
| 32 def __init__(self, ca_cert_path, verify_cb, port, host_name='foo.com', | |
| 33 host='localhost'): | |
| 34 self.host_name = host_name | |
| 35 self.verify_cb = verify_cb | |
| 36 self.ca_cert_path = ca_cert_path | |
| 37 self.port = port | |
| 38 self.host_name = host_name | |
| 39 self.host = host | |
| 40 self.connection = None | |
| 41 | |
| 42 def run_request(self): | |
| 43 context = certutils.get_ssl_context() | |
| 44 context.set_verify(certutils.VERIFY_PEER, self.verify_cb) # Demand a cert | |
| 45 context.use_certificate_file(self.ca_cert_path) | |
| 46 context.load_verify_locations(self.ca_cert_path) | |
| 47 | |
| 48 s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
| 49 self.connection = certutils.get_ssl_connection(context, s) | |
| 50 self.connection.connect((self.host, self.port)) | |
| 51 self.connection.set_tlsext_host_name(self.host_name) | |
| 52 | |
| 53 try: | |
| 54 self.connection.send('\r\n\r\n') | |
| 55 finally: | |
| 56 self.connection.shutdown() | |
| 57 self.connection.close() | |
| 58 | |
| 59 | |
| 60 class Handler(BaseHTTPServer.BaseHTTPRequestHandler): | |
| 61 protocol_version = 'HTTP/1.1' # override BaseHTTPServer setting | |
| 62 | |
| 63 def handle_one_request(self): | |
| 64 """Handle a single HTTP request.""" | |
| 65 self.raw_requestline = self.rfile.readline(65537) | |
| 66 | |
| 67 | |
| 68 class WrappedErrorHandler(Handler): | |
| 69 """Wraps handler to verify expected sslproxy errors are being raised.""" | |
| 70 | |
| 71 def setup(self): | |
| 72 Handler.setup(self) | |
| 73 try: | |
| 74 sslproxy._SetUpUsingDummyCert(self) | |
| 75 except certutils.Error: | |
| 76 self.server.error_function = certutils.Error | |
| 77 | |
| 78 def finish(self): | |
| 79 Handler.finish(self) | |
| 80 self.connection.shutdown() | |
| 81 self.connection.close() | |
| 82 | |
| 83 | |
| 84 class DummyArchive(object): | |
| 85 | |
| 86 def __init__(self): | |
| 87 pass | |
| 88 | |
| 89 | |
| 90 class DummyFetch(object): | |
| 91 | |
| 92 def __init__(self): | |
| 93 self.http_archive = DummyArchive() | |
| 94 | |
| 95 | |
| 96 class Server(BaseHTTPServer.HTTPServer): | |
| 97 """SSL server.""" | |
| 98 | |
| 99 def __init__(self, ca_cert_path, use_error_handler=False, port=0, | |
| 100 host='localhost'): | |
| 101 self.ca_cert_path = ca_cert_path | |
| 102 with open(ca_cert_path, 'r') as ca_file: | |
| 103 self.ca_cert_str = ca_file.read() | |
| 104 self.http_archive_fetch = DummyFetch() | |
| 105 if use_error_handler: | |
| 106 self.HANDLER = WrappedErrorHandler | |
| 107 else: | |
| 108 self.HANDLER = sslproxy.wrap_handler(Handler) | |
| 109 try: | |
| 110 BaseHTTPServer.HTTPServer.__init__(self, (host, port), self.HANDLER) | |
| 111 except Exception, e: | |
| 112 raise RuntimeError('Could not start HTTPSServer on port %d: %s' | |
| 113 % (port, e)) | |
| 114 | |
| 115 def __enter__(self): | |
| 116 thread = threading.Thread(target=self.serve_forever) | |
| 117 thread.daemon = True | |
| 118 thread.start() | |
| 119 return self | |
| 120 | |
| 121 def cleanup(self): | |
| 122 try: | |
| 123 self.shutdown() | |
| 124 except KeyboardInterrupt: | |
| 125 pass | |
| 126 | |
| 127 def __exit__(self, type_, value_, traceback_): | |
| 128 self.cleanup() | |
| 129 | |
| 130 def get_certificate(self, host): | |
| 131 return certutils.generate_cert(self.ca_cert_str, '', host) | |
| 132 | |
| 133 | |
| 134 class TestClient(unittest.TestCase): | |
| 135 _temp_dir = None | |
| 136 | |
| 137 def setUp(self): | |
| 138 self._temp_dir = tempfile.mkdtemp(prefix='sslproxy_', dir='/tmp') | |
| 139 self.ca_cert_path = self._temp_dir + 'testCA.pem' | |
| 140 self.cert_path = self._temp_dir + 'testCA-cert.cer' | |
| 141 self.wrong_ca_cert_path = self._temp_dir + 'wrong.pem' | |
| 142 self.wrong_cert_path = self._temp_dir + 'wrong-cert.cer' | |
| 143 | |
| 144 # Write both pem and cer files for certificates | |
| 145 certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(), | |
| 146 cert_path=self.ca_cert_path) | |
| 147 certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(), | |
| 148 cert_path=self.ca_cert_path) | |
| 149 | |
| 150 def tearDown(self): | |
| 151 if self._temp_dir: | |
| 152 shutil.rmtree(self._temp_dir) | |
| 153 | |
| 154 def verify_cb(self, conn, cert, errnum, depth, ok): | |
| 155 """A callback that verifies the certificate authentication worked. | |
| 156 | |
| 157 Args: | |
| 158 conn: Connection object | |
| 159 cert: x509 object | |
| 160 errnum: possible error number | |
| 161 depth: error depth | |
| 162 ok: 1 if the authentication worked 0 if it didnt. | |
| 163 Returns: | |
| 164 1 or 0 depending on if the verification worked | |
| 165 """ | |
| 166 self.assertFalse(cert.has_expired()) | |
| 167 self.assertGreater(time.strftime('%Y%m%d%H%M%SZ', time.gmtime()), | |
| 168 cert.get_notBefore()) | |
| 169 return ok | |
| 170 | |
| 171 def test_no_host(self): | |
| 172 with Server(self.ca_cert_path) as server: | |
| 173 c = Client(self.cert_path, self.verify_cb, server.server_port, '') | |
| 174 self.assertRaises(certutils.Error, c.run_request) | |
| 175 | |
| 176 def test_client_connection(self): | |
| 177 with Server(self.ca_cert_path) as server: | |
| 178 c = Client(self.cert_path, self.verify_cb, server.server_port, 'foo.com') | |
| 179 c.run_request() | |
| 180 | |
| 181 c = Client(self.cert_path, self.verify_cb, server.server_port, | |
| 182 'random.host') | |
| 183 c.run_request() | |
| 184 | |
| 185 def test_wrong_cert(self): | |
| 186 with Server(self.ca_cert_path, True) as server: | |
| 187 c = Client(self.wrong_cert_path, self.verify_cb, server.server_port, | |
| 188 'foo.com') | |
| 189 self.assertRaises(certutils.Error, c.run_request) | |
| 190 | |
| 191 | |
| 192 if __name__ == '__main__': | |
| 193 signal.signal(signal.SIGINT, signal.SIG_DFL) # Exit on Ctrl-C | |
| 194 unittest.main() | |
| OLD | NEW |