Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(124)

Side by Side Diff: remoting/host/websocket_connection.cc

Issue 11358190: Add simple WebSocket server implementation. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Created 8 years, 1 month ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
OLDNEW
(Empty)
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "remoting/host/websocket_connection.h"
6
7 #include <map>
8 #include <vector>
9
10 #include "base/base64.h"
11 #include "base/compiler_specific.h"
12 #include "base/location.h"
13 #include "base/sha1.h"
14 #include "base/single_thread_task_runner.h"
15 #include "base/string_split.h"
16 #include "base/sys_byteorder.h"
17 #include "base/thread_task_runner_handle.h"
18 #include "net/base/net_errors.h"
19 #include "net/socket/stream_socket.h"
20
21 namespace remoting {
22
23 namespace {
24
25 const int kReadBufferSize = 1024;
26 const char kLineSeparator[] = "\r\n";
27 const char kHeaderEndMarker[] = "\r\n\r\n";
28 const char kHeaderKeyValueSeparator[] = ": ";
29 const int kMaskLength = 4;
30
31 // Fixed value specified in RFC6455. It's used to compute accept token sent to
32 // the client in Sec-WebSocket-Accept key.
33 const char kWebsocketKeySalt[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
34
35 } // namespace
36
37 WebsocketConnection::WebsocketConnection()
38 : delegate_(NULL),
39 state_(READING_HEADERS),
40 receiving_message_(false),
41 ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) {
42 }
43
44 WebsocketConnection::~WebsocketConnection() {
45 Close();
46 }
47
48 void WebsocketConnection::Start(
49 scoped_ptr<net::StreamSocket> socket,
50 ConnectedCallback connected_callback) {
51 socket_ = socket.Pass();
52 connected_callback_ = connected_callback;
53 reader_.Init(socket_.get(), base::Bind(
54 &WebsocketConnection::OnSocketReadResult, base::Unretained(this)));
55 writer_.Init(socket_.get(), base::Bind(
56 &WebsocketConnection::OnSocketWriteError, base::Unretained(this)));
57 }
58
59 void WebsocketConnection::Accept(Delegate* delegate) {
60 DCHECK_EQ(state_, HEADERS_READ);
61
62 state_ = ACCEPTED;
63 delegate_ = delegate;
64
65 std::string accept_key =
66 base::SHA1HashString(websocket_key_ + kWebsocketKeySalt);
67 std::string accept_key_base64;
68 bool result = base::Base64Encode(accept_key, &accept_key_base64);
69 DCHECK(result);
70
71 std::string handshake;
72 handshake += "HTTP/1.1 101 Switching Protocol";
73 handshake += kLineSeparator;
74 handshake += "Upgrade: websocket";
75 handshake += kLineSeparator;
76 handshake += "Connection: Upgrade";
77 handshake += kLineSeparator;
78 handshake += "Sec-WebSocket-Accept: " + accept_key_base64;
79 handshake += kHeaderEndMarker;
80
81 scoped_refptr<net::IOBufferWithSize> buffer =
82 new net::IOBufferWithSize(handshake.size());
83 memcpy(buffer->data(), handshake.data(), handshake.size());
84 writer_.Write(buffer, base::Closure());
85 }
86
87 void WebsocketConnection::Reject() {
88 DCHECK_EQ(state_, HEADERS_READ);
89
90 state_ = CLOSED;
91 std::string response = "HTTP/1.1 401 Unauthorized";
92 response += kHeaderEndMarker;
93 scoped_refptr<net::IOBufferWithSize> buffer =
94 new net::IOBufferWithSize(response.size());
95 memcpy(buffer->data(), response.data(), response.size());
96 writer_.Write(buffer, base::Closure());
97 }
98
99 void WebsocketConnection::SendText(const std::string& text) {
100 SendFragment(OPCODE_TEXT_FRAME, text.data(), text.size());
101 }
102
103 void WebsocketConnection::Close() {
104 switch (state_) {
105 case READING_HEADERS:
106 break;
107
108 case HEADERS_READ:
109 Reject();
110 break;
111
112 case ACCEPTED:
113 SendFragment(OPCODE_CLOSE, NULL, 0);
114 break;
115
116 case CLOSED:
117 break;
118 }
119 state_ = CLOSED;
120 }
121
122 void WebsocketConnection::CloseOnError() {
123 State old_state_ = state_;
124 Close();
125 if (old_state_ == ACCEPTED) {
126 DCHECK(delegate_);
127 delegate_->OnWebsocketClosed();
128 }
129 }
130
131 void WebsocketConnection::OnSocketReadResult(scoped_refptr<net::IOBuffer> data,
132 int result) {
133 if (result <= 0) {
134 if (result != 0) {
135 LOG(ERROR) << "Error when trying to read from WebSocket connection: "
136 << result;
137 }
138 CloseOnError();
139 return;
140 }
141
142 switch (state_) {
143 case READING_HEADERS: {
144 headers_.append(data->data(), data->data() + result);
145 size_t header_end_pos = headers_.find(kHeaderEndMarker);
146 if (header_end_pos != std::string::npos) {
147 bool result;
148 if (header_end_pos != headers_.size() - strlen(kHeaderEndMarker)) {
149 LOG(ERROR) << "WebSocket client tried writing data before handshake "
150 "has finished.";
151 DCHECK(!connected_callback_.is_null());
152 state_ = CLOSED;
153 result = false;
154 } else {
155 // Crop newline symbols from the end.
156 headers_.resize(header_end_pos);
157
158 result = ParseHeaders();
159 if (!result) {
160 state_ = CLOSED;
161 } else {
162 state_ = HEADERS_READ;
163 }
164 }
165 ConnectedCallback cb(connected_callback_);
166 connected_callback_.Reset();
167 cb.Run(result);
168 }
169 break;
170 }
171
172 case HEADERS_READ:
173 LOG(ERROR) << "Received unexpected data before websocket "
174 "connection is accepted.";
175 CloseOnError();
176 break;
177
178 case ACCEPTED:
179 DCHECK(delegate_);
180 current_packet_.append(data->data(), data->data() + result);
181 ProcessData();
182
183 case CLOSED:
184 // Ignore anything received after connection is rejected or closed.
185 break;
186 }
187 }
188
189 void WebsocketConnection::ProcessData() {
190 DCHECK_EQ(state_, ACCEPTED);
191
192 if (current_packet_.size() < 2) {
193 // Header hasn't been received yet.
194 return;
195 }
196
197 bool fin_bit = (current_packet_.data()[0] & 0x80) != 0;
198
199 int rsv_bits = current_packet_.data()[0] & 0x70;
200 if (rsv_bits != 0) {
201 LOG(ERROR) << "Incoming has unsupported RSV bits set.";
202 CloseOnError();
203 return;
204 }
205
206 int opcode = current_packet_.data()[0] & 0x0f;
207
208 int mask_bit = current_packet_.data()[1] & 0x80;
209 if (mask_bit == 0) {
210 LOG(ERROR) << "Incoming frame is not masked.";
211 CloseOnError();
212 return;
213 }
214
215 int length_field_size = 1;
216 uint64 length = current_packet_.data()[1] & 0x7F;
217 if (length == 126) {
218 if (current_packet_.size() < 4) {
219 // Haven't received the whole frame yet.
220 return;
221 }
222 length_field_size = 3;
223 length = base::NetToHost16(
224 *reinterpret_cast<const uint16*>(current_packet_.data() + 2));
225 } else if (length == 127) {
226 if (current_packet_.size() < 10) {
227 // Haven't received the whole frame yet.
228 return;
229 }
230 length_field_size = 9;
231 length = base::NetToHost64(
232 *reinterpret_cast<const uint64*>(current_packet_.data() + 2));
233 }
234
235 int payload_position = 1 + length_field_size + kMaskLength;
236
237 if (current_packet_.size() < payload_position + length) {
238 // Haven't received the whole frame yet.
239 return;
240 }
241
242 if (mask_bit) {
243 const char* mask = current_packet_.data() + length_field_size + 1;
244 UnmaskPayload(
245 mask,
246 const_cast<char*>(current_packet_.data()) + payload_position, length);
247 }
248
249 if (opcode < 0x8) {
250 // Non-control message.
251 current_message_.append(
252 current_packet_.data() + payload_position,
253 current_packet_.data() + payload_position + length);
254 } else {
255 // Control message.
256 if (!fin_bit) {
257 LOG(ERROR) << "Received fragmented control message.";
258 CloseOnError();
259 return;
260 }
261 if (length > 125) {
262 LOG(ERROR) << "Received control message that is larger than 125 bytes.";
263 CloseOnError();
264 return;
265 }
266 }
267
268 switch (opcode) {
269 case OPCODE_CONTINUATION:
270 if (!receiving_message_) {
271 LOG(ERROR) << "Received unexpected continuation frame.";
272 CloseOnError();
273 return;
274 }
275 break;
276
277 case OPCODE_TEXT_FRAME:
278 case OPCODE_BINARY_FRAME:
279 if (receiving_message_) {
280 LOG(ERROR) << "Received unexpected new start frame in a middle of "
281 "a message.";
282 CloseOnError();
283 return;
284 }
285 break;
286
287 case OPCODE_CLOSE:
288 Close();
289 delegate_->OnWebsocketClosed();
290 return;
291
292 case OPCODE_PING:
293 SendFragment(
294 OPCODE_PONG, current_packet_.data() + payload_position, length);
295 break;
296
297 case OPCODE_PONG:
298 break;
299
300 default:
301 LOG(ERROR) << "Received invalid opcode: " << opcode;
302 CloseOnError();
303 return;
304 }
305
306 // Remove the frame from |current_packet_|.
307 current_packet_.erase(0, payload_position + length);
308
309 // Post a task to process the data we have left in the buffer if any.
310 if (!current_packet_.empty()) {
311 base::ThreadTaskRunnerHandle::Get()->PostTask(
312 FROM_HERE, base::Bind(&WebsocketConnection::ProcessData,
313 weak_factory_.GetWeakPtr()));
314 }
315
316 // Handle payload in non-control messages. Delegate can be called only at the
317 // end of this function
318 if (opcode < 0x8) {
319 if (!fin_bit) {
320 receiving_message_ = true;
321 } else {
322 receiving_message_ = false;
323 std::string msg;
324 msg.swap(current_message_);
325 delegate_->OnWebsocketMessage(msg);
326 }
327 }
328 }
329
330 void WebsocketConnection::SendFragment(
331 WebsocketOpcode opcode,
332 const char* payload, int payload_length) {
333 DCHECK_EQ(state_, ACCEPTED);
334
335 int length_field_size = 1;
336 if (payload_length > 65535) {
337 length_field_size = 9;
338 } else if (payload_length > 125) {
339 length_field_size = 3;
340 }
341
342 scoped_refptr<net::IOBufferWithSize> buffer =
343 new net::IOBufferWithSize(1 + length_field_size + payload_length);
344
345 // Always set FIN flag because we never fragment outgoing messages.
346 buffer->data()[0] = opcode | 0x80;
347
348 if (payload_length > 65535) {
349 uint64 size = base::HostToNet64(payload_length);
350 buffer->data()[1] = 127;
351 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
352 } else if (payload_length > 125) {
353 uint16 size = base::HostToNet16(payload_length);
354 buffer->data()[1] = 126;
355 memcpy(buffer->data() + 2, reinterpret_cast<char*>(&size), sizeof(size));
356 } else {
357 buffer->data()[1] = payload_length;
358 }
359 memcpy(buffer->data() + 1 + length_field_size, payload, payload_length);
360
361 writer_.Write(buffer, base::Closure());
362 }
363
364 bool WebsocketConnection::ParseHeaders() {
365 std::vector<std::string> lines;
366 base::SplitStringUsingSubstr(headers_, kLineSeparator, &lines);
367
368 // Parse request line.
369 std::vector<std::string> request_parts;
370 base::SplitString(lines[0], ' ', &request_parts);
371 if (request_parts.size() != 3 ||
372 request_parts[0] != "GET" ||
373 request_parts[2] != "HTTP/1.1") {
374 LOG(ERROR) << "Invalid Request-Line: " << headers_[0];
375 return false;
376 }
377 request_path_ = request_parts[1];
378
379 std::map<std::string, std::string> headers;
380
381 for (size_t i = 1; i < lines.size(); ++i) {
382 std::string separator(kHeaderKeyValueSeparator);
383 size_t pos = lines[i].find(separator);
384 if (pos == std::string::npos || pos == 0) {
385 LOG(ERROR) << "Invalid header line: " << lines[i];
386 return false;
387 }
388 std::string key = lines[i].substr(0, pos);
389 if (headers.find(key) != headers.end()) {
390 LOG(ERROR) << "Duplicate header value: " << key;
391 return false;
392 }
393 headers[key] = lines[i].substr(pos + separator.size());
394 }
395
396 std::map<std::string, std::string>::iterator it = headers.find("Connection");
397 if (it == headers.end() || it->second != "Upgrade") {
398 LOG(ERROR) << "Connection header is missing or invalid.";
399 return false;
400 }
401
402 it = headers.find("Upgrade");
403 if (it == headers.end() || it->second != "websocket") {
404 LOG(ERROR) << "Upgrade header is missing or invalid.";
405 return false;
406 }
407
408 it = headers.find("Host");
409 if (it == headers.end()) {
410 LOG(ERROR) << "Host header is missing.";
411 return false;
412 }
413 request_host_ = it->second;
414
415 it = headers.find("Sec-WebSocket-Version");
416 if (it == headers.end()) {
417 LOG(ERROR) << "Sec-WebSocket-Version header is missing.";
418 return false;
419 }
420 if (it->second != "13") {
421 LOG(ERROR) << "Unsupported WebSocket protocol version: " << it->second;
422 return false;
423 }
424
425 it = headers.find("Origin");
426 if (it == headers.end()) {
427 LOG(ERROR) << "Origin header is missing.";
428 return false;
429 }
430 origin_ = it->second;
431
432 it = headers.find("Sec-WebSocket-Key");
433 if (it == headers.end()) {
434 LOG(ERROR) << "Sec-WebSocket-Key header is missing.";
435 return false;
436 }
437 websocket_key_ = it->second;
438
439 return true;
440 }
441
442 void WebsocketConnection::UnmaskPayload(const char* mask,
443 char* payload, int payload_length) {
444 for (int i = 0; i < payload_length; ++i) {
445 payload[i] = payload[i] ^ mask[i % kMaskLength];
446 }
447 }
448
449 void WebsocketConnection::OnSocketWriteError(int error) {
450 LOG(ERROR) << "Failed to write to a WebSocket. Error: " << error;
451 CloseOnError();
452 }
453
454 } // namespace remoting
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698