OLD | NEW |
1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ | 5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ |
6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_ | 6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_ |
7 | 7 |
8 #include <deque> | 8 #include <deque> |
9 #include <string> | 9 #include <string> |
10 #include <vector> | 10 #include <vector> |
(...skipping 56 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
67 // {async, result}. | 67 // {async, result}. |
68 typedef MockRead MockWrite; | 68 typedef MockRead MockWrite; |
69 | 69 |
70 struct MockWriteResult { | 70 struct MockWriteResult { |
71 MockWriteResult(bool async, int result) : async(async), result(result) {} | 71 MockWriteResult(bool async, int result) : async(async), result(result) {} |
72 | 72 |
73 bool async; | 73 bool async; |
74 int result; | 74 int result; |
75 }; | 75 }; |
76 | 76 |
77 class MockSocket { | 77 // The SocketDataProvider is an interface used by the MockClientSocket |
| 78 // for getting data about individual reads and writes on the socket. |
| 79 class SocketDataProvider { |
78 public: | 80 public: |
79 MockSocket() {} | 81 SocketDataProvider() {} |
80 | 82 |
81 virtual ~MockSocket() {} | 83 virtual ~SocketDataProvider() {} |
82 virtual MockRead GetNextRead() = 0; | 84 virtual MockRead GetNextRead() = 0; |
83 virtual MockWriteResult OnWrite(const std::string& data) = 0; | 85 virtual MockWriteResult OnWrite(const std::string& data) = 0; |
84 virtual void Reset() = 0; | 86 virtual void Reset() = 0; |
85 | 87 |
86 MockConnect connect_data() const { return connect_; } | 88 MockConnect connect_data() const { return connect_; } |
87 | 89 |
88 private: | 90 private: |
89 MockConnect connect_; | 91 MockConnect connect_; |
90 | 92 |
91 DISALLOW_COPY_AND_ASSIGN(MockSocket); | 93 DISALLOW_COPY_AND_ASSIGN(SocketDataProvider); |
92 }; | 94 }; |
93 | 95 |
94 // MockSocket which responds based on static tables of mock reads and writes. | 96 // SocketDataProvider which responds based on static tables of mock reads and |
95 class StaticMockSocket : public MockSocket { | 97 // writes. |
| 98 class StaticSocketDataProvider : public SocketDataProvider { |
96 public: | 99 public: |
97 StaticMockSocket() : reads_(NULL), read_index_(0), | 100 StaticSocketDataProvider() : reads_(NULL), read_index_(0), |
98 writes_(NULL), write_index_(0) {} | 101 writes_(NULL), write_index_(0) {} |
99 StaticMockSocket(MockRead* r, MockWrite* w) : reads_(r), read_index_(0), | 102 StaticSocketDataProvider(MockRead* r, MockWrite* w) : reads_(r), |
100 writes_(w), write_index_(0) {} | 103 read_index_(0), writes_(w), write_index_(0) {} |
101 | 104 |
102 // MockSocket methods: | 105 // SocketDataProvider methods: |
103 virtual MockRead GetNextRead(); | 106 virtual MockRead GetNextRead(); |
104 virtual MockWriteResult OnWrite(const std::string& data); | 107 virtual MockWriteResult OnWrite(const std::string& data); |
105 virtual void Reset(); | 108 virtual void Reset(); |
106 | 109 |
107 // If the test wishes to verify that all data is consumed, it can include | 110 // If the test wishes to verify that all data is consumed, it can include |
108 // a EOF MockRead or MockWrite, which is a zero-length Read or Write. | 111 // a EOF MockRead or MockWrite, which is a zero-length Read or Write. |
109 // The test can then call at_read_eof() or at_write_eof() to verify that | 112 // The test can then call at_read_eof() or at_write_eof() to verify that |
110 // all data has been consumed. | 113 // all data has been consumed. |
111 bool at_read_eof() const { return reads_[read_index_].data_len == 0; } | 114 bool at_read_eof() const { return reads_[read_index_].data_len == 0; } |
112 bool at_write_eof() const { return writes_[write_index_].data_len == 0; } | 115 bool at_write_eof() const { return writes_[write_index_].data_len == 0; } |
113 | 116 |
114 private: | 117 private: |
115 MockRead* reads_; | 118 MockRead* reads_; |
116 int read_index_; | 119 int read_index_; |
117 MockWrite* writes_; | 120 MockWrite* writes_; |
118 int write_index_; | 121 int write_index_; |
119 | 122 |
120 DISALLOW_COPY_AND_ASSIGN(StaticMockSocket); | 123 DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider); |
121 }; | 124 }; |
122 | 125 |
123 // MockSocket which can make decisions about next mock reads based on | 126 // SocketDataProvider which can make decisions about next mock reads based on |
124 // received writes. It can also be used to enforce order of operations, | 127 // received writes. It can also be used to enforce order of operations, for |
125 // for example that tested code must send the "Hello!" message before | 128 // example that tested code must send the "Hello!" message before receiving |
126 // receiving response. This is useful for testing conversation-like | 129 // response. This is useful for testing conversation-like protocols like FTP. |
127 // protocols like FTP. | 130 class DynamicSocketDataProvider : public SocketDataProvider { |
128 class DynamicMockSocket : public MockSocket { | |
129 public: | 131 public: |
130 DynamicMockSocket(); | 132 DynamicSocketDataProvider(); |
131 | 133 |
132 // MockSocket methods: | 134 // SocketDataProvider methods: |
133 virtual MockRead GetNextRead(); | 135 virtual MockRead GetNextRead(); |
134 virtual MockWriteResult OnWrite(const std::string& data) = 0; | 136 virtual MockWriteResult OnWrite(const std::string& data) = 0; |
135 virtual void Reset(); | 137 virtual void Reset(); |
136 | 138 |
137 int short_read_limit() const { return short_read_limit_; } | 139 int short_read_limit() const { return short_read_limit_; } |
138 void set_short_read_limit(int limit) { short_read_limit_ = limit; } | 140 void set_short_read_limit(int limit) { short_read_limit_ = limit; } |
139 | 141 |
140 void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } | 142 void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } |
141 | 143 |
142 protected: | 144 protected: |
143 // The next time there is a read from this socket, it will return |data|. | 145 // The next time there is a read from this socket, it will return |data|. |
144 // Before calling SimulateRead next time, the previous data must be consumed. | 146 // Before calling SimulateRead next time, the previous data must be consumed. |
145 void SimulateRead(const char* data); | 147 void SimulateRead(const char* data); |
146 | 148 |
147 private: | 149 private: |
148 std::deque<MockRead> reads_; | 150 std::deque<MockRead> reads_; |
149 | 151 |
150 // Max number of bytes we will read at a time. 0 means no limit. | 152 // Max number of bytes we will read at a time. 0 means no limit. |
151 int short_read_limit_; | 153 int short_read_limit_; |
152 | 154 |
153 // If true, we'll not require the client to consume all data before we | 155 // If true, we'll not require the client to consume all data before we |
154 // mock the next read. | 156 // mock the next read. |
155 bool allow_unconsumed_reads_; | 157 bool allow_unconsumed_reads_; |
156 | 158 |
157 DISALLOW_COPY_AND_ASSIGN(DynamicMockSocket); | 159 DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider); |
158 }; | 160 }; |
159 | 161 |
160 // MockSSLSockets only need to keep track of the return code from calls to | 162 // SSLSocketDataProviders only need to keep track of the return code from calls |
161 // Connect(). | 163 // to Connect(). |
162 struct MockSSLSocket { | 164 struct SSLSocketDataProvider { |
163 MockSSLSocket(bool async, int result) : connect(async, result) { } | 165 SSLSocketDataProvider(bool async, int result) : connect(async, result) { } |
164 | 166 |
165 MockConnect connect; | 167 MockConnect connect; |
166 }; | 168 }; |
167 | 169 |
168 // Holds an array of Mock{SSL,}Socket elements. As Mock{TCP,SSL}ClientSocket | 170 // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}ClientSocket |
169 // objects get instantiated, they take their data from the i'th element of this | 171 // objects get instantiated, they take their data from the i'th element of this |
170 // array. | 172 // array. |
171 template<typename T> | 173 template<typename T> |
172 class MockSocketArray { | 174 class SocketDataProviderArray { |
173 public: | 175 public: |
174 MockSocketArray() : next_index_(0) { | 176 SocketDataProviderArray() : next_index_(0) { |
175 } | 177 } |
176 | 178 |
177 T* GetNext() { | 179 T* GetNext() { |
178 DCHECK(next_index_ < sockets_.size()); | 180 DCHECK(next_index_ < data_providers_.size()); |
179 return sockets_[next_index_++]; | 181 return data_providers_[next_index_++]; |
180 } | 182 } |
181 | 183 |
182 void Add(T* socket) { | 184 void Add(T* data_provider) { |
183 DCHECK(socket); | 185 DCHECK(data_provider); |
184 sockets_.push_back(socket); | 186 data_providers_.push_back(data_provider); |
185 } | 187 } |
186 | 188 |
187 void ResetNextIndex() { | 189 void ResetNextIndex() { |
188 next_index_ = 0; | 190 next_index_ = 0; |
189 } | 191 } |
190 | 192 |
191 private: | 193 private: |
192 // Index of the next |sockets| element to use. Not an iterator because those | 194 // Index of the next |data_providers_| element to use. Not an iterator |
193 // are invalidated on vector reallocation. | 195 // because those are invalidated on vector reallocation. |
194 size_t next_index_; | 196 size_t next_index_; |
195 | 197 |
196 // Mock sockets to be returned. | 198 // SocketDataProviders to be returned. |
197 std::vector<T*> sockets_; | 199 std::vector<T*> data_providers_; |
198 }; | 200 }; |
199 | 201 |
200 class MockTCPClientSocket; | 202 class MockTCPClientSocket; |
201 class MockSSLClientSocket; | 203 class MockSSLClientSocket; |
202 | 204 |
203 // ClientSocketFactory which contains arrays of sockets of each type. | 205 // ClientSocketFactory which contains arrays of sockets of each type. |
204 // You should first fill the arrays using AddMock{SSL,}Socket. When the factory | 206 // You should first fill the arrays using AddMock{SSL,}Socket. When the factory |
205 // is asked to create a socket, it takes next entry from appropriate array. | 207 // is asked to create a socket, it takes next entry from appropriate array. |
206 // You can use ResetNextMockIndexes to reset that next entry index for all mock | 208 // You can use ResetNextMockIndexes to reset that next entry index for all mock |
207 // socket types. | 209 // socket types. |
208 class MockClientSocketFactory : public ClientSocketFactory { | 210 class MockClientSocketFactory : public ClientSocketFactory { |
209 public: | 211 public: |
210 void AddMockSocket(MockSocket* socket); | 212 void AddSocketDataProvider(SocketDataProvider* socket); |
211 void AddMockSSLSocket(MockSSLSocket* socket); | 213 void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); |
212 void ResetNextMockIndexes(); | 214 void ResetNextMockIndexes(); |
213 | 215 |
214 // Return |index|-th MockTCPClientSocket (starting from 0) that the factory | 216 // Return |index|-th MockTCPClientSocket (starting from 0) that the factory |
215 // created. | 217 // created. |
216 MockTCPClientSocket* GetMockTCPClientSocket(int index) const; | 218 MockTCPClientSocket* GetMockTCPClientSocket(int index) const; |
217 | 219 |
218 // Return |index|-th MockSSLClientSocket (starting from 0) that the factory | 220 // Return |index|-th MockSSLClientSocket (starting from 0) that the factory |
219 // created. | 221 // created. |
220 MockSSLClientSocket* GetMockSSLClientSocket(int index) const; | 222 MockSSLClientSocket* GetMockSSLClientSocket(int index) const; |
221 | 223 |
222 // ClientSocketFactory | 224 // ClientSocketFactory |
223 virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses); | 225 virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses); |
224 virtual SSLClientSocket* CreateSSLClientSocket( | 226 virtual SSLClientSocket* CreateSSLClientSocket( |
225 ClientSocket* transport_socket, | 227 ClientSocket* transport_socket, |
226 const std::string& hostname, | 228 const std::string& hostname, |
227 const SSLConfig& ssl_config); | 229 const SSLConfig& ssl_config); |
228 | 230 |
229 private: | 231 private: |
230 MockSocketArray<MockSocket> mock_sockets_; | 232 SocketDataProviderArray<SocketDataProvider> mock_data_; |
231 MockSocketArray<MockSSLSocket> mock_ssl_sockets_; | 233 SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; |
232 | 234 |
233 // Store pointers to handed out sockets in case the test wants to get them. | 235 // Store pointers to handed out sockets in case the test wants to get them. |
234 std::vector<MockTCPClientSocket*> tcp_client_sockets_; | 236 std::vector<MockTCPClientSocket*> tcp_client_sockets_; |
235 std::vector<MockSSLClientSocket*> ssl_client_sockets_; | 237 std::vector<MockSSLClientSocket*> ssl_client_sockets_; |
236 }; | 238 }; |
237 | 239 |
238 class MockClientSocket : public net::SSLClientSocket { | 240 class MockClientSocket : public net::SSLClientSocket { |
239 public: | 241 public: |
240 MockClientSocket(); | 242 MockClientSocket(); |
241 | 243 |
(...skipping 24 matching lines...) Expand all Loading... |
266 void RunCallbackAsync(net::CompletionCallback* callback, int result); | 268 void RunCallbackAsync(net::CompletionCallback* callback, int result); |
267 void RunCallback(net::CompletionCallback*, int result); | 269 void RunCallback(net::CompletionCallback*, int result); |
268 | 270 |
269 ScopedRunnableMethodFactory<MockClientSocket> method_factory_; | 271 ScopedRunnableMethodFactory<MockClientSocket> method_factory_; |
270 bool connected_; | 272 bool connected_; |
271 }; | 273 }; |
272 | 274 |
273 class MockTCPClientSocket : public MockClientSocket { | 275 class MockTCPClientSocket : public MockClientSocket { |
274 public: | 276 public: |
275 MockTCPClientSocket(const net::AddressList& addresses, | 277 MockTCPClientSocket(const net::AddressList& addresses, |
276 net::MockSocket* socket); | 278 net::SocketDataProvider* socket); |
277 | 279 |
278 // ClientSocket methods: | 280 // ClientSocket methods: |
279 virtual int Connect(net::CompletionCallback* callback, | 281 virtual int Connect(net::CompletionCallback* callback, |
280 LoadLog* load_log); | 282 LoadLog* load_log); |
281 | 283 |
282 // Socket methods: | 284 // Socket methods: |
283 virtual int Read(net::IOBuffer* buf, int buf_len, | 285 virtual int Read(net::IOBuffer* buf, int buf_len, |
284 net::CompletionCallback* callback); | 286 net::CompletionCallback* callback); |
285 virtual int Write(net::IOBuffer* buf, int buf_len, | 287 virtual int Write(net::IOBuffer* buf, int buf_len, |
286 net::CompletionCallback* callback); | 288 net::CompletionCallback* callback); |
287 | 289 |
288 net::AddressList addresses() const { return addresses_; } | 290 net::AddressList addresses() const { return addresses_; } |
289 | 291 |
290 private: | 292 private: |
291 net::AddressList addresses_; | 293 net::AddressList addresses_; |
292 | 294 |
293 net::MockSocket* data_; | 295 net::SocketDataProvider* data_; |
294 int read_offset_; | 296 int read_offset_; |
295 net::MockRead read_data_; | 297 net::MockRead read_data_; |
296 bool need_read_data_; | 298 bool need_read_data_; |
297 }; | 299 }; |
298 | 300 |
299 class MockSSLClientSocket : public MockClientSocket { | 301 class MockSSLClientSocket : public MockClientSocket { |
300 public: | 302 public: |
301 MockSSLClientSocket( | 303 MockSSLClientSocket( |
302 net::ClientSocket* transport_socket, | 304 net::ClientSocket* transport_socket, |
303 const std::string& hostname, | 305 const std::string& hostname, |
304 const net::SSLConfig& ssl_config, | 306 const net::SSLConfig& ssl_config, |
305 net::MockSSLSocket* socket); | 307 net::SSLSocketDataProvider* socket); |
306 ~MockSSLClientSocket(); | 308 ~MockSSLClientSocket(); |
307 | 309 |
308 virtual void GetSSLInfo(net::SSLInfo* ssl_info); | 310 virtual void GetSSLInfo(net::SSLInfo* ssl_info); |
309 | 311 |
310 virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log); | 312 virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log); |
311 virtual void Disconnect(); | 313 virtual void Disconnect(); |
312 | 314 |
313 // Socket methods: | 315 // Socket methods: |
314 virtual int Read(net::IOBuffer* buf, int buf_len, | 316 virtual int Read(net::IOBuffer* buf, int buf_len, |
315 net::CompletionCallback* callback); | 317 net::CompletionCallback* callback); |
316 virtual int Write(net::IOBuffer* buf, int buf_len, | 318 virtual int Write(net::IOBuffer* buf, int buf_len, |
317 net::CompletionCallback* callback); | 319 net::CompletionCallback* callback); |
318 | 320 |
319 private: | 321 private: |
320 class ConnectCallback; | 322 class ConnectCallback; |
321 | 323 |
322 scoped_ptr<ClientSocket> transport_; | 324 scoped_ptr<ClientSocket> transport_; |
323 net::MockSSLSocket* data_; | 325 net::SSLSocketDataProvider* data_; |
324 }; | 326 }; |
325 | 327 |
326 class TestSocketRequest : public CallbackRunner< Tuple1<int> > { | 328 class TestSocketRequest : public CallbackRunner< Tuple1<int> > { |
327 public: | 329 public: |
328 TestSocketRequest( | 330 TestSocketRequest( |
329 std::vector<TestSocketRequest*>* request_order, | 331 std::vector<TestSocketRequest*>* request_order, |
330 size_t* completion_count) | 332 size_t* completion_count) |
331 : request_order_(request_order), | 333 : request_order_(request_order), |
332 completion_count_(completion_count) { | 334 completion_count_(completion_count) { |
333 DCHECK(request_order); | 335 DCHECK(request_order); |
(...skipping 58 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
392 void ReleaseAllConnections(KeepAlive keep_alive); | 394 void ReleaseAllConnections(KeepAlive keep_alive); |
393 | 395 |
394 ScopedVector<TestSocketRequest> requests_; | 396 ScopedVector<TestSocketRequest> requests_; |
395 std::vector<TestSocketRequest*> request_order_; | 397 std::vector<TestSocketRequest*> request_order_; |
396 size_t completion_count_; | 398 size_t completion_count_; |
397 }; | 399 }; |
398 | 400 |
399 } // namespace net | 401 } // namespace net |
400 | 402 |
401 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ | 403 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ |
OLD | NEW |