Index: net/socket/ssl_server_socket_unittest.cc |
diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc |
index 548f7c6c99856c583bec26f6085ad8d1cbdda87b..905014316dc6a732572d6950d18f8a84e96184ed 100644 |
--- a/net/socket/ssl_server_socket_unittest.cc |
+++ b/net/socket/ssl_server_socket_unittest.cc |
@@ -30,6 +30,7 @@ |
#include "base/thread_task_runner_handle.h" |
#include "crypto/nss_util.h" |
#include "crypto/rsa_private_key.h" |
+#include "crypto/signature_creator.h" |
#include "net/base/address_list.h" |
#include "net/base/completion_callback.h" |
#include "net/base/host_port_pair.h" |
@@ -46,9 +47,11 @@ |
#include "net/socket/socket_test_util.h" |
#include "net/socket/ssl_client_socket.h" |
#include "net/socket/stream_socket.h" |
+#include "net/ssl/ssl_cert_request_info.h" |
#include "net/ssl/ssl_cipher_suite_names.h" |
#include "net/ssl/ssl_connection_status_flags.h" |
#include "net/ssl/ssl_info.h" |
+#include "net/ssl/ssl_private_key.h" |
#include "net/ssl/ssl_server_config.h" |
#include "net/test/cert_test_util.h" |
#include "testing/gtest/include/gtest/gtest.h" |
@@ -109,13 +112,17 @@ class FakeDataChannel { |
// asynchronously, which is necessary to reproduce bug 127822. |
void Close() { |
closed_ = true; |
+ if (!read_callback_.is_null()) { |
+ base::MessageLoop::current()->PostTask( |
+ FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, |
+ weak_factory_.GetWeakPtr())); |
+ } |
} |
private: |
void DoReadCallback() { |
if (read_callback_.is_null() || data_.empty()) |
return; |
- |
int copied = PropagateData(read_buf_, read_buf_len_); |
CompletionCallback callback = read_callback_; |
read_callback_.Reset(); |
@@ -254,6 +261,64 @@ class FakeSocket : public StreamSocket { |
DISALLOW_COPY_AND_ASSIGN(FakeSocket); |
}; |
+class TestSSLPrivateKey : public SSLPrivateKey { |
+ public: |
+ TestSSLPrivateKey(crypto::RSAPrivateKey* rsa_private_key) |
+ : rsa_private_key_(rsa_private_key) {} |
+ |
+ ~TestSSLPrivateKey() override {} |
+ |
+ Type GetType() override { return SSLPrivateKey::Type::RSA; } |
+ |
+ std::vector<SSLPrivateKey::Hash> GetDigestPreferences() override { |
+ static const SSLPrivateKey::Hash kHashes[] = {SSLPrivateKey::Hash::SHA256, |
+ SSLPrivateKey::Hash::SHA1}; |
+ return std::vector<SSLPrivateKey::Hash>(kHashes, |
+ kHashes + arraysize(kHashes)); |
+ } |
+ |
+ // NOTE: The following algorithm assumes the answer is a power of 2, which is |
+ // true for the test keys in use. |
+ size_t GetMaxSignatureLengthInBytes() override { |
+ std::vector<uint8> public_key_info; |
+ rsa_private_key_->ExportPublicKey(&public_key_info); |
+ uint result = 1; |
+ while ((result << 1) < public_key_info.size()) |
+ result <<= 1; |
+ return result; |
+ } |
+ |
+ void SignDigest(Hash hash, |
+ const base::StringPiece& input, |
+ const SignCallback& callback) override { |
+ std::vector<uint8> signature; |
+ crypto::SignatureCreator::HashAlgorithm hash_alg; |
+ switch (hash) { |
+ case Hash::SHA1: |
+ hash_alg = crypto::SignatureCreator::SHA1; |
+ break; |
+ |
+ case Hash::SHA256: |
+ hash_alg = crypto::SignatureCreator::SHA256; |
+ break; |
+ |
+ default: |
+ FAIL() << "Unsupported hash function"; |
+ } |
+ crypto::SignatureCreator::Sign(rsa_private_key_.get(), hash_alg, |
+ reinterpret_cast<const uint8*>(input.data()), |
+ input.size(), &signature); |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback, OK, signature)); |
+ } |
+ |
+ private: |
+ void CompleteSignDigest(Error err, const std::vector<uint8_t>& signature) {} |
+ scoped_ptr<crypto::RSAPrivateKey> rsa_private_key_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(TestSSLPrivateKey); |
+}; |
+ |
} // namespace |
// Verify the correctness of the test helper classes first. |
@@ -298,11 +363,21 @@ TEST(FakeSocketTest, DataTransfer) { |
class SSLServerSocketTest : public PlatformTest { |
public: |
+ enum ClientCertSupply { |
+ kNoneSupplied = 0, |
+ kCorrectCertSupplied = 1, |
+ kWrongCertSupplied = 2 |
+ }; |
+ |
+ enum ClientCertExpect { kNoneExpected = 0, kCertRequired = 2 }; |
+ |
SSLServerSocketTest() |
: socket_factory_(ClientSocketFactory::GetDefaultFactory()), |
cert_verifier_(new MockCertVerifier()), |
+ client_cert_verifier_(new MockCertVerifier()), |
transport_security_state_(new TransportSecurityState) { |
cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID); |
+ client_cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID); |
} |
protected: |
@@ -313,25 +388,11 @@ class SSLServerSocketTest : public PlatformTest { |
scoped_ptr<StreamSocket> server_socket( |
new FakeSocket(&channel_2_, &channel_1_)); |
- base::FilePath certs_dir(GetTestCertsDirectory()); |
- |
- base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); |
- std::string cert_der; |
- ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); |
- |
- scoped_refptr<X509Certificate> cert = |
- X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); |
- |
- base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); |
- std::string key_string; |
- ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); |
- std::vector<uint8> key_vector( |
- reinterpret_cast<const uint8*>(key_string.data()), |
- reinterpret_cast<const uint8*>(key_string.data() + |
- key_string.length())); |
- |
- scoped_ptr<crypto::RSAPrivateKey> private_key( |
- crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); |
+ std::string server_cert_der; |
+ scoped_refptr<X509Certificate> server_cert( |
+ ReadTestCert("unittest.selfsigned.der", &server_cert_der)); |
+ scoped_ptr<crypto::RSAPrivateKey> server_private_key( |
+ ReadTestKey("unittest.key.bin")); |
client_ssl_config_.false_start_enabled = false; |
client_ssl_config_.channel_id_enabled = false; |
@@ -339,18 +400,85 @@ class SSLServerSocketTest : public PlatformTest { |
// Certificate provided by the host doesn't need authority. |
SSLConfig::CertAndStatus cert_and_status; |
cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; |
- cert_and_status.der_cert = cert_der; |
+ cert_and_status.der_cert = server_cert_der; |
client_ssl_config_.allowed_bad_certs.push_back(cert_and_status); |
+ SSLConfig ssl_server_config; |
HostPortPair host_and_pair("unittest", 0); |
SSLClientSocketContext context; |
context.cert_verifier = cert_verifier_.get(); |
context.transport_security_state = transport_security_state_.get(); |
+ socket_factory_->ClearSSLSessionCache(); |
client_socket_ = socket_factory_->CreateSSLClientSocket( |
client_connection.Pass(), host_and_pair, client_ssl_config_, context); |
server_socket_ = |
- CreateSSLServerSocket(server_socket.Pass(), cert.get(), |
- private_key.get(), server_ssl_config_); |
+ CreateSSLServerSocket(server_socket.Pass(), server_cert.get(), |
+ server_private_key.get(), server_ssl_config_); |
+ } |
+ |
+ void InitializeClientCertsForClient(ClientCertSupply supply) { |
+ scoped_refptr<X509Certificate> cert; |
+ scoped_ptr<net::SSLPrivateKey> key; |
+ if (supply != kNoneSupplied) { |
+ const char* cert_file_name = supply == kCorrectCertSupplied |
+ ? kClientCertFileName |
+ : kWrongClientCertFileName; |
+ const char* private_key_file_name = supply == kCorrectCertSupplied |
+ ? kClientPrivateKeyFileName |
+ : kWrongClientPrivateKeyFileName; |
+ cert = ImportCertFromFile(GetTestCertsDirectory(), cert_file_name); |
+ key.reset(new TestSSLPrivateKey(ReadTestKey(private_key_file_name))); |
+ } |
+ client_socket_->ForceClientCertificateAndKeyForTest(cert, key.Pass()); |
+ } |
+ |
+ void InitializeClientCertsForServer(ClientCertExpect expect) { |
+ if (expect == kNoneExpected) |
+ return; |
+ |
+ server_socket_->SetRequireClientCert(true); |
+ |
+ if (expect == kCertRequired) { |
+ scoped_refptr<X509Certificate> expected_client_ca_cert( |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertCAFileName)); |
+ CertificateList ca_list; |
+ ca_list.push_back(expected_client_ca_cert); |
+ server_socket_->SetClientCertCAList(ca_list); |
+ scoped_refptr<X509Certificate> expected_client_cert( |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName)); |
+ CertVerifyResult ignored; |
+ ignored.verified_cert = expected_client_cert; |
+ ignored.cert_status = 0; |
+ client_cert_verifier_->AddResultForCert(expected_client_cert.get(), |
+ ignored, OK); |
+ server_socket_->SetClientCertVerifier(client_cert_verifier_.get()); |
+ } |
+ } |
+ |
+ X509Certificate* ReadTestCert(const base::StringPiece& name, |
+ std::string* cert_der) { |
+ base::FilePath certs_dir(GetTestCertsDirectory()); |
+ base::FilePath cert_path = certs_dir.AppendASCII(name); |
+ std::string unneeded; |
+ if (!cert_der) { |
+ cert_der = &unneeded; |
+ } |
+ if (!base::ReadFileToString(cert_path, cert_der)) |
+ return NULL; |
+ return X509Certificate::CreateFromBytes(cert_der->data(), cert_der->size()); |
+ } |
+ |
+ crypto::RSAPrivateKey* ReadTestKey(const base::StringPiece& name) { |
+ base::FilePath certs_dir(GetTestCertsDirectory()); |
+ base::FilePath key_path = certs_dir.AppendASCII(name); |
+ std::string key_string; |
+ if (!base::ReadFileToString(key_path, &key_string)) |
+ return NULL; |
+ std::vector<uint8> key_vector( |
+ reinterpret_cast<const uint8*>(key_string.data()), |
+ reinterpret_cast<const uint8*>(key_string.data() + |
+ key_string.length())); |
+ return crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector); |
} |
FakeDataChannel channel_1_; |
@@ -361,7 +489,15 @@ class SSLServerSocketTest : public PlatformTest { |
scoped_ptr<SSLServerSocket> server_socket_; |
ClientSocketFactory* socket_factory_; |
scoped_ptr<MockCertVerifier> cert_verifier_; |
+ scoped_ptr<MockCertVerifier> client_cert_verifier_; |
scoped_ptr<TransportSecurityState> transport_security_state_; |
+ CertificateList trusted_certs_; |
+ |
+ const char* kClientCertFileName = "client_1.pem"; |
+ const char* kClientPrivateKeyFileName = "client_1.pk8"; |
+ const char* kWrongClientCertFileName = "client_2.pem"; |
+ const char* kWrongClientPrivateKeyFileName = "client_2.pk8"; |
+ const char* kClientCertCAFileName = "client_1_ca.pem"; |
}; |
// This test only executes creation of client and server sockets. This is to |
@@ -411,6 +547,123 @@ TEST_F(SSLServerSocketTest, Handshake) { |
EXPECT_TRUE(is_aead); |
} |
+// TODO(dougsteed). The following tests using client certificates cannot |
+// be performed if NSS with platform-based client auth is in use. That's because |
+// the tests use SSLClientSocket to make requests against the server, and on |
+// those builds, that class does not support supplying of a test key and cert. |
+// An alternative approach that would broaden the applicability of these tests |
+// would be to build and use the openssl flavor of SSLClientSocket, even |
+// on NSS platforms. |
+#if !defined(USE_NSS) || !defined(NSS_PLATFORM_CLIENT_AUTH) |
+ |
+// This test executes Connect() on SSLClientSocket and Handshake() on |
+// SSLServerSocket to make sure handshaking between the two sockets is |
+// completed successfully, using client certificate. |
+TEST_F(SSLServerSocketTest, HandshakeWithClientCert) { |
+ scoped_refptr<X509Certificate> client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); |
+ Initialize(); |
+ InitializeClientCertsForServer(kCertRequired); |
+ InitializeClientCertsForClient(kCorrectCertSupplied); |
+ |
+ TestCompletionCallback connect_callback; |
+ TestCompletionCallback handshake_callback; |
+ |
+ int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
+ EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); |
+ |
+ int client_ret = client_socket_->Connect(connect_callback.callback()); |
+ EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); |
+ |
+ if (client_ret == ERR_IO_PENDING) { |
+ EXPECT_EQ(OK, connect_callback.WaitForResult()); |
+ } |
+ if (server_ret == ERR_IO_PENDING) { |
+ EXPECT_EQ(OK, handshake_callback.WaitForResult()); |
+ } |
+ |
+ // Make sure the cert status is expected. |
+ SSLInfo ssl_info; |
+ client_socket_->GetSSLInfo(&ssl_info); |
+ EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); |
+ server_socket_->GetSSLInfo(&ssl_info); |
+ EXPECT_TRUE(ssl_info.client_cert_sent); |
+ EXPECT_TRUE(ssl_info.client_cert_sent); |
+ EXPECT_TRUE(ssl_info.cert.get()); |
+ EXPECT_TRUE(client_cert->Equals(ssl_info.cert.get())); |
+} |
+ |
+TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) { |
+ Initialize(); |
+ InitializeClientCertsForServer(kCertRequired); |
+ // We use the default setting for the client socket. This causes the client to |
+ // get SSL_CLIENT_AUTH_CERT_NEEDED. This code path allows us to access the |
+ // cert_authorities from the CertificateRequest. |
+ |
+ TestCompletionCallback connect_callback; |
+ TestCompletionCallback handshake_callback; |
+ |
+ int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
+ EXPECT_TRUE(server_ret == ERR_IO_PENDING); |
+ |
+ int client_ret = client_socket_->Connect(connect_callback.callback()); |
+ EXPECT_TRUE(client_ret == ERR_SSL_CLIENT_AUTH_CERT_NEEDED || |
+ client_ret == ERR_IO_PENDING); |
+ |
+ if (client_ret == ERR_IO_PENDING) { |
+ EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, |
+ connect_callback.WaitForResult()); |
+ } |
+ |
+ scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); |
+ client_socket_->GetSSLCertRequestInfo(request_info.get()); |
+ |
+ // Check that the authority name that arrived in the CertificateRequest |
+ // handshake message is as expected. |
+ scoped_refptr<X509Certificate> client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); |
+ EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities)); |
+ |
+ client_socket_->Disconnect(); |
+ |
+ if (server_ret == ERR_IO_PENDING) { |
+ server_ret = handshake_callback.WaitForResult(); |
+ EXPECT_TRUE(server_ret == ERR_CONNECTION_CLOSED || |
+ server_ret == ERR_FAILED); |
+ } |
+} |
+ |
+TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) { |
+ scoped_refptr<X509Certificate> client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); |
+ Initialize(); |
+ InitializeClientCertsForServer(kCertRequired); |
+ InitializeClientCertsForClient(kWrongCertSupplied); |
+ |
+ TestCompletionCallback connect_callback; |
+ TestCompletionCallback handshake_callback; |
+ |
+ int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
+ EXPECT_TRUE(server_ret == ERR_IO_PENDING); |
+ |
+ int client_ret = client_socket_->Connect(connect_callback.callback()); |
+ EXPECT_TRUE(client_ret == ERR_BAD_SSL_CLIENT_AUTH_CERT || |
+ client_ret == ERR_IO_PENDING); |
+ |
+ if (client_ret == ERR_IO_PENDING) { |
+ EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, connect_callback.WaitForResult()); |
+ } |
+ |
+ server_ret = handshake_callback.WaitForResult(); |
+ // We get a different result on NSS and OpenSSL. That's because an error |
+ // mapping with OpenSSL makes an assumption that is true for SSLClientSocket |
+ // but not SSLServerSocket (namely that peer cert rejection only occurs due to |
+ // a cert change during renego). |
+ EXPECT_TRUE(server_ret == ERR_BAD_SSL_CLIENT_AUTH_CERT || |
+ server_ret == ERR_SSL_SERVER_CERT_CHANGED); |
+} |
+#endif //! defined(USE_NSS) || !defined(NSS_PLATFORM_CLIENT_AUTH) |
+ |
TEST_F(SSLServerSocketTest, DataTransfer) { |
Initialize(); |