| Index: net/socket/ssl_client_socket_unittest.cc
|
| diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc
|
| index d7d28bc556e857420d05c5470be0ff84fa6c6d71..daa12b42d8da2df9a479feb091f3b0837bda4fb7 100644
|
| --- a/net/socket/ssl_client_socket_unittest.cc
|
| +++ b/net/socket/ssl_client_socket_unittest.cc
|
| @@ -23,8 +23,11 @@
|
| #include "base/threading/thread_task_runner_handle.h"
|
| #include "base/time/time.h"
|
| #include "base/values.h"
|
| +#include "crypto/rsa_private_key.h"
|
| #include "net/base/address_list.h"
|
| #include "net/base/io_buffer.h"
|
| +#include "net/base/ip_address.h"
|
| +#include "net/base/ip_endpoint.h"
|
| #include "net/base/net_errors.h"
|
| #include "net/base/test_completion_callback.h"
|
| #include "net/cert/asn1_util.h"
|
| @@ -48,14 +51,17 @@
|
| #include "net/socket/client_socket_factory.h"
|
| #include "net/socket/client_socket_handle.h"
|
| #include "net/socket/socket_test_util.h"
|
| +#include "net/socket/ssl_server_socket.h"
|
| #include "net/socket/stream_socket.h"
|
| #include "net/socket/tcp_client_socket.h"
|
| +#include "net/socket/tcp_server_socket.h"
|
| #include "net/ssl/channel_id_service.h"
|
| #include "net/ssl/default_channel_id_store.h"
|
| #include "net/ssl/ssl_cert_request_info.h"
|
| #include "net/ssl/ssl_config_service.h"
|
| #include "net/ssl/ssl_connection_status_flags.h"
|
| #include "net/ssl/ssl_info.h"
|
| +#include "net/ssl/ssl_server_config.h"
|
| #include "net/ssl/test_ssl_private_key.h"
|
| #include "net/test/cert_test_util.h"
|
| #include "net/test/gtest_util.h"
|
| @@ -3205,11 +3211,7 @@ TEST_F(SSLClientSocketTest, AlpnClientDisabled) {
|
|
|
| namespace {
|
|
|
| -// Loads a PEM-encoded private key file into a SSLPrivateKey object.
|
| -// |filepath| is the private key file path.
|
| -// Returns the new SSLPrivateKey.
|
| -scoped_refptr<SSLPrivateKey> LoadPrivateKeyOpenSSL(
|
| - const base::FilePath& filepath) {
|
| +bssl::UniquePtr<EVP_PKEY> LoadEVP_PKEY(const base::FilePath& filepath) {
|
| std::string data;
|
| if (!base::ReadFileToString(filepath, &data)) {
|
| LOG(ERROR) << "Could not read private key file: " << filepath.value();
|
| @@ -3227,7 +3229,18 @@ scoped_refptr<SSLPrivateKey> LoadPrivateKeyOpenSSL(
|
| LOG(ERROR) << "Could not decode private key file: " << filepath.value();
|
| return nullptr;
|
| }
|
| - return WrapOpenSSLPrivateKey(std::move(result));
|
| + return result;
|
| +}
|
| +
|
| +// Loads a PEM-encoded private key file into a SSLPrivateKey object.
|
| +// |filepath| is the private key file path.
|
| +// Returns the new SSLPrivateKey.
|
| +scoped_refptr<SSLPrivateKey> LoadPrivateKeyOpenSSL(
|
| + const base::FilePath& filepath) {
|
| + bssl::UniquePtr<EVP_PKEY> key = LoadEVP_PKEY(filepath);
|
| + if (!key)
|
| + return nullptr;
|
| + return WrapOpenSSLPrivateKey(std::move(key));
|
| }
|
|
|
| } // namespace
|
| @@ -3793,4 +3806,80 @@ TEST_P(SSLClientSocketReadTest, DumpMemoryStats) {
|
| EXPECT_LT(0u, stats2.cert_size);
|
| }
|
|
|
| +TEST_P(SSLClientSocketReadTest, IdleAfterRead) {
|
| + // Set up a TCP server.
|
| + TCPServerSocket server_listener(NULL, NetLogSource());
|
| + ASSERT_THAT(
|
| + server_listener.Listen(IPEndPoint(IPAddress::IPv4Localhost(), 0), 1),
|
| + IsOk());
|
| + IPEndPoint server_address;
|
| + ASSERT_THAT(server_listener.GetLocalAddress(&server_address), IsOk());
|
| +
|
| + // Connect a TCP client and server socket.
|
| + TestCompletionCallback server_callback;
|
| + std::unique_ptr<StreamSocket> server_transport;
|
| + int server_rv =
|
| + server_listener.Accept(&server_transport, server_callback.callback());
|
| +
|
| + TestCompletionCallback client_callback;
|
| + std::unique_ptr<TCPClientSocket> client_transport(new TCPClientSocket(
|
| + AddressList(server_address), NULL, NULL, NetLogSource()));
|
| + int client_rv = client_transport->Connect(client_callback.callback());
|
| +
|
| + EXPECT_THAT(server_callback.GetResult(server_rv), IsOk());
|
| + EXPECT_THAT(client_callback.GetResult(client_rv), IsOk());
|
| +
|
| + // Set up an SSL server.
|
| + base::FilePath certs_dir = GetTestCertsDirectory();
|
| + scoped_refptr<net::X509Certificate> cert =
|
| + ImportCertFromFile(certs_dir, "ok_cert.pem");
|
| + ASSERT_TRUE(cert);
|
| + bssl::UniquePtr<EVP_PKEY> pkey =
|
| + LoadEVP_PKEY(certs_dir.AppendASCII("ok_cert.pem"));
|
| + ASSERT_TRUE(pkey);
|
| + std::unique_ptr<crypto::RSAPrivateKey> key =
|
| + crypto::RSAPrivateKey::CreateFromKey(pkey.get());
|
| + ASSERT_TRUE(key);
|
| + std::unique_ptr<SSLServerContext> server_context =
|
| + CreateSSLServerContext(cert.get(), *key.get(), SSLServerConfig());
|
| +
|
| + // Complete the SSL handshake on both sides.
|
| + std::unique_ptr<SSLClientSocket> client(CreateSSLClientSocket(
|
| + std::move(client_transport), HostPortPair::FromIPEndPoint(server_address),
|
| + SSLConfig()));
|
| + std::unique_ptr<SSLServerSocket> server(
|
| + server_context->CreateSSLServerSocket(std::move(server_transport)));
|
| +
|
| + server_rv = server->Handshake(server_callback.callback());
|
| + client_rv = client->Connect(client_callback.callback());
|
| +
|
| + EXPECT_THAT(server_callback.GetResult(server_rv), IsOk());
|
| + EXPECT_THAT(client_callback.GetResult(client_rv), IsOk());
|
| +
|
| + // Write a single record on the server.
|
| + scoped_refptr<IOBuffer> write_buf(new StringIOBuffer("a"));
|
| + server_rv = server->Write(write_buf.get(), 1, server_callback.callback());
|
| +
|
| + // Read that record on the server, but with a much larger buffer than
|
| + // necessary.
|
| + scoped_refptr<IOBuffer> read_buf(new IOBuffer(1024));
|
| + client_rv =
|
| + Read(client.get(), read_buf.get(), 1024, client_callback.callback());
|
| +
|
| + EXPECT_EQ(1, server_callback.GetResult(server_rv));
|
| + EXPECT_EQ(1, WaitForReadCompletion(client.get(), read_buf.get(), 1024,
|
| + &client_callback, client_rv));
|
| +
|
| + // At this point the client socket should be idle.
|
| + EXPECT_TRUE(client->IsConnectedAndIdle());
|
| +
|
| + // The read buffer should be released.
|
| + StreamSocket::SocketMemoryStats stats;
|
| + client->DumpMemoryStats(&stats);
|
| + EXPECT_EQ(0u, stats.buffer_size);
|
| + EXPECT_EQ(1u, stats.cert_count);
|
| + EXPECT_LT(0u, stats.cert_size);
|
| + EXPECT_EQ(stats.cert_size, stats.total_size);
|
| +}
|
| +
|
| } // namespace net
|
|
|