Chromium Code Reviews| OLD | NEW |
|---|---|
| (Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "remoting/host/websocket_connection.h" | |
| 6 | |
| 7 #include <map> | |
| 8 #include <vector> | |
| 9 | |
| 10 #include "base/base64.h" | |
| 11 #include "base/compiler_specific.h" | |
| 12 #include "base/location.h" | |
| 13 #include "base/sha1.h" | |
| 14 #include "base/single_thread_task_runner.h" | |
| 15 #include "base/string_split.h" | |
| 16 #include "base/sys_byteorder.h" | |
| 17 #include "base/thread_task_runner_handle.h" | |
| 18 #include "net/base/net_errors.h" | |
| 19 #include "net/socket/stream_socket.h" | |
| 20 | |
| 21 namespace remoting { | |
| 22 | |
| 23 namespace { | |
| 24 | |
| 25 const int kReadBufferSize = 1024; | |
| 26 const char kLineSeparator[] = "\r\n"; | |
| 27 const char kHeaderEndMarker[] = "\r\n\r\n"; | |
| 28 const char kHeaderKeyValueSeparator[] = ": "; | |
| 29 const int kMaskLength = 4; | |
| 30 | |
| 31 // Fixed value specified in RFC6455. It's used to compute accept token sent to | |
| 32 // the client in Sec-WebSocket-Accept key. | |
| 33 const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | |
| 34 | |
| 35 } // namespace | |
| 36 | |
| 37 WebsocketConnection::WebsocketConnection() | |
| 38 : delegate_(NULL), | |
| 39 maximum_message_size_(0), | |
| 40 state_(READING_HEADERS), | |
| 41 receiving_message_(false), | |
| 42 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { | |
| 43 } | |
| 44 | |
| 45 WebsocketConnection::~WebsocketConnection() { | |
| 46 Close(); | |
| 47 } | |
| 48 | |
| 49 void WebsocketConnection::Start( | |
| 50 scoped_ptr<net::StreamSocket> socket, | |
| 51 ConnectedCallback connected_callback) { | |
| 52 socket_ = socket.Pass(); | |
| 53 connected_callback_ = connected_callback; | |
| 54 reader_.Init(socket_.get(), base::Bind( | |
| 55 &WebsocketConnection::OnSocketReadResult, base::Unretained(this))); | |
| 56 writer_.Init(socket_.get(), base::Bind( | |
| 57 &WebsocketConnection::OnSocketWriteError, base::Unretained(this))); | |
| 58 } | |
| 59 | |
| 60 void WebsocketConnection::Accept(Delegate* delegate) { | |
| 61 DCHECK_EQ(state_, HEADERS_READ); | |
| 62 | |
| 63 state_ = ACCEPTED; | |
| 64 delegate_ = delegate; | |
| 65 | |
| 66 std::string accept_key = | |
| 67 base::SHA1HashString(websocket_key_ + kWebsocketKeySalt); | |
| 68 std::string accept_key_base64; | |
| 69 bool result = base::Base64Encode(accept_key, &accept_key_base64); | |
| 70 DCHECK(result); | |
| 71 | |
| 72 std::string handshake; | |
| 73 handshake += "HTTP/1.1 101 Switching Protocol"; | |
| 74 handshake += kLineSeparator; | |
| 75 handshake += "Upgrade: websocket"; | |
| 76 handshake += kLineSeparator; | |
| 77 handshake += "Connection: Upgrade"; | |
| 78 handshake += kLineSeparator; | |
| 79 handshake += "Sec-WebSocket-Accept: " + accept_key_base64; | |
| 80 handshake += kHeaderEndMarker; | |
| 81 | |
| 82 scoped_refptr<net::IOBufferWithSize> buffer = | |
| 83 new net::IOBufferWithSize(handshake.size()); | |
| 84 memcpy(buffer->data(), handshake.data(), handshake.size()); | |
| 85 writer_.Write(buffer, base::Closure()); | |
| 86 } | |
| 87 | |
| 88 void WebsocketConnection::Reject() { | |
| 89 DCHECK_EQ(state_, HEADERS_READ); | |
| 90 | |
| 91 state_ = CLOSED; | |
| 92 std::string response = "HTTP/1.1 401 Unauthorized"; | |
| 93 response += kHeaderEndMarker; | |
| 94 scoped_refptr<net::IOBufferWithSize> buffer = | |
| 95 new net::IOBufferWithSize(response.size()); | |
| 96 memcpy(buffer->data(), response.data(), response.size()); | |
| 97 writer_.Write(buffer, base::Closure()); | |
| 98 } | |
| 99 | |
| 100 void WebsocketConnection::set_maximum_message_size(uint64 size) { | |
| 101 maximum_message_size_ = size; | |
| 102 } | |
| 103 | |
| 104 void WebsocketConnection::SendText(const std::string& text) { | |
| 105 SendFragment(OPCODE_TEXT_FRAME, text.data(), text.size()); | |
| 106 } | |
| 107 | |
| 108 void WebsocketConnection::Close() { | |
| 109 switch (state_) { | |
| 110 case READING_HEADERS: | |
| 111 break; | |
| 112 | |
| 113 case HEADERS_READ: | |
| 114 Reject(); | |
| 115 break; | |
| 116 | |
| 117 case ACCEPTED: | |
| 118 SendFragment(OPCODE_CLOSE, NULL, 0); | |
| 119 break; | |
| 120 | |
| 121 case CLOSED: | |
| 122 break; | |
| 123 } | |
| 124 state_ = CLOSED; | |
| 125 } | |
| 126 | |
| 127 void WebsocketConnection::CloseOnError() { | |
| 128 State old_state_ = state_; | |
| 129 Close(); | |
| 130 if (old_state_ == ACCEPTED) { | |
| 131 DCHECK(delegate_); | |
| 132 delegate_->OnWebsocketClosed(); | |
| 133 } | |
| 134 } | |
| 135 | |
| 136 void WebsocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data, | |
| 137 int result) { | |
| 138 if (result <= 0) { | |
| 139 if (result != 0) { | |
| 140 LOG(ERROR) << "Error when trying to read from WebSocket connection: " | |
| 141 << result; | |
| 142 } | |
| 143 CloseOnError(); | |
| 144 return; | |
| 145 } | |
| 146 | |
| 147 switch (state_) { | |
| 148 case READING_HEADERS: { | |
| 149 headers_.append(data->data(), data->data() + result); | |
| 150 size_t header_end_pos = headers_.find(kHeaderEndMarker); | |
| 151 if (header_end_pos != std::string::npos) { | |
| 152 bool result; | |
| 153 if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) { | |
| 154 LOG(ERROR) << "WebSocket client tried writing data before handshake " | |
| 155 "has finished."; | |
| 156 DCHECK(!connected_callback_.is_null()); | |
| 157 state_ = CLOSED; | |
| 158 result = false; | |
| 159 } else { | |
| 160 // Crop newline symbols from the end. | |
| 161 headers_.resize(header_end_pos); | |
| 162 | |
| 163 result = ParseHeaders(); | |
| 164 if (!result) { | |
| 165 state_ = CLOSED; | |
| 166 } else { | |
| 167 state_ = HEADERS_READ; | |
| 168 } | |
| 169 } | |
| 170 ConnectedCallback cb(connected_callback_); | |
| 171 connected_callback_.Reset(); | |
| 172 cb.Run(result); | |
| 173 } | |
| 174 break; | |
| 175 } | |
| 176 | |
| 177 case HEADERS_READ: | |
| 178 LOG(ERROR) << "Received unexpected data before websocket " | |
| 179 "connection is accepted."; | |
| 180 CloseOnError(); | |
| 181 break; | |
| 182 | |
| 183 case ACCEPTED: | |
| 184 DCHECK(delegate_); | |
| 185 received_data_.append(data->data(), data->data() + result); | |
| 186 ProcessData(); | |
| 187 | |
| 188 case CLOSED: | |
| 189 // Ignore anything received after connection is rejected or closed. | |
| 190 break; | |
| 191 } | |
| 192 } | |
| 193 | |
| 194 void WebsocketConnection::ProcessData() { | |
| 195 DCHECK_EQ(state_, ACCEPTED); | |
| 196 | |
| 197 if (received_data_.size() < 2) { | |
| 198 // Header hasn't been received yet. | |
| 199 return; | |
| 200 } | |
| 201 | |
| 202 bool fin_bit = (received_data_.data()[0] & 0x80) != 0; | |
| 203 | |
| 204 int rsv_bits = received_data_.data()[0] & 0x70; | |
|
Wez
2012/11/20 05:44:09
nit: Add a comment summarizing what these bits are
Sergey Ulanov
2012/11/21 01:40:24
Done.
| |
| 205 if (rsv_bits != 0) { | |
| 206 LOG(ERROR) << "Incoming has unsupported RSV bits set."; | |
| 207 CloseOnError(); | |
| 208 return; | |
| 209 } | |
| 210 | |
| 211 int opcode = received_data_.data()[0] & 0x0f; | |
| 212 | |
| 213 int mask_bit = received_data_.data()[1] & 0x80; | |
| 214 if (mask_bit == 0) { | |
| 215 LOG(ERROR) << "Incoming frame is not masked."; | |
| 216 CloseOnError(); | |
| 217 return; | |
| 218 } | |
| 219 | |
| 220 int length_field_size = 1; | |
|
Wez
2012/11/20 05:44:09
Please add a comment summarizing this length proce
Sergey Ulanov
2012/11/21 01:40:24
Done.
| |
| 221 uint64 length = received_data_.data()[1] & 0x7F; | |
| 222 if (length == 126) { | |
| 223 if (received_data_.size() < 4) { | |
| 224 // Haven't received the whole frame yet. | |
|
Wez
2012/11/20 05:44:09
nit: "Haven't received the whole frame header yet"
Sergey Ulanov
2012/11/21 01:40:24
Done.
| |
| 225 return; | |
| 226 } | |
| 227 length_field_size = 3; | |
| 228 length = base::NetToHost16( | |
| 229 *reinterpret_cast<const uint16*>(received_data_.data() + 2)); | |
| 230 } else if (length == 127) { | |
| 231 if (received_data_.size() < 10) { | |
| 232 // Haven't received the whole frame yet. | |
| 233 return; | |
| 234 } | |
| 235 length_field_size = 9; | |
| 236 length = base::NetToHost64( | |
| 237 *reinterpret_cast<const uint64*>(received_data_.data() + 2)); | |
| 238 } | |
| 239 | |
| 240 int payload_position = 1 + length_field_size + kMaskLength; | |
| 241 | |
| 242 if (maximum_message_size_ > 0 && length > maximum_message_size_) { | |
|
Wez
2012/11/20 05:44:09
nit: Add a comment explaining why we need this che
Sergey Ulanov
2012/11/21 01:40:24
Done.
| |
| 243 LOG(ERROR) << "Client tried to send a fragment that is bigger than " | |
| 244 "the maximum message size of " << maximum_message_size_; | |
| 245 CloseOnError(); | |
| 246 return; | |
| 247 } | |
| 248 | |
| 249 if (received_data_.size() < payload_position + length) { | |
| 250 // Haven't received the whole frame yet. | |
| 251 return; | |
| 252 } | |
| 253 | |
| 254 if (mask_bit) { | |
|
Wez
2012/11/20 05:44:09
nit: Add a comment to the effect of "un-mask the m
Sergey Ulanov
2012/11/21 01:40:24
Done.
| |
| 255 const char* mask = received_data_.data() + length_field_size + 1; | |
| 256 UnmaskPayload( | |
| 257 mask, | |
| 258 const_cast<char*>(received_data_.data()) + payload_position, length); | |
| 259 } | |
| 260 | |
| 261 if (opcode < 0x8) { | |
| 262 // Non-control message. | |
| 263 current_message_.append( | |
| 264 received_data_.data() + payload_position, | |
| 265 received_data_.data() + payload_position + length); | |
| 266 | |
| 267 if (maximum_message_size_ > 0 && | |
| 268 current_message_.size() > maximum_message_size_) { | |
|
Wez
2012/11/20 05:44:09
It's too late to check this here; the total size o
Sergey Ulanov
2012/11/21 01:40:24
Done, but it doesn't really matter much.
| |
| 269 LOG(ERROR) << "Client tried to send a message that is bigger than " | |
| 270 "the maximum message size of " << maximum_message_size_; | |
| 271 CloseOnError(); | |
| 272 return; | |
| 273 } | |
| 274 } else { | |
| 275 // Control message. | |
| 276 if (!fin_bit) { | |
| 277 LOG(ERROR) << "Received fragmented control message."; | |
| 278 CloseOnError(); | |
| 279 return; | |
| 280 } | |
| 281 if (length > 125) { | |
| 282 LOG(ERROR) << "Received control message that is larger than 125 bytes."; | |
| 283 CloseOnError(); | |
| 284 return; | |
| 285 } | |
| 286 } | |
| 287 | |
| 288 switch (opcode) { | |
| 289 case OPCODE_CONTINUATION: | |
| 290 if (!receiving_message_) { | |
| 291 LOG(ERROR) << "Received unexpected continuation frame."; | |
| 292 CloseOnError(); | |
| 293 return; | |
| 294 } | |
| 295 break; | |
| 296 | |
| 297 case OPCODE_TEXT_FRAME: | |
| 298 case OPCODE_BINARY_FRAME: | |
| 299 if (receiving_message_) { | |
| 300 LOG(ERROR) << "Received unexpected new start frame in a middle of " | |
| 301 "a message."; | |
| 302 CloseOnError(); | |
| 303 return; | |
| 304 } | |
| 305 break; | |
| 306 | |
| 307 case OPCODE_CLOSE: | |
| 308 Close(); | |
| 309 delegate_->OnWebsocketClosed(); | |
| 310 return; | |
| 311 | |
| 312 case OPCODE_PING: | |
| 313 SendFragment( | |
| 314 OPCODE_PONG, received_data_.data() + payload_position, length); | |
| 315 break; | |
| 316 | |
| 317 case OPCODE_PONG: | |
| 318 break; | |
| 319 | |
| 320 default: | |
| 321 LOG(ERROR) << "Received invalid opcode: " << opcode; | |
| 322 CloseOnError(); | |
| 323 return; | |
| 324 } | |
| 325 | |
| 326 // Remove the frame from |received_data_|. | |
| 327 received_data_.erase(0, payload_position + length); | |
| 328 | |
| 329 // Post a task to process the data we have left in the buffer if any. | |
|
Wez
2012/11/20 05:44:09
nit: "... left in the buffer, if any."
Sergey Ulanov
2012/11/21 01:40:24
Done.
| |
| 330 if (!received_data_.empty()) { | |
| 331 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
| 332 FROM_HERE, base::Bind(&WebsocketConnection::ProcessData, | |
| 333 weak_factory_.GetWeakPtr())); | |
| 334 } | |
| 335 | |
| 336 // Handle payload in non-control messages. Delegate can be called only at the | |
| 337 // end of this function | |
| 338 if (opcode < 0x8) { | |
| 339 if (!fin_bit) { | |
| 340 receiving_message_ = true; | |
| 341 } else { | |
| 342 receiving_message_ = false; | |
| 343 std::string msg; | |
| 344 msg.swap(current_message_); | |
| 345 delegate_->OnWebsocketMessage(msg); | |
| 346 } | |
| 347 } | |
| 348 } | |
| 349 | |
| 350 void WebsocketConnection::SendFragment( | |
| 351 WebsocketOpcode opcode, | |
| 352 const char* payload, int payload_length) { | |
| 353 DCHECK_EQ(state_, ACCEPTED); | |
| 354 | |
| 355 int length_field_size = 1; | |
| 356 if (payload_length > 65535) { | |
| 357 length_field_size = 9; | |
| 358 } else if (payload_length > 125) { | |
| 359 length_field_size = 3; | |
| 360 } | |
| 361 | |
| 362 scoped_refptr<net::IOBufferWithSize> buffer = | |
| 363 new net::IOBufferWithSize(1 + length_field_size + payload_length); | |
| 364 | |
| 365 // Always set FIN flag because we never fragment outgoing messages. | |
| 366 buffer->data()[0] = opcode | 0x80; | |
| 367 | |
| 368 if (payload_length > 65535) { | |
| 369 uint64 size = base::HostToNet64(payload_length); | |
| 370 buffer->data()[1] = 127; | |
| 371 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size)); | |
| 372 } else if (payload_length > 125) { | |
| 373 uint16 size = base::HostToNet16(payload_length); | |
| 374 buffer->data()[1] = 126; | |
| 375 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size)); | |
| 376 } else { | |
| 377 buffer->data()[1] = payload_length; | |
| 378 } | |
| 379 memcpy(buffer->data() + 1 + length_field_size, payload, payload_length); | |
| 380 | |
| 381 writer_.Write(buffer, base::Closure()); | |
| 382 } | |
| 383 | |
| 384 bool WebsocketConnection::ParseHeaders() { | |
| 385 std::vector<std::string> lines; | |
| 386 base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines); | |
| 387 | |
| 388 // Parse request line. | |
| 389 std::vector<std::string> request_parts; | |
| 390 base::SplitString(lines[0], ' ', &request_parts); | |
| 391 if (request_parts.size() != 3 || | |
| 392 request_parts[0] != "GET" || | |
| 393 request_parts[2] != "HTTP/1.1") { | |
| 394 LOG(ERROR) << "Invalid Request-Line: " << headers_[0]; | |
| 395 return false; | |
| 396 } | |
| 397 request_path_ = request_parts[1]; | |
| 398 | |
| 399 std::map<std::string, std::string> headers; | |
| 400 | |
| 401 for (size_t i = 1; i < lines.size(); ++i) { | |
| 402 std::string separator(kHeaderKeyValueSeparator); | |
| 403 size_t pos = lines[i].find(separator); | |
| 404 if (pos == std::string::npos || pos == 0) { | |
| 405 LOG(ERROR) << "Invalid header line: " << lines[i]; | |
| 406 return false; | |
| 407 } | |
| 408 std::string key = lines[i].substr(0, pos); | |
| 409 if (headers.find(key) != headers.end()) { | |
| 410 LOG(ERROR) << "Duplicate header value: " << key; | |
| 411 return false; | |
| 412 } | |
| 413 headers[key] = lines[i].substr(pos + separator.size()); | |
| 414 } | |
| 415 | |
| 416 std::map<std::string, std::string>::iterator it = headers.find("Connection"); | |
| 417 if (it == headers.end() || it->second != "Upgrade") { | |
| 418 LOG(ERROR) << "Connection header is missing or invalid."; | |
| 419 return false; | |
| 420 } | |
| 421 | |
| 422 it = headers.find("Upgrade"); | |
| 423 if (it == headers.end() || it->second != "websocket") { | |
| 424 LOG(ERROR) << "Upgrade header is missing or invalid."; | |
| 425 return false; | |
| 426 } | |
| 427 | |
| 428 it = headers.find("Host"); | |
| 429 if (it == headers.end()) { | |
| 430 LOG(ERROR) << "Host header is missing."; | |
| 431 return false; | |
| 432 } | |
| 433 request_host_ = it->second; | |
| 434 | |
| 435 it = headers.find("Sec-WebSocket-Version"); | |
| 436 if (it == headers.end()) { | |
| 437 LOG(ERROR) << "Sec-WebSocket-Version header is missing."; | |
| 438 return false; | |
| 439 } | |
| 440 if (it->second != "13") { | |
| 441 LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second; | |
| 442 return false; | |
| 443 } | |
| 444 | |
| 445 it = headers.find("Origin"); | |
| 446 if (it == headers.end()) { | |
| 447 LOG(ERROR) << "Origin header is missing."; | |
| 448 return false; | |
| 449 } | |
| 450 origin_ = it->second; | |
| 451 | |
| 452 it = headers.find("Sec-WebSocket-Key"); | |
| 453 if (it == headers.end()) { | |
| 454 LOG(ERROR) << "Sec-WebSocket-Key header is missing."; | |
| 455 return false; | |
| 456 } | |
| 457 websocket_key_ = it->second; | |
| 458 | |
| 459 return true; | |
| 460 } | |
| 461 | |
| 462 void WebsocketConnection::UnmaskPayload(const char* mask, | |
| 463 char* payload, int payload_length) { | |
| 464 for (int i = 0; i < payload_length; ++i) { | |
| 465 payload[i] = payload[i] ^ mask[i % kMaskLength]; | |
| 466 } | |
| 467 } | |
| 468 | |
| 469 void WebsocketConnection::OnSocketWriteError(int error) { | |
| 470 LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error; | |
| 471 CloseOnError(); | |
| 472 } | |
| 473 | |
| 474 } // namespace remoting | |
| OLD | NEW |