| OLD | NEW |
| (Empty) |
| 1 # Copyright 2014 Google Inc. All Rights Reserved. | |
| 2 # | |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
| 4 # you may not use this file except in compliance with the License. | |
| 5 # You may obtain a copy of the License at | |
| 6 # | |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 | |
| 8 # | |
| 9 # Unless required by applicable law or agreed to in writing, software | |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, | |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 12 # See the License for the specific language governing permissions and | |
| 13 # limitations under the License. | |
| 14 | |
| 15 """Test routines to generate dummy certificates.""" | |
| 16 | |
| 17 import BaseHTTPServer | |
| 18 import os | |
| 19 import shutil | |
| 20 import ssl | |
| 21 import tempfile | |
| 22 import threading | |
| 23 import unittest | |
| 24 | |
| 25 import certutils | |
| 26 | |
| 27 | |
| 28 class Server(BaseHTTPServer.HTTPServer): | |
| 29 | |
| 30 def __init__(self, https_root_ca_cert_path): | |
| 31 BaseHTTPServer.HTTPServer.__init__( | |
| 32 self, ('localhost', 0), BaseHTTPServer.BaseHTTPRequestHandler) | |
| 33 self.socket = ssl.wrap_socket( | |
| 34 self.socket, certfile=https_root_ca_cert_path, server_side=True, | |
| 35 do_handshake_on_connect=False) | |
| 36 | |
| 37 def __enter__(self): | |
| 38 thread = threading.Thread(target=self.serve_forever) | |
| 39 thread.daemon = True | |
| 40 thread.start() | |
| 41 return self | |
| 42 | |
| 43 def cleanup(self): | |
| 44 try: | |
| 45 self.shutdown() | |
| 46 except KeyboardInterrupt: | |
| 47 pass | |
| 48 | |
| 49 def __exit__(self, type_, value_, traceback_): | |
| 50 self.cleanup() | |
| 51 | |
| 52 | |
| 53 class CertutilsTest(unittest.TestCase): | |
| 54 | |
| 55 def _check_cert_file(self, cert_file_path, cert_str, key_str=None): | |
| 56 cert_load = open(cert_file_path, 'r').read() | |
| 57 if key_str: | |
| 58 expected_cert = key_str + cert_str | |
| 59 else: | |
| 60 expected_cert = cert_str | |
| 61 self.assertEqual(expected_cert, cert_load) | |
| 62 | |
| 63 def setUp(self): | |
| 64 self._temp_dir = tempfile.mkdtemp(prefix='certutils_', dir='/tmp') | |
| 65 | |
| 66 def tearDown(self): | |
| 67 if self._temp_dir: | |
| 68 shutil.rmtree(self._temp_dir) | |
| 69 | |
| 70 def test_generate_dummy_ca_cert(self): | |
| 71 subject = 'testSubject' | |
| 72 c, _ = certutils.generate_dummy_ca_cert(subject) | |
| 73 c = certutils.load_cert(c) | |
| 74 self.assertEqual(c.get_subject().commonName, subject) | |
| 75 | |
| 76 def test_get_host_cert(self): | |
| 77 ca_cert_path = os.path.join(self._temp_dir, 'rootCA.pem') | |
| 78 issuer = 'testCA' | |
| 79 certutils.write_dummy_ca_cert(*certutils.generate_dummy_ca_cert(issuer), | |
| 80 cert_path=ca_cert_path) | |
| 81 | |
| 82 with Server(ca_cert_path) as server: | |
| 83 cert_str = certutils.get_host_cert('localhost', server.server_port) | |
| 84 cert = certutils.load_cert(cert_str) | |
| 85 self.assertEqual(issuer, cert.get_subject().commonName) | |
| 86 | |
| 87 def test_get_host_cert_gives_empty_for_bad_host(self): | |
| 88 cert_str = certutils.get_host_cert('not_a_valid_host_name_2472341234234234') | |
| 89 self.assertEqual('', cert_str) | |
| 90 | |
| 91 def test_write_dummy_ca_cert(self): | |
| 92 base_path = os.path.join(self._temp_dir, 'testCA') | |
| 93 ca_cert_path = base_path + '.pem' | |
| 94 cert_path = base_path + '-cert.pem' | |
| 95 ca_cert_android = base_path + '-cert.cer' | |
| 96 ca_cert_windows = base_path + '-cert.p12' | |
| 97 | |
| 98 self.assertFalse(os.path.exists(ca_cert_path)) | |
| 99 self.assertFalse(os.path.exists(cert_path)) | |
| 100 self.assertFalse(os.path.exists(ca_cert_android)) | |
| 101 self.assertFalse(os.path.exists(ca_cert_windows)) | |
| 102 c, k = certutils.generate_dummy_ca_cert() | |
| 103 certutils.write_dummy_ca_cert(c, k, ca_cert_path) | |
| 104 | |
| 105 self._check_cert_file(ca_cert_path, c, k) | |
| 106 self._check_cert_file(cert_path, c) | |
| 107 self._check_cert_file(ca_cert_android, c) | |
| 108 self.assertTrue(os.path.exists(ca_cert_windows)) | |
| 109 | |
| 110 def test_generate_cert(self): | |
| 111 ca_cert_path = os.path.join(self._temp_dir, 'testCA.pem') | |
| 112 issuer = 'testIssuer' | |
| 113 certutils.write_dummy_ca_cert( | |
| 114 *certutils.generate_dummy_ca_cert(issuer), cert_path=ca_cert_path) | |
| 115 | |
| 116 with open(ca_cert_path, 'r') as root_file: | |
| 117 root_string = root_file.read() | |
| 118 subject = 'testSubject' | |
| 119 cert_string = certutils.generate_cert( | |
| 120 root_string, '', subject) | |
| 121 cert = certutils.load_cert(cert_string) | |
| 122 self.assertEqual(issuer, cert.get_issuer().commonName) | |
| 123 self.assertEqual(subject, cert.get_subject().commonName) | |
| 124 | |
| 125 with open(ca_cert_path, 'r') as ca_cert_file: | |
| 126 ca_cert_str = ca_cert_file.read() | |
| 127 cert_string = certutils.generate_cert(ca_cert_str, cert_string, | |
| 128 'host') | |
| 129 cert = certutils.load_cert(cert_string) | |
| 130 self.assertEqual(issuer, cert.get_issuer().commonName) | |
| 131 self.assertEqual(subject, cert.get_subject().commonName) | |
| 132 | |
| 133 | |
| 134 if __name__ == '__main__': | |
| 135 unittest.main() | |
| OLD | NEW |