Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(741)

Unified Diff: remoting/host/websocket_connection.cc

Issue 11358190: Add simple WebSocket server implementation. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Created 8 years, 1 month ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: remoting/host/websocket_connection.cc
diff --git a/remoting/host/websocket_connection.cc b/remoting/host/websocket_connection.cc
new file mode 100644
index 0000000000000000000000000000000000000000..897804dc1be02b84e29c38610237acef800f0908
--- /dev/null
+++ b/remoting/host/websocket_connection.cc
@@ -0,0 +1,474 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "remoting/host/websocket_connection.h"
+
+#include <map>
+#include <vector>
+
+#include "base/base64.h"
+#include "base/compiler_specific.h"
+#include "base/location.h"
+#include "base/sha1.h"
+#include "base/single_thread_task_runner.h"
+#include "base/string_split.h"
+#include "base/sys_byteorder.h"
+#include "base/thread_task_runner_handle.h"
+#include "net/base/net_errors.h"
+#include "net/socket/stream_socket.h"
+
+namespace remoting {
+
+namespace {
+
+const int kReadBufferSize = 1024;
+const char kLineSeparator[] = "\r\n";
+const char kHeaderEndMarker[] = "\r\n\r\n";
+const char kHeaderKeyValueSeparator[] = ": ";
+const int kMaskLength = 4;
+
+// Fixed value specified in RFC6455. It's used to compute accept token sent to
+// the client in Sec-WebSocket-Accept key.
+const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+} // namespace
+
+WebsocketConnection::WebsocketConnection()
+ : delegate_(NULL),
+ maximum_message_size_(0),
+ state_(READING_HEADERS),
+ receiving_message_(false),
+ ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
+}
+
+WebsocketConnection::~WebsocketConnection() {
+ Close();
+}
+
+void WebsocketConnection::Start(
+ scoped_ptr<net::StreamSocket> socket,
+ ConnectedCallback connected_callback) {
+ socket_ = socket.Pass();
+ connected_callback_ = connected_callback;
+ reader_.Init(socket_.get(), base::Bind(
+ &WebsocketConnection::OnSocketReadResult, base::Unretained(this)));
+ writer_.Init(socket_.get(), base::Bind(
+ &WebsocketConnection::OnSocketWriteError, base::Unretained(this)));
+}
+
+void WebsocketConnection::Accept(Delegate* delegate) {
+ DCHECK_EQ(state_, HEADERS_READ);
+
+ state_ = ACCEPTED;
+ delegate_ = delegate;
+
+ std::string accept_key =
+ base::SHA1HashString(websocket_key_ + kWebsocketKeySalt);
+ std::string accept_key_base64;
+ bool result = base::Base64Encode(accept_key, &accept_key_base64);
+ DCHECK(result);
+
+ std::string handshake;
+ handshake += "HTTP/1.1 101 Switching Protocol";
+ handshake += kLineSeparator;
+ handshake += "Upgrade: websocket";
+ handshake += kLineSeparator;
+ handshake += "Connection: Upgrade";
+ handshake += kLineSeparator;
+ handshake += "Sec-WebSocket-Accept: " + accept_key_base64;
+ handshake += kHeaderEndMarker;
+
+ scoped_refptr<net::IOBufferWithSize> buffer =
+ new net::IOBufferWithSize(handshake.size());
+ memcpy(buffer->data(), handshake.data(), handshake.size());
+ writer_.Write(buffer, base::Closure());
+}
+
+void WebsocketConnection::Reject() {
+ DCHECK_EQ(state_, HEADERS_READ);
+
+ state_ = CLOSED;
+ std::string response = "HTTP/1.1 401 Unauthorized";
+ response += kHeaderEndMarker;
+ scoped_refptr<net::IOBufferWithSize> buffer =
+ new net::IOBufferWithSize(response.size());
+ memcpy(buffer->data(), response.data(), response.size());
+ writer_.Write(buffer, base::Closure());
+}
+
+void WebsocketConnection::set_maximum_message_size(uint64 size) {
+ maximum_message_size_ = size;
+}
+
+void WebsocketConnection::SendText(const std::string& text) {
+ SendFragment(OPCODE_TEXT_FRAME, text.data(), text.size());
+}
+
+void WebsocketConnection::Close() {
+ switch (state_) {
+ case READING_HEADERS:
+ break;
+
+ case HEADERS_READ:
+ Reject();
+ break;
+
+ case ACCEPTED:
+ SendFragment(OPCODE_CLOSE, NULL, 0);
+ break;
+
+ case CLOSED:
+ break;
+ }
+ state_ = CLOSED;
+}
+
+void WebsocketConnection::CloseOnError() {
+ State old_state_ = state_;
+ Close();
+ if (old_state_ == ACCEPTED) {
+ DCHECK(delegate_);
+ delegate_->OnWebsocketClosed();
+ }
+}
+
+void WebsocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data,
+ int result) {
+ if (result <= 0) {
+ if (result != 0) {
+ LOG(ERROR) << "Error when trying to read from WebSocket connection: "
+ << result;
+ }
+ CloseOnError();
+ return;
+ }
+
+ switch (state_) {
+ case READING_HEADERS: {
+ headers_.append(data->data(), data->data() + result);
+ size_t header_end_pos = headers_.find(kHeaderEndMarker);
+ if (header_end_pos != std::string::npos) {
+ bool result;
+ if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) {
+ LOG(ERROR) << "WebSocket client tried writing data before handshake "
+ "has finished.";
+ DCHECK(!connected_callback_.is_null());
+ state_ = CLOSED;
+ result = false;
+ } else {
+ // Crop newline symbols from the end.
+ headers_.resize(header_end_pos);
+
+ result = ParseHeaders();
+ if (!result) {
+ state_ = CLOSED;
+ } else {
+ state_ = HEADERS_READ;
+ }
+ }
+ ConnectedCallback cb(connected_callback_);
+ connected_callback_.Reset();
+ cb.Run(result);
+ }
+ break;
+ }
+
+ case HEADERS_READ:
+ LOG(ERROR) << "Received unexpected data before websocket "
+ "connection is accepted.";
+ CloseOnError();
+ break;
+
+ case ACCEPTED:
+ DCHECK(delegate_);
+ received_data_.append(data->data(), data->data() + result);
+ ProcessData();
+
+ case CLOSED:
+ // Ignore anything received after connection is rejected or closed.
+ break;
+ }
+}
+
+void WebsocketConnection::ProcessData() {
+ DCHECK_EQ(state_, ACCEPTED);
+
+ if (received_data_.size() < 2) {
+ // Header hasn't been received yet.
+ return;
+ }
+
+ bool fin_bit = (received_data_.data()[0] & 0x80) != 0;
+
+ int rsv_bits = received_data_.data()[0] & 0x70;
Wez 2012/11/20 05:44:09 nit: Add a comment summarizing what these bits are
Sergey Ulanov 2012/11/21 01:40:24 Done.
+ if (rsv_bits != 0) {
+ LOG(ERROR) << "Incoming has unsupported RSV bits set.";
+ CloseOnError();
+ return;
+ }
+
+ int opcode = received_data_.data()[0] & 0x0f;
+
+ int mask_bit = received_data_.data()[1] & 0x80;
+ if (mask_bit == 0) {
+ LOG(ERROR) << "Incoming frame is not masked.";
+ CloseOnError();
+ return;
+ }
+
+ int length_field_size = 1;
Wez 2012/11/20 05:44:09 Please add a comment summarizing this length proce
Sergey Ulanov 2012/11/21 01:40:24 Done.
+ uint64 length = received_data_.data()[1] & 0x7F;
+ if (length == 126) {
+ if (received_data_.size() < 4) {
+ // Haven't received the whole frame yet.
Wez 2012/11/20 05:44:09 nit: "Haven't received the whole frame header yet"
Sergey Ulanov 2012/11/21 01:40:24 Done.
+ return;
+ }
+ length_field_size = 3;
+ length = base::NetToHost16(
+ *reinterpret_cast<const uint16*>(received_data_.data() + 2));
+ } else if (length == 127) {
+ if (received_data_.size() < 10) {
+ // Haven't received the whole frame yet.
+ return;
+ }
+ length_field_size = 9;
+ length = base::NetToHost64(
+ *reinterpret_cast<const uint64*>(received_data_.data() + 2));
+ }
+
+ int payload_position = 1 + length_field_size + kMaskLength;
+
+ if (maximum_message_size_ > 0 && length > maximum_message_size_) {
Wez 2012/11/20 05:44:09 nit: Add a comment explaining why we need this che
Sergey Ulanov 2012/11/21 01:40:24 Done.
+ LOG(ERROR) << "Client tried to send a fragment that is bigger than "
+ "the maximum message size of " << maximum_message_size_;
+ CloseOnError();
+ return;
+ }
+
+ if (received_data_.size() < payload_position + length) {
+ // Haven't received the whole frame yet.
+ return;
+ }
+
+ if (mask_bit) {
Wez 2012/11/20 05:44:09 nit: Add a comment to the effect of "un-mask the m
Sergey Ulanov 2012/11/21 01:40:24 Done.
+ const char* mask = received_data_.data() + length_field_size + 1;
+ UnmaskPayload(
+ mask,
+ const_cast<char*>(received_data_.data()) + payload_position, length);
+ }
+
+ if (opcode < 0x8) {
+ // Non-control message.
+ current_message_.append(
+ received_data_.data() + payload_position,
+ received_data_.data() + payload_position + length);
+
+ if (maximum_message_size_ > 0 &&
+ current_message_.size() > maximum_message_size_) {
Wez 2012/11/20 05:44:09 It's too late to check this here; the total size o
Sergey Ulanov 2012/11/21 01:40:24 Done, but it doesn't really matter much.
+ LOG(ERROR) << "Client tried to send a message that is bigger than "
+ "the maximum message size of " << maximum_message_size_;
+ CloseOnError();
+ return;
+ }
+ } else {
+ // Control message.
+ if (!fin_bit) {
+ LOG(ERROR) << "Received fragmented control message.";
+ CloseOnError();
+ return;
+ }
+ if (length > 125) {
+ LOG(ERROR) << "Received control message that is larger than 125 bytes.";
+ CloseOnError();
+ return;
+ }
+ }
+
+ switch (opcode) {
+ case OPCODE_CONTINUATION:
+ if (!receiving_message_) {
+ LOG(ERROR) << "Received unexpected continuation frame.";
+ CloseOnError();
+ return;
+ }
+ break;
+
+ case OPCODE_TEXT_FRAME:
+ case OPCODE_BINARY_FRAME:
+ if (receiving_message_) {
+ LOG(ERROR) << "Received unexpected new start frame in a middle of "
+ "a message.";
+ CloseOnError();
+ return;
+ }
+ break;
+
+ case OPCODE_CLOSE:
+ Close();
+ delegate_->OnWebsocketClosed();
+ return;
+
+ case OPCODE_PING:
+ SendFragment(
+ OPCODE_PONG, received_data_.data() + payload_position, length);
+ break;
+
+ case OPCODE_PONG:
+ break;
+
+ default:
+ LOG(ERROR) << "Received invalid opcode: " << opcode;
+ CloseOnError();
+ return;
+ }
+
+ // Remove the frame from |received_data_|.
+ received_data_.erase(0, payload_position + length);
+
+ // Post a task to process the data we have left in the buffer if any.
Wez 2012/11/20 05:44:09 nit: "... left in the buffer, if any."
Sergey Ulanov 2012/11/21 01:40:24 Done.
+ if (!received_data_.empty()) {
+ base::ThreadTaskRunnerHandle::Get()->PostTask(
+ FROM_HERE, base::Bind(&WebsocketConnection::ProcessData,
+ weak_factory_.GetWeakPtr()));
+ }
+
+ // Handle payload in non-control messages. Delegate can be called only at the
+ // end of this function
+ if (opcode < 0x8) {
+ if (!fin_bit) {
+ receiving_message_ = true;
+ } else {
+ receiving_message_ = false;
+ std::string msg;
+ msg.swap(current_message_);
+ delegate_->OnWebsocketMessage(msg);
+ }
+ }
+}
+
+void WebsocketConnection::SendFragment(
+ WebsocketOpcode opcode,
+ const char* payload, int payload_length) {
+ DCHECK_EQ(state_, ACCEPTED);
+
+ int length_field_size = 1;
+ if (payload_length > 65535) {
+ length_field_size = 9;
+ } else if (payload_length > 125) {
+ length_field_size = 3;
+ }
+
+ scoped_refptr<net::IOBufferWithSize> buffer =
+ new net::IOBufferWithSize(1 + length_field_size + payload_length);
+
+ // Always set FIN flag because we never fragment outgoing messages.
+ buffer->data()[0] = opcode | 0x80;
+
+ if (payload_length > 65535) {
+ uint64 size = base::HostToNet64(payload_length);
+ buffer->data()[1] = 127;
+ memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
+ } else if (payload_length > 125) {
+ uint16 size = base::HostToNet16(payload_length);
+ buffer->data()[1] = 126;
+ memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
+ } else {
+ buffer->data()[1] = payload_length;
+ }
+ memcpy(buffer->data() + 1 + length_field_size, payload, payload_length);
+
+ writer_.Write(buffer, base::Closure());
+}
+
+bool WebsocketConnection::ParseHeaders() {
+ std::vector<std::string> lines;
+ base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines);
+
+ // Parse request line.
+ std::vector<std::string> request_parts;
+ base::SplitString(lines[0], ' ', &request_parts);
+ if (request_parts.size() != 3 ||
+ request_parts[0] != "GET" ||
+ request_parts[2] != "HTTP/1.1") {
+ LOG(ERROR) << "Invalid Request-Line: " << headers_[0];
+ return false;
+ }
+ request_path_ = request_parts[1];
+
+ std::map<std::string, std::string> headers;
+
+ for (size_t i = 1; i < lines.size(); ++i) {
+ std::string separator(kHeaderKeyValueSeparator);
+ size_t pos = lines[i].find(separator);
+ if (pos == std::string::npos || pos == 0) {
+ LOG(ERROR) << "Invalid header line: " << lines[i];
+ return false;
+ }
+ std::string key = lines[i].substr(0, pos);
+ if (headers.find(key) != headers.end()) {
+ LOG(ERROR) << "Duplicate header value: " << key;
+ return false;
+ }
+ headers[key] = lines[i].substr(pos + separator.size());
+ }
+
+ std::map<std::string, std::string>::iterator it = headers.find("Connection");
+ if (it == headers.end() || it->second != "Upgrade") {
+ LOG(ERROR) << "Connection header is missing or invalid.";
+ return false;
+ }
+
+ it = headers.find("Upgrade");
+ if (it == headers.end() || it->second != "websocket") {
+ LOG(ERROR) << "Upgrade header is missing or invalid.";
+ return false;
+ }
+
+ it = headers.find("Host");
+ if (it == headers.end()) {
+ LOG(ERROR) << "Host header is missing.";
+ return false;
+ }
+ request_host_ = it->second;
+
+ it = headers.find("Sec-WebSocket-Version");
+ if (it == headers.end()) {
+ LOG(ERROR) << "Sec-WebSocket-Version header is missing.";
+ return false;
+ }
+ if (it->second != "13") {
+ LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second;
+ return false;
+ }
+
+ it = headers.find("Origin");
+ if (it == headers.end()) {
+ LOG(ERROR) << "Origin header is missing.";
+ return false;
+ }
+ origin_ = it->second;
+
+ it = headers.find("Sec-WebSocket-Key");
+ if (it == headers.end()) {
+ LOG(ERROR) << "Sec-WebSocket-Key header is missing.";
+ return false;
+ }
+ websocket_key_ = it->second;
+
+ return true;
+}
+
+void WebsocketConnection::UnmaskPayload(const char* mask,
+ char* payload, int payload_length) {
+ for (int i = 0; i < payload_length; ++i) {
+ payload[i] = payload[i] ^ mask[i % kMaskLength];
+ }
+}
+
+void WebsocketConnection::OnSocketWriteError(int error) {
+ LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error;
+ CloseOnError();
+}
+
+} // namespace remoting

Powered by Google App Engine
This is Rietveld 408576698