| Index: remoting/host/websocket_connection.cc
|
| diff --git a/remoting/host/websocket_connection.cc b/remoting/host/websocket_connection.cc
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..bdf5e5339c03e544da3e7e3039ac4229907b3edd
|
| --- /dev/null
|
| +++ b/remoting/host/websocket_connection.cc
|
| @@ -0,0 +1,454 @@
|
| +// Copyright (c) 2012 The Chromium Authors. All rights reserved.
|
| +// Use of this source code is governed by a BSD-style license that can be
|
| +// found in the LICENSE file.
|
| +
|
| +#include "remoting/host/websocket_connection.h"
|
| +
|
| +#include <map>
|
| +#include <vector>
|
| +
|
| +#include "base/base64.h"
|
| +#include "base/compiler_specific.h"
|
| +#include "base/location.h"
|
| +#include "base/sha1.h"
|
| +#include "base/single_thread_task_runner.h"
|
| +#include "base/string_split.h"
|
| +#include "base/sys_byteorder.h"
|
| +#include "base/thread_task_runner_handle.h"
|
| +#include "net/base/net_errors.h"
|
| +#include "net/socket/stream_socket.h"
|
| +
|
| +namespace remoting {
|
| +
|
| +namespace {
|
| +
|
| +const int kReadBufferSize = 1024;
|
| +const char kLineSeparator[] = "\r\n";
|
| +const char kHeaderEndMarker[] = "\r\n\r\n";
|
| +const char kHeaderKeyValueSeparator[] = ": ";
|
| +const int kMaskLength = 4;
|
| +
|
| +// Fixed value specified in RFC6455. It's used to compute accept token sent to
|
| +// the client in Sec-WebSocket-Accept key.
|
| +const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
| +
|
| +} // namespace
|
| +
|
| +WebsocketConnection::WebsocketConnection()
|
| + : delegate_(NULL),
|
| + state_(READING_HEADERS),
|
| + receiving_message_(false),
|
| + ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
|
| +}
|
| +
|
| +WebsocketConnection::~WebsocketConnection() {
|
| + Close();
|
| +}
|
| +
|
| +void WebsocketConnection::Start(
|
| + scoped_ptr<net::StreamSocket> socket,
|
| + ConnectedCallback connected_callback) {
|
| + socket_ = socket.Pass();
|
| + connected_callback_ = connected_callback;
|
| + reader_.Init(socket_.get(), base::Bind(
|
| + &WebsocketConnection::OnSocketReadResult, base::Unretained(this)));
|
| + writer_.Init(socket_.get(), base::Bind(
|
| + &WebsocketConnection::OnSocketWriteError, base::Unretained(this)));
|
| +}
|
| +
|
| +void WebsocketConnection::Accept(Delegate* delegate) {
|
| + DCHECK_EQ(state_, HEADERS_READ);
|
| +
|
| + state_ = ACCEPTED;
|
| + delegate_ = delegate;
|
| +
|
| + std::string accept_key =
|
| + base::SHA1HashString(websocket_key_ + kWebsocketKeySalt);
|
| + std::string accept_key_base64;
|
| + bool result = base::Base64Encode(accept_key, &accept_key_base64);
|
| + DCHECK(result);
|
| +
|
| + std::string handshake;
|
| + handshake += "HTTP/1.1 101 Switching Protocol";
|
| + handshake += kLineSeparator;
|
| + handshake += "Upgrade: websocket";
|
| + handshake += kLineSeparator;
|
| + handshake += "Connection: Upgrade";
|
| + handshake += kLineSeparator;
|
| + handshake += "Sec-WebSocket-Accept: " + accept_key_base64;
|
| + handshake += kHeaderEndMarker;
|
| +
|
| + scoped_refptr<net::IOBufferWithSize> buffer =
|
| + new net::IOBufferWithSize(handshake.size());
|
| + memcpy(buffer->data(), handshake.data(), handshake.size());
|
| + writer_.Write(buffer, base::Closure());
|
| +}
|
| +
|
| +void WebsocketConnection::Reject() {
|
| + DCHECK_EQ(state_, HEADERS_READ);
|
| +
|
| + state_ = CLOSED;
|
| + std::string response = "HTTP/1.1 401 Unauthorized";
|
| + response += kHeaderEndMarker;
|
| + scoped_refptr<net::IOBufferWithSize> buffer =
|
| + new net::IOBufferWithSize(response.size());
|
| + memcpy(buffer->data(), response.data(), response.size());
|
| + writer_.Write(buffer, base::Closure());
|
| +}
|
| +
|
| +void WebsocketConnection::SendText(const std::string& text) {
|
| + SendFragment(OPCODE_TEXT_FRAME, text.data(), text.size());
|
| +}
|
| +
|
| +void WebsocketConnection::Close() {
|
| + switch (state_) {
|
| + case READING_HEADERS:
|
| + break;
|
| +
|
| + case HEADERS_READ:
|
| + Reject();
|
| + break;
|
| +
|
| + case ACCEPTED:
|
| + SendFragment(OPCODE_CLOSE, NULL, 0);
|
| + break;
|
| +
|
| + case CLOSED:
|
| + break;
|
| + }
|
| + state_ = CLOSED;
|
| +}
|
| +
|
| +void WebsocketConnection::CloseOnError() {
|
| + State old_state_ = state_;
|
| + Close();
|
| + if (old_state_ == ACCEPTED) {
|
| + DCHECK(delegate_);
|
| + delegate_->OnWebsocketClosed();
|
| + }
|
| +}
|
| +
|
| +void WebsocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data,
|
| + int result) {
|
| + if (result <= 0) {
|
| + if (result != 0) {
|
| + LOG(ERROR) << "Error when trying to read from WebSocket connection: "
|
| + << result;
|
| + }
|
| + CloseOnError();
|
| + return;
|
| + }
|
| +
|
| + switch (state_) {
|
| + case READING_HEADERS: {
|
| + headers_.append(data->data(), data->data() + result);
|
| + size_t header_end_pos = headers_.find(kHeaderEndMarker);
|
| + if (header_end_pos != std::string::npos) {
|
| + bool result;
|
| + if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) {
|
| + LOG(ERROR) << "WebSocket client tried writing data before handshake "
|
| + "has finished.";
|
| + DCHECK(!connected_callback_.is_null());
|
| + state_ = CLOSED;
|
| + result = false;
|
| + } else {
|
| + // Crop newline symbols from the end.
|
| + headers_.resize(header_end_pos);
|
| +
|
| + result = ParseHeaders();
|
| + if (!result) {
|
| + state_ = CLOSED;
|
| + } else {
|
| + state_ = HEADERS_READ;
|
| + }
|
| + }
|
| + ConnectedCallback cb(connected_callback_);
|
| + connected_callback_.Reset();
|
| + cb.Run(result);
|
| + }
|
| + break;
|
| + }
|
| +
|
| + case HEADERS_READ:
|
| + LOG(ERROR) << "Received unexpected data before websocket "
|
| + "connection is accepted.";
|
| + CloseOnError();
|
| + break;
|
| +
|
| + case ACCEPTED:
|
| + DCHECK(delegate_);
|
| + current_packet_.append(data->data(), data->data() + result);
|
| + ProcessData();
|
| +
|
| + case CLOSED:
|
| + // Ignore anything received after connection is rejected or closed.
|
| + break;
|
| + }
|
| +}
|
| +
|
| +void WebsocketConnection::ProcessData() {
|
| + DCHECK_EQ(state_, ACCEPTED);
|
| +
|
| + if (current_packet_.size() < 2) {
|
| + // Header hasn't been received yet.
|
| + return;
|
| + }
|
| +
|
| + bool fin_bit = (current_packet_.data()[0] & 0x80) != 0;
|
| +
|
| + int rsv_bits = current_packet_.data()[0] & 0x70;
|
| + if (rsv_bits != 0) {
|
| + LOG(ERROR) << "Incoming has unsupported RSV bits set.";
|
| + CloseOnError();
|
| + return;
|
| + }
|
| +
|
| + int opcode = current_packet_.data()[0] & 0x0f;
|
| +
|
| + int mask_bit = current_packet_.data()[1] & 0x80;
|
| + if (mask_bit == 0) {
|
| + LOG(ERROR) << "Incoming frame is not masked.";
|
| + CloseOnError();
|
| + return;
|
| + }
|
| +
|
| + int length_field_size = 1;
|
| + uint64 length = current_packet_.data()[1] & 0x7F;
|
| + if (length == 126) {
|
| + if (current_packet_.size() < 4) {
|
| + // Haven't received the whole frame yet.
|
| + return;
|
| + }
|
| + length_field_size = 3;
|
| + length = base::NetToHost16(
|
| + *reinterpret_cast<const uint16*>(current_packet_.data() + 2));
|
| + } else if (length == 127) {
|
| + if (current_packet_.size() < 10) {
|
| + // Haven't received the whole frame yet.
|
| + return;
|
| + }
|
| + length_field_size = 9;
|
| + length = base::NetToHost64(
|
| + *reinterpret_cast<const uint64*>(current_packet_.data() + 2));
|
| + }
|
| +
|
| + int payload_position = 1 + length_field_size + kMaskLength;
|
| +
|
| + if (current_packet_.size() < payload_position + length) {
|
| + // Haven't received the whole frame yet.
|
| + return;
|
| + }
|
| +
|
| + if (mask_bit) {
|
| + const char* mask = current_packet_.data() + length_field_size + 1;
|
| + UnmaskPayload(
|
| + mask,
|
| + const_cast<char*>(current_packet_.data()) + payload_position, length);
|
| + }
|
| +
|
| + if (opcode < 0x8) {
|
| + // Non-control message.
|
| + current_message_.append(
|
| + current_packet_.data() + payload_position,
|
| + current_packet_.data() + payload_position + length);
|
| + } else {
|
| + // Control message.
|
| + if (!fin_bit) {
|
| + LOG(ERROR) << "Received fragmented control message.";
|
| + CloseOnError();
|
| + return;
|
| + }
|
| + if (length > 125) {
|
| + LOG(ERROR) << "Received control message that is larger than 125 bytes.";
|
| + CloseOnError();
|
| + return;
|
| + }
|
| + }
|
| +
|
| + switch (opcode) {
|
| + case OPCODE_CONTINUATION:
|
| + if (!receiving_message_) {
|
| + LOG(ERROR) << "Received unexpected continuation frame.";
|
| + CloseOnError();
|
| + return;
|
| + }
|
| + break;
|
| +
|
| + case OPCODE_TEXT_FRAME:
|
| + case OPCODE_BINARY_FRAME:
|
| + if (receiving_message_) {
|
| + LOG(ERROR) << "Received unexpected new start frame in a middle of "
|
| + "a message.";
|
| + CloseOnError();
|
| + return;
|
| + }
|
| + break;
|
| +
|
| + case OPCODE_CLOSE:
|
| + Close();
|
| + delegate_->OnWebsocketClosed();
|
| + return;
|
| +
|
| + case OPCODE_PING:
|
| + SendFragment(
|
| + OPCODE_PONG, current_packet_.data() + payload_position, length);
|
| + break;
|
| +
|
| + case OPCODE_PONG:
|
| + break;
|
| +
|
| + default:
|
| + LOG(ERROR) << "Received invalid opcode: " << opcode;
|
| + CloseOnError();
|
| + return;
|
| + }
|
| +
|
| + // Remove the frame from |current_packet_|.
|
| + current_packet_.erase(0, payload_position + length);
|
| +
|
| + // Post a task to process the data we have left in the buffer if any.
|
| + if (!current_packet_.empty()) {
|
| + base::ThreadTaskRunnerHandle::Get()->PostTask(
|
| + FROM_HERE, base::Bind(&WebsocketConnection::ProcessData,
|
| + weak_factory_.GetWeakPtr()));
|
| + }
|
| +
|
| + // Handle payload in non-control messages. Delegate can be called only at the
|
| + // end of this function
|
| + if (opcode < 0x8) {
|
| + if (!fin_bit) {
|
| + receiving_message_ = true;
|
| + } else {
|
| + receiving_message_ = false;
|
| + std::string msg;
|
| + msg.swap(current_message_);
|
| + delegate_->OnWebsocketMessage(msg);
|
| + }
|
| + }
|
| +}
|
| +
|
| +void WebsocketConnection::SendFragment(
|
| + WebsocketOpcode opcode,
|
| + const char* payload, int payload_length) {
|
| + DCHECK_EQ(state_, ACCEPTED);
|
| +
|
| + int length_field_size = 1;
|
| + if (payload_length > 65535) {
|
| + length_field_size = 9;
|
| + } else if (payload_length > 125) {
|
| + length_field_size = 3;
|
| + }
|
| +
|
| + scoped_refptr<net::IOBufferWithSize> buffer =
|
| + new net::IOBufferWithSize(1 + length_field_size + payload_length);
|
| +
|
| + // Always set FIN flag because we never fragment outgoing messages.
|
| + buffer->data()[0] = opcode | 0x80;
|
| +
|
| + if (payload_length > 65535) {
|
| + uint64 size = base::HostToNet64(payload_length);
|
| + buffer->data()[1] = 127;
|
| + memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
|
| + } else if (payload_length > 125) {
|
| + uint16 size = base::HostToNet16(payload_length);
|
| + buffer->data()[1] = 126;
|
| + memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
|
| + } else {
|
| + buffer->data()[1] = payload_length;
|
| + }
|
| + memcpy(buffer->data() + 1 + length_field_size, payload, payload_length);
|
| +
|
| + writer_.Write(buffer, base::Closure());
|
| +}
|
| +
|
| +bool WebsocketConnection::ParseHeaders() {
|
| + std::vector<std::string> lines;
|
| + base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines);
|
| +
|
| + // Parse request line.
|
| + std::vector<std::string> request_parts;
|
| + base::SplitString(lines[0], ' ', &request_parts);
|
| + if (request_parts.size() != 3 ||
|
| + request_parts[0] != "GET" ||
|
| + request_parts[2] != "HTTP/1.1") {
|
| + LOG(ERROR) << "Invalid Request-Line: " << headers_[0];
|
| + return false;
|
| + }
|
| + request_path_ = request_parts[1];
|
| +
|
| + std::map<std::string, std::string> headers;
|
| +
|
| + for (size_t i = 1; i < lines.size(); ++i) {
|
| + std::string separator(kHeaderKeyValueSeparator);
|
| + size_t pos = lines[i].find(separator);
|
| + if (pos == std::string::npos || pos == 0) {
|
| + LOG(ERROR) << "Invalid header line: " << lines[i];
|
| + return false;
|
| + }
|
| + std::string key = lines[i].substr(0, pos);
|
| + if (headers.find(key) != headers.end()) {
|
| + LOG(ERROR) << "Duplicate header value: " << key;
|
| + return false;
|
| + }
|
| + headers[key] = lines[i].substr(pos + separator.size());
|
| + }
|
| +
|
| + std::map<std::string, std::string>::iterator it = headers.find("Connection");
|
| + if (it == headers.end() || it->second != "Upgrade") {
|
| + LOG(ERROR) << "Connection header is missing or invalid.";
|
| + return false;
|
| + }
|
| +
|
| + it = headers.find("Upgrade");
|
| + if (it == headers.end() || it->second != "websocket") {
|
| + LOG(ERROR) << "Upgrade header is missing or invalid.";
|
| + return false;
|
| + }
|
| +
|
| + it = headers.find("Host");
|
| + if (it == headers.end()) {
|
| + LOG(ERROR) << "Host header is missing.";
|
| + return false;
|
| + }
|
| + request_host_ = it->second;
|
| +
|
| + it = headers.find("Sec-WebSocket-Version");
|
| + if (it == headers.end()) {
|
| + LOG(ERROR) << "Sec-WebSocket-Version header is missing.";
|
| + return false;
|
| + }
|
| + if (it->second != "13") {
|
| + LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second;
|
| + return false;
|
| + }
|
| +
|
| + it = headers.find("Origin");
|
| + if (it == headers.end()) {
|
| + LOG(ERROR) << "Origin header is missing.";
|
| + return false;
|
| + }
|
| + origin_ = it->second;
|
| +
|
| + it = headers.find("Sec-WebSocket-Key");
|
| + if (it == headers.end()) {
|
| + LOG(ERROR) << "Sec-WebSocket-Key header is missing.";
|
| + return false;
|
| + }
|
| + websocket_key_ = it->second;
|
| +
|
| + return true;
|
| +}
|
| +
|
| +void WebsocketConnection::UnmaskPayload(const char* mask,
|
| + char* payload, int payload_length) {
|
| + for (int i = 0; i < payload_length; ++i) {
|
| + payload[i] = payload[i] ^ mask[i % kMaskLength];
|
| + }
|
| +}
|
| +
|
| +void WebsocketConnection::OnSocketWriteError(int error) {
|
| + LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error;
|
| + CloseOnError();
|
| +}
|
| +
|
| +} // namespace remoting
|
|
|