| Index: net/socket/ssl_server_socket_unittest.cc
|
| diff --git a/net/socket/ssl_server_socket_unittest.cc b/net/socket/ssl_server_socket_unittest.cc
|
| index 548f7c6c99856c583bec26f6085ad8d1cbdda87b..f88d257b3759a08e7f112c67a5ed4ebe65709338 100644
|
| --- a/net/socket/ssl_server_socket_unittest.cc
|
| +++ b/net/socket/ssl_server_socket_unittest.cc
|
| @@ -30,6 +30,7 @@
|
| #include "base/thread_task_runner_handle.h"
|
| #include "crypto/nss_util.h"
|
| #include "crypto/rsa_private_key.h"
|
| +#include "crypto/signature_creator.h"
|
| #include "net/base/address_list.h"
|
| #include "net/base/completion_callback.h"
|
| #include "net/base/host_port_pair.h"
|
| @@ -39,6 +40,7 @@
|
| #include "net/base/test_data_directory.h"
|
| #include "net/cert/cert_status_flags.h"
|
| #include "net/cert/mock_cert_verifier.h"
|
| +#include "net/cert/mock_client_cert_verifier.h"
|
| #include "net/cert/x509_certificate.h"
|
| #include "net/http/transport_security_state.h"
|
| #include "net/log/net_log.h"
|
| @@ -46,9 +48,11 @@
|
| #include "net/socket/socket_test_util.h"
|
| #include "net/socket/ssl_client_socket.h"
|
| #include "net/socket/stream_socket.h"
|
| +#include "net/ssl/ssl_cert_request_info.h"
|
| #include "net/ssl/ssl_cipher_suite_names.h"
|
| #include "net/ssl/ssl_connection_status_flags.h"
|
| #include "net/ssl/ssl_info.h"
|
| +#include "net/ssl/ssl_private_key.h"
|
| #include "net/ssl/ssl_server_config.h"
|
| #include "net/test/cert_test_util.h"
|
| #include "testing/gtest/include/gtest/gtest.h"
|
| @@ -58,6 +62,12 @@ namespace net {
|
|
|
| namespace {
|
|
|
| +const char kClientCertFileName[] = "client_1.pem";
|
| +const char kClientPrivateKeyFileName[] = "client_1.pk8";
|
| +const char kWrongClientCertFileName[] = "client_2.pem";
|
| +const char kWrongClientPrivateKeyFileName[] = "client_2.pk8";
|
| +const char kClientCertCAFileName[] = "client_1_ca.pem";
|
| +
|
| class FakeDataChannel {
|
| public:
|
| FakeDataChannel()
|
| @@ -109,13 +119,19 @@ class FakeDataChannel {
|
| // asynchronously, which is necessary to reproduce bug 127822.
|
| void Close() {
|
| closed_ = true;
|
| + data_.push(
|
| + new DrainableIOBuffer(new StringIOBuffer(std::string("0", 1)), 1));
|
| + if (!read_callback_.is_null()) {
|
| + base::MessageLoop::current()->PostTask(
|
| + FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback,
|
| + weak_factory_.GetWeakPtr()));
|
| + }
|
| }
|
|
|
| private:
|
| void DoReadCallback() {
|
| if (read_callback_.is_null() || data_.empty())
|
| return;
|
| -
|
| int copied = PropagateData(read_buf_, read_buf_len_);
|
| CompletionCallback callback = read_callback_;
|
| read_callback_.Reset();
|
| @@ -254,6 +270,64 @@ class FakeSocket : public StreamSocket {
|
| DISALLOW_COPY_AND_ASSIGN(FakeSocket);
|
| };
|
|
|
| +class TestSSLPrivateKey : public SSLPrivateKey {
|
| + public:
|
| + TestSSLPrivateKey(crypto::RSAPrivateKey* rsa_private_key)
|
| + : rsa_private_key_(rsa_private_key) {}
|
| +
|
| + ~TestSSLPrivateKey() override {}
|
| +
|
| + Type GetType() override { return SSLPrivateKey::Type::RSA; }
|
| +
|
| + std::vector<SSLPrivateKey::Hash> GetDigestPreferences() override {
|
| + static const SSLPrivateKey::Hash kHashes[] = {SSLPrivateKey::Hash::SHA256,
|
| + SSLPrivateKey::Hash::SHA1};
|
| + return std::vector<SSLPrivateKey::Hash>(kHashes,
|
| + kHashes + arraysize(kHashes));
|
| + }
|
| +
|
| + // NOTE: The following algorithm assumes the answer is a power of 2, which is
|
| + // true for the test keys in use.
|
| + size_t GetMaxSignatureLengthInBytes() override {
|
| + std::vector<uint8> public_key_info;
|
| + rsa_private_key_->ExportPublicKey(&public_key_info);
|
| + uint result = 1;
|
| + while ((result << 1) < public_key_info.size())
|
| + result <<= 1;
|
| + return result;
|
| + }
|
| +
|
| + void SignDigest(Hash hash,
|
| + const base::StringPiece& input,
|
| + const SignCallback& callback) override {
|
| + std::vector<uint8> signature;
|
| + crypto::SignatureCreator::HashAlgorithm hash_alg;
|
| + switch (hash) {
|
| + case Hash::SHA1:
|
| + hash_alg = crypto::SignatureCreator::SHA1;
|
| + break;
|
| +
|
| + case Hash::SHA256:
|
| + hash_alg = crypto::SignatureCreator::SHA256;
|
| + break;
|
| +
|
| + default:
|
| + FAIL() << "Unsupported hash function";
|
| + }
|
| + crypto::SignatureCreator::Sign(rsa_private_key_.get(), hash_alg,
|
| + reinterpret_cast<const uint8*>(input.data()),
|
| + input.size(), &signature);
|
| + base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| + FROM_HERE, base::Bind(callback, OK, signature));
|
| + }
|
| +
|
| + private:
|
| + void CompleteSignDigest(Error err, const std::vector<uint8_t>& signature) {}
|
| + scoped_ptr<crypto::RSAPrivateKey> rsa_private_key_;
|
| +
|
| + DISALLOW_COPY_AND_ASSIGN(TestSSLPrivateKey);
|
| +};
|
| +
|
| } // namespace
|
|
|
| // Verify the correctness of the test helper classes first.
|
| @@ -298,11 +372,21 @@ TEST(FakeSocketTest, DataTransfer) {
|
|
|
| class SSLServerSocketTest : public PlatformTest {
|
| public:
|
| + enum ClientCertSupply {
|
| + kNoneSupplied = 0,
|
| + kCorrectCertSupplied = 1,
|
| + kWrongCertSupplied = 2
|
| + };
|
| +
|
| + enum ClientCertExpect { kNoneExpected = 0, kCertRequired = 2 };
|
| +
|
| SSLServerSocketTest()
|
| : socket_factory_(ClientSocketFactory::GetDefaultFactory()),
|
| cert_verifier_(new MockCertVerifier()),
|
| + client_cert_verifier_(new MockClientCertVerifier()),
|
| transport_security_state_(new TransportSecurityState) {
|
| cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID);
|
| + client_cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID);
|
| }
|
|
|
| protected:
|
| @@ -313,25 +397,11 @@ class SSLServerSocketTest : public PlatformTest {
|
| scoped_ptr<StreamSocket> server_socket(
|
| new FakeSocket(&channel_2_, &channel_1_));
|
|
|
| - base::FilePath certs_dir(GetTestCertsDirectory());
|
| -
|
| - base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
|
| - std::string cert_der;
|
| - ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der));
|
| -
|
| - scoped_refptr<X509Certificate> cert =
|
| - X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
|
| -
|
| - base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
|
| - std::string key_string;
|
| - ASSERT_TRUE(base::ReadFileToString(key_path, &key_string));
|
| - std::vector<uint8> key_vector(
|
| - reinterpret_cast<const uint8*>(key_string.data()),
|
| - reinterpret_cast<const uint8*>(key_string.data() +
|
| - key_string.length()));
|
| -
|
| - scoped_ptr<crypto::RSAPrivateKey> private_key(
|
| - crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
|
| + std::string server_cert_der;
|
| + scoped_refptr<X509Certificate> server_cert(
|
| + ReadTestCert("unittest.selfsigned.der", &server_cert_der));
|
| + scoped_ptr<crypto::RSAPrivateKey> server_private_key(
|
| + ReadTestKey("unittest.key.bin"));
|
|
|
| client_ssl_config_.false_start_enabled = false;
|
| client_ssl_config_.channel_id_enabled = false;
|
| @@ -339,29 +409,97 @@ class SSLServerSocketTest : public PlatformTest {
|
| // Certificate provided by the host doesn't need authority.
|
| SSLConfig::CertAndStatus cert_and_status;
|
| cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID;
|
| - cert_and_status.der_cert = cert_der;
|
| + cert_and_status.der_cert = server_cert_der;
|
| client_ssl_config_.allowed_bad_certs.push_back(cert_and_status);
|
|
|
| HostPortPair host_and_pair("unittest", 0);
|
| SSLClientSocketContext context;
|
| context.cert_verifier = cert_verifier_.get();
|
| context.transport_security_state = transport_security_state_.get();
|
| + socket_factory_->ClearSSLSessionCache();
|
| client_socket_ = socket_factory_->CreateSSLClientSocket(
|
| client_connection.Pass(), host_and_pair, client_ssl_config_, context);
|
| - server_socket_ =
|
| - CreateSSLServerSocket(server_socket.Pass(), cert.get(),
|
| - private_key.get(), server_ssl_config_);
|
| + server_socket_ = CreateSSLServerSocket(
|
| + server_socket.Pass(), server_cert.get(), server_private_key.get(),
|
| + server_ssl_config_, server_context_);
|
| + }
|
| +
|
| + void InitializeClientCertsForClient(ClientCertSupply supply) {
|
| + scoped_refptr<X509Certificate> cert;
|
| + scoped_ptr<net::SSLPrivateKey> key;
|
| + if (supply != kNoneSupplied) {
|
| + const char* cert_file_name = supply == kCorrectCertSupplied
|
| + ? kClientCertFileName
|
| + : kWrongClientCertFileName;
|
| + const char* private_key_file_name = supply == kCorrectCertSupplied
|
| + ? kClientPrivateKeyFileName
|
| + : kWrongClientPrivateKeyFileName;
|
| + cert = ImportCertFromFile(GetTestCertsDirectory(), cert_file_name);
|
| + key.reset(new TestSSLPrivateKey(ReadTestKey(private_key_file_name)));
|
| + }
|
| + client_socket_->ForceClientCertificateAndKeyForTesting(cert, key.Pass());
|
| + }
|
| +
|
| + void InitializeClientCertsForServer(ClientCertExpect expect) {
|
| + if (expect == kNoneExpected)
|
| + return;
|
| +
|
| + server_ssl_config_.require_client_cert = true;
|
| +
|
| + if (expect == kCertRequired) {
|
| + scoped_refptr<X509Certificate> expected_client_ca_cert(
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertCAFileName));
|
| + CertificateList ca_list;
|
| + ca_list.push_back(expected_client_ca_cert);
|
| + server_ssl_config_.client_cert_ca_list = ca_list;
|
| + scoped_refptr<X509Certificate> expected_client_cert(
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName));
|
| + CertVerifyResult ignored;
|
| + ignored.verified_cert = expected_client_cert;
|
| + ignored.cert_status = 0;
|
| + client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK);
|
| +
|
| + server_context_.client_cert_verifier = client_cert_verifier_.get();
|
| + }
|
| + }
|
| +
|
| + X509Certificate* ReadTestCert(const base::StringPiece& name,
|
| + std::string* cert_der) {
|
| + base::FilePath certs_dir(GetTestCertsDirectory());
|
| + base::FilePath cert_path = certs_dir.AppendASCII(name);
|
| + std::string unneeded;
|
| + if (!cert_der)
|
| + cert_der = &unneeded;
|
| + if (!base::ReadFileToString(cert_path, cert_der))
|
| + return NULL;
|
| + return X509Certificate::CreateFromBytes(cert_der->data(), cert_der->size());
|
| + }
|
| +
|
| + crypto::RSAPrivateKey* ReadTestKey(const base::StringPiece& name) {
|
| + base::FilePath certs_dir(GetTestCertsDirectory());
|
| + base::FilePath key_path = certs_dir.AppendASCII(name);
|
| + std::string key_string;
|
| + if (!base::ReadFileToString(key_path, &key_string))
|
| + return NULL;
|
| + std::vector<uint8> key_vector(
|
| + reinterpret_cast<const uint8*>(key_string.data()),
|
| + reinterpret_cast<const uint8*>(key_string.data() +
|
| + key_string.length()));
|
| + return crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector);
|
| }
|
|
|
| FakeDataChannel channel_1_;
|
| FakeDataChannel channel_2_;
|
| SSLConfig client_ssl_config_;
|
| SSLServerConfig server_ssl_config_;
|
| + SSLServerSocketContext server_context_;
|
| scoped_ptr<SSLClientSocket> client_socket_;
|
| scoped_ptr<SSLServerSocket> server_socket_;
|
| ClientSocketFactory* socket_factory_;
|
| scoped_ptr<MockCertVerifier> cert_verifier_;
|
| + scoped_ptr<MockClientCertVerifier> client_cert_verifier_;
|
| scoped_ptr<TransportSecurityState> transport_security_state_;
|
| + CertificateList trusted_certs_;
|
| };
|
|
|
| // This test only executes creation of client and server sockets. This is to
|
| @@ -411,6 +549,122 @@ TEST_F(SSLServerSocketTest, Handshake) {
|
| EXPECT_TRUE(is_aead);
|
| }
|
|
|
| +// TODO(dougsteed) The following tests using client certificates cannot
|
| +// be performed if NSS with platform-based client auth is in use. That's because
|
| +// the tests use SSLClientSocket to make requests against the server, and on
|
| +// those builds, that class does not support supplying of a test key and cert.
|
| +// An alternative approach that would broaden the applicability of these tests
|
| +// would be to build and use the openssl flavor of SSLClientSocket, even
|
| +// on NSS platforms.
|
| +#if defined(USE_OPENSSL) || !defined(NSS_PLATFORM_CLIENT_AUTH)
|
| +
|
| +// This test executes Connect() on SSLClientSocket and Handshake() on
|
| +// SSLServerSocket to make sure handshaking between the two sockets is
|
| +// completed successfully, using client certificate.
|
| +TEST_F(SSLServerSocketTest, HandshakeWithClientCert) {
|
| + scoped_refptr<X509Certificate> client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
|
| + InitializeClientCertsForServer(kCertRequired);
|
| + Initialize();
|
| + InitializeClientCertsForClient(kCorrectCertSupplied);
|
| +
|
| + TestCompletionCallback connect_callback;
|
| + TestCompletionCallback handshake_callback;
|
| +
|
| + int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| + EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING);
|
| +
|
| + int client_ret = client_socket_->Connect(connect_callback.callback());
|
| + EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING);
|
| +
|
| + if (client_ret == ERR_IO_PENDING) {
|
| + EXPECT_EQ(OK, connect_callback.WaitForResult());
|
| + }
|
| + if (server_ret == ERR_IO_PENDING) {
|
| + EXPECT_EQ(OK, handshake_callback.WaitForResult());
|
| + }
|
| +
|
| + // Make sure the cert status is expected.
|
| + SSLInfo ssl_info;
|
| + client_socket_->GetSSLInfo(&ssl_info);
|
| + EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status);
|
| + server_socket_->GetSSLInfo(&ssl_info);
|
| + EXPECT_TRUE(ssl_info.client_cert_sent);
|
| + EXPECT_TRUE(ssl_info.client_cert_sent);
|
| + EXPECT_TRUE(ssl_info.cert.get());
|
| + EXPECT_TRUE(client_cert->Equals(ssl_info.cert.get()));
|
| +}
|
| +
|
| +TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) {
|
| + InitializeClientCertsForServer(kCertRequired);
|
| + Initialize();
|
| + // We use the default setting for the client socket. This causes the client to
|
| + // get SSL_CLIENT_AUTH_CERT_NEEDED. This code path allows us to access the
|
| + // cert_authorities from the CertificateRequest.
|
| +
|
| + TestCompletionCallback connect_callback;
|
| + TestCompletionCallback handshake_callback;
|
| + int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| + EXPECT_TRUE(server_ret == ERR_IO_PENDING);
|
| +
|
| + int client_ret = client_socket_->Connect(connect_callback.callback());
|
| + EXPECT_TRUE(client_ret == ERR_SSL_CLIENT_AUTH_CERT_NEEDED ||
|
| + client_ret == ERR_IO_PENDING);
|
| +
|
| + if (client_ret == ERR_IO_PENDING) {
|
| + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED,
|
| + connect_callback.WaitForResult());
|
| + }
|
| +
|
| + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo();
|
| + client_socket_->GetSSLCertRequestInfo(request_info.get());
|
| +
|
| + // Check that the authority name that arrived in the CertificateRequest
|
| + // handshake message is as expected.
|
| + scoped_refptr<X509Certificate> client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
|
| + EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities));
|
| +
|
| + client_socket_->Disconnect();
|
| +
|
| + if (server_ret == ERR_IO_PENDING) {
|
| + server_ret = handshake_callback.WaitForResult();
|
| + EXPECT_TRUE(server_ret == ERR_CONNECTION_CLOSED ||
|
| + server_ret == ERR_FAILED);
|
| + }
|
| +}
|
| +
|
| +TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) {
|
| + scoped_refptr<X509Certificate> client_cert =
|
| + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName);
|
| + InitializeClientCertsForServer(kCertRequired);
|
| + Initialize();
|
| + InitializeClientCertsForClient(kWrongCertSupplied);
|
| +
|
| + TestCompletionCallback connect_callback;
|
| + TestCompletionCallback handshake_callback;
|
| +
|
| + int server_ret = server_socket_->Handshake(handshake_callback.callback());
|
| + EXPECT_TRUE(server_ret == ERR_IO_PENDING);
|
| +
|
| + int client_ret = client_socket_->Connect(connect_callback.callback());
|
| + EXPECT_TRUE(client_ret == ERR_BAD_SSL_CLIENT_AUTH_CERT ||
|
| + client_ret == ERR_IO_PENDING);
|
| +
|
| + if (client_ret == ERR_IO_PENDING) {
|
| + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, connect_callback.WaitForResult());
|
| + }
|
| +
|
| + server_ret = handshake_callback.WaitForResult();
|
| + // We get a different result on NSS and OpenSSL. That's because an error
|
| + // mapping with OpenSSL makes an assumption that is true for SSLClientSocket
|
| + // but not SSLServerSocket (namely that peer cert rejection only occurs due to
|
| + // a cert change during renego).
|
| + EXPECT_TRUE(server_ret == ERR_BAD_SSL_CLIENT_AUTH_CERT ||
|
| + server_ret == ERR_SSL_SERVER_CERT_CHANGED);
|
| +}
|
| +#endif // defined(USE_OPENSSL) || !defined(NSS_PLATFORM_CLIENT_AUTH)
|
| +
|
| TEST_F(SSLServerSocketTest, DataTransfer) {
|
| Initialize();
|
|
|
|
|