| OLD | NEW |
| 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/server/web_socket.h" | 5 #include "net/server/web_socket.h" |
| 6 | 6 |
| 7 #include "base/base64.h" | 7 #include "base/base64.h" |
| 8 #include "base/rand_util.h" | 8 #include "base/rand_util.h" |
| 9 #include "base/logging.h" | 9 #include "base/logging.h" |
| 10 #include "base/md5.h" | 10 #include "base/md5.h" |
| (...skipping 151 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 162 const unsigned char kReserved3Bit = 0x10; | 162 const unsigned char kReserved3Bit = 0x10; |
| 163 const unsigned char kOpCodeMask = 0xF; | 163 const unsigned char kOpCodeMask = 0xF; |
| 164 const unsigned char kMaskBit = 0x80; | 164 const unsigned char kMaskBit = 0x80; |
| 165 const unsigned char kPayloadLengthMask = 0x7F; | 165 const unsigned char kPayloadLengthMask = 0x7F; |
| 166 | 166 |
| 167 const size_t kMaxSingleBytePayloadLength = 125; | 167 const size_t kMaxSingleBytePayloadLength = 125; |
| 168 const size_t kTwoBytePayloadLengthField = 126; | 168 const size_t kTwoBytePayloadLengthField = 126; |
| 169 const size_t kEightBytePayloadLengthField = 127; | 169 const size_t kEightBytePayloadLengthField = 127; |
| 170 const size_t kMaskingKeyWidthInBytes = 4; | 170 const size_t kMaskingKeyWidthInBytes = 4; |
| 171 | 171 |
| 172 class WebSocketHybi10 : public WebSocket { | 172 class WebSocketHybi17 : public WebSocket { |
| 173 public: | 173 public: |
| 174 static WebSocket* Create(HttpConnection* connection, | 174 static WebSocket* Create(HttpConnection* connection, |
| 175 const HttpServerRequestInfo& request, | 175 const HttpServerRequestInfo& request, |
| 176 size_t* pos) { | 176 size_t* pos) { |
| 177 std::string version = request.GetHeaderValue("Sec-WebSocket-Version"); | 177 std::string version = request.GetHeaderValue("Sec-WebSocket-Version"); |
| 178 if (version != "8" && version != "13") | 178 if (version != "8" && version != "13") |
| 179 return NULL; | 179 return NULL; |
| 180 | 180 |
| 181 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); | 181 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); |
| 182 if (key.empty()) { | 182 if (key.empty()) { |
| 183 connection->Send500("Invalid request format. " | 183 connection->Send500("Invalid request format. " |
| 184 "Sec-WebSocket-Key is empty or isn't specified."); | 184 "Sec-WebSocket-Key is empty or isn't specified."); |
| 185 return NULL; | 185 return NULL; |
| 186 } | 186 } |
| 187 return new WebSocketHybi10(connection, request, pos); | 187 return new WebSocketHybi17(connection, request, pos); |
| 188 } | 188 } |
| 189 | 189 |
| 190 virtual void Accept(const HttpServerRequestInfo& request) { | 190 virtual void Accept(const HttpServerRequestInfo& request) { |
| 191 static const char* const kWebSocketGuid = | 191 static const char* const kWebSocketGuid = |
| 192 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | 192 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
| 193 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); | 193 std::string key = request.GetHeaderValue("Sec-WebSocket-Key"); |
| 194 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); | 194 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
| 195 std::string encoded_hash; | 195 std::string encoded_hash; |
| 196 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); | 196 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
| 197 | 197 |
| (...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 230 case kOpCodeText: | 230 case kOpCodeText: |
| 231 break; | 231 break; |
| 232 case kOpCodeBinary: // We don't support binary frames yet. | 232 case kOpCodeBinary: // We don't support binary frames yet. |
| 233 case kOpCodeContinuation: // We don't support binary frames yet. | 233 case kOpCodeContinuation: // We don't support binary frames yet. |
| 234 case kOpCodePing: // We don't support binary frames yet. | 234 case kOpCodePing: // We don't support binary frames yet. |
| 235 case kOpCodePong: // We don't support binary frames yet. | 235 case kOpCodePong: // We don't support binary frames yet. |
| 236 default: | 236 default: |
| 237 return FRAME_ERROR; | 237 return FRAME_ERROR; |
| 238 } | 238 } |
| 239 | 239 |
| 240 if (!masked_) // According to Hybi-17 spec client MUST mask his frame. |
| 241 return FRAME_ERROR; |
| 242 |
| 240 uint64 payload_length64 = second_byte & kPayloadLengthMask; | 243 uint64 payload_length64 = second_byte & kPayloadLengthMask; |
| 241 if (payload_length64 > kMaxSingleBytePayloadLength) { | 244 if (payload_length64 > kMaxSingleBytePayloadLength) { |
| 242 int extended_payload_length_size; | 245 int extended_payload_length_size; |
| 243 if (payload_length64 == kTwoBytePayloadLengthField) | 246 if (payload_length64 == kTwoBytePayloadLengthField) |
| 244 extended_payload_length_size = 2; | 247 extended_payload_length_size = 2; |
| 245 else { | 248 else { |
| 246 DCHECK(payload_length64 == kEightBytePayloadLengthField); | 249 DCHECK(payload_length64 == kEightBytePayloadLengthField); |
| 247 extended_payload_length_size = 8; | 250 extended_payload_length_size = 8; |
| 248 } | 251 } |
| 249 if (buffer_end - p < extended_payload_length_size) | 252 if (buffer_end - p < extended_payload_length_size) |
| 250 return FRAME_INCOMPLETE; | 253 return FRAME_INCOMPLETE; |
| 251 payload_length64 = 0; | 254 payload_length64 = 0; |
| 252 for (int i = 0; i < extended_payload_length_size; ++i) { | 255 for (int i = 0; i < extended_payload_length_size; ++i) { |
| 253 payload_length64 <<= 8; | 256 payload_length64 <<= 8; |
| 254 payload_length64 |= static_cast<unsigned char>(*p++); | 257 payload_length64 |= static_cast<unsigned char>(*p++); |
| 255 } | 258 } |
| 256 } | 259 } |
| 257 | 260 |
| 258 static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; | 261 static const uint64 max_payload_length = 0x7FFFFFFFFFFFFFFFull; |
| 259 size_t masking_key_length = masked_ ? kMaskingKeyWidthInBytes : 0; | |
| 260 static size_t max_length = std::numeric_limits<size_t>::max(); | 262 static size_t max_length = std::numeric_limits<size_t>::max(); |
| 261 if (payload_length64 > max_payload_length || | 263 if (payload_length64 > max_payload_length || |
| 262 payload_length64 + masking_key_length > max_length) { | 264 payload_length64 + kMaskingKeyWidthInBytes > max_length) { |
| 263 // WebSocket frame length too large. | 265 // WebSocket frame length too large. |
| 264 return FRAME_ERROR; | 266 return FRAME_ERROR; |
| 265 } | 267 } |
| 266 payload_length_ = static_cast<size_t>(payload_length64); | 268 payload_length_ = static_cast<size_t>(payload_length64); |
| 267 | 269 |
| 268 size_t total_length = masking_key_length + payload_length_; | 270 size_t total_length = kMaskingKeyWidthInBytes + payload_length_; |
| 269 if (static_cast<size_t>(buffer_end - p) < total_length) | 271 if (static_cast<size_t>(buffer_end - p) < total_length) |
| 270 return FRAME_INCOMPLETE; | 272 return FRAME_INCOMPLETE; |
| 271 | 273 |
| 272 if (masked_) { | 274 if (masked_) { |
| 273 message->resize(payload_length_); | 275 message->resize(payload_length_); |
| 274 const char* masking_key = p; | 276 const char* masking_key = p; |
| 275 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); | 277 char* payload = const_cast<char*>(p + kMaskingKeyWidthInBytes); |
| 276 for (size_t i = 0; i < payload_length_; ++i) // Unmask the payload. | 278 for (size_t i = 0; i < payload_length_; ++i) // Unmask the payload. |
| 277 (*message)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; | 279 (*message)[i] = payload[i] ^ masking_key[i % kMaskingKeyWidthInBytes]; |
| 278 } else { | 280 } else { |
| 279 std::string buffer(p, p + payload_length_); | 281 std::string buffer(p, p + payload_length_); |
| 280 message->swap(buffer); | 282 message->swap(buffer); |
| 281 } | 283 } |
| 282 | 284 |
| 283 size_t pos = p + masking_key_length + payload_length_ - | 285 size_t pos = p + kMaskingKeyWidthInBytes + payload_length_ - |
| 284 connection_->recv_data().c_str(); | 286 connection_->recv_data().c_str(); |
| 285 connection_->Shift(pos); | 287 connection_->Shift(pos); |
| 286 | 288 |
| 287 return closed_ ? FRAME_CLOSE : FRAME_OK; | 289 return closed_ ? FRAME_CLOSE : FRAME_OK; |
| 288 } | 290 } |
| 289 | 291 |
| 290 virtual void Send(const std::string& message) { | 292 virtual void Send(const std::string& message) { |
| 291 if (closed_) | 293 if (closed_) |
| 292 return; | 294 return; |
| 293 | 295 |
| 294 std::vector<char> frame; | 296 std::vector<char> frame; |
| 295 OpCode op_code = kOpCodeText; | 297 OpCode op_code = kOpCodeText; |
| 296 size_t data_length = message.length(); | 298 size_t data_length = message.length(); |
| 297 | 299 |
| 298 frame.push_back(kFinalBit | op_code); | 300 frame.push_back(kFinalBit | op_code); |
| 299 if (data_length <= kMaxSingleBytePayloadLength) | 301 if (data_length <= kMaxSingleBytePayloadLength) |
| 300 frame.push_back(kMaskBit | data_length); | 302 frame.push_back(data_length); |
| 301 else if (data_length <= 0xFFFF) { | 303 else if (data_length <= 0xFFFF) { |
| 302 frame.push_back(kMaskBit | kTwoBytePayloadLengthField); | 304 frame.push_back(kTwoBytePayloadLengthField); |
| 303 frame.push_back((data_length & 0xFF00) >> 8); | 305 frame.push_back((data_length & 0xFF00) >> 8); |
| 304 frame.push_back(data_length & 0xFF); | 306 frame.push_back(data_length & 0xFF); |
| 305 } else { | 307 } else { |
| 306 frame.push_back(kMaskBit | kEightBytePayloadLengthField); | 308 frame.push_back(kEightBytePayloadLengthField); |
| 307 char extended_payload_length[8]; | 309 char extended_payload_length[8]; |
| 308 size_t remaining = data_length; | 310 size_t remaining = data_length; |
| 309 // Fill the length into extended_payload_length in the network byte order. | 311 // Fill the length into extended_payload_length in the network byte order. |
| 310 for (int i = 0; i < 8; ++i) { | 312 for (int i = 0; i < 8; ++i) { |
| 311 extended_payload_length[7 - i] = remaining & 0xFF; | 313 extended_payload_length[7 - i] = remaining & 0xFF; |
| 312 remaining >>= 8; | 314 remaining >>= 8; |
| 313 } | 315 } |
| 314 frame.insert(frame.end(), | 316 frame.insert(frame.end(), |
| 315 extended_payload_length, | 317 extended_payload_length, |
| 316 extended_payload_length + 8); | 318 extended_payload_length + 8); |
| 317 DCHECK(!remaining); | 319 DCHECK(!remaining); |
| 318 } | 320 } |
| 319 | 321 |
| 320 // Mask the frame. | |
| 321 size_t masking_key_start = frame.size(); | |
| 322 // Add placeholder for masking key. Will be overwritten. | |
| 323 frame.resize(frame.size() + kMaskingKeyWidthInBytes); | |
| 324 size_t payload_start = frame.size(); | |
| 325 const char* data = message.c_str(); | 322 const char* data = message.c_str(); |
| 326 frame.insert(frame.end(), data, data + data_length); | 323 frame.insert(frame.end(), data, data + data_length); |
| 327 | 324 |
| 328 base::RandBytes(&frame[0] + masking_key_start, | |
| 329 kMaskingKeyWidthInBytes); | |
| 330 for (size_t i = 0; i < data_length; ++i) { | |
| 331 frame[payload_start + i] ^= | |
| 332 frame[masking_key_start + i % kMaskingKeyWidthInBytes]; | |
| 333 } | |
| 334 connection_->Send(&frame[0], frame.size()); | 325 connection_->Send(&frame[0], frame.size()); |
| 335 } | 326 } |
| 336 | 327 |
| 337 private: | 328 private: |
| 338 WebSocketHybi10(HttpConnection* connection, | 329 WebSocketHybi17(HttpConnection* connection, |
| 339 const HttpServerRequestInfo& request, | 330 const HttpServerRequestInfo& request, |
| 340 size_t* pos) | 331 size_t* pos) |
| 341 : WebSocket(connection), | 332 : WebSocket(connection), |
| 342 op_code_(0), | 333 op_code_(0), |
| 343 final_(false), | 334 final_(false), |
| 344 reserved1_(false), | 335 reserved1_(false), |
| 345 reserved2_(false), | 336 reserved2_(false), |
| 346 reserved3_(false), | 337 reserved3_(false), |
| 347 masked_(false), | 338 masked_(false), |
| 348 payload_(0), | 339 payload_(0), |
| 349 payload_length_(0), | 340 payload_length_(0), |
| 350 frame_end_(0), | 341 frame_end_(0), |
| 351 closed_(false) { | 342 closed_(false) { |
| 352 } | 343 } |
| 353 | 344 |
| 354 OpCode op_code_; | 345 OpCode op_code_; |
| 355 bool final_; | 346 bool final_; |
| 356 bool reserved1_; | 347 bool reserved1_; |
| 357 bool reserved2_; | 348 bool reserved2_; |
| 358 bool reserved3_; | 349 bool reserved3_; |
| 359 bool masked_; | 350 bool masked_; |
| 360 const char* payload_; | 351 const char* payload_; |
| 361 size_t payload_length_; | 352 size_t payload_length_; |
| 362 const char* frame_end_; | 353 const char* frame_end_; |
| 363 bool closed_; | 354 bool closed_; |
| 364 | 355 |
| 365 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi10); | 356 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); |
| 366 }; | 357 }; |
| 367 | 358 |
| 368 } // anonymous namespace | 359 } // anonymous namespace |
| 369 | 360 |
| 370 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, | 361 WebSocket* WebSocket::CreateWebSocket(HttpConnection* connection, |
| 371 const HttpServerRequestInfo& request, | 362 const HttpServerRequestInfo& request, |
| 372 size_t* pos) { | 363 size_t* pos) { |
| 373 WebSocket* socket = WebSocketHybi10::Create(connection, request, pos); | 364 WebSocket* socket = WebSocketHybi17::Create(connection, request, pos); |
| 374 if (socket) | 365 if (socket) |
| 375 return socket; | 366 return socket; |
| 376 | 367 |
| 377 return WebSocketHixie76::Create(connection, request, pos); | 368 return WebSocketHixie76::Create(connection, request, pos); |
| 378 } | 369 } |
| 379 | 370 |
| 380 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { | 371 WebSocket::WebSocket(HttpConnection* connection) : connection_(connection) { |
| 381 } | 372 } |
| 382 | 373 |
| 383 } // namespace net | 374 } // namespace net |
| OLD | NEW |