| Index: net/socket/ssl_client_socket_unittest.cc
|
| diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc
|
| index 89356bfad20c6117f770fc8be132e70c36146e98..265eabcfa84db8ce56435eb51222ac49fe599d9f 100644
|
| --- a/net/socket/ssl_client_socket_unittest.cc
|
| +++ b/net/socket/ssl_client_socket_unittest.cc
|
| @@ -4,6 +4,7 @@
|
|
|
| #include "net/socket/ssl_client_socket.h"
|
|
|
| +#include "base/callback_helpers.h"
|
| #include "base/memory/ref_counted.h"
|
| #include "net/base/address_list.h"
|
| #include "net/base/cert_test_util.h"
|
| @@ -28,8 +29,207 @@
|
|
|
| //-----------------------------------------------------------------------------
|
|
|
| +namespace {
|
| +
|
| const net::SSLConfig kDefaultSSLConfig;
|
|
|
| +// ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that
|
| +// will ensure a certain amount of data is internally buffered before
|
| +// satisfying a Read() request. It exists to mimic OS-level internal
|
| +// buffering, but in a way to guarantee that X number of bytes will be
|
| +// returned to callers of Read(), regardless of how quickly the OS receives
|
| +// them from the TestServer.
|
| +class ReadBufferingStreamSocket : public net::StreamSocket {
|
| + public:
|
| + explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport);
|
| + virtual ~ReadBufferingStreamSocket() {}
|
| +
|
| + // Sets the internal buffer to |size|. This must not be greater than
|
| + // the largest value supplied to Read() - that is, it does not handle
|
| + // having "leftovers" at the end of Read().
|
| + // Each call to Read() will be prevented from completion until at least
|
| + // |size| data has been read.
|
| + // Set to 0 to turn off buffering, causing Read() to transparently
|
| + // read via the underlying transport.
|
| + void SetBufferSize(int size);
|
| +
|
| + // StreamSocket implementation:
|
| + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE {
|
| + return transport_->Connect(callback);
|
| + }
|
| + virtual void Disconnect() OVERRIDE {
|
| + transport_->Disconnect();
|
| + }
|
| + virtual bool IsConnected() const OVERRIDE {
|
| + return transport_->IsConnected();
|
| + }
|
| + virtual bool IsConnectedAndIdle() const OVERRIDE {
|
| + return transport_->IsConnectedAndIdle();
|
| + }
|
| + virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE {
|
| + return transport_->GetPeerAddress(address);
|
| + }
|
| + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE {
|
| + return transport_->GetLocalAddress(address);
|
| + }
|
| + virtual const net::BoundNetLog& NetLog() const OVERRIDE {
|
| + return transport_->NetLog();
|
| + }
|
| + virtual void SetSubresourceSpeculation() OVERRIDE {
|
| + transport_->SetSubresourceSpeculation();
|
| + }
|
| + virtual void SetOmniboxSpeculation() OVERRIDE {
|
| + transport_->SetOmniboxSpeculation();
|
| + }
|
| + virtual bool WasEverUsed() const OVERRIDE {
|
| + return transport_->WasEverUsed();
|
| + }
|
| + virtual bool UsingTCPFastOpen() const OVERRIDE {
|
| + return transport_->UsingTCPFastOpen();
|
| + }
|
| + virtual int64 NumBytesRead() const OVERRIDE {
|
| + return transport_->NumBytesRead();
|
| + }
|
| + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE {
|
| + return transport_->GetConnectTimeMicros();
|
| + }
|
| + virtual bool WasNpnNegotiated() const OVERRIDE {
|
| + return transport_->WasNpnNegotiated();
|
| + }
|
| + virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE {
|
| + return transport_->GetNegotiatedProtocol();
|
| + }
|
| + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE {
|
| + return transport_->GetSSLInfo(ssl_info);
|
| + }
|
| +
|
| + // Socket implementation:
|
| + virtual int Read(net::IOBuffer* buf, int buf_len,
|
| + const net::CompletionCallback& callback) OVERRIDE;
|
| + virtual int Write(net::IOBuffer* buf, int buf_len,
|
| + const net::CompletionCallback& callback) OVERRIDE {
|
| + return transport_->Write(buf, buf_len, callback);
|
| + }
|
| + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
|
| + return transport_->SetReceiveBufferSize(size);
|
| + }
|
| + virtual bool SetSendBufferSize(int32 size) OVERRIDE {
|
| + return transport_->SetSendBufferSize(size);
|
| + }
|
| +
|
| + private:
|
| + enum State {
|
| + STATE_NONE,
|
| + STATE_READ,
|
| + STATE_READ_COMPLETE,
|
| + };
|
| +
|
| + int DoLoop(int result);
|
| + int DoRead();
|
| + int DoReadComplete(int result);
|
| + void OnReadCompleted(int result);
|
| +
|
| + State state_;
|
| + scoped_ptr<net::StreamSocket> transport_;
|
| + scoped_refptr<net::GrowableIOBuffer> read_buffer_;
|
| + int buffer_size_;
|
| +
|
| + scoped_refptr<net::IOBuffer> user_read_buf_;
|
| + net::CompletionCallback user_read_callback_;
|
| +};
|
| +
|
| +ReadBufferingStreamSocket::ReadBufferingStreamSocket(
|
| + scoped_ptr<net::StreamSocket> transport)
|
| + : transport_(transport.Pass()),
|
| + read_buffer_(new net::GrowableIOBuffer()),
|
| + buffer_size_(0) {
|
| +}
|
| +
|
| +void ReadBufferingStreamSocket::SetBufferSize(int size) {
|
| + DCHECK(!user_read_buf_);
|
| + buffer_size_ = size;
|
| + read_buffer_->SetCapacity(size);
|
| +}
|
| +
|
| +int ReadBufferingStreamSocket::Read(net::IOBuffer* buf,
|
| + int buf_len,
|
| + const net::CompletionCallback& callback) {
|
| + if (buffer_size_ == 0)
|
| + return transport_->Read(buf, buf_len, callback);
|
| +
|
| + if (buf_len < buffer_size_)
|
| + return net::ERR_UNEXPECTED;
|
| +
|
| + state_ = STATE_READ;
|
| + user_read_buf_ = buf;
|
| + int result = DoLoop(net::OK);
|
| + if (result == net::ERR_IO_PENDING)
|
| + user_read_callback_ = callback;
|
| + else
|
| + user_read_buf_ = NULL;
|
| + return result;
|
| +}
|
| +
|
| +int ReadBufferingStreamSocket::DoLoop(int result) {
|
| + int rv = result;
|
| + do {
|
| + State current_state = state_;
|
| + state_ = STATE_NONE;
|
| + switch (current_state) {
|
| + case STATE_READ:
|
| + rv = DoRead();
|
| + break;
|
| + case STATE_READ_COMPLETE:
|
| + rv = DoReadComplete(rv);
|
| + break;
|
| + case STATE_NONE:
|
| + default:
|
| + NOTREACHED() << "Unexpected state: " << current_state;
|
| + rv = net::ERR_UNEXPECTED;
|
| + break;
|
| + }
|
| + } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE);
|
| + return rv;
|
| +}
|
| +
|
| +int ReadBufferingStreamSocket::DoRead() {
|
| + state_ = STATE_READ_COMPLETE;
|
| + int rv = transport_->Read(
|
| + read_buffer_,
|
| + read_buffer_->RemainingCapacity(),
|
| + base::Bind(&ReadBufferingStreamSocket::OnReadCompleted,
|
| + base::Unretained(this)));
|
| + return rv;
|
| +}
|
| +
|
| +int ReadBufferingStreamSocket::DoReadComplete(int result) {
|
| + state_ = STATE_NONE;
|
| + if (result <= 0)
|
| + return result;
|
| +
|
| + read_buffer_->set_offset(read_buffer_->offset() + result);
|
| + if (read_buffer_->RemainingCapacity() > 0) {
|
| + state_ = STATE_READ;
|
| + return net::OK;
|
| + }
|
| +
|
| + memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(),
|
| + read_buffer_->capacity());
|
| + read_buffer_->set_offset(0);
|
| + return read_buffer_->capacity();
|
| +}
|
| +
|
| +void ReadBufferingStreamSocket::OnReadCompleted(int result) {
|
| + result = DoLoop(result);
|
| + if (result == net::ERR_IO_PENDING)
|
| + return;
|
| +
|
| + user_read_buf_ = NULL;
|
| + base::ResetAndReturn(&user_read_callback_).Run(result);
|
| +}
|
| +
|
| +} // namespace
|
| +
|
| class SSLClientSocketTest : public PlatformTest {
|
| public:
|
| SSLClientSocketTest()
|
| @@ -499,6 +699,58 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) {
|
| }
|
| }
|
|
|
| +TEST_F(SSLClientSocketTest, Read_ManySmallRecords) {
|
| + net::TestServer test_server(net::TestServer::TYPE_HTTPS,
|
| + net::TestServer::kLocalhost,
|
| + base::FilePath());
|
| + ASSERT_TRUE(test_server.Start());
|
| +
|
| + net::AddressList addr;
|
| + ASSERT_TRUE(test_server.GetAddressList(&addr));
|
| +
|
| + net::TestCompletionCallback callback;
|
| +
|
| + scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket(
|
| + addr, NULL, net::NetLog::Source()));
|
| + ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket(
|
| + real_transport.Pass());
|
| + int rv = callback.GetResult(transport->Connect(callback.callback()));
|
| + ASSERT_EQ(net::OK, rv);
|
| +
|
| + scoped_ptr<net::SSLClientSocket> sock(
|
| + CreateSSLClientSocket(transport, test_server.host_port_pair(),
|
| + kDefaultSSLConfig));
|
| +
|
| + rv = callback.GetResult(sock->Connect(callback.callback()));
|
| + ASSERT_EQ(net::OK, rv);
|
| + ASSERT_TRUE(sock->IsConnected());
|
| +
|
| + const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n";
|
| + scoped_refptr<net::IOBuffer> request_buffer(
|
| + new net::IOBuffer(arraysize(request_text) - 1));
|
| + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
|
| +
|
| + rv = callback.GetResult(sock->Write(
|
| + request_buffer, arraysize(request_text) - 1, callback.callback()));
|
| + ASSERT_GT(rv, 0);
|
| + ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv);
|
| +
|
| + // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of
|
| + // data (the max SSL record size) at a time. Ensure that at least 15K worth
|
| + // of SSL data is buffered first. The 15K of buffered data is made up of
|
| + // many smaller SSL records (the TestServer writes along 1350 byte
|
| + // plaintext boundaries), although there may also be a few records that are
|
| + // smaller or larger, due to timing and SSL False Start.
|
| + // 15K was chosen because 15K is smaller than the 17K (max) read issued by
|
| + // the SSLClientSocket implementation, and larger than the minimum amount
|
| + // of ciphertext necessary to contain the 8K of plaintext requested below.
|
| + transport->SetBufferSize(15000);
|
| +
|
| + scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192));
|
| + rv = callback.GetResult(sock->Read(buffer, 8192, callback.callback()));
|
| + ASSERT_EQ(rv, 8192);
|
| +}
|
| +
|
| TEST_F(SSLClientSocketTest, Read_Interrupted) {
|
| net::TestServer test_server(net::TestServer::TYPE_HTTPS,
|
| net::TestServer::kLocalhost,
|
| @@ -667,8 +919,6 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) {
|
| EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv);
|
| }
|
|
|
| -// TODO(rsleevi): Not implemented for Schannel. As Schannel is only used when
|
| -// performing client authentication, it will not be tested here.
|
| TEST_F(SSLClientSocketTest, CipherSuiteDisables) {
|
| // Rather than exhaustively disabling every RC4 ciphersuite defined at
|
| // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml,
|
|
|