OLD | NEW |
---|---|
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/socket/ssl_client_socket.h" | 5 #include "net/socket/ssl_client_socket.h" |
6 | 6 |
7 #include "base/callback_helpers.h" | |
7 #include "base/memory/ref_counted.h" | 8 #include "base/memory/ref_counted.h" |
8 #include "net/base/address_list.h" | 9 #include "net/base/address_list.h" |
9 #include "net/base/cert_test_util.h" | 10 #include "net/base/cert_test_util.h" |
10 #include "net/base/host_resolver.h" | 11 #include "net/base/host_resolver.h" |
11 #include "net/base/io_buffer.h" | 12 #include "net/base/io_buffer.h" |
12 #include "net/base/mock_cert_verifier.h" | 13 #include "net/base/mock_cert_verifier.h" |
13 #include "net/base/net_errors.h" | 14 #include "net/base/net_errors.h" |
14 #include "net/base/net_log.h" | 15 #include "net/base/net_log.h" |
15 #include "net/base/net_log_unittest.h" | 16 #include "net/base/net_log_unittest.h" |
16 #include "net/base/ssl_cert_request_info.h" | 17 #include "net/base/ssl_cert_request_info.h" |
17 #include "net/base/ssl_config_service.h" | 18 #include "net/base/ssl_config_service.h" |
18 #include "net/base/test_completion_callback.h" | 19 #include "net/base/test_completion_callback.h" |
19 #include "net/base/test_data_directory.h" | 20 #include "net/base/test_data_directory.h" |
20 #include "net/base/test_root_certs.h" | 21 #include "net/base/test_root_certs.h" |
21 #include "net/socket/client_socket_factory.h" | 22 #include "net/socket/client_socket_factory.h" |
22 #include "net/socket/client_socket_handle.h" | 23 #include "net/socket/client_socket_handle.h" |
23 #include "net/socket/socket_test_util.h" | 24 #include "net/socket/socket_test_util.h" |
24 #include "net/socket/tcp_client_socket.h" | 25 #include "net/socket/tcp_client_socket.h" |
25 #include "net/test/test_server.h" | 26 #include "net/test/test_server.h" |
26 #include "testing/gtest/include/gtest/gtest.h" | 27 #include "testing/gtest/include/gtest/gtest.h" |
27 #include "testing/platform_test.h" | 28 #include "testing/platform_test.h" |
28 | 29 |
29 //----------------------------------------------------------------------------- | 30 //----------------------------------------------------------------------------- |
30 | 31 |
32 namespace { | |
33 | |
31 const net::SSLConfig kDefaultSSLConfig; | 34 const net::SSLConfig kDefaultSSLConfig; |
32 | 35 |
36 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that | |
37 // will ensure a certain amount of data is internally buffered before | |
38 // satisfying a Read() request. It exists to mimic OS-level internal | |
39 // buffering, but in a way to guarantee that X number of bytes will be | |
40 // returned to callers of Read(), regardless of how quickly the OS receives | |
41 // them from the TestServer. | |
42 class ReadBufferingStreamSocket : public net::StreamSocket { | |
43 public: | |
44 explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport); | |
45 virtual ~ReadBufferingStreamSocket() {} | |
46 | |
47 // Sets the internal buffer to |size|. This must not be greater than | |
48 // the largest value supplied to Read() - that is, it does not handle | |
49 // having "leftovers" at the end of Read(). | |
50 // Each call to Read() will be prevented from completion until at least | |
51 // |size| data has been read. | |
52 // Set to 0 to turn off buffering, causing Read() to transparently | |
53 // read via the underlying transport. | |
54 void SetBufferSize(int size); | |
wtc
2013/02/13 22:06:51
Should |size| have the size_t type?
Ryan Sleevi
2013/02/13 22:55:48
In theory, yes, but consistent with the entire Soc
| |
55 | |
56 // StreamSocket implementation: | |
57 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { | |
58 return transport_->Connect(callback); | |
59 } | |
60 virtual void Disconnect() OVERRIDE { | |
61 transport_->Disconnect(); | |
62 } | |
63 virtual bool IsConnected() const OVERRIDE { | |
64 return transport_->IsConnected(); | |
65 } | |
66 virtual bool IsConnectedAndIdle() const OVERRIDE { | |
67 return transport_->IsConnectedAndIdle(); | |
68 } | |
69 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { | |
70 return transport_->GetPeerAddress(address); | |
71 } | |
72 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { | |
73 return transport_->GetLocalAddress(address); | |
74 } | |
75 virtual const net::BoundNetLog& NetLog() const OVERRIDE { | |
76 return transport_->NetLog(); | |
77 } | |
78 virtual void SetSubresourceSpeculation() OVERRIDE { | |
79 transport_->SetSubresourceSpeculation(); | |
80 } | |
81 virtual void SetOmniboxSpeculation() OVERRIDE { | |
82 transport_->SetOmniboxSpeculation(); | |
83 } | |
84 virtual bool WasEverUsed() const OVERRIDE { | |
85 return transport_->WasEverUsed(); | |
86 } | |
87 virtual bool UsingTCPFastOpen() const OVERRIDE { | |
88 return transport_->UsingTCPFastOpen(); | |
89 } | |
90 virtual int64 NumBytesRead() const OVERRIDE { | |
91 return transport_->NumBytesRead(); | |
92 } | |
93 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { | |
94 return transport_->GetConnectTimeMicros(); | |
95 } | |
96 virtual bool WasNpnNegotiated() const OVERRIDE { | |
97 return transport_->WasNpnNegotiated(); | |
98 } | |
99 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { | |
100 return transport_->GetNegotiatedProtocol(); | |
101 } | |
102 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { | |
103 return transport_->GetSSLInfo(ssl_info); | |
104 } | |
105 | |
106 // Socket implementation: | |
107 virtual int Read(net::IOBuffer* buf, int buf_len, | |
108 const net::CompletionCallback& callback) OVERRIDE; | |
109 virtual int Write(net::IOBuffer* buf, int buf_len, | |
110 const net::CompletionCallback& callback) OVERRIDE { | |
111 return transport_->Write(buf, buf_len, callback); | |
112 } | |
113 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
114 return transport_->SetReceiveBufferSize(size); | |
115 } | |
116 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
117 return transport_->SetSendBufferSize(size); | |
118 } | |
119 | |
120 private: | |
121 enum State { | |
122 STATE_NONE, | |
123 STATE_READ, | |
124 STATE_READ_COMPLETE, | |
125 }; | |
126 | |
127 int DoLoop(int result); | |
128 int DoRead(); | |
129 int DoReadComplete(int result); | |
130 void OnReadCompleted(int result); | |
131 | |
132 State state_; | |
133 scoped_ptr<net::StreamSocket> transport_; | |
134 scoped_refptr<net::GrowableIOBuffer> read_buffer_; | |
135 int buffer_size_; | |
136 | |
137 scoped_refptr<net::IOBuffer> user_read_buf_; | |
138 net::CompletionCallback user_read_callback_; | |
139 }; | |
140 | |
141 ReadBufferingStreamSocket::ReadBufferingStreamSocket( | |
142 scoped_ptr<net::StreamSocket> transport) | |
143 : transport_(transport.Pass()), | |
144 read_buffer_(new net::GrowableIOBuffer()), | |
145 buffer_size_(0) { | |
146 } | |
147 | |
148 void ReadBufferingStreamSocket::SetBufferSize(int size) { | |
149 DCHECK(!user_read_buf_); | |
150 buffer_size_ = size; | |
151 read_buffer_->SetCapacity(size); | |
152 } | |
153 | |
154 int ReadBufferingStreamSocket::Read(net::IOBuffer* buf, | |
155 int buf_len, | |
156 const net::CompletionCallback& callback) { | |
157 if (buffer_size_ == 0) | |
158 return transport_->Read(buf, buf_len, callback); | |
159 | |
160 if (buf_len < buffer_size_) | |
161 return net::ERR_UNEXPECTED; | |
162 | |
163 state_ = STATE_READ; | |
164 user_read_buf_ = buf; | |
165 int result = DoLoop(net::OK); | |
166 if (result == net::ERR_IO_PENDING) | |
167 user_read_callback_ = callback; | |
168 else | |
169 user_read_buf_ = NULL; | |
170 return result; | |
171 } | |
172 | |
173 int ReadBufferingStreamSocket::DoLoop(int result) { | |
174 int rv = result; | |
175 do { | |
176 State current_state = state_; | |
177 state_ = STATE_NONE; | |
178 switch (current_state) { | |
179 case STATE_READ: | |
180 rv = DoRead(); | |
181 break; | |
182 case STATE_READ_COMPLETE: | |
183 rv = DoReadComplete(rv); | |
184 break; | |
185 case STATE_NONE: | |
186 default: | |
187 NOTREACHED() << "Unexpected state: " << current_state; | |
188 rv = net::ERR_UNEXPECTED; | |
189 break; | |
190 } | |
191 } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE); | |
192 return rv; | |
193 } | |
194 | |
195 int ReadBufferingStreamSocket::DoRead() { | |
196 state_ = STATE_READ_COMPLETE; | |
197 int rv = transport_->Read( | |
198 read_buffer_, | |
199 read_buffer_->RemainingCapacity(), | |
200 base::Bind(&ReadBufferingStreamSocket::OnReadCompleted, | |
201 base::Unretained(this))); | |
202 return rv; | |
203 } | |
204 | |
205 int ReadBufferingStreamSocket::DoReadComplete(int result) { | |
206 state_ = STATE_NONE; | |
207 if (result <= 0) | |
208 return result; | |
209 | |
210 read_buffer_->set_offset(read_buffer_->offset() + result); | |
211 if (read_buffer_->RemainingCapacity() > 0) { | |
212 state_ = STATE_READ; | |
213 return net::OK; | |
214 } | |
215 | |
216 memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), | |
217 read_buffer_->capacity()); | |
218 read_buffer_->set_offset(0); | |
219 return read_buffer_->capacity(); | |
220 } | |
221 | |
222 void ReadBufferingStreamSocket::OnReadCompleted(int result) { | |
223 result = DoLoop(result); | |
224 if (result == net::ERR_IO_PENDING) | |
225 return; | |
226 | |
227 user_read_buf_ = NULL; | |
228 base::ResetAndReturn(&user_read_callback_).Run(result); | |
229 } | |
230 | |
231 } // namespace | |
232 | |
33 class SSLClientSocketTest : public PlatformTest { | 233 class SSLClientSocketTest : public PlatformTest { |
34 public: | 234 public: |
35 SSLClientSocketTest() | 235 SSLClientSocketTest() |
36 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), | 236 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), |
37 cert_verifier_(new net::MockCertVerifier) { | 237 cert_verifier_(new net::MockCertVerifier) { |
38 cert_verifier_->set_default_result(net::OK); | 238 cert_verifier_->set_default_result(net::OK); |
39 context_.cert_verifier = cert_verifier_.get(); | 239 context_.cert_verifier = cert_verifier_.get(); |
40 } | 240 } |
41 | 241 |
42 protected: | 242 protected: |
(...skipping 449 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
492 | 692 |
493 if (rv == net::ERR_IO_PENDING) | 693 if (rv == net::ERR_IO_PENDING) |
494 rv = callback.WaitForResult(); | 694 rv = callback.WaitForResult(); |
495 | 695 |
496 EXPECT_GE(rv, 0); | 696 EXPECT_GE(rv, 0); |
497 if (rv <= 0) | 697 if (rv <= 0) |
498 break; | 698 break; |
499 } | 699 } |
500 } | 700 } |
501 | 701 |
702 TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { | |
703 net::TestServer test_server(net::TestServer::TYPE_HTTPS, | |
Ryan Hamilton
2013/02/13 17:23:18
as you said, the proliferation of net:: is ... in
| |
704 net::TestServer::kLocalhost, | |
705 base::FilePath()); | |
706 ASSERT_TRUE(test_server.Start()); | |
707 | |
708 net::AddressList addr; | |
709 ASSERT_TRUE(test_server.GetAddressList(&addr)); | |
710 | |
711 net::TestCompletionCallback callback; | |
712 | |
713 scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( | |
714 addr, NULL, net::NetLog::Source())); | |
715 ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket( | |
716 real_transport.Pass()); | |
717 int rv = callback.GetResult(transport->Connect(callback.callback())); | |
718 ASSERT_EQ(net::OK, rv); | |
719 | |
720 scoped_ptr<net::SSLClientSocket> sock( | |
721 CreateSSLClientSocket(transport, test_server.host_port_pair(), | |
722 kDefaultSSLConfig)); | |
723 | |
724 rv = callback.GetResult(sock->Connect(callback.callback())); | |
725 ASSERT_EQ(net::OK, rv); | |
726 ASSERT_TRUE(sock->IsConnected()); | |
727 | |
728 const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; | |
729 scoped_refptr<net::IOBuffer> request_buffer( | |
730 new net::IOBuffer(arraysize(request_text) - 1)); | |
731 memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); | |
732 | |
733 rv = callback.GetResult(sock->Write( | |
734 request_buffer, arraysize(request_text) - 1, callback.callback())); | |
735 ASSERT_GT(rv, 0); | |
736 ASSERT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); | |
737 | |
738 // Note: This relies on SSLClientSocketNSS attempting to read up to 17K of | |
739 // data (the max SSL record size) at a time. This buffer size must be larger | |
740 // than the IOBuffer below, in order to ensure at least as many records | |
741 // needed to satisfy the Read request are internally buffered. | |
742 // This also relies on the TestServer writing records on 1350 byte | |
743 // boundaries. | |
744 transport->SetBufferSize(15000); | |
Ryan Hamilton
2013/02/13 17:23:18
SO this line means that the transport wrapper sock
Ryan Sleevi
2013/02/13 20:44:43
Yes, it means at least 15K must be read/buffered b
Ryan Hamilton
2013/02/13 21:50:34
Sounds great.
| |
745 | |
746 scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192)); | |
747 rv = callback.GetResult(sock->Read(buffer, 8192, callback.callback())); | |
748 ASSERT_EQ(rv, 8192); | |
749 } | |
750 | |
502 TEST_F(SSLClientSocketTest, Read_Interrupted) { | 751 TEST_F(SSLClientSocketTest, Read_Interrupted) { |
503 net::TestServer test_server(net::TestServer::TYPE_HTTPS, | 752 net::TestServer test_server(net::TestServer::TYPE_HTTPS, |
504 net::TestServer::kLocalhost, | 753 net::TestServer::kLocalhost, |
505 base::FilePath()); | 754 base::FilePath()); |
506 ASSERT_TRUE(test_server.Start()); | 755 ASSERT_TRUE(test_server.Start()); |
507 | 756 |
508 net::AddressList addr; | 757 net::AddressList addr; |
509 ASSERT_TRUE(test_server.GetAddressList(&addr)); | 758 ASSERT_TRUE(test_server.GetAddressList(&addr)); |
510 | 759 |
511 net::TestCompletionCallback callback; | 760 net::TestCompletionCallback callback; |
(...skipping 148 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
660 scoped_ptr<net::SSLClientSocket> sock( | 909 scoped_ptr<net::SSLClientSocket> sock( |
661 CreateSSLClientSocket(transport, test_server.host_port_pair(), | 910 CreateSSLClientSocket(transport, test_server.host_port_pair(), |
662 kDefaultSSLConfig)); | 911 kDefaultSSLConfig)); |
663 | 912 |
664 rv = sock->Connect(callback.callback()); | 913 rv = sock->Connect(callback.callback()); |
665 if (rv == net::ERR_IO_PENDING) | 914 if (rv == net::ERR_IO_PENDING) |
666 rv = callback.WaitForResult(); | 915 rv = callback.WaitForResult(); |
667 EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); | 916 EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); |
668 } | 917 } |
669 | 918 |
670 // TODO(rsleevi): Not implemented for Schannel. As Schannel is only used when | |
671 // performing client authentication, it will not be tested here. | |
672 TEST_F(SSLClientSocketTest, CipherSuiteDisables) { | 919 TEST_F(SSLClientSocketTest, CipherSuiteDisables) { |
673 // Rather than exhaustively disabling every RC4 ciphersuite defined at | 920 // Rather than exhaustively disabling every RC4 ciphersuite defined at |
674 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, | 921 // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, |
675 // only disabling those cipher suites that the test server actually | 922 // only disabling those cipher suites that the test server actually |
676 // implements. | 923 // implements. |
677 const uint16 kCiphersToDisable[] = { | 924 const uint16 kCiphersToDisable[] = { |
678 0x0005, // TLS_RSA_WITH_RC4_128_SHA | 925 0x0005, // TLS_RSA_WITH_RC4_128_SHA |
679 }; | 926 }; |
680 | 927 |
681 net::TestServer::SSLOptions ssl_options; | 928 net::TestServer::SSLOptions ssl_options; |
(...skipping 358 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
1040 scoped_refptr<net::SSLCertRequestInfo> request_info = | 1287 scoped_refptr<net::SSLCertRequestInfo> request_info = |
1041 GetCertRequest(ssl_options); | 1288 GetCertRequest(ssl_options); |
1042 ASSERT_TRUE(request_info); | 1289 ASSERT_TRUE(request_info); |
1043 ASSERT_EQ(2u, request_info->cert_authorities.size()); | 1290 ASSERT_EQ(2u, request_info->cert_authorities.size()); |
1044 EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), | 1291 EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), |
1045 request_info->cert_authorities[0]); | 1292 request_info->cert_authorities[0]); |
1046 EXPECT_EQ( | 1293 EXPECT_EQ( |
1047 std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), | 1294 std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), |
1048 request_info->cert_authorities[1]); | 1295 request_info->cert_authorities[1]); |
1049 } | 1296 } |
OLD | NEW |