Index: net/socket/ssl_client_socket_unittest.cc |
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc |
index 89356bfad20c6117f770fc8be132e70c36146e98..6270ad337bb31c944636bcac5b5d90f9bb68c65d 100644 |
--- a/net/socket/ssl_client_socket_unittest.cc |
+++ b/net/socket/ssl_client_socket_unittest.cc |
@@ -4,6 +4,7 @@ |
#include "net/socket/ssl_client_socket.h" |
+#include "base/callback_helpers.h" |
#include "base/memory/ref_counted.h" |
#include "net/base/address_list.h" |
#include "net/base/cert_test_util.h" |
@@ -28,8 +29,207 @@ |
//----------------------------------------------------------------------------- |
+namespace { |
+ |
const net::SSLConfig kDefaultSSLConfig; |
+// ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that |
+// will ensure a certain amount of data is internally buffered before |
+// satisfying a Read() request. It exists to mimic OS-level internal |
+// buffering, but in a way to guarantee that X number of bytes will be |
+// returned to callers of Read(), regardless of how quickly the OS receives |
+// them from the TestServer. |
+class ReadBufferingStreamSocket : public net::StreamSocket { |
+ public: |
+ explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport); |
+ virtual ~ReadBufferingStreamSocket() {} |
+ |
+ // Sets the internal buffer to |size|. This must not be greater than |
+ // the largest value supplied to Read() - that is, it does not handle |
+ // having "leftovers" at the end of Read(). |
+ // Each call to Read() will be prevented from completion until at least |
+ // |size| data has been read. |
+ // Set to 0 to turn off buffering, causing Read() to transparently |
+ // read via the underlying transport. |
+ void SetBufferSize(int size); |
wtc
2013/02/13 22:06:51
Should |size| have the size_t type?
Ryan Sleevi
2013/02/13 22:55:48
In theory, yes, but consistent with the entire Soc
|
+ |
+ // StreamSocket implementation: |
+ virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { |
+ return transport_->Connect(callback); |
+ } |
+ virtual void Disconnect() OVERRIDE { |
+ transport_->Disconnect(); |
+ } |
+ virtual bool IsConnected() const OVERRIDE { |
+ return transport_->IsConnected(); |
+ } |
+ virtual bool IsConnectedAndIdle() const OVERRIDE { |
+ return transport_->IsConnectedAndIdle(); |
+ } |
+ virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { |
+ return transport_->GetPeerAddress(address); |
+ } |
+ virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { |
+ return transport_->GetLocalAddress(address); |
+ } |
+ virtual const net::BoundNetLog& NetLog() const OVERRIDE { |
+ return transport_->NetLog(); |
+ } |
+ virtual void SetSubresourceSpeculation() OVERRIDE { |
+ transport_->SetSubresourceSpeculation(); |
+ } |
+ virtual void SetOmniboxSpeculation() OVERRIDE { |
+ transport_->SetOmniboxSpeculation(); |
+ } |
+ virtual bool WasEverUsed() const OVERRIDE { |
+ return transport_->WasEverUsed(); |
+ } |
+ virtual bool UsingTCPFastOpen() const OVERRIDE { |
+ return transport_->UsingTCPFastOpen(); |
+ } |
+ virtual int64 NumBytesRead() const OVERRIDE { |
+ return transport_->NumBytesRead(); |
+ } |
+ virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { |
+ return transport_->GetConnectTimeMicros(); |
+ } |
+ virtual bool WasNpnNegotiated() const OVERRIDE { |
+ return transport_->WasNpnNegotiated(); |
+ } |
+ virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { |
+ return transport_->GetNegotiatedProtocol(); |
+ } |
+ virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { |
+ return transport_->GetSSLInfo(ssl_info); |
+ } |
+ |
+ // Socket implementation: |
+ virtual int Read(net::IOBuffer* buf, int buf_len, |
+ const net::CompletionCallback& callback) OVERRIDE; |
+ virtual int Write(net::IOBuffer* buf, int buf_len, |
+ const net::CompletionCallback& callback) OVERRIDE { |
+ return transport_->Write(buf, buf_len, callback); |
+ } |
+ virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
+ return transport_->SetReceiveBufferSize(size); |
+ } |
+ virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
+ return transport_->SetSendBufferSize(size); |
+ } |
+ |
+ private: |
+ enum State { |
+ STATE_NONE, |
+ STATE_READ, |
+ STATE_READ_COMPLETE, |
+ }; |
+ |
+ int DoLoop(int result); |
+ int DoRead(); |
+ int DoReadComplete(int result); |
+ void OnReadCompleted(int result); |
+ |
+ State state_; |
+ scoped_ptr<net::StreamSocket> transport_; |
+ scoped_refptr<net::GrowableIOBuffer> read_buffer_; |
+ int buffer_size_; |
+ |
+ scoped_refptr<net::IOBuffer> user_read_buf_; |
+ net::CompletionCallback user_read_callback_; |
+}; |
+ |
+ReadBufferingStreamSocket::ReadBufferingStreamSocket( |
+ scoped_ptr<net::StreamSocket> transport) |
+ : transport_(transport.Pass()), |
+ read_buffer_(new net::GrowableIOBuffer()), |
+ buffer_size_(0) { |
+} |
+ |
+void ReadBufferingStreamSocket::SetBufferSize(int size) { |
+ DCHECK(!user_read_buf_); |
+ buffer_size_ = size; |
+ read_buffer_->SetCapacity(size); |
+} |
+ |
+int ReadBufferingStreamSocket::Read(net::IOBuffer* buf, |
+ int buf_len, |
+ const net::CompletionCallback& callback) { |
+ if (buffer_size_ == 0) |
+ return transport_->Read(buf, buf_len, callback); |
+ |
+ if (buf_len < buffer_size_) |
+ return net::ERR_UNEXPECTED; |
+ |
+ state_ = STATE_READ; |
+ user_read_buf_ = buf; |
+ int result = DoLoop(net::OK); |
+ if (result == net::ERR_IO_PENDING) |
+ user_read_callback_ = callback; |
+ else |
+ user_read_buf_ = NULL; |
+ return result; |
+} |
+ |
+int ReadBufferingStreamSocket::DoLoop(int result) { |
+ int rv = result; |
+ do { |
+ State current_state = state_; |
+ state_ = STATE_NONE; |
+ switch (current_state) { |
+ case STATE_READ: |
+ rv = DoRead(); |
+ break; |
+ case STATE_READ_COMPLETE: |
+ rv = DoReadComplete(rv); |
+ break; |
+ case STATE_NONE: |
+ default: |
+ NOTREACHED() << "Unexpected state: " << current_state; |
+ rv = net::ERR_UNEXPECTED; |
+ break; |
+ } |
+ } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE); |
+ return rv; |
+} |
+ |
+int ReadBufferingStreamSocket::DoRead() { |
+ state_ = STATE_READ_COMPLETE; |
+ int rv = transport_->Read( |
+ read_buffer_, |
+ read_buffer_->RemainingCapacity(), |
+ base::Bind(&ReadBufferingStreamSocket::OnReadCompleted, |
+ base::Unretained(this))); |
+ return rv; |
+} |
+ |
+int ReadBufferingStreamSocket::DoReadComplete(int result) { |
+ state_ = STATE_NONE; |
+ if (result <= 0) |
+ return result; |
+ |
+ read_buffer_->set_offset(read_buffer_->offset() + result); |
+ if (read_buffer_->RemainingCapacity() > 0) { |
+ state_ = STATE_READ; |
+ return net::OK; |
+ } |
+ |
+ memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), |
+ read_buffer_->capacity()); |
+ read_buffer_->set_offset(0); |
+ return read_buffer_->capacity(); |
+} |
+ |
+void ReadBufferingStreamSocket::OnReadCompleted(int result) { |
+ result = DoLoop(result); |
+ if (result == net::ERR_IO_PENDING) |
+ return; |
+ |
+ user_read_buf_ = NULL; |
+ base::ResetAndReturn(&user_read_callback_).Run(result); |
+} |
+ |
+} // namespace |
+ |
class SSLClientSocketTest : public PlatformTest { |
public: |
SSLClientSocketTest() |
@@ -499,6 +699,55 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { |
} |
} |
+TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { |
+ net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
Ryan Hamilton
2013/02/13 17:23:18
as you said, the proliferation of net:: is ... in
|
+ net::TestServer::kLocalhost, |
+ base::FilePath()); |
+ ASSERT_TRUE(test_server.Start()); |
+ |
+ net::AddressList addr; |
+ ASSERT_TRUE(test_server.GetAddressList(&addr)); |
+ |
+ net::TestCompletionCallback callback; |
+ |
+ scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( |
+ addr, NULL, net::NetLog::Source())); |
+ ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket( |
+ real_transport.Pass()); |
+ int rv = callback.GetResult(transport->Connect(callback.callback())); |
+ ASSERT_EQ(net::OK, rv); |
+ |
+ scoped_ptr<net::SSLClientSocket> sock( |
+ CreateSSLClientSocket(transport, test_server.host_port_pair(), |
+ kDefaultSSLConfig)); |
+ |
+ rv = callback.GetResult(sock->Connect(callback.callback())); |
+ ASSERT_EQ(net::OK, rv); |
+ ASSERT_TRUE(sock->IsConnected()); |
+ |
+ const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; |
+ scoped_refptr<net::IOBuffer> request_buffer( |
+ new net::IOBuffer(arraysize(request_text) - 1)); |
+ memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); |
+ |
+ rv = callback.GetResult(sock->Write( |
+ request_buffer, arraysize(request_text) - 1, callback.callback())); |
+ ASSERT_GT(rv, 0); |
+ ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); |
+ |
+ // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of |
+ // data (the max SSL record size) at a time. This buffer size must be larger |
+ // than the IOBuffer below, in order to ensure at least as many records |
+ // needed to satisfy the Read request are internally buffered. |
+ // This also relies on the TestServer writing records on 1350 byte |
+ // boundaries. |
+ transport->SetBufferSize(15000); |
Ryan Hamilton
2013/02/13 17:23:18
SO this line means that the transport wrapper sock
Ryan Sleevi
2013/02/13 20:44:43
Yes, it means at least 15K must be read/buffered b
Ryan Hamilton
2013/02/13 21:50:34
Sounds great.
|
+ |
+ scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192)); |
+ rv = callback.GetResult(sock->Read(buffer, 8192, callback.callback())); |
+ ASSERT_EQ(rv, 8192); |
+} |
+ |
TEST_F(SSLClientSocketTest, Read_Interrupted) { |
net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
net::TestServer::kLocalhost, |
@@ -667,8 +916,6 @@ TEST_F(SSLClientSocketTest, PrematureApplicationData) { |
EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); |
} |
-// TODO(rsleevi): Not implemented for Schannel. As Schannel is only used when |
-// performing client authentication, it will not be tested here. |
TEST_F(SSLClientSocketTest, CipherSuiteDisables) { |
// Rather than exhaustively disabling every RC4 ciphersuite defined at |
// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, |