Chromium Code Reviews| Index: net/socket/ssl_client_socket_unittest.cc |
| diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc |
| index 89356bfad20c6117f770fc8be132e70c36146e98..265eabcfa84db8ce56435eb51222ac49fe599d9f 100644 |
| --- a/net/socket/ssl_client_socket_unittest.cc |
| +++ b/net/socket/ssl_client_socket_unittest.cc |
| @@ -4,6 +4,7 @@ |
| #include "net/socket/ssl_client_socket.h" |
| +#include "base/callback_helpers.h" |
| #include "base/memory/ref_counted.h" |
| #include "net/base/address_list.h" |
| #include "net/base/cert_test_util.h" |
| @@ -28,8 +29,207 @@ |
| //----------------------------------------------------------------------------- |
| +namespace { |
| + |
| const net::SSLConfig kDefaultSSLConfig; |
| +// ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that |
| +// will ensure a certain amount of data is internally buffered before |
| +// satisfying a Read() request. It exists to mimic OS-level internal |
| +// buffering, but in a way to guarantee that X number of bytes will be |
| +// returned to callers of Read(), regardless of how quickly the OS receives |
| +// them from the TestServer. |
| +class ReadBufferingStreamSocket : public net::StreamSocket { |
| + public: |
| + explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport); |
| + virtual ~ReadBufferingStreamSocket() {} |
| + |
| + // Sets the internal buffer to |size|. This must not be greater than |
| + // the largest value supplied to Read() - that is, it does not handle |
| + // having "leftovers" at the end of Read(). |
| + // Each call to Read() will be prevented from completion until at least |
| + // |size| data has been read. |
| + // Set to 0 to turn off buffering, causing Read() to transparently |
| + // read via the underlying transport. |
| + void SetBufferSize(int size); |
| + |
| + // StreamSocket implementation: |
| + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { |
| + return transport_->Connect(callback); |
| + } |
| + virtual void Disconnect() OVERRIDE { |
| + transport_->Disconnect(); |
| + } |
| + virtual bool IsConnected() const OVERRIDE { |
| + return transport_->IsConnected(); |
| + } |
| + virtual bool IsConnectedAndIdle() const OVERRIDE { |
| + return transport_->IsConnectedAndIdle(); |
| + } |
| + virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { |
| + return transport_->GetPeerAddress(address); |
| + } |
| + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { |
| + return transport_->GetLocalAddress(address); |
| + } |
| + virtual const net::BoundNetLog& NetLog() const OVERRIDE { |
| + return transport_->NetLog(); |
| + } |
| + virtual void SetSubresourceSpeculation() OVERRIDE { |
| + transport_->SetSubresourceSpeculation(); |
| + } |
| + virtual void SetOmniboxSpeculation() OVERRIDE { |
| + transport_->SetOmniboxSpeculation(); |
| + } |
| + virtual bool WasEverUsed() const OVERRIDE { |
| + return transport_->WasEverUsed(); |
| + } |
| + virtual bool UsingTCPFastOpen() const OVERRIDE { |
| + return transport_->UsingTCPFastOpen(); |
| + } |
| + virtual int64 NumBytesRead() const OVERRIDE { |
| + return transport_->NumBytesRead(); |
| + } |
| + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { |
| + return transport_->GetConnectTimeMicros(); |
| + } |
| + virtual bool WasNpnNegotiated() const OVERRIDE { |
| + return transport_->WasNpnNegotiated(); |
| + } |
| + virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { |
| + return transport_->GetNegotiatedProtocol(); |
| + } |
| + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { |
| + return transport_->GetSSLInfo(ssl_info); |
| + } |
| + |
| + // Socket implementation: |
| + virtual int Read(net::IOBuffer* buf, int buf_len, |
| + const net::CompletionCallback& callback) OVERRIDE; |
| + virtual int Write(net::IOBuffer* buf, int buf_len, |
| + const net::CompletionCallback& callback) OVERRIDE { |
| + return transport_->Write(buf, buf_len, callback); |
| + } |
| + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
| + return transport_->SetReceiveBufferSize(size); |
| + } |
| + virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
| + return transport_->SetSendBufferSize(size); |
| + } |
| + |
| + private: |
| + enum State { |
| + STATE_NONE, |
| + STATE_READ, |
| + STATE_READ_COMPLETE, |
| + }; |
| + |
| + int DoLoop(int result); |
| + int DoRead(); |
| + int DoReadComplete(int result); |
| + void OnReadCompleted(int result); |
| + |
| + State state_; |
| + scoped_ptr<net::StreamSocket> transport_; |
| + scoped_refptr<net::GrowableIOBuffer> read_buffer_; |
| + int buffer_size_; |
| + |
| + scoped_refptr<net::IOBuffer> user_read_buf_; |
| + net::CompletionCallback user_read_callback_; |
| +}; |
| + |
| +ReadBufferingStreamSocket::ReadBufferingStreamSocket( |
| + scoped_ptr<net::StreamSocket> transport) |
| + : transport_(transport.Pass()), |
| + read_buffer_(new net::GrowableIOBuffer()), |
| + buffer_size_(0) { |
| +} |
| + |
| +void ReadBufferingStreamSocket::SetBufferSize(int size) { |
| + DCHECK(!user_read_buf_); |
| + buffer_size_ = size; |
| + read_buffer_->SetCapacity(size); |
| +} |
| + |
| +int ReadBufferingStreamSocket::Read(net::IOBuffer* buf, |
| + int buf_len, |
| + const net::CompletionCallback& callback) { |
| + if (buffer_size_ == 0) |
| + return transport_->Read(buf, buf_len, callback); |
| + |
| + if (buf_len < buffer_size_) |
| + return net::ERR_UNEXPECTED; |
| + |
| + state_ = STATE_READ; |
| + user_read_buf_ = buf; |
| + int result = DoLoop(net::OK); |
| + if (result == net::ERR_IO_PENDING) |
| + user_read_callback_ = callback; |
| + else |
| + user_read_buf_ = NULL; |
| + return result; |
| +} |
| + |
| +int ReadBufferingStreamSocket::DoLoop(int result) { |
| + int rv = result; |
| + do { |
| + State current_state = state_; |
| + state_ = STATE_NONE; |
| + switch (current_state) { |
| + case STATE_READ: |
| + rv = DoRead(); |
| + break; |
| + case STATE_READ_COMPLETE: |
| + rv = DoReadComplete(rv); |
| + break; |
| + case STATE_NONE: |
| + default: |
| + NOTREACHED() << "Unexpected state: " << current_state; |
| + rv = net::ERR_UNEXPECTED; |
| + break; |
| + } |
| + } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE); |
| + return rv; |
| +} |
| + |
| +int ReadBufferingStreamSocket::DoRead() { |
| + state_ = STATE_READ_COMPLETE; |
| + int rv = transport_->Read( |
| + read_buffer_, |
| + read_buffer_->RemainingCapacity(), |
| + base::Bind(&ReadBufferingStreamSocket::OnReadCompleted, |
| + base::Unretained(this))); |
| + return rv; |
| +} |
| + |
| +int ReadBufferingStreamSocket::DoReadComplete(int result) { |
| + state_ = STATE_NONE; |
| + if (result <= 0) |
| + return result; |
| + |
| + read_buffer_->set_offset(read_buffer_->offset() + result); |
| + if (read_buffer_->RemainingCapacity() > 0) { |
| + state_ = STATE_READ; |
| + return net::OK; |
| + } |
| + |
| + memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), |
| + read_buffer_->capacity()); |
| + read_buffer_->set_offset(0); |
| + return read_buffer_->capacity(); |
| +} |
| + |
| +void ReadBufferingStreamSocket::OnReadCompleted(int result) { |
| + result = DoLoop(result); |
| + if (result == net::ERR_IO_PENDING) |
| + return; |
| + |
| + user_read_buf_ = NULL; |
| + base::ResetAndReturn(&user_read_callback_).Run(result); |
| +} |
| + |
| +} // namespace |
| + |
| class SSLClientSocketTest : public PlatformTest { |
| public: |
| SSLClientSocketTest() |
| @@ -499,6 +699,58 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { |
| } |
| } |
| +TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { |
| + net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
| + net::TestServer::kLocalhost, |
| + base::FilePath()); |
| + ASSERT_TRUE(test_server.Start()); |
| + |
| + net::AddressList addr; |
| + ASSERT_TRUE(test_server.GetAddressList(&addr)); |
| + |
| + net::TestCompletionCallback callback; |
| + |
| + scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( |
| + addr, NULL, net::NetLog::Source())); |
| + ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket( |
| + real_transport.Pass()); |
| + int rv = callback.GetResult(transport->Connect(callback.callback())); |
| + ASSERT_EQ(net::OK, rv); |
| + |
| + scoped_ptr<net::SSLClientSocket> sock( |
| + CreateSSLClientSocket(transport, test_server.host_port_pair(), |
| + kDefaultSSLConfig)); |
| + |
| + rv = callback.GetResult(sock->Connect(callback.callback())); |
| + ASSERT_EQ(net::OK, rv); |
| + ASSERT_TRUE(sock->IsConnected()); |
| + |
| + const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; |
| + scoped_refptr<net::IOBuffer> request_buffer( |
| + new net::IOBuffer(arraysize(request_text) - 1)); |
| + memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); |
| + |
| + rv = callback.GetResult(sock->Write( |
| + request_buffer, arraysize(request_text) - 1, callback.callback())); |
| + ASSERT_GT(rv, 0); |
| + ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); |
| + |
| + // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of |
| + // data (the max SSL record size) at a time. Ensure that at least 15K worth |
| + // of SSL data is buffered first. The 15K of buffered data is made up of |
| + // many smaller SSL records (the TestServer writes along 1350 byte |
| + // plaintext boundaries), although there may also be a few records that are |
| + // smaller or larger, due to timing and SSL False Start. |
| + // 15K was chosen because 15K is smaller than the 17K (max) read issued by |
| + // the SSLClientSocket implementation, and larger than the minimum amount |
| + // of ciphertext necessary to contain the 8K of plaintext requested below. |
|
Ryan Hamilton
2013/02/13 23:17:43
nice comment. very clear now.
|
| + transport->SetBufferSize(15000); |
| + |
| + scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192)); |
| + rv = callback.GetResult(sock->Read(buffer, 8192, callback.callback())); |
| + ASSERT_EQ(rv, 8192); |
| +} |
| + |
| TEST_F(SSLClientSocketTest, Read_Interrupted) { |
| net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
| net::TestServer::kLocalhost, |
| @@ -667,8 +919,6 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { |
| EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); |
| } |
| -// TODO(rsleevi): Not implemented for Schannel. As Schannel is only used when |
| -// performing client authentication, it will not be tested here. |
| TEST_F(SSLClientSocketTest, CipherSuiteDisables) { |
| // Rather than exhaustively disabling every RC4 ciphersuite defined at |
| // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, |