| OLD | NEW |
| 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/socket/web_socket_server_socket.h" | 5 #include "net/socket/web_socket_server_socket.h" |
| 6 | 6 |
| 7 #include <stdlib.h> | 7 #include <stdlib.h> |
| 8 #include <algorithm> | 8 #include <algorithm> |
| 9 | 9 |
| 10 #include "base/callback_old.h" | 10 #include "base/callback_old.h" |
| (...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 72 | 72 |
| 73 // TODO(dilmah): consider switching to socket_test_util.h | 73 // TODO(dilmah): consider switching to socket_test_util.h |
| 74 // Simulates reading from |sample| stream; data supplied in Write() calls are | 74 // Simulates reading from |sample| stream; data supplied in Write() calls are |
| 75 // stored in |answer| buffer. | 75 // stored in |answer| buffer. |
| 76 class TestingTransportSocket : public net::Socket { | 76 class TestingTransportSocket : public net::Socket { |
| 77 public: | 77 public: |
| 78 TestingTransportSocket( | 78 TestingTransportSocket( |
| 79 net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) | 79 net::DrainableIOBuffer* sample, net::DrainableIOBuffer* answer) |
| 80 : sample_(sample), | 80 : sample_(sample), |
| 81 answer_(answer), | 81 answer_(answer), |
| 82 old_final_read_callback_(NULL), | 82 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
| 83 method_factory_(this) { | |
| 84 } | 83 } |
| 85 | 84 |
| 86 ~TestingTransportSocket() { | 85 ~TestingTransportSocket() { |
| 87 if (old_final_read_callback_) { | 86 if (!final_read_callback_.is_null()) { |
| 88 MessageLoop::current()->PostTask(FROM_HERE, | 87 MessageLoop::current()->PostTask(FROM_HERE, |
| 89 method_factory_.NewRunnableMethod( | 88 base::Bind(&TestingTransportSocket::DoReadCallback, |
| 90 &TestingTransportSocket::DoOldReadCallback, | 89 weak_factory_.GetWeakPtr(), |
| 91 old_final_read_callback_, 0)); | 90 final_read_callback_, 0)); |
| 92 } else if (!final_read_callback_.is_null()) { | |
| 93 MessageLoop::current()->PostTask( | |
| 94 FROM_HERE, | |
| 95 method_factory_.NewRunnableMethod( | |
| 96 &TestingTransportSocket::DoReadCallback, | |
| 97 final_read_callback_, 0)); | |
| 98 } | 91 } |
| 99 } | 92 } |
| 100 | 93 |
| 101 // Socket implementation. | 94 // Socket implementation. |
| 102 virtual int Read(net::IOBuffer* buf, int buf_len, | 95 virtual int Read(net::IOBuffer* buf, int buf_len, |
| 103 net::OldCompletionCallback* callback) { | |
| 104 CHECK_GT(buf_len, 0); | |
| 105 int remaining = sample_->BytesRemaining(); | |
| 106 if (remaining < 1) { | |
| 107 if (old_final_read_callback_ || !final_read_callback_.is_null()) | |
| 108 return 0; | |
| 109 old_final_read_callback_ = callback; | |
| 110 return net::ERR_IO_PENDING; | |
| 111 } | |
| 112 int lot = GetRand(1, std::min(remaining, buf_len)); | |
| 113 std::copy(sample_->data(), sample_->data() + lot, buf->data()); | |
| 114 sample_->DidConsume(lot); | |
| 115 if (GetRand(0, 1)) { | |
| 116 return lot; | |
| 117 } | |
| 118 MessageLoop::current()->PostTask(FROM_HERE, | |
| 119 method_factory_.NewRunnableMethod( | |
| 120 &TestingTransportSocket::DoOldReadCallback, callback, lot)); | |
| 121 return net::ERR_IO_PENDING; | |
| 122 } | |
| 123 virtual int Read(net::IOBuffer* buf, int buf_len, | |
| 124 const net::CompletionCallback& callback) { | 96 const net::CompletionCallback& callback) { |
| 125 CHECK_GT(buf_len, 0); | 97 CHECK_GT(buf_len, 0); |
| 126 int remaining = sample_->BytesRemaining(); | 98 int remaining = sample_->BytesRemaining(); |
| 127 if (remaining < 1) { | 99 if (remaining < 1) { |
| 128 if (old_final_read_callback_ || !final_read_callback_.is_null()) | 100 if (!final_read_callback_.is_null()) |
| 129 return 0; | 101 return 0; |
| 130 final_read_callback_ = callback; | 102 final_read_callback_ = callback; |
| 131 return net::ERR_IO_PENDING; | 103 return net::ERR_IO_PENDING; |
| 132 } | 104 } |
| 133 int lot = GetRand(1, std::min(remaining, buf_len)); | 105 int lot = GetRand(1, std::min(remaining, buf_len)); |
| 134 std::copy(sample_->data(), sample_->data() + lot, buf->data()); | 106 std::copy(sample_->data(), sample_->data() + lot, buf->data()); |
| 135 sample_->DidConsume(lot); | 107 sample_->DidConsume(lot); |
| 136 if (GetRand(0, 1)) { | 108 if (GetRand(0, 1)) { |
| 137 return lot; | 109 return lot; |
| 138 } | 110 } |
| 139 MessageLoop::current()->PostTask(FROM_HERE, | 111 MessageLoop::current()->PostTask( |
| 140 method_factory_.NewRunnableMethod( | 112 FROM_HERE, |
| 141 &TestingTransportSocket::DoReadCallback, callback, lot)); | 113 base::Bind(&TestingTransportSocket::DoReadCallback, |
| 114 weak_factory_.GetWeakPtr(), callback, lot)); |
| 142 return net::ERR_IO_PENDING; | 115 return net::ERR_IO_PENDING; |
| 143 } | 116 } |
| 144 | 117 |
| 145 virtual int Write(net::IOBuffer* buf, int buf_len, | 118 virtual int Write(net::IOBuffer* buf, int buf_len, |
| 146 net::OldCompletionCallback* callback) { | 119 const net::CompletionCallback& callback) { |
| 147 CHECK_GT(buf_len, 0); | 120 CHECK_GT(buf_len, 0); |
| 148 int remaining = answer_->BytesRemaining(); | 121 int remaining = answer_->BytesRemaining(); |
| 149 CHECK_GE(remaining, buf_len); | 122 CHECK_GE(remaining, buf_len); |
| 150 int lot = std::min(remaining, buf_len); | 123 int lot = std::min(remaining, buf_len); |
| 151 if (GetRand(0, 1)) | 124 if (GetRand(0, 1)) |
| 152 lot = GetRand(1, lot); | 125 lot = GetRand(1, lot); |
| 153 std::copy(buf->data(), buf->data() + lot, answer_->data()); | 126 std::copy(buf->data(), buf->data() + lot, answer_->data()); |
| 154 answer_->DidConsume(lot); | 127 answer_->DidConsume(lot); |
| 155 if (GetRand(0, 1)) { | 128 if (GetRand(0, 1)) { |
| 156 return lot; | 129 return lot; |
| 157 } | 130 } |
| 158 MessageLoop::current()->PostTask(FROM_HERE, | 131 MessageLoop::current()->PostTask( |
| 159 method_factory_.NewRunnableMethod( | 132 FROM_HERE, |
| 160 &TestingTransportSocket::DoWriteCallback, callback, lot)); | 133 base::Bind(&TestingTransportSocket::DoWriteCallback, |
| 134 weak_factory_.GetWeakPtr(), callback, lot)); |
| 161 return net::ERR_IO_PENDING; | 135 return net::ERR_IO_PENDING; |
| 162 } | 136 } |
| 163 | 137 |
| 164 virtual bool SetReceiveBufferSize(int32 size) { | 138 virtual bool SetReceiveBufferSize(int32 size) { |
| 165 return true; | 139 return true; |
| 166 } | 140 } |
| 167 | 141 |
| 168 virtual bool SetSendBufferSize(int32 size) { | 142 virtual bool SetSendBufferSize(int32 size) { |
| 169 return true; | 143 return true; |
| 170 } | 144 } |
| 171 | 145 |
| 172 net::DrainableIOBuffer* answer() { return answer_.get(); } | 146 net::DrainableIOBuffer* answer() { return answer_.get(); } |
| 173 | 147 |
| 174 void DoOldReadCallback(net::OldCompletionCallback* callback, int result) { | |
| 175 if (result == 0 && !is_closed_) { | |
| 176 MessageLoop::current()->PostTask(FROM_HERE, | |
| 177 method_factory_.NewRunnableMethod( | |
| 178 &TestingTransportSocket::DoOldReadCallback, callback, 0)); | |
| 179 } else { | |
| 180 if (callback) | |
| 181 callback->Run(result); | |
| 182 } | |
| 183 } | |
| 184 void DoReadCallback(const net::CompletionCallback& callback, int result) { | 148 void DoReadCallback(const net::CompletionCallback& callback, int result) { |
| 185 if (result == 0 && !is_closed_) { | 149 if (result == 0 && !is_closed_) { |
| 186 MessageLoop::current()->PostTask(FROM_HERE, | 150 MessageLoop::current()->PostTask( |
| 187 method_factory_.NewRunnableMethod( | 151 FROM_HERE, |
| 188 &TestingTransportSocket::DoReadCallback, callback, 0)); | 152 base::Bind( |
| 153 &TestingTransportSocket::DoReadCallback, |
| 154 weak_factory_.GetWeakPtr(), callback, 0)); |
| 189 } else { | 155 } else { |
| 190 if (!callback.is_null()) | 156 if (!callback.is_null()) |
| 191 callback.Run(result); | 157 callback.Run(result); |
| 192 } | 158 } |
| 193 } | 159 } |
| 194 | 160 |
| 195 void DoWriteCallback(net::OldCompletionCallback* callback, int result) { | 161 void DoWriteCallback(const net::CompletionCallback& callback, int result) { |
| 196 if (callback) | 162 if (!callback.is_null()) |
| 197 callback->Run(result); | 163 callback.Run(result); |
| 198 } | 164 } |
| 199 | 165 |
| 200 bool is_closed_; | 166 bool is_closed_; |
| 201 | 167 |
| 202 // Data to return for Read requests. | 168 // Data to return for Read requests. |
| 203 scoped_refptr<net::DrainableIOBuffer> sample_; | 169 scoped_refptr<net::DrainableIOBuffer> sample_; |
| 204 | 170 |
| 205 // Data pushed to us by server socket (using Write calls). | 171 // Data pushed to us by server socket (using Write calls). |
| 206 scoped_refptr<net::DrainableIOBuffer> answer_; | 172 scoped_refptr<net::DrainableIOBuffer> answer_; |
| 207 | 173 |
| 208 // Final read callback to report zero (zero stands for EOF). | 174 // Final read callback to report zero (zero stands for EOF). |
| 209 net::OldCompletionCallback* old_final_read_callback_; | |
| 210 net::CompletionCallback final_read_callback_; | 175 net::CompletionCallback final_read_callback_; |
| 211 | 176 |
| 212 ScopedRunnableMethodFactory<TestingTransportSocket> method_factory_; | 177 base::WeakPtrFactory<TestingTransportSocket> weak_factory_; |
| 213 }; | 178 }; |
| 214 | 179 |
| 215 class Validator : public net::WebSocketServerSocket::Delegate { | 180 class Validator : public net::WebSocketServerSocket::Delegate { |
| 216 public: | 181 public: |
| 217 Validator(const std::string& resource, | 182 Validator(const std::string& resource, |
| 218 const std::string& origin, | 183 const std::string& origin, |
| 219 const std::string& host) | 184 const std::string& host) |
| 220 : resource_(resource), origin_(origin), host_(host) { | 185 : resource_(resource), origin_(origin), host_(host) { |
| 221 } | 186 } |
| 222 | 187 |
| (...skipping 26 matching lines...) Expand all Loading... |
| 249 char ReferenceSeq(unsigned n, unsigned salt) { | 214 char ReferenceSeq(unsigned n, unsigned salt) { |
| 250 return (salt * 2 + n * 3) % ('z' - 'a') + 'a'; | 215 return (salt * 2 + n * 3) % ('z' - 'a') + 'a'; |
| 251 } | 216 } |
| 252 | 217 |
| 253 class ReadWriteTracker { | 218 class ReadWriteTracker { |
| 254 public: | 219 public: |
| 255 ReadWriteTracker( | 220 ReadWriteTracker( |
| 256 net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write) | 221 net::WebSocketServerSocket* ws, int bytes_to_read, int bytes_to_write) |
| 257 : ws_(ws), | 222 : ws_(ws), |
| 258 buf_size_(1 << 14), | 223 buf_size_(1 << 14), |
| 259 accept_callback_(NewCallback(this, &ReadWriteTracker::OnAccept)), | 224 ALLOW_THIS_IN_INITIALIZER_LIST( |
| 260 read_callback_(NewCallback(this, &ReadWriteTracker::OnRead)), | 225 accept_callback_(this, &ReadWriteTracker::OnAccept)), |
| 261 write_callback_(NewCallback(this, &ReadWriteTracker::OnWrite)), | |
| 262 read_buf_(new net::IOBuffer(buf_size_)), | 226 read_buf_(new net::IOBuffer(buf_size_)), |
| 263 write_buf_(new net::IOBuffer(buf_size_)), | 227 write_buf_(new net::IOBuffer(buf_size_)), |
| 264 bytes_remaining_to_read_(bytes_to_read), | 228 bytes_remaining_to_read_(bytes_to_read), |
| 265 bytes_remaining_to_write_(bytes_to_write), | 229 bytes_remaining_to_write_(bytes_to_write), |
| 266 read_initiated_(false), | 230 read_initiated_(false), |
| 267 write_initiated_(false), | 231 write_initiated_(false), |
| 268 got_final_zero_(false) { | 232 got_final_zero_(false) { |
| 269 int rv = ws_->Accept(accept_callback_.get()); | 233 int rv = ws_->Accept(&accept_callback_); |
| 270 if (rv != net::ERR_IO_PENDING) | 234 if (rv != net::ERR_IO_PENDING) |
| 271 OnAccept(rv); | 235 OnAccept(rv); |
| 272 } | 236 } |
| 273 | 237 |
| 274 ~ReadWriteTracker() { | 238 ~ReadWriteTracker() { |
| 275 CHECK_EQ(bytes_remaining_to_write_, 0); | 239 CHECK_EQ(bytes_remaining_to_write_, 0); |
| 276 CHECK_EQ(bytes_remaining_to_read_, 0); | 240 CHECK_EQ(bytes_remaining_to_read_, 0); |
| 277 } | 241 } |
| 278 | 242 |
| 279 void OnAccept(int result) { | 243 void OnAccept(int result) { |
| 280 ASSERT_EQ(result, 0); | 244 ASSERT_EQ(result, 0); |
| 281 if (GetRand(0, 1)) { | 245 if (GetRand(0, 1)) { |
| 282 DoRead(); | 246 DoRead(); |
| 283 DoWrite(); | 247 DoWrite(); |
| 284 } else { | 248 } else { |
| 285 DoWrite(); | 249 DoWrite(); |
| 286 DoRead(); | 250 DoRead(); |
| 287 } | 251 } |
| 288 } | 252 } |
| 289 | 253 |
| 290 void DoWrite() { | 254 void DoWrite() { |
| 291 if (bytes_remaining_to_write_ < 1) | 255 if (bytes_remaining_to_write_ < 1) |
| 292 return; | 256 return; |
| 293 int lot = GetRand(1, bytes_remaining_to_write_); | 257 int lot = GetRand(1, bytes_remaining_to_write_); |
| 294 lot = std::min(lot, buf_size_); | 258 lot = std::min(lot, buf_size_); |
| 295 for (int i = 0; i < lot; ++i) | 259 for (int i = 0; i < lot; ++i) |
| 296 write_buf_->data()[i] = ReferenceSeq( | 260 write_buf_->data()[i] = ReferenceSeq( |
| 297 bytes_remaining_to_write_ - i - 1, kWriteSalt); | 261 bytes_remaining_to_write_ - i - 1, kWriteSalt); |
| 298 int rv = ws_->Write(write_buf_, lot, write_callback_.get()); | 262 int rv = ws_->Write(write_buf_, lot, base::Bind(&ReadWriteTracker::OnWrite, |
| 263 base::Unretained(this))); |
| 299 if (rv != net::ERR_IO_PENDING) | 264 if (rv != net::ERR_IO_PENDING) |
| 300 OnWrite(rv); | 265 OnWrite(rv); |
| 301 } | 266 } |
| 302 | 267 |
| 303 void DoRead() { | 268 void DoRead() { |
| 304 int lot = GetRand(1, buf_size_); | 269 int lot = GetRand(1, buf_size_); |
| 305 if (bytes_remaining_to_read_ < 1) { | 270 if (bytes_remaining_to_read_ < 1) { |
| 306 if (got_final_zero_) | 271 if (got_final_zero_) |
| 307 return; | 272 return; |
| 308 } else { | 273 } else { |
| 309 lot = GetRand(1, bytes_remaining_to_read_); | 274 lot = GetRand(1, bytes_remaining_to_read_); |
| 310 lot = std::min(lot, buf_size_); | 275 lot = std::min(lot, buf_size_); |
| 311 } | 276 } |
| 312 int rv = ws_->Read(read_buf_, lot, read_callback_.get()); | 277 int rv = ws_->Read(read_buf_, lot, base::Bind(&ReadWriteTracker::OnRead, |
| 278 base::Unretained(this))); |
| 313 if (rv != net::ERR_IO_PENDING) | 279 if (rv != net::ERR_IO_PENDING) |
| 314 OnRead(rv); | 280 OnRead(rv); |
| 315 } | 281 } |
| 316 | 282 |
| 317 void OnWrite(int result) { | 283 void OnWrite(int result) { |
| 318 ASSERT_GT(result, 0); | 284 ASSERT_GT(result, 0); |
| 319 ASSERT_LE(result, bytes_remaining_to_write_); | 285 ASSERT_LE(result, bytes_remaining_to_write_); |
| 320 bytes_remaining_to_write_ -= result; | 286 bytes_remaining_to_write_ -= result; |
| 321 DoWrite(); | 287 DoWrite(); |
| 322 } | 288 } |
| (...skipping 10 matching lines...) Expand all Loading... |
| 333 ASSERT_EQ(read_buf_->data()[i], ReferenceSeq( | 299 ASSERT_EQ(read_buf_->data()[i], ReferenceSeq( |
| 334 bytes_remaining_to_read_ - i - 1, kReadSalt)); | 300 bytes_remaining_to_read_ - i - 1, kReadSalt)); |
| 335 } | 301 } |
| 336 bytes_remaining_to_read_ -= result; | 302 bytes_remaining_to_read_ -= result; |
| 337 DoRead(); | 303 DoRead(); |
| 338 } | 304 } |
| 339 | 305 |
| 340 private: | 306 private: |
| 341 net::WebSocketServerSocket* const ws_; | 307 net::WebSocketServerSocket* const ws_; |
| 342 int const buf_size_; | 308 int const buf_size_; |
| 343 scoped_ptr<net::OldCompletionCallback> accept_callback_; | 309 net::OldCompletionCallbackImpl<ReadWriteTracker> accept_callback_; |
| 344 scoped_ptr<net::OldCompletionCallback> read_callback_; | |
| 345 scoped_ptr<net::OldCompletionCallback> write_callback_; | |
| 346 scoped_refptr<net::IOBuffer> read_buf_; | 310 scoped_refptr<net::IOBuffer> read_buf_; |
| 347 scoped_refptr<net::IOBuffer> write_buf_; | 311 scoped_refptr<net::IOBuffer> write_buf_; |
| 348 int bytes_remaining_to_read_; | 312 int bytes_remaining_to_read_; |
| 349 int bytes_remaining_to_write_; | 313 int bytes_remaining_to_write_; |
| 350 bool read_initiated_; | 314 bool read_initiated_; |
| 351 bool write_initiated_; | 315 bool write_initiated_; |
| 352 bool got_final_zero_; | 316 bool got_final_zero_; |
| 353 }; | 317 }; |
| 354 | 318 |
| 355 } // namespace | 319 } // namespace |
| (...skipping 269 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 625 MessageLoop::current()->RunAllPending(); | 589 MessageLoop::current()->RunAllPending(); |
| 626 | 590 |
| 627 for (size_t i = kill_list.size(); i--;) | 591 for (size_t i = kill_list.size(); i--;) |
| 628 delete kill_list[i]; | 592 delete kill_list[i]; |
| 629 for (size_t i = tracker_list.size(); i--;) | 593 for (size_t i = tracker_list.size(); i--;) |
| 630 delete tracker_list[i]; | 594 delete tracker_list[i]; |
| 631 MessageLoop::current()->RunAllPending(); | 595 MessageLoop::current()->RunAllPending(); |
| 632 } | 596 } |
| 633 | 597 |
| 634 } // namespace net | 598 } // namespace net |
| OLD | NEW |