Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(705)

Side by Side Diff: remoting/host/websocket_connection.cc

Issue 11358190: Add simple WebSocket server implementation. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Created 8 years, 1 month ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
« no previous file with comments | « remoting/host/websocket_connection.h ('k') | remoting/host/websocket_connection_unittest.cc » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
(Empty)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/host/websocket_connection.h"
6
7 #include <map>
8 #include <vector>
9
10 #include "base/base64.h"
11 #include "base/compiler_specific.h"
12 #include "base/location.h"
13 #include "base/sha1.h"
14 #include "base/single_thread_task_runner.h"
15 #include "base/string_split.h"
16 #include "base/sys_byteorder.h"
17 #include "base/thread_task_runner_handle.h"
18 #include "net/base/net_errors.h"
19 #include "net/socket/stream_socket.h"
20
21 namespace remoting {
22
23 namespace {
24
25 const int kReadBufferSize = 1024;
26 const char kLineSeparator[] = "\r\n";
27 const char kHeaderEndMarker[] = "\r\n\r\n";
28 const char kHeaderKeyValueSeparator[] = ": ";
29 const int kMaskLength = 4;
30
31 // Maximum frame length that can be encoded without extended length filed.
32 const uint32 kMaxNotExtendedLength = 125;
33
34 // Maximum frame length that can be encoded in 16 bits.
35 const uint32 kMax16BitLength = 65535;
36
37 // Special values of the length field used to extend frame length to 16 or 64
38 // bits.
39 const uint32 kLength16BitMarker = 126;
40 const uint32 kLength64BitMarker = 127;
41
42 // Fixed value specified in RFC6455. It's used to compute accept token sent to
43 // the client in Sec-WebSocket-Accept key.
44 const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
45
46 } // namespace
47
48 WebSocketConnection::WebSocketConnection()
49 : delegate_(NULL),
50 maximum_message_size_(0),
51 state_(READING_HEADERS),
52 receiving_message_(false),
53 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
54 }
55
56 WebSocketConnection::~WebSocketConnection() {
57 Close();
58 }
59
60 void WebSocketConnection::Start(
61 scoped_ptr<net::StreamSocket> socket,
62 ConnectedCallback connected_callback) {
63 socket_ = socket.Pass();
64 connected_callback_ = connected_callback;
65 reader_.Init(socket_.get(), base::Bind(
66 &WebSocketConnection::OnSocketReadResult, base::Unretained(this)));
67 writer_.Init(socket_.get(), base::Bind(
68 &WebSocketConnection::OnSocketWriteError, base::Unretained(this)));
69 }
70
71 void WebSocketConnection::Accept(Delegate* delegate) {
72 DCHECK_EQ(state_, HEADERS_READ);
73
74 state_ = ACCEPTED;
75 delegate_ = delegate;
76
77 std::string accept_key =
78 base::SHA1HashString(websocket_key_ + kWebsocketKeySalt);
79 std::string accept_key_base64;
80 bool result = base::Base64Encode(accept_key, &accept_key_base64);
81 DCHECK(result);
82
83 std::string handshake;
84 handshake += "HTTP/1.1 101 Switching Protocol";
85 handshake += kLineSeparator;
86 handshake += "Upgrade: websocket";
87 handshake += kLineSeparator;
88 handshake += "Connection: Upgrade";
89 handshake += kLineSeparator;
90 handshake += "Sec-WebSocket-Accept: " + accept_key_base64;
91 handshake += kHeaderEndMarker;
92
93 scoped_refptr<net::IOBufferWithSize> buffer =
94 new net::IOBufferWithSize(handshake.size());
95 memcpy(buffer->data(), handshake.data(), handshake.size());
96 writer_.Write(buffer, base::Closure());
97 }
98
99 void WebSocketConnection::Reject() {
100 DCHECK_EQ(state_, HEADERS_READ);
101
102 state_ = CLOSED;
103 std::string response = "HTTP/1.1 401 Unauthorized";
104 response += kHeaderEndMarker;
105 scoped_refptr<net::IOBufferWithSize> buffer =
106 new net::IOBufferWithSize(response.size());
107 memcpy(buffer->data(), response.data(), response.size());
108 writer_.Write(buffer, base::Closure());
109 }
110
111 void WebSocketConnection::set_maximum_message_size(uint64 size) {
112 maximum_message_size_ = size;
113 }
114
115 void WebSocketConnection::SendText(const std::string& text) {
116 SendFragment(OPCODE_TEXT_FRAME, text);
117 }
118
119 void WebSocketConnection::Close() {
120 switch (state_) {
121 case READING_HEADERS:
122 break;
123
124 case HEADERS_READ:
125 Reject();
126 break;
127
128 case ACCEPTED:
129 SendFragment(OPCODE_CLOSE, std::string());
130 break;
131
132 case CLOSED:
133 break;
134 }
135 state_ = CLOSED;
136 }
137
138 void WebSocketConnection::CloseOnError() {
139 State old_state_ = state_;
140 Close();
141 if (old_state_ == ACCEPTED) {
142 DCHECK(delegate_);
143 delegate_->OnWebSocketClosed();
144 }
145 }
146
147 void WebSocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data,
148 int result) {
149 if (result <= 0) {
150 if (result != 0) {
151 LOG(ERROR) << "Error when trying to read from WebSocket connection: "
152 << result;
153 }
154 CloseOnError();
155 return;
156 }
157
158 switch (state_) {
159 case READING_HEADERS: {
160 headers_.append(data->data(), data->data() + result);
161 size_t header_end_pos = headers_.find(kHeaderEndMarker);
162 if (header_end_pos != std::string::npos) {
163 bool result;
164 if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) {
165 LOG(ERROR) << "WebSocket client tried writing data before handshake "
166 "has finished.";
167 DCHECK(!connected_callback_.is_null());
168 state_ = CLOSED;
169 result = false;
170 } else {
171 // Crop newline symbols from the end.
172 headers_.resize(header_end_pos);
173
174 result = ParseHeaders();
175 if (!result) {
176 state_ = CLOSED;
177 } else {
178 state_ = HEADERS_READ;
179 }
180 }
181 ConnectedCallback cb(connected_callback_);
182 connected_callback_.Reset();
183 cb.Run(result);
184 }
185 break;
186 }
187
188 case HEADERS_READ:
189 LOG(ERROR) << "Received unexpected data before websocket "
190 "connection is accepted.";
191 CloseOnError();
192 break;
193
194 case ACCEPTED:
195 DCHECK(delegate_);
196 received_data_.append(data->data(), data->data() + result);
197 ProcessData();
198
199 case CLOSED:
200 // Ignore anything received after connection is rejected or closed.
201 break;
202 }
203 }
204
205 void WebSocketConnection::ProcessData() {
206 DCHECK_EQ(state_, ACCEPTED);
207
208 if (received_data_.size() < 2) {
209 // Header hasn't been received yet.
210 return;
211 }
212
213 bool fin_bit = (received_data_.data()[0] & 0x80) != 0;
214
215 // 3 bits after FIN are reserved for WebSocket extensions. RFC6455 requires
216 // that endpoint fails connection if any of these bits is set while no
217 // extension that uses these bits was negotiated.
218 int rsv_bits = received_data_.data()[0] & 0x70;
219 if (rsv_bits != 0) {
220 LOG(ERROR) << "Incoming has unsupported RSV bits set.";
221 CloseOnError();
222 return;
223 }
224
225 int opcode = received_data_.data()[0] & 0x0f;
226
227 int mask_bit = received_data_.data()[1] & 0x80;
228 if (mask_bit == 0) {
229 LOG(ERROR) << "Incoming frame is not masked.";
230 CloseOnError();
231 return;
232 }
233
234 // Length field has variable size in each WebSocket frame - it's either 1, 3
235 // or 9 bytes with the first bit always reserved for MASK flag. The first byte
236 // is set to 126 or 127 for 16 and 64 bit extensions respectively. Code below
237 // extracts |length| value and sets |length_field_size| accordingly.
238 int length_field_size = 1;
239 uint64 length = received_data_.data()[1] & 0x7F;
240 if (length == kLength16BitMarker) {
241 if (received_data_.size() < 4) {
242 // Haven't received the whole frame header yet.
243 return;
244 }
245 length_field_size = 3;
246 length = base::NetToHost16(
247 *reinterpret_cast<const uint16*>(received_data_.data() + 2));
248 } else if (length == kLength64BitMarker) {
249 if (received_data_.size() < 10) {
250 // Haven't received the whole frame header yet.
251 return;
252 }
253 length_field_size = 9;
254 length = base::NetToHost64(
255 *reinterpret_cast<const uint64*>(received_data_.data() + 2));
256 }
257
258 int payload_position = 1 + length_field_size + kMaskLength;
259
260 // Check that the size of the frame is below the limit. It needs to be done
261 // before we read the payload to avoid allocating buffer for a bogus frame
262 // that is too big.
263 if (maximum_message_size_ > 0 && length > maximum_message_size_) {
264 LOG(ERROR) << "Client tried to send a fragment that is bigger than "
265 "the maximum message size of " << maximum_message_size_;
266 CloseOnError();
267 return;
268 }
269
270 if (received_data_.size() < payload_position + length) {
271 // Haven't received the whole frame yet.
272 return;
273 }
274
275 // Unmask the payload.
276 if (mask_bit) {
277 const char* mask = received_data_.data() + length_field_size + 1;
278 UnmaskPayload(
279 mask,
280 const_cast<char*>(received_data_.data()) + payload_position, length);
281 }
282
283 const char* payload = received_data_.data() + payload_position;
284
285 if (opcode < 0x8) {
286 if (maximum_message_size_ > 0 &&
287 current_message_.size() + length > maximum_message_size_) {
288 LOG(ERROR) << "Client tried to send a message that is bigger than "
289 "the maximum message size of " << maximum_message_size_;
290 CloseOnError();
291 return;
292 }
293
294 // Non-control message.
295 current_message_.append(payload, payload + length);
296 } else {
297 // Control message.
298 if (!fin_bit) {
299 LOG(ERROR) << "Received fragmented control message.";
300 CloseOnError();
301 return;
302 }
303 if (length > kMaxNotExtendedLength) {
304 LOG(ERROR) << "Received control message that is larger than 125 bytes.";
305 CloseOnError();
306 return;
307 }
308 }
309
310 switch (opcode) {
311 case OPCODE_CONTINUATION:
312 if (!receiving_message_) {
313 LOG(ERROR) << "Received unexpected continuation frame.";
314 CloseOnError();
315 return;
316 }
317 break;
318
319 case OPCODE_TEXT_FRAME:
320 case OPCODE_BINARY_FRAME:
321 if (receiving_message_) {
322 LOG(ERROR) << "Received unexpected new start frame in a middle of "
323 "a message.";
324 CloseOnError();
325 return;
326 }
327 break;
328
329 case OPCODE_CLOSE:
330 Close();
331 delegate_->OnWebSocketClosed();
332 return;
333
334 case OPCODE_PING:
335 SendFragment(OPCODE_PONG, std::string(payload, payload + length));
336 break;
337
338 case OPCODE_PONG:
339 break;
340
341 default:
342 LOG(ERROR) << "Received invalid opcode: " << opcode;
343 CloseOnError();
344 return;
345 }
346
347 // Remove the frame from |received_data_|.
348 received_data_.erase(0, payload_position + length);
349
350 // Post a task to process the data left in the buffer, if any.
351 if (!received_data_.empty()) {
352 base::ThreadTaskRunnerHandle::Get()->PostTask(
353 FROM_HERE, base::Bind(&WebSocketConnection::ProcessData,
354 weak_factory_.GetWeakPtr()));
355 }
356
357 // Handle payload in non-control messages. Delegate can be called only at the
358 // end of this function
359 if (opcode < 0x8) {
360 if (!fin_bit) {
361 receiving_message_ = true;
362 } else {
363 receiving_message_ = false;
364 std::string msg;
365 msg.swap(current_message_);
366 delegate_->OnWebSocketMessage(msg);
367 }
368 }
369 }
370
371 void WebSocketConnection::SendFragment(WebsocketOpcode opcode,
372 const std::string& payload) {
373 DCHECK_EQ(state_, ACCEPTED);
374
375 int length_field_size = 1;
376 if (payload.size() > kMax16BitLength) {
377 length_field_size = 9;
378 } else if (payload.size() > kMaxNotExtendedLength) {
379 length_field_size = 3;
380 }
381
382 scoped_refptr<net::IOBufferWithSize> buffer =
383 new net::IOBufferWithSize(1 + length_field_size + payload.size());
384
385 // Always set FIN flag because we never fragment outgoing messages.
386 buffer->data()[0] = opcode | 0x80;
387
388 if (payload.size() > kMax16BitLength) {
389 uint64 size = base::HostToNet64(payload.size());
390 buffer->data()[1] = kLength64BitMarker;
391 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
392 } else if (payload.size() > kMaxNotExtendedLength) {
393 uint16 size = base::HostToNet16(payload.size());
394 buffer->data()[1] = kLength16BitMarker;
395 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
396 } else {
397 buffer->data()[1] = payload.size();
398 }
399 memcpy(buffer->data() + 1 + length_field_size,
400 payload.data(), payload.size());
401
402 writer_.Write(buffer, base::Closure());
403 }
404
405 bool WebSocketConnection::ParseHeaders() {
406 std::vector<std::string> lines;
407 base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines);
408
409 // Parse request line.
410 std::vector<std::string> request_parts;
411 base::SplitString(lines[0], ' ', &request_parts);
412 if (request_parts.size() != 3 ||
413 request_parts[0] != "GET" ||
414 request_parts[2] != "HTTP/1.1") {
415 LOG(ERROR) << "Invalid Request-Line: " << headers_[0];
416 return false;
417 }
418 request_path_ = request_parts[1];
419
420 std::map<std::string, std::string> headers;
421
422 for (size_t i = 1; i < lines.size(); ++i) {
423 std::string separator(kHeaderKeyValueSeparator);
424 size_t pos = lines[i].find(separator);
425 if (pos == std::string::npos || pos == 0) {
426 LOG(ERROR) << "Invalid header line: " << lines[i];
427 return false;
428 }
429 std::string key = lines[i].substr(0, pos);
430 if (headers.find(key) != headers.end()) {
431 LOG(ERROR) << "Duplicate header value: " << key;
432 return false;
433 }
434 headers[key] = lines[i].substr(pos + separator.size());
435 }
436
437 std::map<std::string, std::string>::iterator it = headers.find("Connection");
438 if (it == headers.end() || it->second != "Upgrade") {
439 LOG(ERROR) << "Connection header is missing or invalid.";
440 return false;
441 }
442
443 it = headers.find("Upgrade");
444 if (it == headers.end() || it->second != "websocket") {
445 LOG(ERROR) << "Upgrade header is missing or invalid.";
446 return false;
447 }
448
449 it = headers.find("Host");
450 if (it == headers.end()) {
451 LOG(ERROR) << "Host header is missing.";
452 return false;
453 }
454 request_host_ = it->second;
455
456 it = headers.find("Sec-WebSocket-Version");
457 if (it == headers.end()) {
458 LOG(ERROR) << "Sec-WebSocket-Version header is missing.";
459 return false;
460 }
461 if (it->second != "13") {
462 LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second;
463 return false;
464 }
465
466 it = headers.find("Origin");
467 if (it == headers.end()) {
468 LOG(ERROR) << "Origin header is missing.";
469 return false;
470 }
471 origin_ = it->second;
472
473 it = headers.find("Sec-WebSocket-Key");
474 if (it == headers.end()) {
475 LOG(ERROR) << "Sec-WebSocket-Key header is missing.";
476 return false;
477 }
478 websocket_key_ = it->second;
479
480 return true;
481 }
482
483 void WebSocketConnection::UnmaskPayload(const char* mask,
484 char* payload, int payload_length) {
485 for (int i = 0; i < payload_length; ++i) {
486 payload[i] = payload[i] ^ mask[i % kMaskLength];
487 }
488 }
489
490 void WebSocketConnection::OnSocketWriteError(int error) {
491 LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error;
492 CloseOnError();
493 }
494
495 } // namespace remoting
OLDNEW
« no previous file with comments | « remoting/host/websocket_connection.h ('k') | remoting/host/websocket_connection_unittest.cc » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698