| OLD | NEW |
| (Empty) |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "net/socket/web_socket_server_socket.h" | |
| 6 | |
| 7 #include <algorithm> | |
| 8 #include <deque> | |
| 9 #include <limits> | |
| 10 #include <map> | |
| 11 #include <vector> | |
| 12 | |
| 13 #include "base/basictypes.h" | |
| 14 #include "base/bind.h" | |
| 15 #include "base/bind_helpers.h" | |
| 16 #include "base/logging.h" | |
| 17 #include "base/md5.h" | |
| 18 #include "base/memory/ref_counted.h" | |
| 19 #include "base/memory/scoped_ptr.h" | |
| 20 #include "base/memory/weak_ptr.h" | |
| 21 #include "base/message_loop.h" | |
| 22 #include "base/string_util.h" | |
| 23 #include "base/sys_byteorder.h" | |
| 24 #include "googleurl/src/gurl.h" | |
| 25 #include "net/base/completion_callback.h" | |
| 26 #include "net/base/io_buffer.h" | |
| 27 #include "net/base/net_errors.h" | |
| 28 | |
| 29 namespace { | |
| 30 | |
| 31 const size_t kHandshakeLimitBytes = 1 << 14; | |
| 32 | |
| 33 const char kCrOctet = '\r'; | |
| 34 COMPILE_ASSERT(kCrOctet == '\x0d', ASCII); | |
| 35 const char kLfOctet = '\n'; | |
| 36 COMPILE_ASSERT(kLfOctet == '\x0a', ASCII); | |
| 37 const char kSpaceOctet = ' '; | |
| 38 COMPILE_ASSERT(kSpaceOctet == '\x20', ASCII); | |
| 39 const char kCommaOctet = ','; | |
| 40 COMPILE_ASSERT(kCommaOctet == '\x2c', ASCII); | |
| 41 | |
| 42 const char kCRLF[] = { kCrOctet, kLfOctet, 0 }; | |
| 43 const char kCRLFCRLF[] = { kCrOctet, kLfOctet, kCrOctet, kLfOctet, 0 }; | |
| 44 | |
| 45 const char kPlainHostFieldName[] = "Host"; | |
| 46 const char kPlainOriginFieldName[] = "Origin"; | |
| 47 const char kOriginFieldName[] = "Sec-WebSocket-Origin"; | |
| 48 const char kProtocolFieldName[] = "Sec-WebSocket-Protocol"; | |
| 49 const char kVersionFieldName[] = "Sec-WebSocket-Version"; | |
| 50 const char kLocationFieldName[] = "Sec-WebSocket-Location"; | |
| 51 const char kKey1FieldName[] = "Sec-WebSocket-Key1"; | |
| 52 const char kKey2FieldName[] = "Sec-WebSocket-Key2"; | |
| 53 | |
| 54 int CountSpaces(const std::string& s) { | |
| 55 return std::count(s.begin(), s.end(), kSpaceOctet); | |
| 56 } | |
| 57 | |
| 58 // Returns true on success. | |
| 59 bool FetchDecimalDigits(const std::string& s, uint32* result) { | |
| 60 *result = 0; | |
| 61 bool got_something = false; | |
| 62 for (size_t i = 0; i < s.size(); ++i) { | |
| 63 if (IsAsciiDigit(s[i])) { | |
| 64 got_something = true; | |
| 65 if (*result > std::numeric_limits<uint32>::max() / 10) | |
| 66 return false; | |
| 67 *result *= 10; | |
| 68 int digit = s[i] - '0'; | |
| 69 if (*result > std::numeric_limits<uint32>::max() - digit) | |
| 70 return false; | |
| 71 *result += digit; | |
| 72 } | |
| 73 } | |
| 74 return got_something; | |
| 75 } | |
| 76 | |
| 77 // Returns number of fetched subprotocols or negative error code. | |
| 78 int FetchSubprotocolList( | |
| 79 const std::string& s, std::vector<std::string>* subprotocol_list) { | |
| 80 subprotocol_list->clear(); | |
| 81 subprotocol_list->push_back(std::string()); | |
| 82 for (size_t i = 0; i < s.size(); ++i) { | |
| 83 if (s[i] > '\x20' && s[i] < '\x7f' && s[i] != kCommaOctet) | |
| 84 subprotocol_list->back() += s[i]; | |
| 85 else if (!subprotocol_list->back().empty()) { | |
| 86 if (subprotocol_list->size() < 16) | |
| 87 subprotocol_list->push_back(std::string()); | |
| 88 else | |
| 89 return net::ERR_LIMIT_VIOLATION; | |
| 90 } | |
| 91 } | |
| 92 if (subprotocol_list->back().empty()) | |
| 93 subprotocol_list->pop_back(); | |
| 94 if (subprotocol_list->empty()) | |
| 95 return net::ERR_WS_PROTOCOL_ERROR; | |
| 96 | |
| 97 { | |
| 98 std::vector<std::string> tmp(*subprotocol_list); | |
| 99 std::sort(tmp.begin(), tmp.end()); | |
| 100 if (tmp.end() != std::unique(tmp.begin(), tmp.end())) | |
| 101 return net::ERR_WS_PROTOCOL_ERROR; | |
| 102 } | |
| 103 return subprotocol_list->size(); | |
| 104 } | |
| 105 | |
| 106 class WebSocketServerSocketImpl : public net::WebSocketServerSocket { | |
| 107 public: | |
| 108 WebSocketServerSocketImpl(net::Socket* transport_socket, Delegate* delegate) | |
| 109 : phase_(PHASE_NYMPH), | |
| 110 frame_bytes_remaining_(0), | |
| 111 transport_socket_(transport_socket), | |
| 112 delegate_(delegate), | |
| 113 handshake_buf_(new net::IOBuffer(kHandshakeLimitBytes)), | |
| 114 fill_handshake_buf_(new net::DrainableIOBuffer( | |
| 115 handshake_buf_, kHandshakeLimitBytes)), | |
| 116 process_handshake_buf_(new net::DrainableIOBuffer( | |
| 117 handshake_buf_, kHandshakeLimitBytes)), | |
| 118 is_transport_read_pending_(false), | |
| 119 is_transport_write_pending_(false), | |
| 120 weak_factory_(this) { | |
| 121 DCHECK(transport_socket); | |
| 122 DCHECK(delegate); | |
| 123 } | |
| 124 | |
| 125 virtual ~WebSocketServerSocketImpl() { | |
| 126 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); | |
| 127 if (it != pending_reqs_.end() && | |
| 128 it->type == PendingReq::TYPE_READ && | |
| 129 it->io_buf != NULL && | |
| 130 it->io_buf->data() != NULL && | |
| 131 !it->callback.is_null()) { | |
| 132 it->callback.Run(0); // Report EOF. | |
| 133 } | |
| 134 } | |
| 135 | |
| 136 private: | |
| 137 enum Phase { | |
| 138 // Before Accept() is called. | |
| 139 PHASE_NYMPH, | |
| 140 | |
| 141 // After Accept() is called and until handshake success/fail. | |
| 142 PHASE_HANDSHAKE, | |
| 143 | |
| 144 // Processing data stream. | |
| 145 PHASE_FRAME_OUTSIDE, // Outside data frame. | |
| 146 PHASE_FRAME_INSIDE, // Inside text frame. | |
| 147 PHASE_FRAME_LENGTH, // Reading length of binary frame. | |
| 148 PHASE_FRAME_SKIP, // Skipping binary frame. | |
| 149 | |
| 150 // After termination. | |
| 151 PHASE_SHUT | |
| 152 }; | |
| 153 | |
| 154 struct PendingReq { | |
| 155 enum Type { | |
| 156 // Frame delimiters or handshake (as opposed to user data). | |
| 157 TYPE_METADATA = 1 << 0, | |
| 158 // Read request. | |
| 159 TYPE_READ = 1 << 1, | |
| 160 // Write request. | |
| 161 TYPE_WRITE = 1 << 2, | |
| 162 | |
| 163 TYPE_READ_METADATA = TYPE_READ | TYPE_METADATA, | |
| 164 TYPE_WRITE_METADATA = TYPE_WRITE | TYPE_METADATA | |
| 165 }; | |
| 166 | |
| 167 PendingReq(Type type, net::DrainableIOBuffer* io_buf, | |
| 168 const net::CompletionCallback& callback) | |
| 169 : type(type), | |
| 170 io_buf(io_buf), | |
| 171 callback(callback) { | |
| 172 switch (type) { | |
| 173 case PendingReq::TYPE_READ: | |
| 174 case PendingReq::TYPE_WRITE: | |
| 175 case PendingReq::TYPE_READ_METADATA: | |
| 176 case PendingReq::TYPE_WRITE_METADATA: { | |
| 177 DCHECK(io_buf); | |
| 178 break; | |
| 179 } | |
| 180 default: { | |
| 181 NOTREACHED(); | |
| 182 break; | |
| 183 } | |
| 184 } | |
| 185 } | |
| 186 | |
| 187 Type type; | |
| 188 scoped_refptr<net::DrainableIOBuffer> io_buf; | |
| 189 net::CompletionCallback callback; | |
| 190 }; | |
| 191 | |
| 192 // Socket implementation. | |
| 193 virtual int Read(net::IOBuffer* buf, int buf_len, | |
| 194 const net::CompletionCallback& callback) OVERRIDE { | |
| 195 if (buf_len == 0) | |
| 196 return 0; | |
| 197 if (buf == NULL || buf_len < 0) { | |
| 198 NOTREACHED(); | |
| 199 return net::ERR_INVALID_ARGUMENT; | |
| 200 } | |
| 201 while (int bytes_remaining = fill_handshake_buf_->BytesConsumed() - | |
| 202 process_handshake_buf_->BytesConsumed()) { | |
| 203 DCHECK(!is_transport_read_pending_); | |
| 204 DCHECK(GetPendingReq(PendingReq::TYPE_READ) == pending_reqs_.end()); | |
| 205 switch (phase_) { | |
| 206 case PHASE_FRAME_OUTSIDE: | |
| 207 case PHASE_FRAME_INSIDE: | |
| 208 case PHASE_FRAME_LENGTH: | |
| 209 case PHASE_FRAME_SKIP: { | |
| 210 int n = std::min(bytes_remaining, buf_len); | |
| 211 int rv = ProcessDataFrames( | |
| 212 process_handshake_buf_->data(), n, buf->data(), buf_len); | |
| 213 process_handshake_buf_->DidConsume(n); | |
| 214 if (rv == 0) { | |
| 215 // ProcessDataFrames may return zero for non-empty buffer if it | |
| 216 // contains only frame delimiters without real data. In this case: | |
| 217 // try again and do not just return zero (zero stands for EOF). | |
| 218 continue; | |
| 219 } | |
| 220 return rv; | |
| 221 } | |
| 222 case PHASE_SHUT: { | |
| 223 return 0; | |
| 224 } | |
| 225 case PHASE_NYMPH: | |
| 226 case PHASE_HANDSHAKE: | |
| 227 default: { | |
| 228 NOTREACHED(); | |
| 229 return net::ERR_UNEXPECTED; | |
| 230 } | |
| 231 } | |
| 232 } | |
| 233 switch (phase_) { | |
| 234 case PHASE_FRAME_OUTSIDE: | |
| 235 case PHASE_FRAME_INSIDE: | |
| 236 case PHASE_FRAME_LENGTH: | |
| 237 case PHASE_FRAME_SKIP: { | |
| 238 pending_reqs_.push_back(PendingReq( | |
| 239 PendingReq::TYPE_READ, | |
| 240 new net::DrainableIOBuffer(buf, buf_len), | |
| 241 callback)); | |
| 242 ConsiderTransportRead(); | |
| 243 break; | |
| 244 } | |
| 245 case PHASE_SHUT: { | |
| 246 return 0; | |
| 247 } | |
| 248 case PHASE_NYMPH: | |
| 249 case PHASE_HANDSHAKE: | |
| 250 default: { | |
| 251 NOTREACHED(); | |
| 252 return net::ERR_UNEXPECTED; | |
| 253 } | |
| 254 } | |
| 255 return net::ERR_IO_PENDING; | |
| 256 } | |
| 257 | |
| 258 virtual int Write(net::IOBuffer* buf, int buf_len, | |
| 259 const net::CompletionCallback& callback) OVERRIDE { | |
| 260 if (buf_len == 0) | |
| 261 return 0; | |
| 262 if (buf == NULL || buf_len < 0) { | |
| 263 NOTREACHED(); | |
| 264 return net::ERR_INVALID_ARGUMENT; | |
| 265 } | |
| 266 DCHECK_EQ(std::find(buf->data(), buf->data() + buf_len, '\xff'), | |
| 267 buf->data() + buf_len); | |
| 268 switch (phase_) { | |
| 269 case PHASE_FRAME_OUTSIDE: | |
| 270 case PHASE_FRAME_INSIDE: | |
| 271 case PHASE_FRAME_LENGTH: | |
| 272 case PHASE_FRAME_SKIP: { | |
| 273 break; | |
| 274 } | |
| 275 case PHASE_SHUT: { | |
| 276 return net::ERR_SOCKET_NOT_CONNECTED; | |
| 277 } | |
| 278 case PHASE_NYMPH: | |
| 279 case PHASE_HANDSHAKE: | |
| 280 default: { | |
| 281 NOTREACHED(); | |
| 282 return net::ERR_UNEXPECTED; | |
| 283 } | |
| 284 } | |
| 285 | |
| 286 net::IOBuffer* frame_start = new net::IOBuffer(1); | |
| 287 frame_start->data()[0] = '\x00'; | |
| 288 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, | |
| 289 new net::DrainableIOBuffer(frame_start, 1), | |
| 290 net::CompletionCallback())); | |
| 291 | |
| 292 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE, | |
| 293 new net::DrainableIOBuffer(buf, buf_len), | |
| 294 callback)); | |
| 295 | |
| 296 net::IOBuffer* frame_end = new net::IOBuffer(1); | |
| 297 frame_end->data()[0] = '\xff'; | |
| 298 pending_reqs_.push_back(PendingReq(PendingReq::TYPE_WRITE_METADATA, | |
| 299 new net::DrainableIOBuffer(frame_end, 1), | |
| 300 net::CompletionCallback())); | |
| 301 | |
| 302 ConsiderTransportWrite(); | |
| 303 return net::ERR_IO_PENDING; | |
| 304 } | |
| 305 | |
| 306 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
| 307 return transport_socket_->SetReceiveBufferSize(size); | |
| 308 } | |
| 309 | |
| 310 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
| 311 return transport_socket_->SetSendBufferSize(size); | |
| 312 } | |
| 313 | |
| 314 // WebSocketServerSocket implementation. | |
| 315 virtual int Accept(const net::CompletionCallback& callback) OVERRIDE { | |
| 316 if (phase_ != PHASE_NYMPH) | |
| 317 return net::ERR_UNEXPECTED; | |
| 318 phase_ = PHASE_HANDSHAKE; | |
| 319 pending_reqs_.push_front(PendingReq( | |
| 320 PendingReq::TYPE_READ_METADATA, fill_handshake_buf_.get(), callback)); | |
| 321 ConsiderTransportRead(); | |
| 322 return net::ERR_IO_PENDING; | |
| 323 } | |
| 324 | |
| 325 std::deque<PendingReq>::iterator GetPendingReq(PendingReq::Type type) { | |
| 326 for (std::deque<PendingReq>::iterator it = pending_reqs_.begin(); | |
| 327 it != pending_reqs_.end(); ++it) { | |
| 328 if (it->type & type) | |
| 329 return it; | |
| 330 } | |
| 331 return pending_reqs_.end(); | |
| 332 } | |
| 333 | |
| 334 void ConsiderTransportRead() { | |
| 335 if (pending_reqs_.empty()) | |
| 336 return; | |
| 337 if (is_transport_read_pending_) | |
| 338 return; | |
| 339 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); | |
| 340 if (it == pending_reqs_.end()) | |
| 341 return; | |
| 342 if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { | |
| 343 NOTREACHED(); | |
| 344 return; | |
| 345 } | |
| 346 is_transport_read_pending_ = true; | |
| 347 int rv = transport_socket_->Read( | |
| 348 it->io_buf.get(), it->io_buf->BytesRemaining(), | |
| 349 base::Bind(&WebSocketServerSocketImpl::OnRead, | |
| 350 base::Unretained(this))); | |
| 351 if (rv != net::ERR_IO_PENDING) { | |
| 352 // PostTask rather than direct call in order to: | |
| 353 // (1) guarantee calling callback after returning from Read(); | |
| 354 // (2) avoid potential stack overflow; | |
| 355 MessageLoop::current()->PostTask( | |
| 356 FROM_HERE, base::Bind(&WebSocketServerSocketImpl::OnRead, | |
| 357 weak_factory_.GetWeakPtr(), rv)); | |
| 358 } | |
| 359 } | |
| 360 | |
| 361 void ConsiderTransportWrite() { | |
| 362 if (is_transport_write_pending_) | |
| 363 return; | |
| 364 if (pending_reqs_.empty()) | |
| 365 return; | |
| 366 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); | |
| 367 if (it == pending_reqs_.end()) | |
| 368 return; | |
| 369 if (it->io_buf == NULL || it->io_buf->BytesRemaining() == 0) { | |
| 370 NOTREACHED(); | |
| 371 Shut(net::ERR_UNEXPECTED); | |
| 372 return; | |
| 373 } | |
| 374 is_transport_write_pending_ = true; | |
| 375 int rv = transport_socket_->Write( | |
| 376 it->io_buf.get(), it->io_buf->BytesRemaining(), | |
| 377 base::Bind(&WebSocketServerSocketImpl::OnWrite, | |
| 378 base::Unretained(this))); | |
| 379 if (rv != net::ERR_IO_PENDING) { | |
| 380 // PostTask rather than direct call in order to: | |
| 381 // (1) guarantee calling callback after returning from Read(); | |
| 382 // (2) avoid potential stack overflow; | |
| 383 MessageLoop::current()->PostTask( | |
| 384 FROM_HERE, base::Bind(&WebSocketServerSocketImpl::OnWrite, | |
| 385 weak_factory_.GetWeakPtr(), rv)); | |
| 386 } | |
| 387 } | |
| 388 | |
| 389 void Shut(int result) { | |
| 390 if (result > 0 || result == net::ERR_IO_PENDING) | |
| 391 result = net::ERR_UNEXPECTED; | |
| 392 if (result != 0) { | |
| 393 while (!pending_reqs_.empty()) { | |
| 394 PendingReq& req = pending_reqs_.front(); | |
| 395 if (!req.callback.is_null()) | |
| 396 req.callback.Run(result); | |
| 397 pending_reqs_.pop_front(); | |
| 398 } | |
| 399 transport_socket_.reset(); // terminate underlying connection. | |
| 400 } | |
| 401 phase_ = PHASE_SHUT; | |
| 402 } | |
| 403 | |
| 404 // Callbacks for transport socket. | |
| 405 void OnRead(int result) { | |
| 406 if (!is_transport_read_pending_) { | |
| 407 NOTREACHED(); | |
| 408 Shut(net::ERR_UNEXPECTED); | |
| 409 return; | |
| 410 } | |
| 411 is_transport_read_pending_ = false; | |
| 412 | |
| 413 if (result <= 0) { | |
| 414 Shut(result); | |
| 415 return; | |
| 416 } | |
| 417 | |
| 418 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_READ); | |
| 419 if (it == pending_reqs_.end() || | |
| 420 it->io_buf == NULL || | |
| 421 it->io_buf->data() == NULL) { | |
| 422 NOTREACHED(); | |
| 423 Shut(net::ERR_UNEXPECTED); | |
| 424 return; | |
| 425 } | |
| 426 if ((phase_ == PHASE_HANDSHAKE) == (it->type == PendingReq::TYPE_READ)) { | |
| 427 NOTREACHED(); | |
| 428 Shut(net::ERR_UNEXPECTED); | |
| 429 return; | |
| 430 } | |
| 431 | |
| 432 switch (phase_) { | |
| 433 case PHASE_HANDSHAKE: { | |
| 434 if (it != pending_reqs_.begin() || it->io_buf != fill_handshake_buf_) { | |
| 435 NOTREACHED(); | |
| 436 Shut(net::ERR_UNEXPECTED); | |
| 437 return; | |
| 438 } | |
| 439 fill_handshake_buf_->DidConsume(result); | |
| 440 // ProcessHandshake invalidates iterators for |pending_reqs_| | |
| 441 int rv = ProcessHandshake(); | |
| 442 if (rv > 0) { | |
| 443 process_handshake_buf_->DidConsume(rv); | |
| 444 phase_ = PHASE_FRAME_OUTSIDE; | |
| 445 net::CompletionCallback cb = pending_reqs_.front().callback; | |
| 446 pending_reqs_.pop_front(); | |
| 447 ConsiderTransportWrite(); // Schedule answer handshake. | |
| 448 if (!cb.is_null()) | |
| 449 cb.Run(0); | |
| 450 } else if (rv == net::ERR_IO_PENDING) { | |
| 451 if (fill_handshake_buf_->BytesRemaining() < 1) | |
| 452 Shut(net::ERR_LIMIT_VIOLATION); | |
| 453 } else if (rv < 0) { | |
| 454 Shut(rv); | |
| 455 } else { | |
| 456 Shut(net::ERR_UNEXPECTED); | |
| 457 } | |
| 458 break; | |
| 459 } | |
| 460 case PHASE_FRAME_OUTSIDE: | |
| 461 case PHASE_FRAME_INSIDE: | |
| 462 case PHASE_FRAME_LENGTH: | |
| 463 case PHASE_FRAME_SKIP: { | |
| 464 int rv = ProcessDataFrames( | |
| 465 it->io_buf->data(), result, | |
| 466 it->io_buf->data(), it->io_buf->BytesRemaining()); | |
| 467 if (rv < 0) { | |
| 468 Shut(rv); | |
| 469 return; | |
| 470 } | |
| 471 if (rv > 0 || phase_ == PHASE_SHUT) { | |
| 472 net::CompletionCallback cb = it->callback; | |
| 473 pending_reqs_.erase(it); | |
| 474 if (!cb.is_null()) | |
| 475 cb.Run(rv); | |
| 476 } | |
| 477 break; | |
| 478 } | |
| 479 case PHASE_NYMPH: | |
| 480 default: { | |
| 481 NOTREACHED(); | |
| 482 Shut(net::ERR_UNEXPECTED); | |
| 483 break; | |
| 484 } | |
| 485 } | |
| 486 ConsiderTransportRead(); | |
| 487 } | |
| 488 | |
| 489 void OnWrite(int result) { | |
| 490 if (!is_transport_write_pending_) { | |
| 491 NOTREACHED(); | |
| 492 Shut(net::ERR_UNEXPECTED); | |
| 493 return; | |
| 494 } | |
| 495 is_transport_write_pending_ = false; | |
| 496 | |
| 497 if (result < 0) { | |
| 498 Shut(result); | |
| 499 return; | |
| 500 } | |
| 501 | |
| 502 std::deque<PendingReq>::iterator it = GetPendingReq(PendingReq::TYPE_WRITE); | |
| 503 if (it == pending_reqs_.end() || | |
| 504 it->io_buf == NULL || | |
| 505 it->io_buf->data() == NULL) { | |
| 506 NOTREACHED(); | |
| 507 Shut(net::ERR_UNEXPECTED); | |
| 508 return; | |
| 509 } | |
| 510 DCHECK_LE(result, it->io_buf->BytesRemaining()); | |
| 511 it->io_buf->DidConsume(result); | |
| 512 if (it->io_buf->BytesRemaining() == 0) { | |
| 513 net::CompletionCallback cb = it->callback; | |
| 514 int bytes_written = it->io_buf->BytesConsumed(); | |
| 515 DCHECK_GT(bytes_written, 0); | |
| 516 pending_reqs_.erase(it); | |
| 517 if (!cb.is_null()) | |
| 518 cb.Run(bytes_written); | |
| 519 } | |
| 520 ConsiderTransportWrite(); | |
| 521 } | |
| 522 | |
| 523 // Returns (positive) number of consumed bytes on success. | |
| 524 // Returns ERR_IO_PENDING in case of incomplete input. | |
| 525 // Returns ERR_WS_PROTOCOL_ERROR or ERR_LIMIT_VIOLATION in case of failure to | |
| 526 // reasonably parse input. | |
| 527 int ProcessHandshake() { | |
| 528 static const char kGetPrefix[] = "GET "; | |
| 529 static const char kKeyValueDelimiter[] = ": "; | |
| 530 | |
| 531 class Fields { | |
| 532 public: | |
| 533 bool Has(const std::string& name) { | |
| 534 return map_.find(StringToLowerASCII(name)) != map_.end(); | |
| 535 } | |
| 536 | |
| 537 std::string Get(const std::string& name) { | |
| 538 return Has(name) ? map_[StringToLowerASCII(name)] : std::string(); | |
| 539 } | |
| 540 | |
| 541 void Set(const std::string& name, const std::string& value) { | |
| 542 map_[StringToLowerASCII(name)] = StringToLowerASCII(value); | |
| 543 } | |
| 544 | |
| 545 private: | |
| 546 std::map<std::string, std::string> map_; | |
| 547 } fields; | |
| 548 | |
| 549 char* buf = process_handshake_buf_->data(); | |
| 550 size_t buf_size = fill_handshake_buf_->BytesConsumed(); | |
| 551 | |
| 552 if (buf_size < 1) | |
| 553 return net::ERR_IO_PENDING; | |
| 554 if (!std::equal(buf, buf + std::min(buf_size, strlen(kGetPrefix)), | |
| 555 kGetPrefix)) { | |
| 556 // Data head does not match what is expected. | |
| 557 return net::ERR_WS_PROTOCOL_ERROR; | |
| 558 } | |
| 559 if (buf_size >= kHandshakeLimitBytes) | |
| 560 return net::ERR_LIMIT_VIOLATION; | |
| 561 char* buf_end = buf + buf_size; | |
| 562 | |
| 563 if (buf_size < strlen(kGetPrefix)) | |
| 564 return net::ERR_IO_PENDING; | |
| 565 char* resource_begin = buf + strlen(kGetPrefix); | |
| 566 char* resource_end = std::find(resource_begin, buf_end, kSpaceOctet); | |
| 567 if (resource_end == buf_end) | |
| 568 return net::ERR_IO_PENDING; | |
| 569 std::string resource(resource_begin, resource_end); | |
| 570 if (!IsStringUTF8(resource) || | |
| 571 resource.find_first_of(kCRLF) != std::string::npos) { | |
| 572 return net::ERR_WS_PROTOCOL_ERROR; | |
| 573 } | |
| 574 char* term_pos = std::search( | |
| 575 buf, buf_end, kCRLFCRLF, kCRLFCRLF + strlen(kCRLFCRLF)); | |
| 576 char key3[8]; // Notation (key3) matches websocket RFC. | |
| 577 size_t message_len = buf_end - term_pos; | |
| 578 if (message_len < sizeof(key3) + strlen(kCRLFCRLF)) | |
| 579 return net::ERR_IO_PENDING; | |
| 580 term_pos += strlen(kCRLFCRLF); | |
| 581 memcpy(key3, term_pos, sizeof(key3)); | |
| 582 term_pos += sizeof(key3); | |
| 583 // First line is "GET resource" line, so skip it. | |
| 584 char* pos = std::search(buf, term_pos, kCRLF, kCRLF + strlen(kCRLF)); | |
| 585 if (pos == term_pos) | |
| 586 return net::ERR_WS_PROTOCOL_ERROR; | |
| 587 for (;;) { | |
| 588 pos += strlen(kCRLF); | |
| 589 if (term_pos - pos < | |
| 590 static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { | |
| 591 return net::ERR_WS_PROTOCOL_ERROR; | |
| 592 } | |
| 593 if (term_pos - pos == | |
| 594 static_cast<ptrdiff_t>(sizeof(key3) + strlen(kCRLF))) { | |
| 595 break; | |
| 596 } | |
| 597 char* next_pos = std::search( | |
| 598 pos, term_pos, kKeyValueDelimiter, | |
| 599 kKeyValueDelimiter + strlen(kKeyValueDelimiter)); | |
| 600 if (next_pos == term_pos) | |
| 601 return net::ERR_WS_PROTOCOL_ERROR; | |
| 602 std::string key(pos, next_pos); | |
| 603 if (!IsStringASCII(key) || | |
| 604 key.find_first_of(kCRLF) != std::string::npos) { | |
| 605 return net::ERR_WS_PROTOCOL_ERROR; | |
| 606 } | |
| 607 pos = std::search(next_pos += strlen(kKeyValueDelimiter), term_pos, | |
| 608 kCRLF, kCRLF + strlen(kCRLF)); | |
| 609 if (pos == term_pos) | |
| 610 return net::ERR_WS_PROTOCOL_ERROR; | |
| 611 if (!key.empty()) { | |
| 612 std::string value(next_pos, pos); | |
| 613 if (!IsStringASCII(value) || | |
| 614 value.find_first_of(kCRLF) != std::string::npos) { | |
| 615 return net::ERR_WS_PROTOCOL_ERROR; | |
| 616 } | |
| 617 fields.Set(key, value); | |
| 618 } | |
| 619 } | |
| 620 | |
| 621 // Values of Upgrade and Connection fields are hardcoded in the protocol. | |
| 622 if (fields.Get("Upgrade") != "websocket" || | |
| 623 fields.Get("Connection") != "upgrade") { | |
| 624 return net::ERR_WS_PROTOCOL_ERROR; | |
| 625 } | |
| 626 if (fields.Has(kVersionFieldName)) { | |
| 627 NOTIMPLEMENTED(); // new protocol. | |
| 628 return net::ERR_NOT_IMPLEMENTED; | |
| 629 } | |
| 630 | |
| 631 if (!fields.Has(kPlainOriginFieldName)) | |
| 632 return net::ERR_CONNECTION_REFUSED; | |
| 633 // Normalize (e.g. w.r.t. leading slashes) origin. | |
| 634 GURL origin = GURL(fields.Get(kPlainOriginFieldName)).GetOrigin(); | |
| 635 if (!origin.is_valid()) | |
| 636 return net::ERR_WS_PROTOCOL_ERROR; | |
| 637 std::string normalized_origin = origin.spec(); | |
| 638 | |
| 639 if (!fields.Has(kPlainHostFieldName)) | |
| 640 return net::ERR_CONNECTION_REFUSED; | |
| 641 | |
| 642 std::vector<std::string> subprotocol_list; | |
| 643 if (fields.Has(kProtocolFieldName)) { | |
| 644 int rv = FetchSubprotocolList( | |
| 645 fields.Get(kProtocolFieldName), &subprotocol_list); | |
| 646 if (rv < 0) | |
| 647 return rv; | |
| 648 DCHECK(subprotocol_list.end() == std::find( | |
| 649 subprotocol_list.begin(), subprotocol_list.end(), "")); | |
| 650 } | |
| 651 | |
| 652 std::string location; | |
| 653 std::string subprotocol; | |
| 654 if (!delegate_->ValidateWebSocket(resource, | |
| 655 normalized_origin, | |
| 656 fields.Get(kPlainHostFieldName), | |
| 657 subprotocol_list, | |
| 658 &location, | |
| 659 &subprotocol)) { | |
| 660 return net::ERR_CONNECTION_REFUSED; | |
| 661 } | |
| 662 if (subprotocol_list.empty()) { | |
| 663 DCHECK(subprotocol.empty()); | |
| 664 } else { | |
| 665 if (!subprotocol.empty()) { | |
| 666 if (subprotocol_list.end() == std::find( | |
| 667 subprotocol_list.begin(), subprotocol_list.end(), subprotocol)) { | |
| 668 NOTREACHED() << "delegate must pick subprotocol from given list"; | |
| 669 return net::ERR_UNEXPECTED; | |
| 670 } | |
| 671 } | |
| 672 } | |
| 673 | |
| 674 uint32 key_number1 = 0; | |
| 675 uint32 key_number2 = 0; | |
| 676 if (!FetchDecimalDigits(fields.Get(kKey1FieldName), &key_number1) || | |
| 677 !FetchDecimalDigits(fields.Get(kKey2FieldName), &key_number2)) { | |
| 678 return net::ERR_WS_PROTOCOL_ERROR; | |
| 679 } | |
| 680 | |
| 681 // We limit incoming header size so following numbers shall not be too high. | |
| 682 int spaces1 = CountSpaces(fields.Get(kKey1FieldName)); | |
| 683 int spaces2 = CountSpaces(fields.Get(kKey2FieldName)); | |
| 684 if (spaces1 == 0 || | |
| 685 spaces2 == 0 || | |
| 686 key_number1 % spaces1 != 0 || | |
| 687 key_number2 % spaces2 != 0) { | |
| 688 return net::ERR_WS_PROTOCOL_ERROR; | |
| 689 } | |
| 690 | |
| 691 char challenge[4 + 4 + sizeof(key3)]; | |
| 692 int32 part1 = base::HostToNet32(key_number1 / spaces1); | |
| 693 int32 part2 = base::HostToNet32(key_number2 / spaces2); | |
| 694 memcpy(challenge, &part1, 4); | |
| 695 memcpy(challenge + 4, &part2, 4); | |
| 696 memcpy(challenge + 4 + 4, key3, sizeof(key3)); | |
| 697 base::MD5Digest challenge_response; | |
| 698 base::MD5Sum(challenge, sizeof(challenge), &challenge_response); | |
| 699 | |
| 700 // Concocting response handshake. | |
| 701 class Buffer { | |
| 702 public: | |
| 703 Buffer() | |
| 704 : io_buf_(new net::IOBuffer(kHandshakeLimitBytes)), | |
| 705 bytes_written_(0), | |
| 706 is_ok_(true) { | |
| 707 } | |
| 708 | |
| 709 bool Write(const void* p, int len) { | |
| 710 DCHECK(p); | |
| 711 DCHECK_GE(len, 0); | |
| 712 if (!is_ok_) | |
| 713 return false; | |
| 714 if (bytes_written_ + len > kHandshakeLimitBytes) { | |
| 715 NOTREACHED(); | |
| 716 is_ok_ = false; | |
| 717 return false; | |
| 718 } | |
| 719 memcpy(io_buf_->data() + bytes_written_, p, len); | |
| 720 bytes_written_ += len; | |
| 721 return true; | |
| 722 } | |
| 723 | |
| 724 bool WriteLine(const char* p) { | |
| 725 return Write(p, strlen(p)) && Write(kCRLF, strlen(kCRLF)); | |
| 726 } | |
| 727 | |
| 728 operator net::DrainableIOBuffer*() { | |
| 729 return new net::DrainableIOBuffer(io_buf_.get(), bytes_written_); | |
| 730 } | |
| 731 | |
| 732 bool is_ok() { return is_ok_; } | |
| 733 | |
| 734 private: | |
| 735 scoped_refptr<net::IOBuffer> io_buf_; | |
| 736 size_t bytes_written_; | |
| 737 bool is_ok_; | |
| 738 } buffer; | |
| 739 | |
| 740 buffer.WriteLine("HTTP/1.1 101 WebSocket Protocol Handshake"); | |
| 741 buffer.WriteLine("Upgrade: WebSocket"); | |
| 742 buffer.WriteLine("Connection: Upgrade"); | |
| 743 | |
| 744 { | |
| 745 // Take care of Location field. | |
| 746 char tmp[2048]; | |
| 747 int rv = base::snprintf(tmp, sizeof(tmp), | |
| 748 "%s: %s", | |
| 749 kLocationFieldName, | |
| 750 location.c_str()); | |
| 751 if (rv <= 0 || rv + 0u >= sizeof(tmp)) | |
| 752 return net::ERR_LIMIT_VIOLATION; | |
| 753 buffer.WriteLine(tmp); | |
| 754 } | |
| 755 { | |
| 756 // Take care of Origin field. | |
| 757 char tmp[2048]; | |
| 758 int rv = base::snprintf(tmp, sizeof(tmp), | |
| 759 "%s: %s", | |
| 760 kOriginFieldName, | |
| 761 fields.Get(kPlainOriginFieldName).c_str()); | |
| 762 if (rv <= 0 || rv + 0u >= sizeof(tmp)) | |
| 763 return net::ERR_LIMIT_VIOLATION; | |
| 764 buffer.WriteLine(tmp); | |
| 765 } | |
| 766 if (!subprotocol.empty()) { | |
| 767 char tmp[2048]; | |
| 768 int rv = base::snprintf(tmp, sizeof(tmp), | |
| 769 "%s: %s", | |
| 770 kProtocolFieldName, | |
| 771 subprotocol.c_str()); | |
| 772 if (rv <= 0 || rv + 0u >= sizeof(tmp)) | |
| 773 return net::ERR_LIMIT_VIOLATION; | |
| 774 buffer.WriteLine(tmp); | |
| 775 } | |
| 776 buffer.WriteLine(""); | |
| 777 buffer.Write(&challenge_response, sizeof(challenge_response)); | |
| 778 | |
| 779 if (!buffer.is_ok()) | |
| 780 return net::ERR_LIMIT_VIOLATION; | |
| 781 | |
| 782 pending_reqs_.push_back(PendingReq( | |
| 783 PendingReq::TYPE_WRITE_METADATA, buffer, net::CompletionCallback())); | |
| 784 DCHECK_GT(term_pos - buf, 0); | |
| 785 return term_pos - buf; | |
| 786 } | |
| 787 | |
| 788 // Removes frame delimiters and returns net number of data bytes (or error). | |
| 789 // |out| may be equal to |buf|, in that case it is in-place operation. | |
| 790 int ProcessDataFrames(char* buf, int buf_len, char* out, int out_len) { | |
| 791 if (out_len < buf_len) { | |
| 792 NOTREACHED(); | |
| 793 return net::ERR_UNEXPECTED; | |
| 794 } | |
| 795 int out_pos = 0; | |
| 796 for (char* p = buf; p < buf + buf_len; ++p) { | |
| 797 switch (phase_) { | |
| 798 case PHASE_FRAME_INSIDE: { | |
| 799 if (*p == '\x00') | |
| 800 return net::ERR_WS_PROTOCOL_ERROR; | |
| 801 if (*p == '\xff') | |
| 802 phase_ = PHASE_FRAME_OUTSIDE; | |
| 803 else | |
| 804 out[out_pos++] = *p; | |
| 805 break; | |
| 806 } | |
| 807 case PHASE_FRAME_OUTSIDE: { | |
| 808 if (*p == '\x00') { | |
| 809 phase_ = PHASE_FRAME_INSIDE; | |
| 810 } else if (*p == '\xff') { | |
| 811 phase_ = PHASE_FRAME_LENGTH; | |
| 812 frame_bytes_remaining_ = 0; | |
| 813 } | |
| 814 else { | |
| 815 return net::ERR_WS_PROTOCOL_ERROR; | |
| 816 } | |
| 817 break; | |
| 818 } | |
| 819 case PHASE_FRAME_LENGTH: { | |
| 820 static const int kValueBits = 7; | |
| 821 static const char kValueMask = (1 << kValueBits) - 1; | |
| 822 frame_bytes_remaining_ <<= kValueBits; | |
| 823 frame_bytes_remaining_ += (*p & kValueMask); | |
| 824 if (*p & ~kValueMask) { | |
| 825 // Check that next byte would not overflow. | |
| 826 if (frame_bytes_remaining_ > | |
| 827 (std::numeric_limits<int>::max() - ((1 << 7) - 1)) >> 7) { | |
| 828 return net::ERR_LIMIT_VIOLATION; | |
| 829 } | |
| 830 } else { | |
| 831 if (frame_bytes_remaining_ == 0) { | |
| 832 phase_ = PHASE_SHUT; | |
| 833 return out_pos; | |
| 834 } else { | |
| 835 phase_ = PHASE_FRAME_SKIP; | |
| 836 } | |
| 837 } | |
| 838 break; | |
| 839 } | |
| 840 case PHASE_FRAME_SKIP: { | |
| 841 DCHECK_GE(frame_bytes_remaining_, 1); | |
| 842 frame_bytes_remaining_ -= 1; | |
| 843 if (frame_bytes_remaining_ < 1) | |
| 844 phase_ = PHASE_FRAME_OUTSIDE; | |
| 845 break; | |
| 846 } | |
| 847 default: { | |
| 848 NOTREACHED(); | |
| 849 } | |
| 850 } | |
| 851 } | |
| 852 return out_pos; | |
| 853 } | |
| 854 | |
| 855 // State machinery. | |
| 856 Phase phase_; | |
| 857 | |
| 858 // Counts frame length for PHASE_FRAME_LENGTH and PHASE_FRAME_SKIP. | |
| 859 int frame_bytes_remaining_; | |
| 860 | |
| 861 // Underlying socket. | |
| 862 scoped_ptr<net::Socket> transport_socket_; | |
| 863 | |
| 864 // Validation is performed via delegate. | |
| 865 Delegate* delegate_; | |
| 866 | |
| 867 // IOBuffer used to communicate with transport at initial stage. | |
| 868 scoped_refptr<net::IOBuffer> handshake_buf_; | |
| 869 scoped_refptr<net::DrainableIOBuffer> fill_handshake_buf_; | |
| 870 scoped_refptr<net::DrainableIOBuffer> process_handshake_buf_; | |
| 871 | |
| 872 // Pending IO requests we need to complete. | |
| 873 std::deque<PendingReq> pending_reqs_; | |
| 874 | |
| 875 // Whether transport requests are pending. | |
| 876 bool is_transport_read_pending_; | |
| 877 bool is_transport_write_pending_; | |
| 878 | |
| 879 base::WeakPtrFactory<WebSocketServerSocketImpl> weak_factory_; | |
| 880 | |
| 881 DISALLOW_COPY_AND_ASSIGN(WebSocketServerSocketImpl); | |
| 882 }; | |
| 883 | |
| 884 } // namespace | |
| 885 | |
| 886 namespace net { | |
| 887 | |
| 888 WebSocketServerSocket* CreateWebSocketServerSocket( | |
| 889 Socket* transport_socket, WebSocketServerSocket::Delegate* delegate) { | |
| 890 return new WebSocketServerSocketImpl(transport_socket, delegate); | |
| 891 } | |
| 892 | |
| 893 WebSocketServerSocket::~WebSocketServerSocket() { | |
| 894 } | |
| 895 | |
| 896 } // namespace net; | |
| OLD | NEW |