| OLD | NEW |
| (Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "remoting/host/websocket_connection.h" |
| 6 |
| 7 #include <map> |
| 8 #include <vector> |
| 9 |
| 10 #include "base/base64.h" |
| 11 #include "base/compiler_specific.h" |
| 12 #include "base/location.h" |
| 13 #include "base/sha1.h" |
| 14 #include "base/single_thread_task_runner.h" |
| 15 #include "base/string_split.h" |
| 16 #include "base/sys_byteorder.h" |
| 17 #include "base/thread_task_runner_handle.h" |
| 18 #include "net/base/net_errors.h" |
| 19 #include "net/socket/stream_socket.h" |
| 20 |
| 21 namespace remoting { |
| 22 |
| 23 namespace { |
| 24 |
| 25 const int kReadBufferSize = 1024; |
| 26 const char kLineSeparator[] = "\r\n"; |
| 27 const char kHeaderEndMarker[] = "\r\n\r\n"; |
| 28 const char kHeaderKeyValueSeparator[] = ": "; |
| 29 const int kMaskLength = 4; |
| 30 |
| 31 // Maximum frame length that can be encoded without extended length filed. |
| 32 const uint32 kMaxNotExtendedLength = 125; |
| 33 |
| 34 // Maximum frame length that can be encoded in 16 bits. |
| 35 const uint32 kMax16BitLength = 65535; |
| 36 |
| 37 // Special values of the length field used to extend frame length to 16 or 64 |
| 38 // bits. |
| 39 const uint32 kLength16BitMarker = 126; |
| 40 const uint32 kLength64BitMarker = 127; |
| 41 |
| 42 // Fixed value specified in RFC6455. It's used to compute accept token sent to |
| 43 // the client in Sec-WebSocket-Accept key. |
| 44 const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
| 45 |
| 46 } // namespace |
| 47 |
| 48 WebSocketConnection::WebSocketConnection() |
| 49 : delegate_(NULL), |
| 50 maximum_message_size_(0), |
| 51 state_(READING_HEADERS), |
| 52 receiving_message_(false), |
| 53 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
| 54 } |
| 55 |
| 56 WebSocketConnection::~WebSocketConnection() { |
| 57 Close(); |
| 58 } |
| 59 |
| 60 void WebSocketConnection::Start( |
| 61 scoped_ptr<net::StreamSocket> socket, |
| 62 ConnectedCallback connected_callback) { |
| 63 socket_ = socket.Pass(); |
| 64 connected_callback_ = connected_callback; |
| 65 reader_.Init(socket_.get(), base::Bind( |
| 66 &WebSocketConnection::OnSocketReadResult, base::Unretained(this))); |
| 67 writer_.Init(socket_.get(), base::Bind( |
| 68 &WebSocketConnection::OnSocketWriteError, base::Unretained(this))); |
| 69 } |
| 70 |
| 71 void WebSocketConnection::Accept(Delegate* delegate) { |
| 72 DCHECK_EQ(state_, HEADERS_READ); |
| 73 |
| 74 state_ = ACCEPTED; |
| 75 delegate_ = delegate; |
| 76 |
| 77 std::string accept_key = |
| 78 base::SHA1HashString(websocket_key_ + kWebsocketKeySalt); |
| 79 std::string accept_key_base64; |
| 80 bool result = base::Base64Encode(accept_key, &accept_key_base64); |
| 81 DCHECK(result); |
| 82 |
| 83 std::string handshake; |
| 84 handshake += "HTTP/1.1 101 Switching Protocol"; |
| 85 handshake += kLineSeparator; |
| 86 handshake += "Upgrade: websocket"; |
| 87 handshake += kLineSeparator; |
| 88 handshake += "Connection: Upgrade"; |
| 89 handshake += kLineSeparator; |
| 90 handshake += "Sec-WebSocket-Accept: " + accept_key_base64; |
| 91 handshake += kHeaderEndMarker; |
| 92 |
| 93 scoped_refptr<net::IOBufferWithSize> buffer = |
| 94 new net::IOBufferWithSize(handshake.size()); |
| 95 memcpy(buffer->data(), handshake.data(), handshake.size()); |
| 96 writer_.Write(buffer, base::Closure()); |
| 97 } |
| 98 |
| 99 void WebSocketConnection::Reject() { |
| 100 DCHECK_EQ(state_, HEADERS_READ); |
| 101 |
| 102 state_ = CLOSED; |
| 103 std::string response = "HTTP/1.1 401 Unauthorized"; |
| 104 response += kHeaderEndMarker; |
| 105 scoped_refptr<net::IOBufferWithSize> buffer = |
| 106 new net::IOBufferWithSize(response.size()); |
| 107 memcpy(buffer->data(), response.data(), response.size()); |
| 108 writer_.Write(buffer, base::Closure()); |
| 109 } |
| 110 |
| 111 void WebSocketConnection::set_maximum_message_size(uint64 size) { |
| 112 maximum_message_size_ = size; |
| 113 } |
| 114 |
| 115 void WebSocketConnection::SendText(const std::string& text) { |
| 116 SendFragment(OPCODE_TEXT_FRAME, text); |
| 117 } |
| 118 |
| 119 void WebSocketConnection::Close() { |
| 120 switch (state_) { |
| 121 case READING_HEADERS: |
| 122 break; |
| 123 |
| 124 case HEADERS_READ: |
| 125 Reject(); |
| 126 break; |
| 127 |
| 128 case ACCEPTED: |
| 129 SendFragment(OPCODE_CLOSE, std::string()); |
| 130 break; |
| 131 |
| 132 case CLOSED: |
| 133 break; |
| 134 } |
| 135 state_ = CLOSED; |
| 136 } |
| 137 |
| 138 void WebSocketConnection::CloseOnError() { |
| 139 State old_state_ = state_; |
| 140 Close(); |
| 141 if (old_state_ == ACCEPTED) { |
| 142 DCHECK(delegate_); |
| 143 delegate_->OnWebSocketClosed(); |
| 144 } |
| 145 } |
| 146 |
| 147 void WebSocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data, |
| 148 int result) { |
| 149 if (result <= 0) { |
| 150 if (result != 0) { |
| 151 LOG(ERROR) << "Error when trying to read from WebSocket connection: " |
| 152 << result; |
| 153 } |
| 154 CloseOnError(); |
| 155 return; |
| 156 } |
| 157 |
| 158 switch (state_) { |
| 159 case READING_HEADERS: { |
| 160 headers_.append(data->data(), data->data() + result); |
| 161 size_t header_end_pos = headers_.find(kHeaderEndMarker); |
| 162 if (header_end_pos != std::string::npos) { |
| 163 bool result; |
| 164 if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) { |
| 165 LOG(ERROR) << "WebSocket client tried writing data before handshake " |
| 166 "has finished."; |
| 167 DCHECK(!connected_callback_.is_null()); |
| 168 state_ = CLOSED; |
| 169 result = false; |
| 170 } else { |
| 171 // Crop newline symbols from the end. |
| 172 headers_.resize(header_end_pos); |
| 173 |
| 174 result = ParseHeaders(); |
| 175 if (!result) { |
| 176 state_ = CLOSED; |
| 177 } else { |
| 178 state_ = HEADERS_READ; |
| 179 } |
| 180 } |
| 181 ConnectedCallback cb(connected_callback_); |
| 182 connected_callback_.Reset(); |
| 183 cb.Run(result); |
| 184 } |
| 185 break; |
| 186 } |
| 187 |
| 188 case HEADERS_READ: |
| 189 LOG(ERROR) << "Received unexpected data before websocket " |
| 190 "connection is accepted."; |
| 191 CloseOnError(); |
| 192 break; |
| 193 |
| 194 case ACCEPTED: |
| 195 DCHECK(delegate_); |
| 196 received_data_.append(data->data(), data->data() + result); |
| 197 ProcessData(); |
| 198 |
| 199 case CLOSED: |
| 200 // Ignore anything received after connection is rejected or closed. |
| 201 break; |
| 202 } |
| 203 } |
| 204 |
| 205 void WebSocketConnection::ProcessData() { |
| 206 DCHECK_EQ(state_, ACCEPTED); |
| 207 |
| 208 if (received_data_.size() < 2) { |
| 209 // Header hasn't been received yet. |
| 210 return; |
| 211 } |
| 212 |
| 213 bool fin_bit = (received_data_.data()[0] & 0x80) != 0; |
| 214 |
| 215 // 3 bits after FIN are reserved for WebSocket extensions. RFC6455 requires |
| 216 // that endpoint fails connection if any of these bits is set while no |
| 217 // extension that uses these bits was negotiated. |
| 218 int rsv_bits = received_data_.data()[0] & 0x70; |
| 219 if (rsv_bits != 0) { |
| 220 LOG(ERROR) << "Incoming has unsupported RSV bits set."; |
| 221 CloseOnError(); |
| 222 return; |
| 223 } |
| 224 |
| 225 int opcode = received_data_.data()[0] & 0x0f; |
| 226 |
| 227 int mask_bit = received_data_.data()[1] & 0x80; |
| 228 if (mask_bit == 0) { |
| 229 LOG(ERROR) << "Incoming frame is not masked."; |
| 230 CloseOnError(); |
| 231 return; |
| 232 } |
| 233 |
| 234 // Length field has variable size in each WebSocket frame - it's either 1, 3 |
| 235 // or 9 bytes with the first bit always reserved for MASK flag. The first byte |
| 236 // is set to 126 or 127 for 16 and 64 bit extensions respectively. Code below |
| 237 // extracts |length| value and sets |length_field_size| accordingly. |
| 238 int length_field_size = 1; |
| 239 uint64 length = received_data_.data()[1] & 0x7F; |
| 240 if (length == kLength16BitMarker) { |
| 241 if (received_data_.size() < 4) { |
| 242 // Haven't received the whole frame header yet. |
| 243 return; |
| 244 } |
| 245 length_field_size = 3; |
| 246 length = base::NetToHost16( |
| 247 *reinterpret_cast<const uint16*>(received_data_.data() + 2)); |
| 248 } else if (length == kLength64BitMarker) { |
| 249 if (received_data_.size() < 10) { |
| 250 // Haven't received the whole frame header yet. |
| 251 return; |
| 252 } |
| 253 length_field_size = 9; |
| 254 length = base::NetToHost64( |
| 255 *reinterpret_cast<const uint64*>(received_data_.data() + 2)); |
| 256 } |
| 257 |
| 258 int payload_position = 1 + length_field_size + kMaskLength; |
| 259 |
| 260 // Check that the size of the frame is below the limit. It needs to be done |
| 261 // before we read the payload to avoid allocating buffer for a bogus frame |
| 262 // that is too big. |
| 263 if (maximum_message_size_ > 0 && length > maximum_message_size_) { |
| 264 LOG(ERROR) << "Client tried to send a fragment that is bigger than " |
| 265 "the maximum message size of " << maximum_message_size_; |
| 266 CloseOnError(); |
| 267 return; |
| 268 } |
| 269 |
| 270 if (received_data_.size() < payload_position + length) { |
| 271 // Haven't received the whole frame yet. |
| 272 return; |
| 273 } |
| 274 |
| 275 // Unmask the payload. |
| 276 if (mask_bit) { |
| 277 const char* mask = received_data_.data() + length_field_size + 1; |
| 278 UnmaskPayload( |
| 279 mask, |
| 280 const_cast<char*>(received_data_.data()) + payload_position, length); |
| 281 } |
| 282 |
| 283 const char* payload = received_data_.data() + payload_position; |
| 284 |
| 285 if (opcode < 0x8) { |
| 286 if (maximum_message_size_ > 0 && |
| 287 current_message_.size() + length > maximum_message_size_) { |
| 288 LOG(ERROR) << "Client tried to send a message that is bigger than " |
| 289 "the maximum message size of " << maximum_message_size_; |
| 290 CloseOnError(); |
| 291 return; |
| 292 } |
| 293 |
| 294 // Non-control message. |
| 295 current_message_.append(payload, payload + length); |
| 296 } else { |
| 297 // Control message. |
| 298 if (!fin_bit) { |
| 299 LOG(ERROR) << "Received fragmented control message."; |
| 300 CloseOnError(); |
| 301 return; |
| 302 } |
| 303 if (length > kMaxNotExtendedLength) { |
| 304 LOG(ERROR) << "Received control message that is larger than 125 bytes."; |
| 305 CloseOnError(); |
| 306 return; |
| 307 } |
| 308 } |
| 309 |
| 310 switch (opcode) { |
| 311 case OPCODE_CONTINUATION: |
| 312 if (!receiving_message_) { |
| 313 LOG(ERROR) << "Received unexpected continuation frame."; |
| 314 CloseOnError(); |
| 315 return; |
| 316 } |
| 317 break; |
| 318 |
| 319 case OPCODE_TEXT_FRAME: |
| 320 case OPCODE_BINARY_FRAME: |
| 321 if (receiving_message_) { |
| 322 LOG(ERROR) << "Received unexpected new start frame in a middle of " |
| 323 "a message."; |
| 324 CloseOnError(); |
| 325 return; |
| 326 } |
| 327 break; |
| 328 |
| 329 case OPCODE_CLOSE: |
| 330 Close(); |
| 331 delegate_->OnWebSocketClosed(); |
| 332 return; |
| 333 |
| 334 case OPCODE_PING: |
| 335 SendFragment(OPCODE_PONG, std::string(payload, payload + length)); |
| 336 break; |
| 337 |
| 338 case OPCODE_PONG: |
| 339 break; |
| 340 |
| 341 default: |
| 342 LOG(ERROR) << "Received invalid opcode: " << opcode; |
| 343 CloseOnError(); |
| 344 return; |
| 345 } |
| 346 |
| 347 // Remove the frame from |received_data_|. |
| 348 received_data_.erase(0, payload_position + length); |
| 349 |
| 350 // Post a task to process the data left in the buffer, if any. |
| 351 if (!received_data_.empty()) { |
| 352 base::ThreadTaskRunnerHandle::Get()->PostTask( |
| 353 FROM_HERE, base::Bind(&WebSocketConnection::ProcessData, |
| 354 weak_factory_.GetWeakPtr())); |
| 355 } |
| 356 |
| 357 // Handle payload in non-control messages. Delegate can be called only at the |
| 358 // end of this function |
| 359 if (opcode < 0x8) { |
| 360 if (!fin_bit) { |
| 361 receiving_message_ = true; |
| 362 } else { |
| 363 receiving_message_ = false; |
| 364 std::string msg; |
| 365 msg.swap(current_message_); |
| 366 delegate_->OnWebSocketMessage(msg); |
| 367 } |
| 368 } |
| 369 } |
| 370 |
| 371 void WebSocketConnection::SendFragment(WebsocketOpcode opcode, |
| 372 const std::string& payload) { |
| 373 DCHECK_EQ(state_, ACCEPTED); |
| 374 |
| 375 int length_field_size = 1; |
| 376 if (payload.size() > kMax16BitLength) { |
| 377 length_field_size = 9; |
| 378 } else if (payload.size() > kMaxNotExtendedLength) { |
| 379 length_field_size = 3; |
| 380 } |
| 381 |
| 382 scoped_refptr<net::IOBufferWithSize> buffer = |
| 383 new net::IOBufferWithSize(1 + length_field_size + payload.size()); |
| 384 |
| 385 // Always set FIN flag because we never fragment outgoing messages. |
| 386 buffer->data()[0] = opcode | 0x80; |
| 387 |
| 388 if (payload.size() > kMax16BitLength) { |
| 389 uint64 size = base::HostToNet64(payload.size()); |
| 390 buffer->data()[1] = kLength64BitMarker; |
| 391 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size)); |
| 392 } else if (payload.size() > kMaxNotExtendedLength) { |
| 393 uint16 size = base::HostToNet16(payload.size()); |
| 394 buffer->data()[1] = kLength16BitMarker; |
| 395 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size)); |
| 396 } else { |
| 397 buffer->data()[1] = payload.size(); |
| 398 } |
| 399 memcpy(buffer->data() + 1 + length_field_size, |
| 400 payload.data(), payload.size()); |
| 401 |
| 402 writer_.Write(buffer, base::Closure()); |
| 403 } |
| 404 |
| 405 bool WebSocketConnection::ParseHeaders() { |
| 406 std::vector<std::string> lines; |
| 407 base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines); |
| 408 |
| 409 // Parse request line. |
| 410 std::vector<std::string> request_parts; |
| 411 base::SplitString(lines[0], ' ', &request_parts); |
| 412 if (request_parts.size() != 3 || |
| 413 request_parts[0] != "GET" || |
| 414 request_parts[2] != "HTTP/1.1") { |
| 415 LOG(ERROR) << "Invalid Request-Line: " << headers_[0]; |
| 416 return false; |
| 417 } |
| 418 request_path_ = request_parts[1]; |
| 419 |
| 420 std::map<std::string, std::string> headers; |
| 421 |
| 422 for (size_t i = 1; i < lines.size(); ++i) { |
| 423 std::string separator(kHeaderKeyValueSeparator); |
| 424 size_t pos = lines[i].find(separator); |
| 425 if (pos == std::string::npos || pos == 0) { |
| 426 LOG(ERROR) << "Invalid header line: " << lines[i]; |
| 427 return false; |
| 428 } |
| 429 std::string key = lines[i].substr(0, pos); |
| 430 if (headers.find(key) != headers.end()) { |
| 431 LOG(ERROR) << "Duplicate header value: " << key; |
| 432 return false; |
| 433 } |
| 434 headers[key] = lines[i].substr(pos + separator.size()); |
| 435 } |
| 436 |
| 437 std::map<std::string, std::string>::iterator it = headers.find("Connection"); |
| 438 if (it == headers.end() || it->second != "Upgrade") { |
| 439 LOG(ERROR) << "Connection header is missing or invalid."; |
| 440 return false; |
| 441 } |
| 442 |
| 443 it = headers.find("Upgrade"); |
| 444 if (it == headers.end() || it->second != "websocket") { |
| 445 LOG(ERROR) << "Upgrade header is missing or invalid."; |
| 446 return false; |
| 447 } |
| 448 |
| 449 it = headers.find("Host"); |
| 450 if (it == headers.end()) { |
| 451 LOG(ERROR) << "Host header is missing."; |
| 452 return false; |
| 453 } |
| 454 request_host_ = it->second; |
| 455 |
| 456 it = headers.find("Sec-WebSocket-Version"); |
| 457 if (it == headers.end()) { |
| 458 LOG(ERROR) << "Sec-WebSocket-Version header is missing."; |
| 459 return false; |
| 460 } |
| 461 if (it->second != "13") { |
| 462 LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second; |
| 463 return false; |
| 464 } |
| 465 |
| 466 it = headers.find("Origin"); |
| 467 if (it == headers.end()) { |
| 468 LOG(ERROR) << "Origin header is missing."; |
| 469 return false; |
| 470 } |
| 471 origin_ = it->second; |
| 472 |
| 473 it = headers.find("Sec-WebSocket-Key"); |
| 474 if (it == headers.end()) { |
| 475 LOG(ERROR) << "Sec-WebSocket-Key header is missing."; |
| 476 return false; |
| 477 } |
| 478 websocket_key_ = it->second; |
| 479 |
| 480 return true; |
| 481 } |
| 482 |
| 483 void WebSocketConnection::UnmaskPayload(const char* mask, |
| 484 char* payload, int payload_length) { |
| 485 for (int i = 0; i < payload_length; ++i) { |
| 486 payload[i] = payload[i] ^ mask[i % kMaskLength]; |
| 487 } |
| 488 } |
| 489 |
| 490 void WebSocketConnection::OnSocketWriteError(int error) { |
| 491 LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error; |
| 492 CloseOnError(); |
| 493 } |
| 494 |
| 495 } // namespace remoting |
| OLD | NEW |