OLD | NEW |
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/socket/web_socket_server_socket.h" | 5 #include "net/socket/web_socket_server_socket.h" |
6 | 6 |
7 #include <stdlib.h> | 7 #include <stdlib.h> |
8 #include <algorithm> | 8 #include <algorithm> |
9 | 9 |
10 #include "base/callback_old.h" | 10 #include "base/callback_old.h" |
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
72 | 72 |
73 // TODO(dilmah): consider switching to socket_test_util.h | 73 // TODO(dilmah): consider switching to socket_test_util.h |
74 // Simulates reading from |sample| stream; data supplied in Write() calls are | 74 // Simulates reading from |sample| stream; data supplied in Write() calls are |
75 // stored in |answer| buffer. | 75 // stored in |answer| buffer. |
76 class TestingTransportSocket : public net::Socket { | 76 class TestingTransportSocket : public net::Socket { |
77 public: | 77 public: |
78 TestingTransportSocket( | 78 TestingTransportSocket( |
79 net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) | 79 net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) |
80 : sample_(sample), | 80 : sample_(sample), |
81 answer_(answer), | 81 answer_(answer), |
82 old_final_read_callback_(NULL), | 82 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
83 method_factory_(this) { | |
84 } | 83 } |
85 | 84 |
86 ~TestingTransportSocket() { | 85 ~TestingTransportSocket() { |
87 if (old_final_read_callback_) { | 86 if (!final_read_callback_.is_null()) { |
88 MessageLoop::current()->PostTask(FROM_HERE, | 87 MessageLoop::current()->PostTask(FROM_HERE, |
89 method_factory_.NewRunnableMethod( | 88 base::Bind(&TestingTransportSocket::DoReadCallback, |
90 &TestingTransportSocket::DoOldReadCallback, | 89 weak_factory_.GetWeakPtr(), |
91 old_final_read_callback_, 0)); | 90 final_read_callback_, 0)); |
92 } else if (!final_read_callback_.is_null()) { | |
93 MessageLoop::current()->PostTask( | |
94 FROM_HERE, | |
95 method_factory_.NewRunnableMethod( | |
96 &TestingTransportSocket::DoReadCallback, | |
97 final_read_callback_, 0)); | |
98 } | 91 } |
99 } | 92 } |
100 | 93 |
101 // Socket implementation. | 94 // Socket implementation. |
102 virtual int Read(net::IOBuffer* buf, int buf_len, | 95 virtual int Read(net::IOBuffer* buf, int buf_len, |
103 net::OldCompletionCallback* callback) { | |
104 CHECK_GT(buf_len, 0); | |
105 int remaining = sample_->BytesRemaining(); | |
106 if (remaining < 1) { | |
107 if (old_final_read_callback_ || !final_read_callback_.is_null()) | |
108 return 0; | |
109 old_final_read_callback_ = callback; | |
110 return net::ERR_IO_PENDING; | |
111 } | |
112 int lot = GetRand(1, std::min(remaining, buf_len)); | |
113 std::copy(sample_->data(), sample_->data() + lot, buf->data()); | |
114 sample_->DidConsume(lot); | |
115 if (GetRand(0, 1)) { | |
116 return lot; | |
117 } | |
118 MessageLoop::current()->PostTask(FROM_HERE, | |
119 method_factory_.NewRunnableMethod( | |
120 &TestingTransportSocket::DoOldReadCallback, callback, lot)); | |
121 return net::ERR_IO_PENDING; | |
122 } | |
123 virtual int Read(net::IOBuffer* buf, int buf_len, | |
124 const net::CompletionCallback& callback) { | 96 const net::CompletionCallback& callback) { |
125 CHECK_GT(buf_len, 0); | 97 CHECK_GT(buf_len, 0); |
126 int remaining = sample_->BytesRemaining(); | 98 int remaining = sample_->BytesRemaining(); |
127 if (remaining < 1) { | 99 if (remaining < 1) { |
128 if (old_final_read_callback_ || !final_read_callback_.is_null()) | 100 if (!final_read_callback_.is_null()) |
129 return 0; | 101 return 0; |
130 final_read_callback_ = callback; | 102 final_read_callback_ = callback; |
131 return net::ERR_IO_PENDING; | 103 return net::ERR_IO_PENDING; |
132 } | 104 } |
133 int lot = GetRand(1, std::min(remaining, buf_len)); | 105 int lot = GetRand(1, std::min(remaining, buf_len)); |
134 std::copy(sample_->data(), sample_->data() + lot, buf->data()); | 106 std::copy(sample_->data(), sample_->data() + lot, buf->data()); |
135 sample_->DidConsume(lot); | 107 sample_->DidConsume(lot); |
136 if (GetRand(0, 1)) { | 108 if (GetRand(0, 1)) { |
137 return lot; | 109 return lot; |
138 } | 110 } |
139 MessageLoop::current()->PostTask(FROM_HERE, | 111 MessageLoop::current()->PostTask( |
140 method_factory_.NewRunnableMethod( | 112 FROM_HERE, |
141 &TestingTransportSocket::DoReadCallback, callback, lot)); | 113 base::Bind(&TestingTransportSocket::DoReadCallback, |
| 114 weak_factory_.GetWeakPtr(), callback, lot)); |
142 return net::ERR_IO_PENDING; | 115 return net::ERR_IO_PENDING; |
143 } | 116 } |
144 | 117 |
145 virtual int Write(net::IOBuffer* buf, int buf_len, | 118 virtual int Write(net::IOBuffer* buf, int buf_len, |
146 net::OldCompletionCallback* callback) { | 119 const net::CompletionCallback& callback) { |
147 CHECK_GT(buf_len, 0); | 120 CHECK_GT(buf_len, 0); |
148 int remaining = answer_->BytesRemaining(); | 121 int remaining = answer_->BytesRemaining(); |
149 CHECK_GE(remaining, buf_len); | 122 CHECK_GE(remaining, buf_len); |
150 int lot = std::min(remaining, buf_len); | 123 int lot = std::min(remaining, buf_len); |
151 if (GetRand(0, 1)) | 124 if (GetRand(0, 1)) |
152 lot = GetRand(1, lot); | 125 lot = GetRand(1, lot); |
153 std::copy(buf->data(), buf->data() + lot, answer_->data()); | 126 std::copy(buf->data(), buf->data() + lot, answer_->data()); |
154 answer_->DidConsume(lot); | 127 answer_->DidConsume(lot); |
155 if (GetRand(0, 1)) { | 128 if (GetRand(0, 1)) { |
156 return lot; | 129 return lot; |
157 } | 130 } |
158 MessageLoop::current()->PostTask(FROM_HERE, | 131 MessageLoop::current()->PostTask( |
159 method_factory_.NewRunnableMethod( | 132 FROM_HERE, |
160 &TestingTransportSocket::DoWriteCallback, callback, lot)); | 133 base::Bind(&TestingTransportSocket::DoWriteCallback, |
| 134 weak_factory_.GetWeakPtr(), callback, lot)); |
161 return net::ERR_IO_PENDING; | 135 return net::ERR_IO_PENDING; |
162 } | 136 } |
163 | 137 |
164 virtual bool SetReceiveBufferSize(int32 size) { | 138 virtual bool SetReceiveBufferSize(int32 size) { |
165 return true; | 139 return true; |
166 } | 140 } |
167 | 141 |
168 virtual bool SetSendBufferSize(int32 size) { | 142 virtual bool SetSendBufferSize(int32 size) { |
169 return true; | 143 return true; |
170 } | 144 } |
171 | 145 |
172 net::DrainableIOBuffer* answer() { return answer_.get(); } | 146 net::DrainableIOBuffer* answer() { return answer_.get(); } |
173 | 147 |
174 void DoOldReadCallback(net::OldCompletionCallback* callback, int result) { | |
175 if (result == 0 && !is_closed_) { | |
176 MessageLoop::current()->PostTask(FROM_HERE, | |
177 method_factory_.NewRunnableMethod( | |
178 &TestingTransportSocket::DoOldReadCallback, callback, 0)); | |
179 } else { | |
180 if (callback) | |
181 callback->Run(result); | |
182 } | |
183 } | |
184 void DoReadCallback(const net::CompletionCallback& callback, int result) { | 148 void DoReadCallback(const net::CompletionCallback& callback, int result) { |
185 if (result == 0 && !is_closed_) { | 149 if (result == 0 && !is_closed_) { |
186 MessageLoop::current()->PostTask(FROM_HERE, | 150 MessageLoop::current()->PostTask( |
187 method_factory_.NewRunnableMethod( | 151 FROM_HERE, |
188 &TestingTransportSocket::DoReadCallback, callback, 0)); | 152 base::Bind( |
| 153 &TestingTransportSocket::DoReadCallback, |
| 154 weak_factory_.GetWeakPtr(), callback, 0)); |
189 } else { | 155 } else { |
190 if (!callback.is_null()) | 156 if (!callback.is_null()) |
191 callback.Run(result); | 157 callback.Run(result); |
192 } | 158 } |
193 } | 159 } |
194 | 160 |
195 void DoWriteCallback(net::OldCompletionCallback* callback, int result) { | 161 void DoWriteCallback(const net::CompletionCallback& callback, int result) { |
196 if (callback) | 162 if (!callback.is_null()) |
197 callback->Run(result); | 163 callback.Run(result); |
198 } | 164 } |
199 | 165 |
200 bool is_closed_; | 166 bool is_closed_; |
201 | 167 |
202 // Data to return for Read requests. | 168 // Data to return for Read requests. |
203 scoped_refptr<net::DrainableIOBuffer> sample_; | 169 scoped_refptr<net::DrainableIOBuffer> sample_; |
204 | 170 |
205 // Data pushed to us by server socket (using Write calls). | 171 // Data pushed to us by server socket (using Write calls). |
206 scoped_refptr<net::DrainableIOBuffer> answer_; | 172 scoped_refptr<net::DrainableIOBuffer> answer_; |
207 | 173 |
208 // Final read callback to report zero (zero stands for EOF). | 174 // Final read callback to report zero (zero stands for EOF). |
209 net::OldCompletionCallback* old_final_read_callback_; | |
210 net::CompletionCallback final_read_callback_; | 175 net::CompletionCallback final_read_callback_; |
211 | 176 |
212 ScopedRunnableMethodFactory<TestingTransportSocket> method_factory_; | 177 base::WeakPtrFactory<TestingTransportSocket> weak_factory_; |
213 }; | 178 }; |
214 | 179 |
215 class Validator : public net::WebSocketServerSocket::Delegate { | 180 class Validator : public net::WebSocketServerSocket::Delegate { |
216 public: | 181 public: |
217 Validator(const std::string& resource, | 182 Validator(const std::string& resource, |
218 const std::string& origin, | 183 const std::string& origin, |
219 const std::string& host) | 184 const std::string& host) |
220 : resource_(resource), origin_(origin), host_(host) { | 185 : resource_(resource), origin_(origin), host_(host) { |
221 } | 186 } |
222 | 187 |
(...skipping 26 matching lines...) Expand all Loading... |
249 char ReferenceSeq(unsigned n, unsigned salt) { | 214 char ReferenceSeq(unsigned n, unsigned salt) { |
250 return (salt * 2 + n * 3) % ('z' - 'a') + 'a'; | 215 return (salt * 2 + n * 3) % ('z' - 'a') + 'a'; |
251 } | 216 } |
252 | 217 |
253 class ReadWriteTracker { | 218 class ReadWriteTracker { |
254 public: | 219 public: |
255 ReadWriteTracker( | 220 ReadWriteTracker( |
256 net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write) | 221 net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write) |
257 : ws_(ws), | 222 : ws_(ws), |
258 buf_size_(1 << 14), | 223 buf_size_(1 << 14), |
259 accept_callback_(NewCallback(this, &ReadWriteTracker::OnAccept)), | 224 ALLOW_THIS_IN_INITIALIZER_LIST( |
260 read_callback_(NewCallback(this, &ReadWriteTracker::OnRead)), | 225 accept_callback_(this, &ReadWriteTracker::OnAccept)), |
261 write_callback_(NewCallback(this, &ReadWriteTracker::OnWrite)), | |
262 read_buf_(new net::IOBuffer(buf_size_)), | 226 read_buf_(new net::IOBuffer(buf_size_)), |
263 write_buf_(new net::IOBuffer(buf_size_)), | 227 write_buf_(new net::IOBuffer(buf_size_)), |
264 bytes_remaining_to_read_(bytes_to_read), | 228 bytes_remaining_to_read_(bytes_to_read), |
265 bytes_remaining_to_write_(bytes_to_write), | 229 bytes_remaining_to_write_(bytes_to_write), |
266 read_initiated_(false), | 230 read_initiated_(false), |
267 write_initiated_(false), | 231 write_initiated_(false), |
268 got_final_zero_(false) { | 232 got_final_zero_(false) { |
269 int rv = ws_->Accept(accept_callback_.get()); | 233 int rv = ws_->Accept(&accept_callback_); |
270 if (rv != net::ERR_IO_PENDING) | 234 if (rv != net::ERR_IO_PENDING) |
271 OnAccept(rv); | 235 OnAccept(rv); |
272 } | 236 } |
273 | 237 |
274 ~ReadWriteTracker() { | 238 ~ReadWriteTracker() { |
275 CHECK_EQ(bytes_remaining_to_write_, 0); | 239 CHECK_EQ(bytes_remaining_to_write_, 0); |
276 CHECK_EQ(bytes_remaining_to_read_, 0); | 240 CHECK_EQ(bytes_remaining_to_read_, 0); |
277 } | 241 } |
278 | 242 |
279 void OnAccept(int result) { | 243 void OnAccept(int result) { |
280 ASSERT_EQ(result, 0); | 244 ASSERT_EQ(result, 0); |
281 if (GetRand(0, 1)) { | 245 if (GetRand(0, 1)) { |
282 DoRead(); | 246 DoRead(); |
283 DoWrite(); | 247 DoWrite(); |
284 } else { | 248 } else { |
285 DoWrite(); | 249 DoWrite(); |
286 DoRead(); | 250 DoRead(); |
287 } | 251 } |
288 } | 252 } |
289 | 253 |
290 void DoWrite() { | 254 void DoWrite() { |
291 if (bytes_remaining_to_write_ < 1) | 255 if (bytes_remaining_to_write_ < 1) |
292 return; | 256 return; |
293 int lot = GetRand(1, bytes_remaining_to_write_); | 257 int lot = GetRand(1, bytes_remaining_to_write_); |
294 lot = std::min(lot, buf_size_); | 258 lot = std::min(lot, buf_size_); |
295 for (int i = 0; i < lot; ++i) | 259 for (int i = 0; i < lot; ++i) |
296 write_buf_->data()[i] = ReferenceSeq( | 260 write_buf_->data()[i] = ReferenceSeq( |
297 bytes_remaining_to_write_ - i - 1, kWriteSalt); | 261 bytes_remaining_to_write_ - i - 1, kWriteSalt); |
298 int rv = ws_->Write(write_buf_, lot, write_callback_.get()); | 262 int rv = ws_->Write(write_buf_, lot, base::Bind(&ReadWriteTracker::OnWrite, |
| 263 base::Unretained(this))); |
299 if (rv != net::ERR_IO_PENDING) | 264 if (rv != net::ERR_IO_PENDING) |
300 OnWrite(rv); | 265 OnWrite(rv); |
301 } | 266 } |
302 | 267 |
303 void DoRead() { | 268 void DoRead() { |
304 int lot = GetRand(1, buf_size_); | 269 int lot = GetRand(1, buf_size_); |
305 if (bytes_remaining_to_read_ < 1) { | 270 if (bytes_remaining_to_read_ < 1) { |
306 if (got_final_zero_) | 271 if (got_final_zero_) |
307 return; | 272 return; |
308 } else { | 273 } else { |
309 lot = GetRand(1, bytes_remaining_to_read_); | 274 lot = GetRand(1, bytes_remaining_to_read_); |
310 lot = std::min(lot, buf_size_); | 275 lot = std::min(lot, buf_size_); |
311 } | 276 } |
312 int rv = ws_->Read(read_buf_, lot, read_callback_.get()); | 277 int rv = ws_->Read(read_buf_, lot, base::Bind(&ReadWriteTracker::OnRead, |
| 278 base::Unretained(this))); |
313 if (rv != net::ERR_IO_PENDING) | 279 if (rv != net::ERR_IO_PENDING) |
314 OnRead(rv); | 280 OnRead(rv); |
315 } | 281 } |
316 | 282 |
317 void OnWrite(int result) { | 283 void OnWrite(int result) { |
318 ASSERT_GT(result, 0); | 284 ASSERT_GT(result, 0); |
319 ASSERT_LE(result, bytes_remaining_to_write_); | 285 ASSERT_LE(result, bytes_remaining_to_write_); |
320 bytes_remaining_to_write_ -= result; | 286 bytes_remaining_to_write_ -= result; |
321 DoWrite(); | 287 DoWrite(); |
322 } | 288 } |
(...skipping 10 matching lines...) Expand all Loading... |
333 ASSERT_EQ(read_buf_->data()[i], ReferenceSeq( | 299 ASSERT_EQ(read_buf_->data()[i], ReferenceSeq( |
334 bytes_remaining_to_read_ - i - 1, kReadSalt)); | 300 bytes_remaining_to_read_ - i - 1, kReadSalt)); |
335 } | 301 } |
336 bytes_remaining_to_read_ -= result; | 302 bytes_remaining_to_read_ -= result; |
337 DoRead(); | 303 DoRead(); |
338 } | 304 } |
339 | 305 |
340 private: | 306 private: |
341 net::WebSocketServerSocket* const ws_; | 307 net::WebSocketServerSocket* const ws_; |
342 int const buf_size_; | 308 int const buf_size_; |
343 scoped_ptr<net::OldCompletionCallback> accept_callback_; | 309 net::OldCompletionCallbackImpl<ReadWriteTracker> accept_callback_; |
344 scoped_ptr<net::OldCompletionCallback> read_callback_; | |
345 scoped_ptr<net::OldCompletionCallback> write_callback_; | |
346 scoped_refptr<net::IOBuffer> read_buf_; | 310 scoped_refptr<net::IOBuffer> read_buf_; |
347 scoped_refptr<net::IOBuffer> write_buf_; | 311 scoped_refptr<net::IOBuffer> write_buf_; |
348 int bytes_remaining_to_read_; | 312 int bytes_remaining_to_read_; |
349 int bytes_remaining_to_write_; | 313 int bytes_remaining_to_write_; |
350 bool read_initiated_; | 314 bool read_initiated_; |
351 bool write_initiated_; | 315 bool write_initiated_; |
352 bool got_final_zero_; | 316 bool got_final_zero_; |
353 }; | 317 }; |
354 | 318 |
355 } // namespace | 319 } // namespace |
(...skipping 269 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
625 MessageLoop::current()->RunAllPending(); | 589 MessageLoop::current()->RunAllPending(); |
626 | 590 |
627 for (size_t i = kill_list.size(); i--;) | 591 for (size_t i = kill_list.size(); i--;) |
628 delete kill_list[i]; | 592 delete kill_list[i]; |
629 for (size_t i = tracker_list.size(); i--;) | 593 for (size_t i = tracker_list.size(); i--;) |
630 delete tracker_list[i]; | 594 delete tracker_list[i]; |
631 MessageLoop::current()->RunAllPending(); | 595 MessageLoop::current()->RunAllPending(); |
632 } | 596 } |
633 | 597 |
634 } // namespace net | 598 } // namespace net |
OLD | NEW |