| Index: net/socket/ssl_server_socket_unittest.cc
|
| diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc
|
| index ac2d44ec4136fb88c0573646e91e0e9697eaf69e..9a2dd9cadd6ee38e6b529d970874643ce4f6a824 100644
|
| --- a/net/socket/ssl_server_socket_unittest.cc
|
| +++ b/net/socket/ssl_server_socket_unittest.cc
|
| @@ -20,6 +20,7 @@
|
| #include <queue>
|
| #include <utility>
|
|
|
| +#include "base/callback_helpers.h"
|
| #include "base/compiler_specific.h"
|
| #include "base/files/file_path.h"
|
| #include "base/files/file_util.h"
|
| @@ -29,8 +30,11 @@
|
| #include "base/message_loop/message_loop.h"
|
| #include "base/single_thread_task_runner.h"
|
| #include "base/thread_task_runner_handle.h"
|
| +#include "build/build_config.h"
|
| #include "crypto/nss_util.h"
|
| #include "crypto/rsa_private_key.h"
|
| +#include "crypto/scoped_openssl_types.h"
|
| +#include "crypto/signature_creator.h"
|
| #include "net/base/address_list.h"
|
| #include "net/base/completion_callback.h"
|
| #include "net/base/host_port_pair.h"
|
| @@ -40,6 +44,7 @@
|
| #include "net/base/test_data_directory.h"
|
| #include "net/cert/cert_status_flags.h"
|
| #include "net/cert/mock_cert_verifier.h"
|
| +#include "net/cert/mock_client_cert_verifier.h"
|
| #include "net/cert/x509_certificate.h"
|
| #include "net/http/transport_security_state.h"
|
| #include "net/log/net_log.h"
|
| @@ -47,18 +52,34 @@
|
| #include "net/socket/socket_test_util.h"
|
| #include "net/socket/ssl_client_socket.h"
|
| #include "net/socket/stream_socket.h"
|
| +#include "net/ssl/scoped_openssl_types.h"
|
| +#include "net/ssl/ssl_cert_request_info.h"
|
| #include "net/ssl/ssl_cipher_suite_names.h"
|
| #include "net/ssl/ssl_connection_status_flags.h"
|
| #include "net/ssl/ssl_info.h"
|
| +#include "net/ssl/ssl_private_key.h"
|
| #include "net/ssl/ssl_server_config.h"
|
| +#include "net/ssl/test_ssl_private_key.h"
|
| #include "net/test/cert_test_util.h"
|
| #include "testing/gtest/include/gtest/gtest.h"
|
| #include "testing/platform_test.h"
|
|
|
| +#if defined(USE_OPENSSL)
|
| +#include <openssl/evp.h>
|
| +#include <openssl/ssl.h>
|
| +#include <openssl/x509.h>
|
| +#endif
|
| +
|
| namespace net {
|
|
|
| namespace {
|
|
|
| +const char kClientCertFileName[] = "client_1.pem";
|
| +const char kClientPrivateKeyFileName[] = "client_1.pk8";
|
| +const char kWrongClientCertFileName[] = "client_2.pem";
|
| +const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
|
| +const char kClientCertCAFileName[] = "client_1_ca.pem";
|
| +
|
| class FakeDataChannel {
|
| public:
|
| FakeDataChannel()
|
| @@ -110,11 +131,24 @@ class FakeDataChannel {
|
| // asynchronously, which is necessary to reproduce bug 127822.
|
| void Close() {
|
| closed_ = true;
|
| + if (!read_callback_.is_null()) {
|
| + base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| + FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
|
| + weak_factory_.GetWeakPtr()));
|
| + }
|
| }
|
|
|
| private:
|
| void DoReadCallback() {
|
| - if (read_callback_.is_null() || data_.empty())
|
| + if (read_callback_.is_null())
|
| + return;
|
| +
|
| + if (closed_) {
|
| + base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
|
| + return;
|
| + }
|
| +
|
| + if (data_.empty())
|
| return;
|
|
|
| int copied = PropagateData(read_buf_, read_buf_len_);
|
| @@ -302,8 +336,10 @@ class SSLServerSocketTest : public PlatformTest {
|
| SSLServerSocketTest()
|
| : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
|
| cert_verifier_(new MockCertVerifier()),
|
| + client_cert_verifier_(new MockClientCertVerifier()),
|
| transport_security_state_(new TransportSecurityState) {
|
| - cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID);
|
| + cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
|
| + client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID);
|
| }
|
|
|
| protected:
|
| @@ -314,25 +350,10 @@ class SSLServerSocketTest : public PlatformTest {
|
| scoped_ptr<StreamSocket> server_socket(
|
| new FakeSocket(&channel_2_, &channel_1_));
|
|
|
| - base::FilePath certs_dir(GetTestCertsDirectory());
|
| -
|
| - base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
|
| - std::string cert_der;
|
| - ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
|
| -
|
| - scoped_refptr<X509Certificate> cert =
|
| - X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
|
| -
|
| - base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
|
| - std::string key_string;
|
| - ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
|
| - std::vector<uint8_t> key_vector(
|
| - reinterpret_cast<const uint8_t*>(key_string.data()),
|
| - reinterpret_cast<const uint8_t*>(key_string.data() +
|
| - key_string.length()));
|
| -
|
| - scoped_ptr<crypto::RSAPrivateKey> private_key(
|
| - crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
|
| + scoped_refptr<X509Certificate> server_cert(
|
| + ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"));
|
| + scoped_ptr<crypto::RSAPrivateKey> server_private_key(
|
| + ReadTestKey("unittest.key.bin"));
|
|
|
| client_ssl_config_.false_start_enabled = false;
|
| client_ssl_config_.channel_id_enabled = false;
|
| @@ -340,19 +361,78 @@ class SSLServerSocketTest : public PlatformTest {
|
| // Certificate provided by the host doesn't need authority.
|
| SSLConfig::CertAndStatus cert_and_status;
|
| cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
|
| - cert_and_status.der_cert = cert_der;
|
| + std::string server_cert_der;
|
| + CHECK(X509Certificate::GetDEREncoded(server_cert->os_cert_handle(),
|
| + &server_cert_der));
|
| + cert_and_status.der_cert = server_cert_der;
|
| client_ssl_config_.allowed_bad_certs.push_back(cert_and_status);
|
|
|
| HostPortPair host_and_pair("unittest", 0);
|
| SSLClientSocketContext context;
|
| context.cert_verifier = cert_verifier_.get();
|
| context.transport_security_state = transport_security_state_.get();
|
| + socket_factory_->ClearSSLSessionCache();
|
| client_socket_ = socket_factory_->CreateSSLClientSocket(
|
| std::move(client_connection), host_and_pair, client_ssl_config_,
|
| context);
|
| - server_socket_ = CreateSSLServerSocket(std::move(server_socket), cert.get(),
|
| - *private_key, server_ssl_config_);
|
| + server_socket_ =
|
| + CreateSSLServerSocket(std::move(server_socket), server_cert.get(),
|
| + *server_private_key, server_ssl_config_);
|
| + }
|
| +
|
| +#if defined(USE_OPENSSL)
|
| + void ConfigureClientCertsForClient(const char* cert_file_name,
|
| + const char* private_key_file_name) {
|
| + client_ssl_config_.send_client_cert = true;
|
| + client_ssl_config_.client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
|
| + ASSERT_TRUE(client_ssl_config_.client_cert);
|
| + scoped_ptr<crypto::RSAPrivateKey> key = ReadTestKey(private_key_file_name);
|
| + ASSERT_TRUE(key);
|
| + client_ssl_config_.client_private_key = WrapOpenSSLPrivateKey(
|
| + crypto::ScopedEVP_PKEY(EVP_PKEY_up_ref(key->key())));
|
| + }
|
| +
|
| + void ConfigureClientCertsForServer() {
|
| + server_ssl_config_.client_cert_type =
|
| + SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT;
|
| +
|
| + ScopedX509NameStack cert_names(
|
| + SSL_load_client_CA_file(GetTestCertsDirectory()
|
| + .AppendASCII(kClientCertCAFileName)
|
| + .MaybeAsASCII()
|
| + .c_str()));
|
| + ASSERT_TRUE(cert_names);
|
| + for (size_t i = 0; i < sk_X509_NAME_num(cert_names.get()); ++i) {
|
| + uint8_t* str = nullptr;
|
| + int length = i2d_X509_NAME(sk_X509_NAME_value(cert_names.get(), i), &str);
|
| + server_ssl_config_.cert_authorities_.push_back(std::string(
|
| + reinterpret_cast<const char*>(str), static_cast<size_t>(length)));
|
| + OPENSSL_free(str);
|
| + }
|
| +
|
| + scoped_refptr<X509Certificate> expected_client_cert(
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
|
| + client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
|
| +
|
| + server_ssl_config_.client_cert_verifier = client_cert_verifier_.get();
|
| + }
|
| +
|
| + scoped_ptr<crypto::RSAPrivateKey> ReadTestKey(const base::StringPiece& name) {
|
| + base::FilePath certs_dir(GetTestCertsDirectory());
|
| + base::FilePath key_path = certs_dir.AppendASCII(name);
|
| + std::string key_string;
|
| + if (!base::ReadFileToString(key_path, &key_string))
|
| + return nullptr;
|
| + std::vector<uint8_t> key_vector(
|
| + reinterpret_cast<const uint8_t*>(key_string.data()),
|
| + reinterpret_cast<const uint8_t*>(key_string.data() +
|
| + key_string.length()));
|
| + scoped_ptr<crypto::RSAPrivateKey> key(
|
| + crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
|
| + return key;
|
| }
|
| +#endif
|
|
|
| FakeDataChannel channel_1_;
|
| FakeDataChannel channel_2_;
|
| @@ -362,6 +442,7 @@ class SSLServerSocketTest : public PlatformTest {
|
| scoped_ptr<SSLServerSocket> server_socket_;
|
| ClientSocketFactory* socket_factory_;
|
| scoped_ptr<MockCertVerifier> cert_verifier_;
|
| + scoped_ptr<MockClientCertVerifier> client_cert_verifier_;
|
| scoped_ptr<TransportSecurityState> transport_security_state_;
|
| };
|
|
|
| @@ -378,21 +459,17 @@ TEST_F(SSLServerSocketTest, Initialize) {
|
| TEST_F(SSLServerSocketTest, Handshake) {
|
| Initialize();
|
|
|
| - TestCompletionCallback connect_callback;
|
| TestCompletionCallback handshake_callback;
|
| -
|
| int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| - EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
|
|
|
| + TestCompletionCallback connect_callback;
|
| int client_ret = client_socket_->Connect(connect_callback.callback());
|
| - EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
|
|
|
| - if (client_ret == ERR_IO_PENDING) {
|
| - EXPECT_EQ(OK, connect_callback.WaitForResult());
|
| - }
|
| - if (server_ret == ERR_IO_PENDING) {
|
| - EXPECT_EQ(OK, handshake_callback.WaitForResult());
|
| - }
|
| + client_ret = connect_callback.GetResult(client_ret);
|
| + server_ret = handshake_callback.GetResult(server_ret);
|
| +
|
| + ASSERT_EQ(OK, client_ret);
|
| + ASSERT_EQ(OK, server_ret);
|
|
|
| // Make sure the cert status is expected.
|
| SSLInfo ssl_info;
|
| @@ -412,16 +489,101 @@ TEST_F(SSLServerSocketTest, Handshake) {
|
| EXPECT_TRUE(is_aead);
|
| }
|
|
|
| -TEST_F(SSLServerSocketTest, DataTransfer) {
|
| +// NSS ports don't support client certificates.
|
| +#if defined(USE_OPENSSL)
|
| +
|
| +// This test executes Connect() on SSLClientSocket and Handshake() on
|
| +// SSLServerSocket to make sure handshaking between the two sockets is
|
| +// completed successfully, using client certificate.
|
| +TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
|
| + scoped_refptr<X509Certificate> client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
|
| + ConfigureClientCertsForClient(kClientCertFileName, kClientPrivateKeyFileName);
|
| + ConfigureClientCertsForServer();
|
| Initialize();
|
|
|
| + TestCompletionCallback handshake_callback;
|
| + int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| +
|
| TestCompletionCallback connect_callback;
|
| + int client_ret = client_socket_->Connect(connect_callback.callback());
|
| +
|
| + client_ret = connect_callback.GetResult(client_ret);
|
| + server_ret = handshake_callback.GetResult(server_ret);
|
| +
|
| + ASSERT_EQ(OK, client_ret);
|
| + ASSERT_EQ(OK, server_ret);
|
| +
|
| + // Make sure the cert status is expected.
|
| + SSLInfo ssl_info;
|
| + client_socket_->GetSSLInfo(&ssl_info);
|
| + EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
|
| + server_socket_->GetSSLInfo(&ssl_info);
|
| + EXPECT_TRUE(ssl_info.cert.get());
|
| + EXPECT_TRUE(client_cert->Equals(ssl_info.cert.get()));
|
| +}
|
| +
|
| +TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
|
| + ConfigureClientCertsForServer();
|
| + Initialize();
|
| + // Use the default setting for the client socket, which is to not send
|
| + // a client certificate. This will cause the client to receive an
|
| + // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the
|
| + // requested cert_authorities from the CertificateRequest sent by the
|
| + // server.
|
| +
|
| TestCompletionCallback handshake_callback;
|
| + int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| +
|
| + TestCompletionCallback connect_callback;
|
| + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
|
| + connect_callback.GetResult(
|
| + client_socket_->Connect(connect_callback.callback())));
|
| +
|
| + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo();
|
| + client_socket_->GetSSLCertRequestInfo(request_info.get());
|
| +
|
| + // Check that the authority name that arrived in the CertificateRequest
|
| + // handshake message is as expected.
|
| + scoped_refptr<X509Certificate> client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
|
| + EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities));
|
| +
|
| + client_socket_->Disconnect();
|
| +
|
| + EXPECT_EQ(ERR_FAILED, handshake_callback.GetResult(server_ret));
|
| +}
|
| +
|
| +TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
|
| + scoped_refptr<X509Certificate> client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
|
| + ConfigureClientCertsForClient(kWrongClientCertFileName,
|
| + kWrongClientPrivateKeyFileName);
|
| + ConfigureClientCertsForServer();
|
| + Initialize();
|
| +
|
| + TestCompletionCallback handshake_callback;
|
| + int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| +
|
| + TestCompletionCallback connect_callback;
|
| + int client_ret = client_socket_->Connect(connect_callback.callback());
|
| +
|
| + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
|
| + connect_callback.GetResult(client_ret));
|
| + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT,
|
| + handshake_callback.GetResult(server_ret));
|
| +}
|
| +#endif // defined(USE_OPENSSL)
|
| +
|
| +TEST_F(SSLServerSocketTest, DataTransfer) {
|
| + Initialize();
|
|
|
| // Establish connection.
|
| + TestCompletionCallback connect_callback;
|
| int client_ret = client_socket_->Connect(connect_callback.callback());
|
| ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
|
|
|
| + TestCompletionCallback handshake_callback;
|
| int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
|
|
|
| @@ -499,13 +661,13 @@ TEST_F(SSLServerSocketTest, DataTransfer) {
|
| TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
|
| Initialize();
|
|
|
| - TestCompletionCallback connect_callback;
|
| - TestCompletionCallback handshake_callback;
|
|
|
| // Establish connection.
|
| + TestCompletionCallback connect_callback;
|
| int client_ret = client_socket_->Connect(connect_callback.callback());
|
| ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
|
|
|
| + TestCompletionCallback handshake_callback;
|
| int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
|
|
|
| @@ -521,7 +683,6 @@ TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) {
|
| // socket won't return ERR_IO_PENDING. This ensures that the client
|
| // will call Read() on the transport socket again.
|
| TestCompletionCallback write_callback;
|
| -
|
| server_ret = server_socket_->Write(
|
| write_buf.get(), write_buf->size(), write_callback.callback());
|
| EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING);
|
| @@ -552,11 +713,10 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) {
|
| Initialize();
|
|
|
| TestCompletionCallback connect_callback;
|
| - TestCompletionCallback handshake_callback;
|
| -
|
| int client_ret = client_socket_->Connect(connect_callback.callback());
|
| ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
|
|
|
| + TestCompletionCallback handshake_callback;
|
| int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
|
|
|
| @@ -616,9 +776,9 @@ TEST_F(SSLServerSocketTest, RequireEcdheFlag) {
|
| Initialize();
|
|
|
| TestCompletionCallback connect_callback;
|
| - TestCompletionCallback handshake_callback;
|
| -
|
| int client_ret = client_socket_->Connect(connect_callback.callback());
|
| +
|
| + TestCompletionCallback handshake_callback;
|
| int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
|
|
| client_ret = connect_callback.GetResult(client_ret);
|
|
|