| OLD | NEW |
| (Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "remoting/host/websocket_connection.h" |
| 6 |
| 7 #include <map> |
| 8 #include <vector> |
| 9 |
| 10 #include "base/base64.h" |
| 11 #include "base/compiler_specific.h" |
| 12 #include "base/location.h" |
| 13 #include "base/sha1.h" |
| 14 #include "base/single_thread_task_runner.h" |
| 15 #include "base/string_split.h" |
| 16 #include "base/sys_byteorder.h" |
| 17 #include "base/thread_task_runner_handle.h" |
| 18 #include "net/base/net_errors.h" |
| 19 #include "net/socket/stream_socket.h" |
| 20 |
| 21 namespace remoting { |
| 22 |
| 23 namespace { |
| 24 |
| 25 const int kReadBufferSize = 1024; |
| 26 const char kLineSeparator[] = "\r\n"; |
| 27 const char kHeaderEndMarker[] = "\r\n\r\n"; |
| 28 const char kHeaderKeyValueSeparator[] = ": "; |
| 29 const int kMaskLength = 4; |
| 30 |
| 31 // Fixed value specified in RFC6455. It's used to compute accept token sent to |
| 32 // the client in Sec-WebSocket-Accept key. |
| 33 const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
| 34 |
| 35 } // namespace |
| 36 |
| 37 WebsocketConnection::WebsocketConnection() |
| 38 : delegate_(NULL), |
| 39 state_(READING_HEADERS), |
| 40 receiving_message_(false), |
| 41 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
| 42 } |
| 43 |
| 44 WebsocketConnection::~WebsocketConnection() { |
| 45 Close(); |
| 46 } |
| 47 |
| 48 void WebsocketConnection::Start( |
| 49 scoped_ptr<net::StreamSocket> socket, |
| 50 ConnectedCallback connected_callback) { |
| 51 socket_ = socket.Pass(); |
| 52 connected_callback_ = connected_callback; |
| 53 reader_.Init(socket_.get(), base::Bind( |
| 54 &WebsocketConnection::OnSocketReadResult, base::Unretained(this))); |
| 55 writer_.Init(socket_.get(), base::Bind( |
| 56 &WebsocketConnection::OnSocketWriteError, base::Unretained(this))); |
| 57 } |
| 58 |
| 59 void WebsocketConnection::Accept(Delegate* delegate) { |
| 60 DCHECK_EQ(state_, HEADERS_READ); |
| 61 |
| 62 state_ = ACCEPTED; |
| 63 delegate_ = delegate; |
| 64 |
| 65 std::string accept_key = |
| 66 base::SHA1HashString(websocket_key_ + kWebsocketKeySalt); |
| 67 std::string accept_key_base64; |
| 68 bool result = base::Base64Encode(accept_key, &accept_key_base64); |
| 69 DCHECK(result); |
| 70 |
| 71 std::string handshake; |
| 72 handshake += "HTTP/1.1 101 Switching Protocol"; |
| 73 handshake += kLineSeparator; |
| 74 handshake += "Upgrade: websocket"; |
| 75 handshake += kLineSeparator; |
| 76 handshake += "Connection: Upgrade"; |
| 77 handshake += kLineSeparator; |
| 78 handshake += "Sec-WebSocket-Accept: " + accept_key_base64; |
| 79 handshake += kHeaderEndMarker; |
| 80 |
| 81 scoped_refptr<net::IOBufferWithSize> buffer = |
| 82 new net::IOBufferWithSize(handshake.size()); |
| 83 memcpy(buffer->data(), handshake.data(), handshake.size()); |
| 84 writer_.Write(buffer, base::Closure()); |
| 85 } |
| 86 |
| 87 void WebsocketConnection::Reject() { |
| 88 DCHECK_EQ(state_, HEADERS_READ); |
| 89 |
| 90 state_ = CLOSED; |
| 91 std::string response = "HTTP/1.1 401 Unauthorized"; |
| 92 response += kHeaderEndMarker; |
| 93 scoped_refptr<net::IOBufferWithSize> buffer = |
| 94 new net::IOBufferWithSize(response.size()); |
| 95 memcpy(buffer->data(), response.data(), response.size()); |
| 96 writer_.Write(buffer, base::Closure()); |
| 97 } |
| 98 |
| 99 void WebsocketConnection::SendText(const std::string& text) { |
| 100 SendFragment(OPCODE_TEXT_FRAME, text.data(), text.size()); |
| 101 } |
| 102 |
| 103 void WebsocketConnection::Close() { |
| 104 switch (state_) { |
| 105 case READING_HEADERS: |
| 106 break; |
| 107 |
| 108 case HEADERS_READ: |
| 109 Reject(); |
| 110 break; |
| 111 |
| 112 case ACCEPTED: |
| 113 SendFragment(OPCODE_CLOSE, NULL, 0); |
| 114 break; |
| 115 |
| 116 case CLOSED: |
| 117 break; |
| 118 } |
| 119 state_ = CLOSED; |
| 120 } |
| 121 |
| 122 void WebsocketConnection::CloseOnError() { |
| 123 State old_state_ = state_; |
| 124 Close(); |
| 125 if (old_state_ == ACCEPTED) { |
| 126 DCHECK(delegate_); |
| 127 delegate_->OnWebsocketClosed(); |
| 128 } |
| 129 } |
| 130 |
| 131 void WebsocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data, |
| 132 int result) { |
| 133 if (result <= 0) { |
| 134 if (result != 0) { |
| 135 LOG(ERROR) << "Error when trying to read from WebSocket connection: " |
| 136 << result; |
| 137 } |
| 138 CloseOnError(); |
| 139 return; |
| 140 } |
| 141 |
| 142 switch (state_) { |
| 143 case READING_HEADERS: { |
| 144 headers_.append(data->data(), data->data() + result); |
| 145 size_t header_end_pos = headers_.find(kHeaderEndMarker); |
| 146 if (header_end_pos != std::string::npos) { |
| 147 bool result; |
| 148 if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) { |
| 149 LOG(ERROR) << "WebSocket client tried writing data before handshake " |
| 150 "has finished."; |
| 151 DCHECK(!connected_callback_.is_null()); |
| 152 state_ = CLOSED; |
| 153 result = false; |
| 154 } else { |
| 155 // Crop newline symbols from the end. |
| 156 headers_.resize(header_end_pos); |
| 157 |
| 158 result = ParseHeaders(); |
| 159 if (!result) { |
| 160 state_ = CLOSED; |
| 161 } else { |
| 162 state_ = HEADERS_READ; |
| 163 } |
| 164 } |
| 165 ConnectedCallback cb(connected_callback_); |
| 166 connected_callback_.Reset(); |
| 167 cb.Run(result); |
| 168 } |
| 169 break; |
| 170 } |
| 171 |
| 172 case HEADERS_READ: |
| 173 LOG(ERROR) << "Received unexpected data before websocket " |
| 174 "connection is accepted."; |
| 175 CloseOnError(); |
| 176 break; |
| 177 |
| 178 case ACCEPTED: |
| 179 DCHECK(delegate_); |
| 180 current_packet_.append(data->data(), data->data() + result); |
| 181 ProcessData(); |
| 182 |
| 183 case CLOSED: |
| 184 // Ignore anything received after connection is rejected or closed. |
| 185 break; |
| 186 } |
| 187 } |
| 188 |
| 189 void WebsocketConnection::ProcessData() { |
| 190 DCHECK_EQ(state_, ACCEPTED); |
| 191 |
| 192 if (current_packet_.size() < 2) { |
| 193 // Header hasn't been received yet. |
| 194 return; |
| 195 } |
| 196 |
| 197 bool fin_bit = (current_packet_.data()[0] & 0x80) != 0; |
| 198 |
| 199 int rsv_bits = current_packet_.data()[0] & 0x70; |
| 200 if (rsv_bits != 0) { |
| 201 LOG(ERROR) << "Incoming has unsupported RSV bits set."; |
| 202 CloseOnError(); |
| 203 return; |
| 204 } |
| 205 |
| 206 int opcode = current_packet_.data()[0] & 0x0f; |
| 207 |
| 208 int mask_bit = current_packet_.data()[1] & 0x80; |
| 209 if (mask_bit == 0) { |
| 210 LOG(ERROR) << "Incoming frame is not masked."; |
| 211 CloseOnError(); |
| 212 return; |
| 213 } |
| 214 |
| 215 int length_field_size = 1; |
| 216 uint64 length = current_packet_.data()[1] & 0x7F; |
| 217 if (length == 126) { |
| 218 if (current_packet_.size() < 4) { |
| 219 // Haven't received the whole frame yet. |
| 220 return; |
| 221 } |
| 222 length_field_size = 3; |
| 223 length = base::NetToHost16( |
| 224 *reinterpret_cast<const uint16*>(current_packet_.data() + 2)); |
| 225 } else if (length == 127) { |
| 226 if (current_packet_.size() < 10) { |
| 227 // Haven't received the whole frame yet. |
| 228 return; |
| 229 } |
| 230 length_field_size = 9; |
| 231 length = base::NetToHost64( |
| 232 *reinterpret_cast<const uint64*>(current_packet_.data() + 2)); |
| 233 } |
| 234 |
| 235 int payload_position = 1 + length_field_size + kMaskLength; |
| 236 |
| 237 if (current_packet_.size() < payload_position + length) { |
| 238 // Haven't received the whole frame yet. |
| 239 return; |
| 240 } |
| 241 |
| 242 if (mask_bit) { |
| 243 const char* mask = current_packet_.data() + length_field_size + 1; |
| 244 UnmaskPayload( |
| 245 mask, |
| 246 const_cast<char*>(current_packet_.data()) + payload_position, length); |
| 247 } |
| 248 |
| 249 if (opcode < 0x8) { |
| 250 // Non-control message. |
| 251 current_message_.append( |
| 252 current_packet_.data() + payload_position, |
| 253 current_packet_.data() + payload_position + length); |
| 254 } else { |
| 255 // Control message. |
| 256 if (!fin_bit) { |
| 257 LOG(ERROR) << "Received fragmented control message."; |
| 258 CloseOnError(); |
| 259 return; |
| 260 } |
| 261 if (length > 125) { |
| 262 LOG(ERROR) << "Received control message that is larger than 125 bytes."; |
| 263 CloseOnError(); |
| 264 return; |
| 265 } |
| 266 } |
| 267 |
| 268 switch (opcode) { |
| 269 case OPCODE_CONTINUATION: |
| 270 if (!receiving_message_) { |
| 271 LOG(ERROR) << "Received unexpected continuation frame."; |
| 272 CloseOnError(); |
| 273 return; |
| 274 } |
| 275 break; |
| 276 |
| 277 case OPCODE_TEXT_FRAME: |
| 278 case OPCODE_BINARY_FRAME: |
| 279 if (receiving_message_) { |
| 280 LOG(ERROR) << "Received unexpected new start frame in a middle of " |
| 281 "a message."; |
| 282 CloseOnError(); |
| 283 return; |
| 284 } |
| 285 break; |
| 286 |
| 287 case OPCODE_CLOSE: |
| 288 Close(); |
| 289 delegate_->OnWebsocketClosed(); |
| 290 return; |
| 291 |
| 292 case OPCODE_PING: |
| 293 SendFragment( |
| 294 OPCODE_PONG, current_packet_.data() + payload_position, length); |
| 295 break; |
| 296 |
| 297 case OPCODE_PONG: |
| 298 break; |
| 299 |
| 300 default: |
| 301 LOG(ERROR) << "Received invalid opcode: " << opcode; |
| 302 CloseOnError(); |
| 303 return; |
| 304 } |
| 305 |
| 306 // Remove the frame from |current_packet_|. |
| 307 current_packet_.erase(0, payload_position + length); |
| 308 |
| 309 // Post a task to process the data we have left in the buffer if any. |
| 310 if (!current_packet_.empty()) { |
| 311 base::ThreadTaskRunnerHandle::Get()->PostTask( |
| 312 FROM_HERE, base::Bind(&WebsocketConnection::ProcessData, |
| 313 weak_factory_.GetWeakPtr())); |
| 314 } |
| 315 |
| 316 // Handle payload in non-control messages. Delegate can be called only at the |
| 317 // end of this function |
| 318 if (opcode < 0x8) { |
| 319 if (!fin_bit) { |
| 320 receiving_message_ = true; |
| 321 } else { |
| 322 receiving_message_ = false; |
| 323 std::string msg; |
| 324 msg.swap(current_message_); |
| 325 delegate_->OnWebsocketMessage(msg); |
| 326 } |
| 327 } |
| 328 } |
| 329 |
| 330 void WebsocketConnection::SendFragment( |
| 331 WebsocketOpcode opcode, |
| 332 const char* payload, int payload_length) { |
| 333 DCHECK_EQ(state_, ACCEPTED); |
| 334 |
| 335 int length_field_size = 1; |
| 336 if (payload_length > 65535) { |
| 337 length_field_size = 9; |
| 338 } else if (payload_length > 125) { |
| 339 length_field_size = 3; |
| 340 } |
| 341 |
| 342 scoped_refptr<net::IOBufferWithSize> buffer = |
| 343 new net::IOBufferWithSize(1 + length_field_size + payload_length); |
| 344 |
| 345 // Always set FIN flag because we never fragment outgoing messages. |
| 346 buffer->data()[0] = opcode | 0x80; |
| 347 |
| 348 if (payload_length > 65535) { |
| 349 uint64 size = base::HostToNet64(payload_length); |
| 350 buffer->data()[1] = 127; |
| 351 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size)); |
| 352 } else if (payload_length > 125) { |
| 353 uint16 size = base::HostToNet16(payload_length); |
| 354 buffer->data()[1] = 126; |
| 355 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size)); |
| 356 } else { |
| 357 buffer->data()[1] = payload_length; |
| 358 } |
| 359 memcpy(buffer->data() + 1 + length_field_size, payload, payload_length); |
| 360 |
| 361 writer_.Write(buffer, base::Closure()); |
| 362 } |
| 363 |
| 364 bool WebsocketConnection::ParseHeaders() { |
| 365 std::vector<std::string> lines; |
| 366 base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines); |
| 367 |
| 368 // Parse request line. |
| 369 std::vector<std::string> request_parts; |
| 370 base::SplitString(lines[0], ' ', &request_parts); |
| 371 if (request_parts.size() != 3 || |
| 372 request_parts[0] != "GET" || |
| 373 request_parts[2] != "HTTP/1.1") { |
| 374 LOG(ERROR) << "Invalid Request-Line: " << headers_[0]; |
| 375 return false; |
| 376 } |
| 377 request_path_ = request_parts[1]; |
| 378 |
| 379 std::map<std::string, std::string> headers; |
| 380 |
| 381 for (size_t i = 1; i < lines.size(); ++i) { |
| 382 std::string separator(kHeaderKeyValueSeparator); |
| 383 size_t pos = lines[i].find(separator); |
| 384 if (pos == std::string::npos || pos == 0) { |
| 385 LOG(ERROR) << "Invalid header line: " << lines[i]; |
| 386 return false; |
| 387 } |
| 388 std::string key = lines[i].substr(0, pos); |
| 389 if (headers.find(key) != headers.end()) { |
| 390 LOG(ERROR) << "Duplicate header value: " << key; |
| 391 return false; |
| 392 } |
| 393 headers[key] = lines[i].substr(pos + separator.size()); |
| 394 } |
| 395 |
| 396 std::map<std::string, std::string>::iterator it = headers.find("Connection"); |
| 397 if (it == headers.end() || it->second != "Upgrade") { |
| 398 LOG(ERROR) << "Connection header is missing or invalid."; |
| 399 return false; |
| 400 } |
| 401 |
| 402 it = headers.find("Upgrade"); |
| 403 if (it == headers.end() || it->second != "websocket") { |
| 404 LOG(ERROR) << "Upgrade header is missing or invalid."; |
| 405 return false; |
| 406 } |
| 407 |
| 408 it = headers.find("Host"); |
| 409 if (it == headers.end()) { |
| 410 LOG(ERROR) << "Host header is missing."; |
| 411 return false; |
| 412 } |
| 413 request_host_ = it->second; |
| 414 |
| 415 it = headers.find("Sec-WebSocket-Version"); |
| 416 if (it == headers.end()) { |
| 417 LOG(ERROR) << "Sec-WebSocket-Version header is missing."; |
| 418 return false; |
| 419 } |
| 420 if (it->second != "13") { |
| 421 LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second; |
| 422 return false; |
| 423 } |
| 424 |
| 425 it = headers.find("Origin"); |
| 426 if (it == headers.end()) { |
| 427 LOG(ERROR) << "Origin header is missing."; |
| 428 return false; |
| 429 } |
| 430 origin_ = it->second; |
| 431 |
| 432 it = headers.find("Sec-WebSocket-Key"); |
| 433 if (it == headers.end()) { |
| 434 LOG(ERROR) << "Sec-WebSocket-Key header is missing."; |
| 435 return false; |
| 436 } |
| 437 websocket_key_ = it->second; |
| 438 |
| 439 return true; |
| 440 } |
| 441 |
| 442 void WebsocketConnection::UnmaskPayload(const char* mask, |
| 443 char* payload, int payload_length) { |
| 444 for (int i = 0; i < payload_length; ++i) { |
| 445 payload[i] = payload[i] ^ mask[i % kMaskLength]; |
| 446 } |
| 447 } |
| 448 |
| 449 void WebsocketConnection::OnSocketWriteError(int error) { |
| 450 LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error; |
| 451 CloseOnError(); |
| 452 } |
| 453 |
| 454 } // namespace remoting |
| OLD | NEW |