OLD | NEW |
| (Empty) |
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "net/base/socket_test_util.h" | |
6 | |
7 #include "base/basictypes.h" | |
8 #include "base/compiler_specific.h" | |
9 #include "base/message_loop.h" | |
10 #include "net/base/io_buffer.h" | |
11 #include "net/base/socket.h" | |
12 #include "net/base/ssl_client_socket.h" | |
13 #include "net/base/ssl_info.h" | |
14 #include "testing/gtest/include/gtest/gtest.h" | |
15 | |
16 namespace { | |
17 | |
18 class MockClientSocket : public net::SSLClientSocket { | |
19 public: | |
20 MockClientSocket(); | |
21 | |
22 // ClientSocket methods: | |
23 virtual int Connect(net::CompletionCallback* callback) = 0; | |
24 | |
25 // SSLClientSocket methods: | |
26 virtual void GetSSLInfo(net::SSLInfo* ssl_info); | |
27 virtual void GetSSLCertRequestInfo( | |
28 net::SSLCertRequestInfo* cert_request_info); | |
29 virtual void Disconnect(); | |
30 virtual bool IsConnected() const; | |
31 virtual bool IsConnectedAndIdle() const; | |
32 | |
33 // Socket methods: | |
34 virtual int Read(net::IOBuffer* buf, int buf_len, | |
35 net::CompletionCallback* callback) = 0; | |
36 virtual int Write(net::IOBuffer* buf, int buf_len, | |
37 net::CompletionCallback* callback) = 0; | |
38 | |
39 #if defined(OS_LINUX) | |
40 virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen); | |
41 #endif | |
42 | |
43 protected: | |
44 void RunCallbackAsync(net::CompletionCallback* callback, int result); | |
45 void RunCallback(int result); | |
46 | |
47 ScopedRunnableMethodFactory<MockClientSocket> method_factory_; | |
48 net::CompletionCallback* callback_; | |
49 bool connected_; | |
50 }; | |
51 | |
52 class MockTCPClientSocket : public MockClientSocket { | |
53 public: | |
54 MockTCPClientSocket(const net::AddressList& addresses, | |
55 net::MockSocket* socket); | |
56 | |
57 // ClientSocket methods: | |
58 virtual int Connect(net::CompletionCallback* callback); | |
59 | |
60 // Socket methods: | |
61 virtual int Read(net::IOBuffer* buf, int buf_len, | |
62 net::CompletionCallback* callback); | |
63 virtual int Write(net::IOBuffer* buf, int buf_len, | |
64 net::CompletionCallback* callback); | |
65 | |
66 private: | |
67 net::MockSocket* data_; | |
68 int read_offset_; | |
69 net::MockRead* read_data_; | |
70 bool need_read_data_; | |
71 }; | |
72 | |
73 class MockSSLClientSocket : public MockClientSocket { | |
74 public: | |
75 MockSSLClientSocket( | |
76 net::ClientSocket* transport_socket, | |
77 const std::string& hostname, | |
78 const net::SSLConfig& ssl_config, | |
79 net::MockSSLSocket* socket); | |
80 ~MockSSLClientSocket(); | |
81 | |
82 virtual void GetSSLInfo(net::SSLInfo* ssl_info); | |
83 | |
84 virtual int Connect(net::CompletionCallback* callback); | |
85 virtual void Disconnect(); | |
86 | |
87 // Socket methods: | |
88 virtual int Read(net::IOBuffer* buf, int buf_len, | |
89 net::CompletionCallback* callback); | |
90 virtual int Write(net::IOBuffer* buf, int buf_len, | |
91 net::CompletionCallback* callback); | |
92 | |
93 private: | |
94 class ConnectCallback; | |
95 | |
96 scoped_ptr<ClientSocket> transport_; | |
97 net::MockSSLSocket* data_; | |
98 }; | |
99 | |
100 MockClientSocket::MockClientSocket() | |
101 : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), | |
102 callback_(NULL), | |
103 connected_(false) { | |
104 } | |
105 | |
106 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { | |
107 NOTREACHED(); | |
108 } | |
109 | |
110 void MockClientSocket::GetSSLCertRequestInfo( | |
111 net::SSLCertRequestInfo* cert_request_info) { | |
112 NOTREACHED(); | |
113 } | |
114 | |
115 void MockClientSocket::Disconnect() { | |
116 connected_ = false; | |
117 callback_ = NULL; | |
118 } | |
119 | |
120 bool MockClientSocket::IsConnected() const { | |
121 return connected_; | |
122 } | |
123 | |
124 bool MockClientSocket::IsConnectedAndIdle() const { | |
125 return connected_; | |
126 } | |
127 | |
128 #if defined(OS_LINUX) | |
129 int MockClientSocket::GetPeerName(struct sockaddr *name, socklen_t *namelen) { | |
130 memset(reinterpret_cast<char *>(name), 0, *namelen); | |
131 return net::OK; | |
132 } | |
133 #endif // defined(OS_LINUX) | |
134 | |
135 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback, | |
136 int result) { | |
137 callback_ = callback; | |
138 MessageLoop::current()->PostTask(FROM_HERE, | |
139 method_factory_.NewRunnableMethod( | |
140 &MockClientSocket::RunCallback, result)); | |
141 } | |
142 | |
143 void MockClientSocket::RunCallback(int result) { | |
144 net::CompletionCallback* c = callback_; | |
145 callback_ = NULL; | |
146 if (c) | |
147 c->Run(result); | |
148 } | |
149 | |
150 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, | |
151 net::MockSocket* socket) | |
152 : data_(socket), | |
153 read_offset_(0), | |
154 read_data_(NULL), | |
155 need_read_data_(true) { | |
156 DCHECK(data_); | |
157 data_->Reset(); | |
158 } | |
159 | |
160 int MockTCPClientSocket::Connect(net::CompletionCallback* callback) { | |
161 DCHECK(!callback_); | |
162 if (connected_) | |
163 return net::OK; | |
164 connected_ = true; | |
165 if (data_->connect_data().async) { | |
166 RunCallbackAsync(callback, data_->connect_data().result); | |
167 return net::ERR_IO_PENDING; | |
168 } | |
169 return data_->connect_data().result; | |
170 } | |
171 | |
172 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, | |
173 net::CompletionCallback* callback) { | |
174 DCHECK(!callback_); | |
175 if (need_read_data_) { | |
176 read_data_ = data_->GetNextRead(); | |
177 need_read_data_ = false; | |
178 } | |
179 int result = read_data_->result; | |
180 if (read_data_->data) { | |
181 if (read_data_->data_len - read_offset_ > 0) { | |
182 result = std::min(buf_len, read_data_->data_len - read_offset_); | |
183 memcpy(buf->data(), read_data_->data + read_offset_, result); | |
184 read_offset_ += result; | |
185 if (read_offset_ == read_data_->data_len) { | |
186 need_read_data_ = true; | |
187 read_offset_ = 0; | |
188 } | |
189 } else { | |
190 result = 0; // EOF | |
191 } | |
192 } | |
193 if (read_data_->async) { | |
194 RunCallbackAsync(callback, result); | |
195 return net::ERR_IO_PENDING; | |
196 } | |
197 return result; | |
198 } | |
199 | |
200 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, | |
201 net::CompletionCallback* callback) { | |
202 DCHECK(buf); | |
203 DCHECK(buf_len > 0); | |
204 DCHECK(!callback_); | |
205 | |
206 std::string data(buf->data(), buf_len); | |
207 net::MockWriteResult write_result = data_->OnWrite(data); | |
208 | |
209 if (write_result.async) { | |
210 RunCallbackAsync(callback, write_result.result); | |
211 return net::ERR_IO_PENDING; | |
212 } | |
213 return write_result.result; | |
214 } | |
215 | |
216 class MockSSLClientSocket::ConnectCallback : | |
217 public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { | |
218 public: | |
219 ConnectCallback(MockSSLClientSocket *ssl_client_socket, | |
220 net::CompletionCallback* user_callback, | |
221 int rv) | |
222 : ALLOW_THIS_IN_INITIALIZER_LIST( | |
223 net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( | |
224 this, &ConnectCallback::Wrapper)), | |
225 ssl_client_socket_(ssl_client_socket), | |
226 user_callback_(user_callback), | |
227 rv_(rv) { | |
228 } | |
229 | |
230 private: | |
231 void Wrapper(int rv) { | |
232 if (rv_ == net::OK) | |
233 ssl_client_socket_->connected_ = true; | |
234 user_callback_->Run(rv_); | |
235 delete this; | |
236 } | |
237 | |
238 MockSSLClientSocket* ssl_client_socket_; | |
239 net::CompletionCallback* user_callback_; | |
240 int rv_; | |
241 }; | |
242 | |
243 MockSSLClientSocket::MockSSLClientSocket( | |
244 net::ClientSocket* transport_socket, | |
245 const std::string& hostname, | |
246 const net::SSLConfig& ssl_config, | |
247 net::MockSSLSocket* socket) | |
248 : transport_(transport_socket), | |
249 data_(socket) { | |
250 DCHECK(data_); | |
251 } | |
252 | |
253 MockSSLClientSocket::~MockSSLClientSocket() { | |
254 Disconnect(); | |
255 } | |
256 | |
257 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { | |
258 ssl_info->Reset(); | |
259 } | |
260 | |
261 int MockSSLClientSocket::Connect(net::CompletionCallback* callback) { | |
262 DCHECK(!callback_); | |
263 ConnectCallback* connect_callback = new ConnectCallback( | |
264 this, callback, data_->connect.result); | |
265 int rv = transport_->Connect(connect_callback); | |
266 if (rv == net::OK) { | |
267 delete connect_callback; | |
268 if (data_->connect.async) { | |
269 RunCallbackAsync(callback, data_->connect.result); | |
270 return net::ERR_IO_PENDING; | |
271 } | |
272 if (data_->connect.result == net::OK) | |
273 connected_ = true; | |
274 return data_->connect.result; | |
275 } | |
276 return rv; | |
277 } | |
278 | |
279 void MockSSLClientSocket::Disconnect() { | |
280 MockClientSocket::Disconnect(); | |
281 if (transport_ != NULL) | |
282 transport_->Disconnect(); | |
283 } | |
284 | |
285 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, | |
286 net::CompletionCallback* callback) { | |
287 DCHECK(!callback_); | |
288 return transport_->Read(buf, buf_len, callback); | |
289 } | |
290 | |
291 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, | |
292 net::CompletionCallback* callback) { | |
293 DCHECK(!callback_); | |
294 return transport_->Write(buf, buf_len, callback); | |
295 } | |
296 | |
297 } // namespace | |
298 | |
299 namespace net { | |
300 | |
301 MockRead* StaticMockSocket::GetNextRead() { | |
302 return &reads_[read_index_++]; | |
303 } | |
304 | |
305 MockWriteResult StaticMockSocket::OnWrite(const std::string& data) { | |
306 if (!writes_) { | |
307 // Not using mock writes; succeed synchronously. | |
308 return MockWriteResult(false, data.length()); | |
309 } | |
310 | |
311 // Check that what we are writing matches the expectation. | |
312 // Then give the mocked return value. | |
313 net::MockWrite* w = &writes_[write_index_++]; | |
314 int result = w->result; | |
315 if (w->data) { | |
316 std::string expected_data(w->data, w->data_len); | |
317 EXPECT_EQ(expected_data, data); | |
318 if (expected_data != data) | |
319 return MockWriteResult(false, net::ERR_UNEXPECTED); | |
320 if (result == net::OK) | |
321 result = w->data_len; | |
322 } | |
323 return MockWriteResult(w->async, result); | |
324 } | |
325 | |
326 void StaticMockSocket::Reset() { | |
327 read_index_ = 0; | |
328 write_index_ = 0; | |
329 } | |
330 | |
331 DynamicMockSocket::DynamicMockSocket() | |
332 : read_(false, ERR_UNEXPECTED), | |
333 has_read_(false) { | |
334 } | |
335 | |
336 MockRead* DynamicMockSocket::GetNextRead() { | |
337 if (!has_read_) | |
338 return unexpected_read(); | |
339 has_read_ = false; | |
340 return &read_; | |
341 } | |
342 | |
343 void DynamicMockSocket::Reset() { | |
344 has_read_ = false; | |
345 } | |
346 | |
347 void DynamicMockSocket::SimulateRead(const char* data) { | |
348 EXPECT_FALSE(has_read_) << "Unconsumed read: " << read_.data; | |
349 read_ = MockRead(data); | |
350 has_read_ = true; | |
351 } | |
352 | |
353 void MockClientSocketFactory::AddMockSocket(MockSocket* socket) { | |
354 mock_sockets_.Add(socket); | |
355 } | |
356 | |
357 void MockClientSocketFactory::AddMockSSLSocket(MockSSLSocket* socket) { | |
358 mock_ssl_sockets_.Add(socket); | |
359 } | |
360 | |
361 void MockClientSocketFactory::ResetNextMockIndexes() { | |
362 mock_sockets_.ResetNextIndex(); | |
363 mock_ssl_sockets_.ResetNextIndex(); | |
364 } | |
365 | |
366 ClientSocket* MockClientSocketFactory::CreateTCPClientSocket( | |
367 const AddressList& addresses) { | |
368 return new MockTCPClientSocket(addresses, mock_sockets_.GetNext()); | |
369 } | |
370 | |
371 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( | |
372 ClientSocket* transport_socket, | |
373 const std::string& hostname, | |
374 const SSLConfig& ssl_config) { | |
375 return new MockSSLClientSocket(transport_socket, hostname, ssl_config, | |
376 mock_ssl_sockets_.GetNext()); | |
377 } | |
378 | |
379 } // namespace net | |
OLD | NEW |