Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(342)

Side by Side Diff: remoting/host/websocket_connection.cc

Issue 11358190: Add simple WebSocket server implementation. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Created 8 years, 1 month ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
OLDNEW
(Empty)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/host/websocket_connection.h"
6
7 #include <map>
8 #include <vector>
9
10 #include "base/base64.h"
11 #include "base/compiler_specific.h"
12 #include "base/location.h"
13 #include "base/sha1.h"
14 #include "base/single_thread_task_runner.h"
15 #include "base/string_split.h"
16 #include "base/sys_byteorder.h"
17 #include "base/thread_task_runner_handle.h"
18 #include "net/base/net_errors.h"
19 #include "net/socket/stream_socket.h"
20
21 namespace remoting {
22
23 namespace {
24
25 const int kReadBufferSize = 1024;
26 const char kLineSeparator[] = "\r\n";
27 const char kHeaderEndMarker[] = "\r\n\r\n";
28 const char kHeaderKeyValueSeparator[] = ": ";
29 const int kMaskLength = 4;
30
31 // Fixed value specified in RFC6455. It's used to compute accept token sent to
32 // the client in Sec-WebSocket-Accept key.
33 const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
34
35 } // namespace
36
37 WebsocketConnection::WebsocketConnection()
38 : delegate_(NULL),
39 maximum_message_size_(0),
40 state_(READING_HEADERS),
41 receiving_message_(false),
42 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
43 }
44
45 WebsocketConnection::~WebsocketConnection() {
46 Close();
47 }
48
49 void WebsocketConnection::Start(
50 scoped_ptr<net::StreamSocket> socket,
51 ConnectedCallback connected_callback) {
52 socket_ = socket.Pass();
53 connected_callback_ = connected_callback;
54 reader_.Init(socket_.get(), base::Bind(
55 &WebsocketConnection::OnSocketReadResult, base::Unretained(this)));
56 writer_.Init(socket_.get(), base::Bind(
57 &WebsocketConnection::OnSocketWriteError, base::Unretained(this)));
58 }
59
60 void WebsocketConnection::Accept(Delegate* delegate) {
61 DCHECK_EQ(state_, HEADERS_READ);
62
63 state_ = ACCEPTED;
64 delegate_ = delegate;
65
66 std::string accept_key =
67 base::SHA1HashString(websocket_key_ + kWebsocketKeySalt);
68 std::string accept_key_base64;
69 bool result = base::Base64Encode(accept_key, &accept_key_base64);
70 DCHECK(result);
71
72 std::string handshake;
73 handshake += "HTTP/1.1 101 Switching Protocol";
74 handshake += kLineSeparator;
75 handshake += "Upgrade: websocket";
76 handshake += kLineSeparator;
77 handshake += "Connection: Upgrade";
78 handshake += kLineSeparator;
79 handshake += "Sec-WebSocket-Accept: " + accept_key_base64;
80 handshake += kHeaderEndMarker;
81
82 scoped_refptr<net::IOBufferWithSize> buffer =
83 new net::IOBufferWithSize(handshake.size());
84 memcpy(buffer->data(), handshake.data(), handshake.size());
85 writer_.Write(buffer, base::Closure());
86 }
87
88 void WebsocketConnection::Reject() {
89 DCHECK_EQ(state_, HEADERS_READ);
90
91 state_ = CLOSED;
92 std::string response = "HTTP/1.1 401 Unauthorized";
93 response += kHeaderEndMarker;
94 scoped_refptr<net::IOBufferWithSize> buffer =
95 new net::IOBufferWithSize(response.size());
96 memcpy(buffer->data(), response.data(), response.size());
97 writer_.Write(buffer, base::Closure());
98 }
99
100 void WebsocketConnection::set_maximum_message_size(uint64 size) {
101 maximum_message_size_ = size;
102 }
103
104 void WebsocketConnection::SendText(const std::string& text) {
105 SendFragment(OPCODE_TEXT_FRAME, text.data(), text.size());
106 }
107
108 void WebsocketConnection::Close() {
109 switch (state_) {
110 case READING_HEADERS:
111 break;
112
113 case HEADERS_READ:
114 Reject();
115 break;
116
117 case ACCEPTED:
118 SendFragment(OPCODE_CLOSE, NULL, 0);
119 break;
120
121 case CLOSED:
122 break;
123 }
124 state_ = CLOSED;
125 }
126
127 void WebsocketConnection::CloseOnError() {
128 State old_state_ = state_;
129 Close();
130 if (old_state_ == ACCEPTED) {
131 DCHECK(delegate_);
132 delegate_->OnWebsocketClosed();
133 }
134 }
135
136 void WebsocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data,
137 int result) {
138 if (result <= 0) {
139 if (result != 0) {
140 LOG(ERROR) << "Error when trying to read from WebSocket connection: "
141 << result;
142 }
143 CloseOnError();
144 return;
145 }
146
147 switch (state_) {
148 case READING_HEADERS: {
149 headers_.append(data->data(), data->data() + result);
150 size_t header_end_pos = headers_.find(kHeaderEndMarker);
151 if (header_end_pos != std::string::npos) {
152 bool result;
153 if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) {
154 LOG(ERROR) << "WebSocket client tried writing data before handshake "
155 "has finished.";
156 DCHECK(!connected_callback_.is_null());
157 state_ = CLOSED;
158 result = false;
159 } else {
160 // Crop newline symbols from the end.
161 headers_.resize(header_end_pos);
162
163 result = ParseHeaders();
164 if (!result) {
165 state_ = CLOSED;
166 } else {
167 state_ = HEADERS_READ;
168 }
169 }
170 ConnectedCallback cb(connected_callback_);
171 connected_callback_.Reset();
172 cb.Run(result);
173 }
174 break;
175 }
176
177 case HEADERS_READ:
178 LOG(ERROR) << "Received unexpected data before websocket "
179 "connection is accepted.";
180 CloseOnError();
181 break;
182
183 case ACCEPTED:
184 DCHECK(delegate_);
185 received_data_.append(data->data(), data->data() + result);
186 ProcessData();
187
188 case CLOSED:
189 // Ignore anything received after connection is rejected or closed.
190 break;
191 }
192 }
193
194 void WebsocketConnection::ProcessData() {
195 DCHECK_EQ(state_, ACCEPTED);
196
197 if (received_data_.size() < 2) {
198 // Header hasn't been received yet.
199 return;
200 }
201
202 bool fin_bit = (received_data_.data()[0] & 0x80) != 0;
203
204 int rsv_bits = received_data_.data()[0] & 0x70;
Wez 2012/11/20 05:44:09 nit: Add a comment summarizing what these bits are
Sergey Ulanov 2012/11/21 01:40:24 Done.
205 if (rsv_bits != 0) {
206 LOG(ERROR) << "Incoming has unsupported RSV bits set.";
207 CloseOnError();
208 return;
209 }
210
211 int opcode = received_data_.data()[0] & 0x0f;
212
213 int mask_bit = received_data_.data()[1] & 0x80;
214 if (mask_bit == 0) {
215 LOG(ERROR) << "Incoming frame is not masked.";
216 CloseOnError();
217 return;
218 }
219
220 int length_field_size = 1;
Wez 2012/11/20 05:44:09 Please add a comment summarizing this length proce
Sergey Ulanov 2012/11/21 01:40:24 Done.
221 uint64 length = received_data_.data()[1] & 0x7F;
222 if (length == 126) {
223 if (received_data_.size() < 4) {
224 // Haven't received the whole frame yet.
Wez 2012/11/20 05:44:09 nit: "Haven't received the whole frame header yet"
Sergey Ulanov 2012/11/21 01:40:24 Done.
225 return;
226 }
227 length_field_size = 3;
228 length = base::NetToHost16(
229 *reinterpret_cast<const uint16*>(received_data_.data() + 2));
230 } else if (length == 127) {
231 if (received_data_.size() < 10) {
232 // Haven't received the whole frame yet.
233 return;
234 }
235 length_field_size = 9;
236 length = base::NetToHost64(
237 *reinterpret_cast<const uint64*>(received_data_.data() + 2));
238 }
239
240 int payload_position = 1 + length_field_size + kMaskLength;
241
242 if (maximum_message_size_ > 0 && length > maximum_message_size_) {
Wez 2012/11/20 05:44:09 nit: Add a comment explaining why we need this che
Sergey Ulanov 2012/11/21 01:40:24 Done.
243 LOG(ERROR) << "Client tried to send a fragment that is bigger than "
244 "the maximum message size of " << maximum_message_size_;
245 CloseOnError();
246 return;
247 }
248
249 if (received_data_.size() < payload_position + length) {
250 // Haven't received the whole frame yet.
251 return;
252 }
253
254 if (mask_bit) {
Wez 2012/11/20 05:44:09 nit: Add a comment to the effect of "un-mask the m
Sergey Ulanov 2012/11/21 01:40:24 Done.
255 const char* mask = received_data_.data() + length_field_size + 1;
256 UnmaskPayload(
257 mask,
258 const_cast<char*>(received_data_.data()) + payload_position, length);
259 }
260
261 if (opcode < 0x8) {
262 // Non-control message.
263 current_message_.append(
264 received_data_.data() + payload_position,
265 received_data_.data() + payload_position + length);
266
267 if (maximum_message_size_ > 0 &&
268 current_message_.size() > maximum_message_size_) {
Wez 2012/11/20 05:44:09 It's too late to check this here; the total size o
Sergey Ulanov 2012/11/21 01:40:24 Done, but it doesn't really matter much.
269 LOG(ERROR) << "Client tried to send a message that is bigger than "
270 "the maximum message size of " << maximum_message_size_;
271 CloseOnError();
272 return;
273 }
274 } else {
275 // Control message.
276 if (!fin_bit) {
277 LOG(ERROR) << "Received fragmented control message.";
278 CloseOnError();
279 return;
280 }
281 if (length > 125) {
282 LOG(ERROR) << "Received control message that is larger than 125 bytes.";
283 CloseOnError();
284 return;
285 }
286 }
287
288 switch (opcode) {
289 case OPCODE_CONTINUATION:
290 if (!receiving_message_) {
291 LOG(ERROR) << "Received unexpected continuation frame.";
292 CloseOnError();
293 return;
294 }
295 break;
296
297 case OPCODE_TEXT_FRAME:
298 case OPCODE_BINARY_FRAME:
299 if (receiving_message_) {
300 LOG(ERROR) << "Received unexpected new start frame in a middle of "
301 "a message.";
302 CloseOnError();
303 return;
304 }
305 break;
306
307 case OPCODE_CLOSE:
308 Close();
309 delegate_->OnWebsocketClosed();
310 return;
311
312 case OPCODE_PING:
313 SendFragment(
314 OPCODE_PONG, received_data_.data() + payload_position, length);
315 break;
316
317 case OPCODE_PONG:
318 break;
319
320 default:
321 LOG(ERROR) << "Received invalid opcode: " << opcode;
322 CloseOnError();
323 return;
324 }
325
326 // Remove the frame from |received_data_|.
327 received_data_.erase(0, payload_position + length);
328
329 // Post a task to process the data we have left in the buffer if any.
Wez 2012/11/20 05:44:09 nit: "... left in the buffer, if any."
Sergey Ulanov 2012/11/21 01:40:24 Done.
330 if (!received_data_.empty()) {
331 base::ThreadTaskRunnerHandle::Get()->PostTask(
332 FROM_HERE, base::Bind(&WebsocketConnection::ProcessData,
333 weak_factory_.GetWeakPtr()));
334 }
335
336 // Handle payload in non-control messages. Delegate can be called only at the
337 // end of this function
338 if (opcode < 0x8) {
339 if (!fin_bit) {
340 receiving_message_ = true;
341 } else {
342 receiving_message_ = false;
343 std::string msg;
344 msg.swap(current_message_);
345 delegate_->OnWebsocketMessage(msg);
346 }
347 }
348 }
349
350 void WebsocketConnection::SendFragment(
351 WebsocketOpcode opcode,
352 const char* payload, int payload_length) {
353 DCHECK_EQ(state_, ACCEPTED);
354
355 int length_field_size = 1;
356 if (payload_length > 65535) {
357 length_field_size = 9;
358 } else if (payload_length > 125) {
359 length_field_size = 3;
360 }
361
362 scoped_refptr<net::IOBufferWithSize> buffer =
363 new net::IOBufferWithSize(1 + length_field_size + payload_length);
364
365 // Always set FIN flag because we never fragment outgoing messages.
366 buffer->data()[0] = opcode | 0x80;
367
368 if (payload_length > 65535) {
369 uint64 size = base::HostToNet64(payload_length);
370 buffer->data()[1] = 127;
371 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
372 } else if (payload_length > 125) {
373 uint16 size = base::HostToNet16(payload_length);
374 buffer->data()[1] = 126;
375 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
376 } else {
377 buffer->data()[1] = payload_length;
378 }
379 memcpy(buffer->data() + 1 + length_field_size, payload, payload_length);
380
381 writer_.Write(buffer, base::Closure());
382 }
383
384 bool WebsocketConnection::ParseHeaders() {
385 std::vector<std::string> lines;
386 base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines);
387
388 // Parse request line.
389 std::vector<std::string> request_parts;
390 base::SplitString(lines[0], ' ', &request_parts);
391 if (request_parts.size() != 3 ||
392 request_parts[0] != "GET" ||
393 request_parts[2] != "HTTP/1.1") {
394 LOG(ERROR) << "Invalid Request-Line: " << headers_[0];
395 return false;
396 }
397 request_path_ = request_parts[1];
398
399 std::map<std::string, std::string> headers;
400
401 for (size_t i = 1; i < lines.size(); ++i) {
402 std::string separator(kHeaderKeyValueSeparator);
403 size_t pos = lines[i].find(separator);
404 if (pos == std::string::npos || pos == 0) {
405 LOG(ERROR) << "Invalid header line: " << lines[i];
406 return false;
407 }
408 std::string key = lines[i].substr(0, pos);
409 if (headers.find(key) != headers.end()) {
410 LOG(ERROR) << "Duplicate header value: " << key;
411 return false;
412 }
413 headers[key] = lines[i].substr(pos + separator.size());
414 }
415
416 std::map<std::string, std::string>::iterator it = headers.find("Connection");
417 if (it == headers.end() || it->second != "Upgrade") {
418 LOG(ERROR) << "Connection header is missing or invalid.";
419 return false;
420 }
421
422 it = headers.find("Upgrade");
423 if (it == headers.end() || it->second != "websocket") {
424 LOG(ERROR) << "Upgrade header is missing or invalid.";
425 return false;
426 }
427
428 it = headers.find("Host");
429 if (it == headers.end()) {
430 LOG(ERROR) << "Host header is missing.";
431 return false;
432 }
433 request_host_ = it->second;
434
435 it = headers.find("Sec-WebSocket-Version");
436 if (it == headers.end()) {
437 LOG(ERROR) << "Sec-WebSocket-Version header is missing.";
438 return false;
439 }
440 if (it->second != "13") {
441 LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second;
442 return false;
443 }
444
445 it = headers.find("Origin");
446 if (it == headers.end()) {
447 LOG(ERROR) << "Origin header is missing.";
448 return false;
449 }
450 origin_ = it->second;
451
452 it = headers.find("Sec-WebSocket-Key");
453 if (it == headers.end()) {
454 LOG(ERROR) << "Sec-WebSocket-Key header is missing.";
455 return false;
456 }
457 websocket_key_ = it->second;
458
459 return true;
460 }
461
462 void WebsocketConnection::UnmaskPayload(const char* mask,
463 char* payload, int payload_length) {
464 for (int i = 0; i < payload_length; ++i) {
465 payload[i] = payload[i] ^ mask[i % kMaskLength];
466 }
467 }
468
469 void WebsocketConnection::OnSocketWriteError(int error) {
470 LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error;
471 CloseOnError();
472 }
473
474 } // namespace remoting
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698