OLD | NEW |
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/server/web_socket.h" | 5 #include "net/server/web_socket.h" |
6 | 6 |
7 #include "base/base64.h" | 7 #include "base/base64.h" |
8 #include "base/rand_util.h" | 8 #include "base/rand_util.h" |
9 #include "base/logging.h" | 9 #include "base/logging.h" |
10 #include "base/md5.h" | 10 #include "base/md5.h" |
(...skipping 151 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
162 const unsigned char kReserved3Bit = 0x10; | 162 const unsigned char kReserved3Bit = 0x10; |
163 const unsigned char kOpCodeMask = 0xF; | 163 const unsigned char kOpCodeMask = 0xF; |
164 const unsigned char kMaskBit = 0x80; | 164 const unsigned char kMaskBit = 0x80; |
165 const unsigned char kPayloadLengthMask = 0x7F; | 165 const unsigned char kPayloadLengthMask = 0x7F; |
166 | 166 |
167 const size_t kMaxSingleBytePayloadLength = 125; | 167 const size_t kMaxSingleBytePayloadLength = 125; |
168 const size_t kTwoBytePayloadLengthField = 126; | 168 const size_t kTwoBytePayloadLengthField = 126; |
169 const size_t kEightBytePayloadLengthField = 127; | 169 const size_t kEightBytePayloadLengthField = 127; |
170 const size_t kMaskingKeyWidthInBytes = 4; | 170 const size_t kMaskingKeyWidthInBytes = 4; |
171 | 171 |
172 class WebSocketHybi10 : public WebSocket { | 172 class WebSocketHybi17 : public WebSocket { |
173 public: | 173 public: |
174 static WebSocket* Create(HttpConnection* connection, | 174 static WebSocket* Create(HttpConnection* connection, |
175 const HttpServerRequestInfo& request, | 175 const HttpServerRequestInfo& request, |
176 size_t* pos) { | 176 size_t* pos) { |
177 std::string version = request.GetHeaderValue("Sec-WebSocket-Version"); | 177 std::string version = request.GetHeaderValue("Sec-WebSocket-Version"); |
178 if (version != "8" && version != "13") | 178 if (version != "8" && version != "13") |
179 return NULL; | 179 return NULL; |
180 | 180 |
181 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); | 181 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); |
182 if (key.empty()) { | 182 if (key.empty()) { |
183 connection->Send500("Invalid request format. " | 183 connection->Send500("Invalid request format. " |
184 "Sec-WebSocket-Key is empty or isn't specified."); | 184 "Sec-WebSocket-Key is empty or isn't specified."); |
185 return NULL; | 185 return NULL; |
186 } | 186 } |
187 return new WebSocketHybi10(connection, request, pos); | 187 return new WebSocketHybi17(connection, request, pos); |
188 } | 188 } |
189 | 189 |
190 virtual void Accept(const HttpServerRequestInfo& request) { | 190 virtual void Accept(const HttpServerRequestInfo& request) { |
191 static const char* const kWebSocketGuid = | 191 static const char* const kWebSocketGuid = |
192 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | 192 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
193 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); | 193 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); |
194 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); | 194 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
195 std::string encoded_hash; | 195 std::string encoded_hash; |
196 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); | 196 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
197 | 197 |
(...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
230 case kOpCodeText: | 230 case kOpCodeText: |
231 break; | 231 break; |
232 case kOpCodeBinary: // We don't support binary frames yet. | 232 case kOpCodeBinary: // We don't support binary frames yet. |
233 case kOpCodeContinuation: // We don't support binary frames yet. | 233 case kOpCodeContinuation: // We don't support binary frames yet. |
234 case kOpCodePing: // We don't support binary frames yet. | 234 case kOpCodePing: // We don't support binary frames yet. |
235 case kOpCodePong: // We don't support binary frames yet. | 235 case kOpCodePong: // We don't support binary frames yet. |
236 default: | 236 default: |
237 return FRAME_ERROR; | 237 return FRAME_ERROR; |
238 } | 238 } |
239 | 239 |
| 240 if (!masked_) // According to Hybi-17 spec client MUST mask his frame. |
| 241 return FRAME_ERROR; |
| 242 |
240 uint64 payload_length64 = second_byte & kPayloadLengthMask; | 243 uint64 payload_length64 = second_byte & kPayloadLengthMask; |
241 if (payload_length64 > kMaxSingleBytePayloadLength) { | 244 if (payload_length64 > kMaxSingleBytePayloadLength) { |
242 int extended_payload_length_size; | 245 int extended_payload_length_size; |
243 if (payload_length64 == kTwoBytePayloadLengthField) | 246 if (payload_length64 == kTwoBytePayloadLengthField) |
244 extended_payload_length_size = 2; | 247 extended_payload_length_size = 2; |
245 else { | 248 else { |
246 DCHECK(payload_length64 == kEightBytePayloadLengthField); | 249 DCHECK(payload_length64 == kEightBytePayloadLengthField); |
247 extended_payload_length_size = 8; | 250 extended_payload_length_size = 8; |
248 } | 251 } |
249 if (buffer_end - p < extended_payload_length_size) | 252 if (buffer_end - p < extended_payload_length_size) |
250 return FRAME_INCOMPLETE; | 253 return FRAME_INCOMPLETE; |
251 payload_length64 = 0; | 254 payload_length64 = 0; |
252 for (int i = 0; i < extended_payload_length_size; ++i) { | 255 for (int i = 0; i < extended_payload_length_size; ++i) { |
253 payload_length64 <<= 8; | 256 payload_length64 <<= 8; |
254 payload_length64 |= static_cast<unsigned char>(*p++); | 257 payload_length64 |= static_cast<unsigned char>(*p++); |
255 } | 258 } |
256 } | 259 } |
257 | 260 |
258 static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; | 261 static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; |
259 size_t masking_key_length = masked_ ? kMaskingKeyWidthInBytes : 0; | |
260 static size_t max_length = std::numeric_limits<size_t>::max(); | 262 static size_t max_length = std::numeric_limits<size_t>::max(); |
261 if (payload_length64 > max_payload_length || | 263 if (payload_length64 > max_payload_length || |
262 payload_length64 + masking_key_length > max_length) { | 264 payload_length64 + kMaskingKeyWidthInBytes > max_length) { |
263 // WebSocket frame length too large. | 265 // WebSocket frame length too large. |
264 return FRAME_ERROR; | 266 return FRAME_ERROR; |
265 } | 267 } |
266 payload_length_ = static_cast<size_t>(payload_length64); | 268 payload_length_ = static_cast<size_t>(payload_length64); |
267 | 269 |
268 size_t total_length = masking_key_length + payload_length_; | 270 size_t total_length = kMaskingKeyWidthInBytes + payload_length_; |
269 if (static_cast<size_t>(buffer_end - p) < total_length) | 271 if (static_cast<size_t>(buffer_end - p) < total_length) |
270 return FRAME_INCOMPLETE; | 272 return FRAME_INCOMPLETE; |
271 | 273 |
272 if (masked_) { | 274 if (masked_) { |
273 message->resize(payload_length_); | 275 message->resize(payload_length_); |
274 const char* masking_key = p; | 276 const char* masking_key = p; |
275 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); | 277 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); |
276 for (size_t i = 0; i < payload_length_; ++i) // Unmask the payload. | 278 for (size_t i = 0; i < payload_length_; ++i) // Unmask the payload. |
277 (*message)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; | 279 (*message)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; |
278 } else { | 280 } else { |
279 std::string buffer(p, p + payload_length_); | 281 std::string buffer(p, p + payload_length_); |
280 message->swap(buffer); | 282 message->swap(buffer); |
281 } | 283 } |
282 | 284 |
283 size_t pos = p + masking_key_length + payload_length_ - | 285 size_t pos = p + kMaskingKeyWidthInBytes + payload_length_ - |
284 connection_->recv_data().c_str(); | 286 connection_->recv_data().c_str(); |
285 connection_->Shift(pos); | 287 connection_->Shift(pos); |
286 | 288 |
287 return closed_ ? FRAME_CLOSE : FRAME_OK; | 289 return closed_ ? FRAME_CLOSE : FRAME_OK; |
288 } | 290 } |
289 | 291 |
290 virtual void Send(const std::string& message) { | 292 virtual void Send(const std::string& message) { |
291 if (closed_) | 293 if (closed_) |
292 return; | 294 return; |
293 | 295 |
294 std::vector<char> frame; | 296 std::vector<char> frame; |
295 OpCode op_code = kOpCodeText; | 297 OpCode op_code = kOpCodeText; |
296 size_t data_length = message.length(); | 298 size_t data_length = message.length(); |
297 | 299 |
298 frame.push_back(kFinalBit | op_code); | 300 frame.push_back(kFinalBit | op_code); |
299 if (data_length <= kMaxSingleBytePayloadLength) | 301 if (data_length <= kMaxSingleBytePayloadLength) |
300 frame.push_back(kMaskBit | data_length); | 302 frame.push_back(data_length); |
301 else if (data_length <= 0xFFFF) { | 303 else if (data_length <= 0xFFFF) { |
302 frame.push_back(kMaskBit | kTwoBytePayloadLengthField); | 304 frame.push_back(kTwoBytePayloadLengthField); |
303 frame.push_back((data_length & 0xFF00) >> 8); | 305 frame.push_back((data_length & 0xFF00) >> 8); |
304 frame.push_back(data_length & 0xFF); | 306 frame.push_back(data_length & 0xFF); |
305 } else { | 307 } else { |
306 frame.push_back(kMaskBit | kEightBytePayloadLengthField); | 308 frame.push_back(kEightBytePayloadLengthField); |
307 char extended_payload_length[8]; | 309 char extended_payload_length[8]; |
308 size_t remaining = data_length; | 310 size_t remaining = data_length; |
309 // Fill the length into extended_payload_length in the network byte order. | 311 // Fill the length into extended_payload_length in the network byte order. |
310 for (int i = 0; i < 8; ++i) { | 312 for (int i = 0; i < 8; ++i) { |
311 extended_payload_length[7 - i] = remaining & 0xFF; | 313 extended_payload_length[7 - i] = remaining & 0xFF; |
312 remaining >>= 8; | 314 remaining >>= 8; |
313 } | 315 } |
314 frame.insert(frame.end(), | 316 frame.insert(frame.end(), |
315 extended_payload_length, | 317 extended_payload_length, |
316 extended_payload_length + 8); | 318 extended_payload_length + 8); |
317 DCHECK(!remaining); | 319 DCHECK(!remaining); |
318 } | 320 } |
319 | 321 |
320 // Mask the frame. | |
321 size_t masking_key_start = frame.size(); | |
322 // Add placeholder for masking key. Will be overwritten. | |
323 frame.resize(frame.size() + kMaskingKeyWidthInBytes); | |
324 size_t payload_start = frame.size(); | |
325 const char* data = message.c_str(); | 322 const char* data = message.c_str(); |
326 frame.insert(frame.end(), data, data + data_length); | 323 frame.insert(frame.end(), data, data + data_length); |
327 | 324 |
328 base::RandBytes(&frame[0] + masking_key_start, | |
329 kMaskingKeyWidthInBytes); | |
330 for (size_t i = 0; i < data_length; ++i) { | |
331 frame[payload_start + i] ^= | |
332 frame[masking_key_start + i % kMaskingKeyWidthInBytes]; | |
333 } | |
334 connection_->Send(&frame[0], frame.size()); | 325 connection_->Send(&frame[0], frame.size()); |
335 } | 326 } |
336 | 327 |
337 private: | 328 private: |
338 WebSocketHybi10(HttpConnection* connection, | 329 WebSocketHybi17(HttpConnection* connection, |
339 const HttpServerRequestInfo& request, | 330 const HttpServerRequestInfo& request, |
340 size_t* pos) | 331 size_t* pos) |
341 : WebSocket(connection), | 332 : WebSocket(connection), |
342 op_code_(0), | 333 op_code_(0), |
343 final_(false), | 334 final_(false), |
344 reserved1_(false), | 335 reserved1_(false), |
345 reserved2_(false), | 336 reserved2_(false), |
346 reserved3_(false), | 337 reserved3_(false), |
347 masked_(false), | 338 masked_(false), |
348 payload_(0), | 339 payload_(0), |
349 payload_length_(0), | 340 payload_length_(0), |
350 frame_end_(0), | 341 frame_end_(0), |
351 closed_(false) { | 342 closed_(false) { |
352 } | 343 } |
353 | 344 |
354 OpCode op_code_; | 345 OpCode op_code_; |
355 bool final_; | 346 bool final_; |
356 bool reserved1_; | 347 bool reserved1_; |
357 bool reserved2_; | 348 bool reserved2_; |
358 bool reserved3_; | 349 bool reserved3_; |
359 bool masked_; | 350 bool masked_; |
360 const char* payload_; | 351 const char* payload_; |
361 size_t payload_length_; | 352 size_t payload_length_; |
362 const char* frame_end_; | 353 const char* frame_end_; |
363 bool closed_; | 354 bool closed_; |
364 | 355 |
365 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi10); | 356 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); |
366 }; | 357 }; |
367 | 358 |
368 } // anonymous namespace | 359 } // anonymous namespace |
369 | 360 |
370 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, | 361 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, |
371 const HttpServerRequestInfo& request, | 362 const HttpServerRequestInfo& request, |
372 size_t* pos) { | 363 size_t* pos) { |
373 WebSocket* socket = WebSocketHybi10::Create(connection, request, pos); | 364 WebSocket* socket = WebSocketHybi17::Create(connection, request, pos); |
374 if (socket) | 365 if (socket) |
375 return socket; | 366 return socket; |
376 | 367 |
377 return WebSocketHixie76::Create(connection, request, pos); | 368 return WebSocketHixie76::Create(connection, request, pos); |
378 } | 369 } |
379 | 370 |
380 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { | 371 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { |
381 } | 372 } |
382 | 373 |
383 } // namespace net | 374 } // namespace net |
OLD | NEW |