OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/server/web_socket.h" | 5 #include "net/server/web_socket.h" |
6 | 6 |
7 #include <limits> | 7 #include <limits> |
8 | 8 |
9 #include "base/base64.h" | 9 #include "base/base64.h" |
10 #include "base/rand_util.h" | 10 #include "base/rand_util.h" |
11 #include "base/logging.h" | 11 #include "base/logging.h" |
12 #include "base/md5.h" | 12 #include "base/md5.h" |
13 #include "base/sha1.h" | 13 #include "base/sha1.h" |
14 #include "base/strings/string_number_conversions.h" | 14 #include "base/strings/string_number_conversions.h" |
15 #include "base/strings/stringprintf.h" | 15 #include "base/strings/stringprintf.h" |
16 #include "base/sys_byteorder.h" | 16 #include "base/sys_byteorder.h" |
17 #include "net/server/http_connection.h" | 17 #include "net/server/http_connection.h" |
| 18 #include "net/server/http_server.h" |
18 #include "net/server/http_server_request_info.h" | 19 #include "net/server/http_server_request_info.h" |
19 #include "net/server/http_server_response_info.h" | 20 #include "net/server/http_server_response_info.h" |
20 | 21 |
21 namespace net { | 22 namespace net { |
22 | 23 |
23 namespace { | 24 namespace { |
24 | 25 |
25 static uint32 WebSocketKeyFingerprint(const std::string& str) { | 26 static uint32 WebSocketKeyFingerprint(const std::string& str) { |
26 std::string result; | 27 std::string result; |
27 const char* p_char = str.c_str(); | 28 const char* p_char = str.c_str(); |
28 int length = str.length(); | 29 int length = str.length(); |
29 int spaces = 0; | 30 int spaces = 0; |
30 for (int i = 0; i < length; ++i) { | 31 for (int i = 0; i < length; ++i) { |
31 if (p_char[i] >= '0' && p_char[i] <= '9') | 32 if (p_char[i] >= '0' && p_char[i] <= '9') |
32 result.append(&p_char[i], 1); | 33 result.append(&p_char[i], 1); |
33 else if (p_char[i] == ' ') | 34 else if (p_char[i] == ' ') |
34 spaces++; | 35 spaces++; |
35 } | 36 } |
36 if (spaces == 0) | 37 if (spaces == 0) |
37 return 0; | 38 return 0; |
38 int64 number = 0; | 39 int64 number = 0; |
39 if (!base::StringToInt64(result, &number)) | 40 if (!base::StringToInt64(result, &number)) |
40 return 0; | 41 return 0; |
41 return base::HostToNet32(static_cast<uint32>(number / spaces)); | 42 return base::HostToNet32(static_cast<uint32>(number / spaces)); |
42 } | 43 } |
43 | 44 |
44 class WebSocketHixie76 : public net::WebSocket { | 45 class WebSocketHixie76 : public net::WebSocket { |
45 public: | 46 public: |
46 static net::WebSocket* Create(HttpConnection* connection, | 47 static net::WebSocket* Create(HttpServer* server, |
| 48 HttpConnection* connection, |
47 const HttpServerRequestInfo& request, | 49 const HttpServerRequestInfo& request, |
48 size_t* pos) { | 50 size_t* pos) { |
49 if (connection->recv_data().length() < *pos + kWebSocketHandshakeBodyLen) | 51 if (connection->read_buf()->GetSize() |
| 52 < static_cast<int>(*pos + kWebSocketHandshakeBodyLen)) |
50 return NULL; | 53 return NULL; |
51 return new WebSocketHixie76(connection, request, pos); | 54 return new WebSocketHixie76(server, connection, request, pos); |
52 } | 55 } |
53 | 56 |
54 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { | 57 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { |
55 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); | 58 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); |
56 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); | 59 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); |
57 | 60 |
58 uint32 fp1 = WebSocketKeyFingerprint(key1); | 61 uint32 fp1 = WebSocketKeyFingerprint(key1); |
59 uint32 fp2 = WebSocketKeyFingerprint(key2); | 62 uint32 fp2 = WebSocketKeyFingerprint(key2); |
60 | 63 |
61 char data[16]; | 64 char data[16]; |
62 memcpy(data, &fp1, 4); | 65 memcpy(data, &fp1, 4); |
63 memcpy(data + 4, &fp2, 4); | 66 memcpy(data + 4, &fp2, 4); |
64 memcpy(data + 8, &key3_[0], 8); | 67 memcpy(data + 8, &key3_[0], 8); |
65 | 68 |
66 base::MD5Digest digest; | 69 base::MD5Digest digest; |
67 base::MD5Sum(data, 16, &digest); | 70 base::MD5Sum(data, 16, &digest); |
68 | 71 |
69 std::string origin = request.GetHeaderValue("origin"); | 72 std::string origin = request.GetHeaderValue("origin"); |
70 std::string host = request.GetHeaderValue("host"); | 73 std::string host = request.GetHeaderValue("host"); |
71 std::string location = "ws://" + host + request.path; | 74 std::string location = "ws://" + host + request.path; |
72 connection_->Send(base::StringPrintf( | 75 server_->SendRaw( |
73 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" | 76 connection_->id(), |
74 "Upgrade: WebSocket\r\n" | 77 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
75 "Connection: Upgrade\r\n" | 78 "Upgrade: WebSocket\r\n" |
76 "Sec-WebSocket-Origin: %s\r\n" | 79 "Connection: Upgrade\r\n" |
77 "Sec-WebSocket-Location: %s\r\n" | 80 "Sec-WebSocket-Origin: %s\r\n" |
78 "\r\n", | 81 "Sec-WebSocket-Location: %s\r\n" |
79 origin.c_str(), | 82 "\r\n", |
80 location.c_str())); | 83 origin.c_str(), |
81 connection_->Send(reinterpret_cast<char*>(digest.a), 16); | 84 location.c_str())); |
| 85 server_->SendRaw(connection_->id(), |
| 86 std::string(reinterpret_cast<char*>(digest.a), 16)); |
82 } | 87 } |
83 | 88 |
84 virtual ParseResult Read(std::string* message) OVERRIDE { | 89 virtual ParseResult Read(std::string* message) OVERRIDE { |
85 DCHECK(message); | 90 DCHECK(message); |
86 const std::string& data = connection_->recv_data(); | 91 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
87 if (data[0]) | 92 if (read_buf->StartOfBuffer()[0]) |
88 return FRAME_ERROR; | 93 return FRAME_ERROR; |
89 | 94 |
| 95 base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); |
90 size_t pos = data.find('\377', 1); | 96 size_t pos = data.find('\377', 1); |
91 if (pos == std::string::npos) | 97 if (pos == base::StringPiece::npos) |
92 return FRAME_INCOMPLETE; | 98 return FRAME_INCOMPLETE; |
93 | 99 |
94 std::string buffer(data.begin() + 1, data.begin() + pos); | 100 message->assign(data.data() + 1, pos - 1); |
95 message->swap(buffer); | 101 read_buf->DidConsume(pos + 1); |
96 connection_->Shift(pos + 1); | |
97 | 102 |
98 return FRAME_OK; | 103 return FRAME_OK; |
99 } | 104 } |
100 | 105 |
101 virtual void Send(const std::string& message) OVERRIDE { | 106 virtual void Send(const std::string& message) OVERRIDE { |
102 char message_start = 0; | 107 char message_start = 0; |
103 char message_end = -1; | 108 char message_end = -1; |
104 connection_->Send(&message_start, 1); | 109 server_->SendRaw(connection_->id(), std::string(1, message_start)); |
105 connection_->Send(message); | 110 server_->SendRaw(connection_->id(), message); |
106 connection_->Send(&message_end, 1); | 111 server_->SendRaw(connection_->id(), std::string(1, message_end)); |
107 } | 112 } |
108 | 113 |
109 private: | 114 private: |
110 static const int kWebSocketHandshakeBodyLen; | 115 static const int kWebSocketHandshakeBodyLen; |
111 | 116 |
112 WebSocketHixie76(HttpConnection* connection, | 117 WebSocketHixie76(HttpServer* server, |
| 118 HttpConnection* connection, |
113 const HttpServerRequestInfo& request, | 119 const HttpServerRequestInfo& request, |
114 size_t* pos) : WebSocket(connection) { | 120 size_t* pos) |
| 121 : WebSocket(server, connection) { |
115 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); | 122 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); |
116 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); | 123 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); |
117 | 124 |
118 if (key1.empty()) { | 125 if (key1.empty()) { |
119 connection->Send(HttpServerResponseInfo::CreateFor500( | 126 server->SendResponse( |
120 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " | 127 connection->id(), |
121 "specified.")); | 128 HttpServerResponseInfo::CreateFor500( |
| 129 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " |
| 130 "specified.")); |
122 return; | 131 return; |
123 } | 132 } |
124 | 133 |
125 if (key2.empty()) { | 134 if (key2.empty()) { |
126 connection->Send(HttpServerResponseInfo::CreateFor500( | 135 server->SendResponse( |
127 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " | 136 connection->id(), |
128 "specified.")); | 137 HttpServerResponseInfo::CreateFor500( |
| 138 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " |
| 139 "specified.")); |
129 return; | 140 return; |
130 } | 141 } |
131 | 142 |
132 key3_ = connection->recv_data().substr( | 143 key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, |
133 *pos, | 144 kWebSocketHandshakeBodyLen); |
134 *pos + kWebSocketHandshakeBodyLen); | |
135 *pos += kWebSocketHandshakeBodyLen; | 145 *pos += kWebSocketHandshakeBodyLen; |
136 } | 146 } |
137 | 147 |
138 std::string key3_; | 148 std::string key3_; |
139 | 149 |
140 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); | 150 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); |
141 }; | 151 }; |
142 | 152 |
143 const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; | 153 const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; |
144 | 154 |
(...skipping 17 matching lines...) Expand all Loading... |
162 const unsigned char kMaskBit = 0x80; | 172 const unsigned char kMaskBit = 0x80; |
163 const unsigned char kPayloadLengthMask = 0x7F; | 173 const unsigned char kPayloadLengthMask = 0x7F; |
164 | 174 |
165 const size_t kMaxSingleBytePayloadLength = 125; | 175 const size_t kMaxSingleBytePayloadLength = 125; |
166 const size_t kTwoBytePayloadLengthField = 126; | 176 const size_t kTwoBytePayloadLengthField = 126; |
167 const size_t kEightBytePayloadLengthField = 127; | 177 const size_t kEightBytePayloadLengthField = 127; |
168 const size_t kMaskingKeyWidthInBytes = 4; | 178 const size_t kMaskingKeyWidthInBytes = 4; |
169 | 179 |
170 class WebSocketHybi17 : public WebSocket { | 180 class WebSocketHybi17 : public WebSocket { |
171 public: | 181 public: |
172 static WebSocket* Create(HttpConnection* connection, | 182 static WebSocket* Create(HttpServer* server, |
| 183 HttpConnection* connection, |
173 const HttpServerRequestInfo& request, | 184 const HttpServerRequestInfo& request, |
174 size_t* pos) { | 185 size_t* pos) { |
175 std::string version = request.GetHeaderValue("sec-websocket-version"); | 186 std::string version = request.GetHeaderValue("sec-websocket-version"); |
176 if (version != "8" && version != "13") | 187 if (version != "8" && version != "13") |
177 return NULL; | 188 return NULL; |
178 | 189 |
179 std::string key = request.GetHeaderValue("sec-websocket-key"); | 190 std::string key = request.GetHeaderValue("sec-websocket-key"); |
180 if (key.empty()) { | 191 if (key.empty()) { |
181 connection->Send(HttpServerResponseInfo::CreateFor500( | 192 server->SendResponse( |
182 "Invalid request format. Sec-WebSocket-Key is empty or isn't " | 193 connection->id(), |
183 "specified.")); | 194 HttpServerResponseInfo::CreateFor500( |
| 195 "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
| 196 "specified.")); |
184 return NULL; | 197 return NULL; |
185 } | 198 } |
186 return new WebSocketHybi17(connection, request, pos); | 199 return new WebSocketHybi17(server, connection, request, pos); |
187 } | 200 } |
188 | 201 |
189 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { | 202 virtual void Accept(const HttpServerRequestInfo& request) OVERRIDE { |
190 static const char* const kWebSocketGuid = | 203 static const char* const kWebSocketGuid = |
191 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | 204 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
192 std::string key = request.GetHeaderValue("sec-websocket-key"); | 205 std::string key = request.GetHeaderValue("sec-websocket-key"); |
193 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); | 206 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
194 std::string encoded_hash; | 207 std::string encoded_hash; |
195 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); | 208 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
196 | 209 |
197 std::string response = base::StringPrintf( | 210 server_->SendRaw( |
198 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" | 211 connection_->id(), |
199 "Upgrade: WebSocket\r\n" | 212 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
200 "Connection: Upgrade\r\n" | 213 "Upgrade: WebSocket\r\n" |
201 "Sec-WebSocket-Accept: %s\r\n" | 214 "Connection: Upgrade\r\n" |
202 "\r\n", | 215 "Sec-WebSocket-Accept: %s\r\n" |
203 encoded_hash.c_str()); | 216 "\r\n", |
204 connection_->Send(response); | 217 encoded_hash.c_str())); |
205 } | 218 } |
206 | 219 |
207 virtual ParseResult Read(std::string* message) OVERRIDE { | 220 virtual ParseResult Read(std::string* message) OVERRIDE { |
208 const std::string& frame = connection_->recv_data(); | 221 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
| 222 base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); |
209 int bytes_consumed = 0; | 223 int bytes_consumed = 0; |
210 | |
211 ParseResult result = | 224 ParseResult result = |
212 WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message); | 225 WebSocket::DecodeFrameHybi17(frame, true, &bytes_consumed, message); |
213 if (result == FRAME_OK) | 226 if (result == FRAME_OK) |
214 connection_->Shift(bytes_consumed); | 227 read_buf->DidConsume(bytes_consumed); |
215 if (result == FRAME_CLOSE) | 228 if (result == FRAME_CLOSE) |
216 closed_ = true; | 229 closed_ = true; |
217 return result; | 230 return result; |
218 } | 231 } |
219 | 232 |
220 virtual void Send(const std::string& message) OVERRIDE { | 233 virtual void Send(const std::string& message) OVERRIDE { |
221 if (closed_) | 234 if (closed_) |
222 return; | 235 return; |
223 std::string data = WebSocket::EncodeFrameHybi17(message, 0); | 236 server_->SendRaw(connection_->id(), |
224 connection_->Send(data); | 237 WebSocket::EncodeFrameHybi17(message, 0)); |
225 } | 238 } |
226 | 239 |
227 private: | 240 private: |
228 WebSocketHybi17(HttpConnection* connection, | 241 WebSocketHybi17(HttpServer* server, |
| 242 HttpConnection* connection, |
229 const HttpServerRequestInfo& request, | 243 const HttpServerRequestInfo& request, |
230 size_t* pos) | 244 size_t* pos) |
231 : WebSocket(connection), | 245 : WebSocket(server, connection), |
232 op_code_(0), | 246 op_code_(0), |
233 final_(false), | 247 final_(false), |
234 reserved1_(false), | 248 reserved1_(false), |
235 reserved2_(false), | 249 reserved2_(false), |
236 reserved3_(false), | 250 reserved3_(false), |
237 masked_(false), | 251 masked_(false), |
238 payload_(0), | 252 payload_(0), |
239 payload_length_(0), | 253 payload_length_(0), |
240 frame_end_(0), | 254 frame_end_(0), |
241 closed_(false) { | 255 closed_(false) { |
242 } | 256 } |
243 | 257 |
244 OpCode op_code_; | 258 OpCode op_code_; |
245 bool final_; | 259 bool final_; |
246 bool reserved1_; | 260 bool reserved1_; |
247 bool reserved2_; | 261 bool reserved2_; |
248 bool reserved3_; | 262 bool reserved3_; |
249 bool masked_; | 263 bool masked_; |
250 const char* payload_; | 264 const char* payload_; |
251 size_t payload_length_; | 265 size_t payload_length_; |
252 const char* frame_end_; | 266 const char* frame_end_; |
253 bool closed_; | 267 bool closed_; |
254 | 268 |
255 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); | 269 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); |
256 }; | 270 }; |
257 | 271 |
258 } // anonymous namespace | 272 } // anonymous namespace |
259 | 273 |
260 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, | 274 WebSocket* WebSocket::CreateWebSocket(HttpServer* server, |
| 275 HttpConnection* connection, |
261 const HttpServerRequestInfo& request, | 276 const HttpServerRequestInfo& request, |
262 size_t* pos) { | 277 size_t* pos) { |
263 WebSocket* socket = WebSocketHybi17::Create(connection, request, pos); | 278 WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); |
264 if (socket) | 279 if (socket) |
265 return socket; | 280 return socket; |
266 | 281 |
267 return WebSocketHixie76::Create(connection, request, pos); | 282 return WebSocketHixie76::Create(server, connection, request, pos); |
268 } | 283 } |
269 | 284 |
270 // static | 285 // static |
271 WebSocket::ParseResult WebSocket::DecodeFrameHybi17(const std::string& frame, | 286 WebSocket::ParseResult WebSocket::DecodeFrameHybi17( |
272 bool client_frame, | 287 const base::StringPiece& frame, |
273 int* bytes_consumed, | 288 bool client_frame, |
274 std::string* output) { | 289 int* bytes_consumed, |
| 290 std::string* output) { |
275 size_t data_length = frame.length(); | 291 size_t data_length = frame.length(); |
276 if (data_length < 2) | 292 if (data_length < 2) |
277 return FRAME_INCOMPLETE; | 293 return FRAME_INCOMPLETE; |
278 | 294 |
279 const char* buffer_begin = const_cast<char*>(frame.data()); | 295 const char* buffer_begin = const_cast<char*>(frame.data()); |
280 const char* p = buffer_begin; | 296 const char* p = buffer_begin; |
281 const char* buffer_end = p + data_length; | 297 const char* buffer_end = p + data_length; |
282 | 298 |
283 unsigned char first_byte = *p++; | 299 unsigned char first_byte = *p++; |
284 unsigned char second_byte = *p++; | 300 unsigned char second_byte = *p++; |
(...skipping 57 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
342 if (static_cast<size_t>(buffer_end - p) < total_length) | 358 if (static_cast<size_t>(buffer_end - p) < total_length) |
343 return FRAME_INCOMPLETE; | 359 return FRAME_INCOMPLETE; |
344 | 360 |
345 if (masked) { | 361 if (masked) { |
346 output->resize(payload_length); | 362 output->resize(payload_length); |
347 const char* masking_key = p; | 363 const char* masking_key = p; |
348 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); | 364 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); |
349 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload. | 365 for (size_t i = 0; i < payload_length; ++i) // Unmask the payload. |
350 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; | 366 (*output)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; |
351 } else { | 367 } else { |
352 std::string buffer(p, p + payload_length); | 368 output->assign(p, p + payload_length); |
353 output->swap(buffer); | |
354 } | 369 } |
355 | 370 |
356 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin; | 371 size_t pos = p + actual_masking_key_length + payload_length - buffer_begin; |
357 *bytes_consumed = pos; | 372 *bytes_consumed = pos; |
358 return closed ? FRAME_CLOSE : FRAME_OK; | 373 return closed ? FRAME_CLOSE : FRAME_OK; |
359 } | 374 } |
360 | 375 |
361 // static | 376 // static |
362 std::string WebSocket::EncodeFrameHybi17(const std::string& message, | 377 std::string WebSocket::EncodeFrameHybi17(const std::string& message, |
363 int masking_key) { | 378 int masking_key) { |
(...skipping 29 matching lines...) Expand all Loading... |
393 const char* mask_bytes = reinterpret_cast<char*>(&masking_key); | 408 const char* mask_bytes = reinterpret_cast<char*>(&masking_key); |
394 frame.insert(frame.end(), mask_bytes, mask_bytes + 4); | 409 frame.insert(frame.end(), mask_bytes, mask_bytes + 4); |
395 for (size_t i = 0; i < data_length; ++i) // Mask the payload. | 410 for (size_t i = 0; i < data_length; ++i) // Mask the payload. |
396 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]); | 411 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]); |
397 } else { | 412 } else { |
398 frame.insert(frame.end(), data, data + data_length); | 413 frame.insert(frame.end(), data, data + data_length); |
399 } | 414 } |
400 return std::string(&frame[0], frame.size()); | 415 return std::string(&frame[0], frame.size()); |
401 } | 416 } |
402 | 417 |
403 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { | 418 WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) |
| 419 : server_(server), |
| 420 connection_(connection) { |
404 } | 421 } |
405 | 422 |
406 } // namespace net | 423 } // namespace net |
OLD | NEW |