| OLD | NEW |
| (Empty) |
| 1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "net/base/socket_test_util.h" | |
| 6 | |
| 7 #include "base/basictypes.h" | |
| 8 #include "base/compiler_specific.h" | |
| 9 #include "base/message_loop.h" | |
| 10 #include "net/base/io_buffer.h" | |
| 11 #include "net/base/socket.h" | |
| 12 #include "net/base/ssl_client_socket.h" | |
| 13 #include "net/base/ssl_info.h" | |
| 14 #include "testing/gtest/include/gtest/gtest.h" | |
| 15 | |
| 16 namespace { | |
| 17 | |
| 18 class MockClientSocket : public net::SSLClientSocket { | |
| 19 public: | |
| 20 MockClientSocket(); | |
| 21 | |
| 22 // ClientSocket methods: | |
| 23 virtual int Connect(net::CompletionCallback* callback) = 0; | |
| 24 | |
| 25 // SSLClientSocket methods: | |
| 26 virtual void GetSSLInfo(net::SSLInfo* ssl_info); | |
| 27 virtual void GetSSLCertRequestInfo( | |
| 28 net::SSLCertRequestInfo* cert_request_info); | |
| 29 virtual void Disconnect(); | |
| 30 virtual bool IsConnected() const; | |
| 31 virtual bool IsConnectedAndIdle() const; | |
| 32 | |
| 33 // Socket methods: | |
| 34 virtual int Read(net::IOBuffer* buf, int buf_len, | |
| 35 net::CompletionCallback* callback) = 0; | |
| 36 virtual int Write(net::IOBuffer* buf, int buf_len, | |
| 37 net::CompletionCallback* callback) = 0; | |
| 38 | |
| 39 #if defined(OS_LINUX) | |
| 40 virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); | |
| 41 #endif | |
| 42 | |
| 43 protected: | |
| 44 void RunCallbackAsync(net::CompletionCallback* callback, int result); | |
| 45 void RunCallback(int result); | |
| 46 | |
| 47 ScopedRunnableMethodFactory<MockClientSocket> method_factory_; | |
| 48 net::CompletionCallback* callback_; | |
| 49 bool connected_; | |
| 50 }; | |
| 51 | |
| 52 class MockTCPClientSocket : public MockClientSocket { | |
| 53 public: | |
| 54 MockTCPClientSocket(const net::AddressList& addresses, | |
| 55 net::MockSocket* socket); | |
| 56 | |
| 57 // ClientSocket methods: | |
| 58 virtual int Connect(net::CompletionCallback* callback); | |
| 59 | |
| 60 // Socket methods: | |
| 61 virtual int Read(net::IOBuffer* buf, int buf_len, | |
| 62 net::CompletionCallback* callback); | |
| 63 virtual int Write(net::IOBuffer* buf, int buf_len, | |
| 64 net::CompletionCallback* callback); | |
| 65 | |
| 66 private: | |
| 67 net::MockSocket* data_; | |
| 68 int read_offset_; | |
| 69 net::MockRead* read_data_; | |
| 70 bool need_read_data_; | |
| 71 }; | |
| 72 | |
| 73 class MockSSLClientSocket : public MockClientSocket { | |
| 74 public: | |
| 75 MockSSLClientSocket( | |
| 76 net::ClientSocket* transport_socket, | |
| 77 const std::string& hostname, | |
| 78 const net::SSLConfig& ssl_config, | |
| 79 net::MockSSLSocket* socket); | |
| 80 ~MockSSLClientSocket(); | |
| 81 | |
| 82 virtual void GetSSLInfo(net::SSLInfo* ssl_info); | |
| 83 | |
| 84 virtual int Connect(net::CompletionCallback* callback); | |
| 85 virtual void Disconnect(); | |
| 86 | |
| 87 // Socket methods: | |
| 88 virtual int Read(net::IOBuffer* buf, int buf_len, | |
| 89 net::CompletionCallback* callback); | |
| 90 virtual int Write(net::IOBuffer* buf, int buf_len, | |
| 91 net::CompletionCallback* callback); | |
| 92 | |
| 93 private: | |
| 94 class ConnectCallback; | |
| 95 | |
| 96 scoped_ptr<ClientSocket> transport_; | |
| 97 net::MockSSLSocket* data_; | |
| 98 }; | |
| 99 | |
| 100 MockClientSocket::MockClientSocket() | |
| 101 : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), | |
| 102 callback_(NULL), | |
| 103 connected_(false) { | |
| 104 } | |
| 105 | |
| 106 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { | |
| 107 NOTREACHED(); | |
| 108 } | |
| 109 | |
| 110 void MockClientSocket::GetSSLCertRequestInfo( | |
| 111 net::SSLCertRequestInfo* cert_request_info) { | |
| 112 NOTREACHED(); | |
| 113 } | |
| 114 | |
| 115 void MockClientSocket::Disconnect() { | |
| 116 connected_ = false; | |
| 117 callback_ = NULL; | |
| 118 } | |
| 119 | |
| 120 bool MockClientSocket::IsConnected() const { | |
| 121 return connected_; | |
| 122 } | |
| 123 | |
| 124 bool MockClientSocket::IsConnectedAndIdle() const { | |
| 125 return connected_; | |
| 126 } | |
| 127 | |
| 128 #if defined(OS_LINUX) | |
| 129 int MockClientSocket::GetPeerName(struct sockaddr *name, socklen_t *namelen) { | |
| 130 memset(reinterpret_cast<char *>(name), 0, *namelen); | |
| 131 return net::OK; | |
| 132 } | |
| 133 #endif // defined(OS_LINUX) | |
| 134 | |
| 135 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback, | |
| 136 int result) { | |
| 137 callback_ = callback; | |
| 138 MessageLoop::current()->PostTask(FROM_HERE, | |
| 139 method_factory_.NewRunnableMethod( | |
| 140 &MockClientSocket::RunCallback, result)); | |
| 141 } | |
| 142 | |
| 143 void MockClientSocket::RunCallback(int result) { | |
| 144 net::CompletionCallback* c = callback_; | |
| 145 callback_ = NULL; | |
| 146 if (c) | |
| 147 c->Run(result); | |
| 148 } | |
| 149 | |
| 150 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, | |
| 151 net::MockSocket* socket) | |
| 152 : data_(socket), | |
| 153 read_offset_(0), | |
| 154 read_data_(NULL), | |
| 155 need_read_data_(true) { | |
| 156 DCHECK(data_); | |
| 157 data_->Reset(); | |
| 158 } | |
| 159 | |
| 160 int MockTCPClientSocket::Connect(net::CompletionCallback* callback) { | |
| 161 DCHECK(!callback_); | |
| 162 if (connected_) | |
| 163 return net::OK; | |
| 164 connected_ = true; | |
| 165 if (data_->connect_data().async) { | |
| 166 RunCallbackAsync(callback, data_->connect_data().result); | |
| 167 return net::ERR_IO_PENDING; | |
| 168 } | |
| 169 return data_->connect_data().result; | |
| 170 } | |
| 171 | |
| 172 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, | |
| 173 net::CompletionCallback* callback) { | |
| 174 DCHECK(!callback_); | |
| 175 if (need_read_data_) { | |
| 176 read_data_ = data_->GetNextRead(); | |
| 177 need_read_data_ = false; | |
| 178 } | |
| 179 int result = read_data_->result; | |
| 180 if (read_data_->data) { | |
| 181 if (read_data_->data_len - read_offset_ > 0) { | |
| 182 result = std::min(buf_len, read_data_->data_len - read_offset_); | |
| 183 memcpy(buf->data(), read_data_->data + read_offset_, result); | |
| 184 read_offset_ += result; | |
| 185 if (read_offset_ == read_data_->data_len) { | |
| 186 need_read_data_ = true; | |
| 187 read_offset_ = 0; | |
| 188 } | |
| 189 } else { | |
| 190 result = 0; // EOF | |
| 191 } | |
| 192 } | |
| 193 if (read_data_->async) { | |
| 194 RunCallbackAsync(callback, result); | |
| 195 return net::ERR_IO_PENDING; | |
| 196 } | |
| 197 return result; | |
| 198 } | |
| 199 | |
| 200 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, | |
| 201 net::CompletionCallback* callback) { | |
| 202 DCHECK(buf); | |
| 203 DCHECK(buf_len > 0); | |
| 204 DCHECK(!callback_); | |
| 205 | |
| 206 std::string data(buf->data(), buf_len); | |
| 207 net::MockWriteResult write_result = data_->OnWrite(data); | |
| 208 | |
| 209 if (write_result.async) { | |
| 210 RunCallbackAsync(callback, write_result.result); | |
| 211 return net::ERR_IO_PENDING; | |
| 212 } | |
| 213 return write_result.result; | |
| 214 } | |
| 215 | |
| 216 class MockSSLClientSocket::ConnectCallback : | |
| 217 public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { | |
| 218 public: | |
| 219 ConnectCallback(MockSSLClientSocket *ssl_client_socket, | |
| 220 net::CompletionCallback* user_callback, | |
| 221 int rv) | |
| 222 : ALLOW_THIS_IN_INITIALIZER_LIST( | |
| 223 net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( | |
| 224 this, &ConnectCallback::Wrapper)), | |
| 225 ssl_client_socket_(ssl_client_socket), | |
| 226 user_callback_(user_callback), | |
| 227 rv_(rv) { | |
| 228 } | |
| 229 | |
| 230 private: | |
| 231 void Wrapper(int rv) { | |
| 232 if (rv_ == net::OK) | |
| 233 ssl_client_socket_->connected_ = true; | |
| 234 user_callback_->Run(rv_); | |
| 235 delete this; | |
| 236 } | |
| 237 | |
| 238 MockSSLClientSocket* ssl_client_socket_; | |
| 239 net::CompletionCallback* user_callback_; | |
| 240 int rv_; | |
| 241 }; | |
| 242 | |
| 243 MockSSLClientSocket::MockSSLClientSocket( | |
| 244 net::ClientSocket* transport_socket, | |
| 245 const std::string& hostname, | |
| 246 const net::SSLConfig& ssl_config, | |
| 247 net::MockSSLSocket* socket) | |
| 248 : transport_(transport_socket), | |
| 249 data_(socket) { | |
| 250 DCHECK(data_); | |
| 251 } | |
| 252 | |
| 253 MockSSLClientSocket::~MockSSLClientSocket() { | |
| 254 Disconnect(); | |
| 255 } | |
| 256 | |
| 257 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { | |
| 258 ssl_info->Reset(); | |
| 259 } | |
| 260 | |
| 261 int MockSSLClientSocket::Connect(net::CompletionCallback* callback) { | |
| 262 DCHECK(!callback_); | |
| 263 ConnectCallback* connect_callback = new ConnectCallback( | |
| 264 this, callback, data_->connect.result); | |
| 265 int rv = transport_->Connect(connect_callback); | |
| 266 if (rv == net::OK) { | |
| 267 delete connect_callback; | |
| 268 if (data_->connect.async) { | |
| 269 RunCallbackAsync(callback, data_->connect.result); | |
| 270 return net::ERR_IO_PENDING; | |
| 271 } | |
| 272 if (data_->connect.result == net::OK) | |
| 273 connected_ = true; | |
| 274 return data_->connect.result; | |
| 275 } | |
| 276 return rv; | |
| 277 } | |
| 278 | |
| 279 void MockSSLClientSocket::Disconnect() { | |
| 280 MockClientSocket::Disconnect(); | |
| 281 if (transport_ != NULL) | |
| 282 transport_->Disconnect(); | |
| 283 } | |
| 284 | |
| 285 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, | |
| 286 net::CompletionCallback* callback) { | |
| 287 DCHECK(!callback_); | |
| 288 return transport_->Read(buf, buf_len, callback); | |
| 289 } | |
| 290 | |
| 291 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, | |
| 292 net::CompletionCallback* callback) { | |
| 293 DCHECK(!callback_); | |
| 294 return transport_->Write(buf, buf_len, callback); | |
| 295 } | |
| 296 | |
| 297 } // namespace | |
| 298 | |
| 299 namespace net { | |
| 300 | |
| 301 MockRead* StaticMockSocket::GetNextRead() { | |
| 302 return &reads_[read_index_++]; | |
| 303 } | |
| 304 | |
| 305 MockWriteResult StaticMockSocket::OnWrite(const std::string& data) { | |
| 306 if (!writes_) { | |
| 307 // Not using mock writes; succeed synchronously. | |
| 308 return MockWriteResult(false, data.length()); | |
| 309 } | |
| 310 | |
| 311 // Check that what we are writing matches the expectation. | |
| 312 // Then give the mocked return value. | |
| 313 net::MockWrite* w = &writes_[write_index_++]; | |
| 314 int result = w->result; | |
| 315 if (w->data) { | |
| 316 std::string expected_data(w->data, w->data_len); | |
| 317 EXPECT_EQ(expected_data, data); | |
| 318 if (expected_data != data) | |
| 319 return MockWriteResult(false, net::ERR_UNEXPECTED); | |
| 320 if (result == net::OK) | |
| 321 result = w->data_len; | |
| 322 } | |
| 323 return MockWriteResult(w->async, result); | |
| 324 } | |
| 325 | |
| 326 void StaticMockSocket::Reset() { | |
| 327 read_index_ = 0; | |
| 328 write_index_ = 0; | |
| 329 } | |
| 330 | |
| 331 DynamicMockSocket::DynamicMockSocket() | |
| 332 : read_(false, ERR_UNEXPECTED), | |
| 333 has_read_(false) { | |
| 334 } | |
| 335 | |
| 336 MockRead* DynamicMockSocket::GetNextRead() { | |
| 337 if (!has_read_) | |
| 338 return unexpected_read(); | |
| 339 has_read_ = false; | |
| 340 return &read_; | |
| 341 } | |
| 342 | |
| 343 void DynamicMockSocket::Reset() { | |
| 344 has_read_ = false; | |
| 345 } | |
| 346 | |
| 347 void DynamicMockSocket::SimulateRead(const char* data) { | |
| 348 EXPECT_FALSE(has_read_) << "Unconsumed read: " << read_.data; | |
| 349 read_ = MockRead(data); | |
| 350 has_read_ = true; | |
| 351 } | |
| 352 | |
| 353 void MockClientSocketFactory::AddMockSocket(MockSocket* socket) { | |
| 354 mock_sockets_.Add(socket); | |
| 355 } | |
| 356 | |
| 357 void MockClientSocketFactory::AddMockSSLSocket(MockSSLSocket* socket) { | |
| 358 mock_ssl_sockets_.Add(socket); | |
| 359 } | |
| 360 | |
| 361 void MockClientSocketFactory::ResetNextMockIndexes() { | |
| 362 mock_sockets_.ResetNextIndex(); | |
| 363 mock_ssl_sockets_.ResetNextIndex(); | |
| 364 } | |
| 365 | |
| 366 ClientSocket* MockClientSocketFactory::CreateTCPClientSocket( | |
| 367 const AddressList& addresses) { | |
| 368 return new MockTCPClientSocket(addresses, mock_sockets_.GetNext()); | |
| 369 } | |
| 370 | |
| 371 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( | |
| 372 ClientSocket* transport_socket, | |
| 373 const std::string& hostname, | |
| 374 const SSLConfig& ssl_config) { | |
| 375 return new MockSSLClientSocket(transport_socket, hostname, ssl_config, | |
| 376 mock_ssl_sockets_.GetNext()); | |
| 377 } | |
| 378 | |
| 379 } // namespace net | |
| OLD | NEW |