Index: net/socket/ssl_server_socket_unittest.cc |
diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc |
index ac2d44ec4136fb88c0573646e91e0e9697eaf69e..9a2dd9cadd6ee38e6b529d970874643ce4f6a824 100644 |
--- a/net/socket/ssl_server_socket_unittest.cc |
+++ b/net/socket/ssl_server_socket_unittest.cc |
@@ -20,6 +20,7 @@ |
#include <queue> |
#include <utility> |
+#include "base/callback_helpers.h" |
#include "base/compiler_specific.h" |
#include "base/files/file_path.h" |
#include "base/files/file_util.h" |
@@ -29,8 +30,11 @@ |
#include "base/message_loop/message_loop.h" |
#include "base/single_thread_task_runner.h" |
#include "base/thread_task_runner_handle.h" |
+#include "build/build_config.h" |
#include "crypto/nss_util.h" |
#include "crypto/rsa_private_key.h" |
+#include "crypto/scoped_openssl_types.h" |
+#include "crypto/signature_creator.h" |
#include "net/base/address_list.h" |
#include "net/base/completion_callback.h" |
#include "net/base/host_port_pair.h" |
@@ -40,6 +44,7 @@ |
#include "net/base/test_data_directory.h" |
#include "net/cert/cert_status_flags.h" |
#include "net/cert/mock_cert_verifier.h" |
+#include "net/cert/mock_client_cert_verifier.h" |
#include "net/cert/x509_certificate.h" |
#include "net/http/transport_security_state.h" |
#include "net/log/net_log.h" |
@@ -47,18 +52,34 @@ |
#include "net/socket/socket_test_util.h" |
#include "net/socket/ssl_client_socket.h" |
#include "net/socket/stream_socket.h" |
+#include "net/ssl/scoped_openssl_types.h" |
+#include "net/ssl/ssl_cert_request_info.h" |
#include "net/ssl/ssl_cipher_suite_names.h" |
#include "net/ssl/ssl_connection_status_flags.h" |
#include "net/ssl/ssl_info.h" |
+#include "net/ssl/ssl_private_key.h" |
#include "net/ssl/ssl_server_config.h" |
+#include "net/ssl/test_ssl_private_key.h" |
#include "net/test/cert_test_util.h" |
#include "testing/gtest/include/gtest/gtest.h" |
#include "testing/platform_test.h" |
+#if defined(USE_OPENSSL) |
+#include <openssl/evp.h> |
+#include <openssl/ssl.h> |
+#include <openssl/x509.h> |
+#endif |
+ |
namespace net { |
namespace { |
+const char kClientCertFileName[] = "client_1.pem"; |
+const char kClientPrivateKeyFileName[] = "client_1.pk8"; |
+const char kWrongClientCertFileName[] = "client_2.pem"; |
+const char kWrongClientPrivateKeyFileName[] = "client_2.pk8"; |
+const char kClientCertCAFileName[] = "client_1_ca.pem"; |
+ |
class FakeDataChannel { |
public: |
FakeDataChannel() |
@@ -110,11 +131,24 @@ class FakeDataChannel { |
// asynchronously, which is necessary to reproduce bug 127822. |
void Close() { |
closed_ = true; |
+ if (!read_callback_.is_null()) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, |
+ weak_factory_.GetWeakPtr())); |
+ } |
} |
private: |
void DoReadCallback() { |
- if (read_callback_.is_null() || data_.empty()) |
+ if (read_callback_.is_null()) |
+ return; |
+ |
+ if (closed_) { |
+ base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED); |
+ return; |
+ } |
+ |
+ if (data_.empty()) |
return; |
int copied = PropagateData(read_buf_, read_buf_len_); |
@@ -302,8 +336,10 @@ class SSLServerSocketTest : public PlatformTest { |
SSLServerSocketTest() |
: socket_factory_(ClientSocketFactory::GetDefaultFactory()), |
cert_verifier_(new MockCertVerifier()), |
+ client_cert_verifier_(new MockClientCertVerifier()), |
transport_security_state_(new TransportSecurityState) { |
- cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID); |
+ cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID); |
+ client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID); |
} |
protected: |
@@ -314,25 +350,10 @@ class SSLServerSocketTest : public PlatformTest { |
scoped_ptr<StreamSocket> server_socket( |
new FakeSocket(&channel_2_, &channel_1_)); |
- base::FilePath certs_dir(GetTestCertsDirectory()); |
- |
- base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); |
- std::string cert_der; |
- ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); |
- |
- scoped_refptr<X509Certificate> cert = |
- X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); |
- |
- base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); |
- std::string key_string; |
- ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); |
- std::vector<uint8_t> key_vector( |
- reinterpret_cast<const uint8_t*>(key_string.data()), |
- reinterpret_cast<const uint8_t*>(key_string.data() + |
- key_string.length())); |
- |
- scoped_ptr<crypto::RSAPrivateKey> private_key( |
- crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); |
+ scoped_refptr<X509Certificate> server_cert( |
+ ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der")); |
+ scoped_ptr<crypto::RSAPrivateKey> server_private_key( |
+ ReadTestKey("unittest.key.bin")); |
client_ssl_config_.false_start_enabled = false; |
client_ssl_config_.channel_id_enabled = false; |
@@ -340,19 +361,78 @@ class SSLServerSocketTest : public PlatformTest { |
// Certificate provided by the host doesn't need authority. |
SSLConfig::CertAndStatus cert_and_status; |
cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; |
- cert_and_status.der_cert = cert_der; |
+ std::string server_cert_der; |
+ CHECK(X509Certificate::GetDEREncoded(server_cert->os_cert_handle(), |
+ &server_cert_der)); |
+ cert_and_status.der_cert = server_cert_der; |
client_ssl_config_.allowed_bad_certs.push_back(cert_and_status); |
HostPortPair host_and_pair("unittest", 0); |
SSLClientSocketContext context; |
context.cert_verifier = cert_verifier_.get(); |
context.transport_security_state = transport_security_state_.get(); |
+ socket_factory_->ClearSSLSessionCache(); |
client_socket_ = socket_factory_->CreateSSLClientSocket( |
std::move(client_connection), host_and_pair, client_ssl_config_, |
context); |
- server_socket_ = CreateSSLServerSocket(std::move(server_socket), cert.get(), |
- *private_key, server_ssl_config_); |
+ server_socket_ = |
+ CreateSSLServerSocket(std::move(server_socket), server_cert.get(), |
+ *server_private_key, server_ssl_config_); |
+ } |
+ |
+#if defined(USE_OPENSSL) |
+ void ConfigureClientCertsForClient(const char* cert_file_name, |
+ const char* private_key_file_name) { |
+ client_ssl_config_.send_client_cert = true; |
+ client_ssl_config_.client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), cert_file_name); |
+ ASSERT_TRUE(client_ssl_config_.client_cert); |
+ scoped_ptr<crypto::RSAPrivateKey> key = ReadTestKey(private_key_file_name); |
+ ASSERT_TRUE(key); |
+ client_ssl_config_.client_private_key = WrapOpenSSLPrivateKey( |
+ crypto::ScopedEVP_PKEY(EVP_PKEY_up_ref(key->key()))); |
+ } |
+ |
+ void ConfigureClientCertsForServer() { |
+ server_ssl_config_.client_cert_type = |
+ SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; |
+ |
+ ScopedX509NameStack cert_names( |
+ SSL_load_client_CA_file(GetTestCertsDirectory() |
+ .AppendASCII(kClientCertCAFileName) |
+ .MaybeAsASCII() |
+ .c_str())); |
+ ASSERT_TRUE(cert_names); |
+ for (size_t i = 0; i < sk_X509_NAME_num(cert_names.get()); ++i) { |
+ uint8_t* str = nullptr; |
+ int length = i2d_X509_NAME(sk_X509_NAME_value(cert_names.get(), i), &str); |
+ server_ssl_config_.cert_authorities_.push_back(std::string( |
+ reinterpret_cast<const char*>(str), static_cast<size_t>(length))); |
+ OPENSSL_free(str); |
+ } |
+ |
+ scoped_refptr<X509Certificate> expected_client_cert( |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName)); |
+ client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK); |
+ |
+ server_ssl_config_.client_cert_verifier = client_cert_verifier_.get(); |
+ } |
+ |
+ scoped_ptr<crypto::RSAPrivateKey> ReadTestKey(const base::StringPiece& name) { |
+ base::FilePath certs_dir(GetTestCertsDirectory()); |
+ base::FilePath key_path = certs_dir.AppendASCII(name); |
+ std::string key_string; |
+ if (!base::ReadFileToString(key_path, &key_string)) |
+ return nullptr; |
+ std::vector<uint8_t> key_vector( |
+ reinterpret_cast<const uint8_t*>(key_string.data()), |
+ reinterpret_cast<const uint8_t*>(key_string.data() + |
+ key_string.length())); |
+ scoped_ptr<crypto::RSAPrivateKey> key( |
+ crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); |
+ return key; |
} |
+#endif |
FakeDataChannel channel_1_; |
FakeDataChannel channel_2_; |
@@ -362,6 +442,7 @@ class SSLServerSocketTest : public PlatformTest { |
scoped_ptr<SSLServerSocket> server_socket_; |
ClientSocketFactory* socket_factory_; |
scoped_ptr<MockCertVerifier> cert_verifier_; |
+ scoped_ptr<MockClientCertVerifier> client_cert_verifier_; |
scoped_ptr<TransportSecurityState> transport_security_state_; |
}; |
@@ -378,21 +459,17 @@ TEST_F(SSLServerSocketTest, Initialize) { |
TEST_F(SSLServerSocketTest, Handshake) { |
Initialize(); |
- TestCompletionCallback connect_callback; |
TestCompletionCallback handshake_callback; |
- |
int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
- EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); |
+ TestCompletionCallback connect_callback; |
int client_ret = client_socket_->Connect(connect_callback.callback()); |
- EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); |
- if (client_ret == ERR_IO_PENDING) { |
- EXPECT_EQ(OK, connect_callback.WaitForResult()); |
- } |
- if (server_ret == ERR_IO_PENDING) { |
- EXPECT_EQ(OK, handshake_callback.WaitForResult()); |
- } |
+ client_ret = connect_callback.GetResult(client_ret); |
+ server_ret = handshake_callback.GetResult(server_ret); |
+ |
+ ASSERT_EQ(OK, client_ret); |
+ ASSERT_EQ(OK, server_ret); |
// Make sure the cert status is expected. |
SSLInfo ssl_info; |
@@ -412,16 +489,101 @@ TEST_F(SSLServerSocketTest, Handshake) { |
EXPECT_TRUE(is_aead); |
} |
-TEST_F(SSLServerSocketTest, DataTransfer) { |
+// NSS ports don't support client certificates. |
+#if defined(USE_OPENSSL) |
+ |
+// This test executes Connect() on SSLClientSocket and Handshake() on |
+// SSLServerSocket to make sure handshaking between the two sockets is |
+// completed successfully, using client certificate. |
+TEST_F(SSLServerSocketTest, HandshakeWithClientCert) { |
+ scoped_refptr<X509Certificate> client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); |
+ ConfigureClientCertsForClient(kClientCertFileName, kClientPrivateKeyFileName); |
+ ConfigureClientCertsForServer(); |
Initialize(); |
+ TestCompletionCallback handshake_callback; |
+ int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
+ |
TestCompletionCallback connect_callback; |
+ int client_ret = client_socket_->Connect(connect_callback.callback()); |
+ |
+ client_ret = connect_callback.GetResult(client_ret); |
+ server_ret = handshake_callback.GetResult(server_ret); |
+ |
+ ASSERT_EQ(OK, client_ret); |
+ ASSERT_EQ(OK, server_ret); |
+ |
+ // Make sure the cert status is expected. |
+ SSLInfo ssl_info; |
+ client_socket_->GetSSLInfo(&ssl_info); |
+ EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); |
+ server_socket_->GetSSLInfo(&ssl_info); |
+ EXPECT_TRUE(ssl_info.cert.get()); |
+ EXPECT_TRUE(client_cert->Equals(ssl_info.cert.get())); |
+} |
+ |
+TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) { |
+ ConfigureClientCertsForServer(); |
+ Initialize(); |
+ // Use the default setting for the client socket, which is to not send |
+ // a client certificate. This will cause the client to receive an |
+ // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the |
+ // requested cert_authorities from the CertificateRequest sent by the |
+ // server. |
+ |
TestCompletionCallback handshake_callback; |
+ int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
+ |
+ TestCompletionCallback connect_callback; |
+ EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, |
+ connect_callback.GetResult( |
+ client_socket_->Connect(connect_callback.callback()))); |
+ |
+ scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); |
+ client_socket_->GetSSLCertRequestInfo(request_info.get()); |
+ |
+ // Check that the authority name that arrived in the CertificateRequest |
+ // handshake message is as expected. |
+ scoped_refptr<X509Certificate> client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); |
+ EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities)); |
+ |
+ client_socket_->Disconnect(); |
+ |
+ EXPECT_EQ(ERR_FAILED, handshake_callback.GetResult(server_ret)); |
+} |
+ |
+TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) { |
+ scoped_refptr<X509Certificate> client_cert = |
+ ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); |
+ ConfigureClientCertsForClient(kWrongClientCertFileName, |
+ kWrongClientPrivateKeyFileName); |
+ ConfigureClientCertsForServer(); |
+ Initialize(); |
+ |
+ TestCompletionCallback handshake_callback; |
+ int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
+ |
+ TestCompletionCallback connect_callback; |
+ int client_ret = client_socket_->Connect(connect_callback.callback()); |
+ |
+ EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, |
+ connect_callback.GetResult(client_ret)); |
+ EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, |
+ handshake_callback.GetResult(server_ret)); |
+} |
+#endif // defined(USE_OPENSSL) |
+ |
+TEST_F(SSLServerSocketTest, DataTransfer) { |
+ Initialize(); |
// Establish connection. |
+ TestCompletionCallback connect_callback; |
int client_ret = client_socket_->Connect(connect_callback.callback()); |
ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); |
+ TestCompletionCallback handshake_callback; |
int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); |
@@ -499,13 +661,13 @@ TEST_F(SSLServerSocketTest, DataTransfer) { |
TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { |
Initialize(); |
- TestCompletionCallback connect_callback; |
- TestCompletionCallback handshake_callback; |
// Establish connection. |
+ TestCompletionCallback connect_callback; |
int client_ret = client_socket_->Connect(connect_callback.callback()); |
ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); |
+ TestCompletionCallback handshake_callback; |
int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); |
@@ -521,7 +683,6 @@ TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { |
// socket won't return ERR_IO_PENDING. This ensures that the client |
// will call Read() on the transport socket again. |
TestCompletionCallback write_callback; |
- |
server_ret = server_socket_->Write( |
write_buf.get(), write_buf->size(), write_callback.callback()); |
EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); |
@@ -552,11 +713,10 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { |
Initialize(); |
TestCompletionCallback connect_callback; |
- TestCompletionCallback handshake_callback; |
- |
int client_ret = client_socket_->Connect(connect_callback.callback()); |
ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); |
+ TestCompletionCallback handshake_callback; |
int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); |
@@ -616,9 +776,9 @@ TEST_F(SSLServerSocketTest, RequireEcdheFlag) { |
Initialize(); |
TestCompletionCallback connect_callback; |
- TestCompletionCallback handshake_callback; |
- |
int client_ret = client_socket_->Connect(connect_callback.callback()); |
+ |
+ TestCompletionCallback handshake_callback; |
int server_ret = server_socket_->Handshake(handshake_callback.callback()); |
client_ret = connect_callback.GetResult(client_ret); |