Chromium Code Reviews| OLD | NEW |
|---|---|
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
| 6 | 6 |
| 7 #include <algorithm> | 7 #include <algorithm> |
| 8 #include <iterator> | 8 #include <iterator> |
| 9 #include <set> | |
| 9 #include <string> | 10 #include <string> |
| 10 #include <vector> | 11 #include <vector> |
| 11 | 12 |
| 12 #include "base/base64.h" | 13 #include "base/base64.h" |
| 13 #include "base/basictypes.h" | 14 #include "base/basictypes.h" |
| 14 #include "base/bind.h" | 15 #include "base/bind.h" |
| 15 #include "base/containers/hash_tables.h" | 16 #include "base/containers/hash_tables.h" |
| 16 #include "base/stl_util.h" | 17 #include "base/stl_util.h" |
| 18 #include "base/strings/string_number_conversions.h" | |
| 17 #include "base/strings/string_util.h" | 19 #include "base/strings/string_util.h" |
| 18 #include "base/strings/stringprintf.h" | 20 #include "base/strings/stringprintf.h" |
| 19 #include "base/time/time.h" | 21 #include "base/time/time.h" |
| 20 #include "crypto/random.h" | 22 #include "crypto/random.h" |
| 21 #include "net/http/http_request_headers.h" | 23 #include "net/http/http_request_headers.h" |
| 22 #include "net/http/http_request_info.h" | 24 #include "net/http/http_request_info.h" |
| 23 #include "net/http/http_response_body_drainer.h" | 25 #include "net/http/http_response_body_drainer.h" |
| 24 #include "net/http/http_response_headers.h" | 26 #include "net/http/http_response_headers.h" |
| 25 #include "net/http/http_status_code.h" | 27 #include "net/http/http_status_code.h" |
| 26 #include "net/http/http_stream_parser.h" | 28 #include "net/http/http_stream_parser.h" |
| 27 #include "net/socket/client_socket_handle.h" | 29 #include "net/socket/client_socket_handle.h" |
| 28 #include "net/websockets/websocket_basic_stream.h" | 30 #include "net/websockets/websocket_basic_stream.h" |
| 31 #include "net/websockets/websocket_deflate_predictor.h" | |
| 32 #include "net/websockets/websocket_deflate_predictor_impl.h" | |
| 33 #include "net/websockets/websocket_deflate_stream.h" | |
| 34 #include "net/websockets/websocket_deflater.h" | |
| 29 #include "net/websockets/websocket_extension_parser.h" | 35 #include "net/websockets/websocket_extension_parser.h" |
| 30 #include "net/websockets/websocket_handshake_constants.h" | 36 #include "net/websockets/websocket_handshake_constants.h" |
| 31 #include "net/websockets/websocket_handshake_handler.h" | 37 #include "net/websockets/websocket_handshake_handler.h" |
| 32 #include "net/websockets/websocket_handshake_request_info.h" | 38 #include "net/websockets/websocket_handshake_request_info.h" |
| 33 #include "net/websockets/websocket_handshake_response_info.h" | 39 #include "net/websockets/websocket_handshake_response_info.h" |
| 34 #include "net/websockets/websocket_stream.h" | 40 #include "net/websockets/websocket_stream.h" |
| 35 | 41 |
| 36 namespace net { | 42 namespace net { |
| 43 | |
| 44 // TODO(ricea): If more extensions are added, replace this with a more general | |
| 45 // mechanism. | |
| 46 struct WebSocketExtensionParams { | |
| 47 WebSocketExtensionParams() | |
| 48 : deflate_enabled(false), | |
| 49 client_window_bits(15), | |
| 50 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} | |
| 51 | |
| 52 bool deflate_enabled; | |
|
yhirano
2014/01/23 12:18:26
When do you turn this flag on?
Adam Rice
2014/01/24 02:04:36
Nowhere. Fixed, and added a test.
| |
| 53 int client_window_bits; | |
| 54 WebSocketDeflater::ContextTakeOverMode deflate_mode; | |
| 55 }; | |
| 56 | |
| 37 namespace { | 57 namespace { |
| 38 | 58 |
| 39 enum GetHeaderResult { | 59 enum GetHeaderResult { |
| 40 GET_HEADER_OK, | 60 GET_HEADER_OK, |
| 41 GET_HEADER_MISSING, | 61 GET_HEADER_MISSING, |
| 42 GET_HEADER_MULTIPLE, | 62 GET_HEADER_MULTIPLE, |
| 43 }; | 63 }; |
| 44 | 64 |
| 45 std::string MissingHeaderMessage(const std::string& header_name) { | 65 std::string MissingHeaderMessage(const std::string& header_name) { |
| 46 return std::string("'") + header_name + "' header is missing"; | 66 return std::string("'") + header_name + "' header is missing"; |
| (...skipping 148 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
| 195 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 215 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
| 196 *failure_message = | 216 *failure_message = |
| 197 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 217 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
| 198 "but no response was received"; | 218 "but no response was received"; |
| 199 return false; | 219 return false; |
| 200 } | 220 } |
| 201 *sub_protocol = value; | 221 *sub_protocol = value; |
| 202 return true; | 222 return true; |
| 203 } | 223 } |
| 204 | 224 |
| 225 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, | |
| 226 std::string* failure_message, | |
| 227 WebSocketExtensionParams* params) { | |
| 228 static const char kClientPrefix[] = "client_"; | |
| 229 static const char kServerPrefix[] = "server_"; | |
| 230 static const char kNoContentTakeover[] = "no_content_takeover"; | |
|
yhirano
2014/01/23 12:18:26
ContentTakeOver should be ContextTakeOver.
Same fo
Adam Rice
2014/01/24 02:04:36
Oops. Fixed.
| |
| 231 static const char kMaxWindowBits[] = "max_window_bits"; | |
| 232 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | |
| 233 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, | |
| 234 the_strings_server_and_client_must_be_the_same_length); | |
| 235 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | |
| 236 | |
| 237 DCHECK(extension.name() == "permessage-deflate"); | |
| 238 const ParameterVector& parameters = extension.parameters(); | |
| 239 std::set<std::string> seen_names; | |
| 240 for (ParameterVector::const_iterator it = parameters.begin(); | |
| 241 it != parameters.end(); ++it) { | |
| 242 const std::string& name = it->name(); | |
| 243 if (seen_names.count(name) != 0) { | |
| 244 *failure_message = | |
| 245 "Received duplicate permessage-deflate extension parameter " + name; | |
| 246 return false; | |
| 247 } | |
| 248 seen_names.insert(name); | |
| 249 const std::string client_or_server(name, 0, kPrefixLen); | |
| 250 const bool is_client = (client_or_server == kClientPrefix); | |
| 251 if (!is_client && client_or_server != kServerPrefix) { | |
| 252 *failure_message = | |
| 253 "Received an unexpected permessage-deflate extension parameter"; | |
| 254 return false; | |
| 255 } | |
| 256 const std::string rest(name, kPrefixLen); | |
| 257 if (rest == kNoContentTakeover) { | |
| 258 if (it->HasValue()) { | |
| 259 *failure_message = "Received invalid " + name + " parameter"; | |
| 260 return false; | |
| 261 } | |
| 262 if (is_client) | |
| 263 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; | |
| 264 } else if (rest == kMaxWindowBits) { | |
| 265 if (!it->HasValue()) { | |
| 266 *failure_message = name + " must have value"; | |
| 267 return false; | |
| 268 } | |
| 269 int bits = 0; | |
| 270 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || | |
| 271 it->value()[0] == '0' || | |
| 272 it->value().find_first_not_of("0123456789") != std::string::npos) { | |
| 273 *failure_message = "Received invalid " + name + " parameter"; | |
| 274 return false; | |
| 275 } | |
| 276 if (is_client) | |
| 277 params->client_window_bits = bits; | |
| 278 } else { | |
| 279 *failure_message = | |
| 280 "Received an unexpected permessage-deflate extension parameter"; | |
| 281 return false; | |
| 282 } | |
| 283 } | |
| 284 return true; | |
| 285 } | |
| 286 | |
| 205 bool ValidateExtensions(const HttpResponseHeaders* headers, | 287 bool ValidateExtensions(const HttpResponseHeaders* headers, |
| 206 const std::vector<std::string>& requested_extensions, | 288 const std::vector<std::string>& requested_extensions, |
| 207 std::string* extensions, | 289 std::string* extensions, |
| 208 std::string* failure_message) { | 290 std::string* failure_message, |
| 291 WebSocketExtensionParams* params) { | |
| 209 void* state = NULL; | 292 void* state = NULL; |
| 210 std::string value; | 293 std::string value; |
| 294 std::vector<std::string> accepted_extensions; | |
| 295 // TODO(ricea): If adding support for additional extensions, generalise this | |
| 296 // code. | |
| 297 bool seen_permessage_deflate = false; | |
| 211 while (headers->EnumerateHeader( | 298 while (headers->EnumerateHeader( |
| 212 &state, websockets::kSecWebSocketExtensions, &value)) { | 299 &state, websockets::kSecWebSocketExtensions, &value)) { |
| 213 WebSocketExtensionParser parser; | 300 WebSocketExtensionParser parser; |
| 214 parser.Parse(value); | 301 parser.Parse(value); |
| 215 if (parser.has_error()) { | 302 if (parser.has_error()) { |
| 216 // TODO(yhirano) Set appropriate failure message. | 303 // TODO(yhirano) Set appropriate failure message. |
| 217 *failure_message = | 304 *failure_message = |
| 218 "'Sec-WebSocket-Extensions' header value is " | 305 "'Sec-WebSocket-Extensions' header value is " |
| 219 "rejected by the parser: " + | 306 "rejected by the parser: " + |
| 220 value; | 307 value; |
| 221 return false; | 308 return false; |
| 222 } | 309 } |
| 223 // TODO(ricea): Accept permessage-deflate with valid parameters. | 310 if (parser.extension().name() == "permessage-deflate") { |
| 224 *failure_message = | 311 if (seen_permessage_deflate) { |
| 225 "Found an unsupported extension '" + | 312 *failure_message = "Received duplicate permessage-deflate response"; |
| 226 parser.extension().name() + | 313 return false; |
| 227 "' in 'Sec-WebSocket-Extensions' header"; | 314 } |
| 228 return false; | 315 seen_permessage_deflate = true; |
| 316 if (!ValidatePerMessageDeflateExtension( | |
| 317 parser.extension(), failure_message, params)) | |
| 318 return false; | |
| 319 } else { | |
| 320 *failure_message = | |
| 321 "Found an unsupported extension '" + | |
| 322 parser.extension().name() + | |
| 323 "' in 'Sec-WebSocket-Extensions' header"; | |
| 324 return false; | |
| 325 } | |
| 326 accepted_extensions.push_back(value); | |
| 229 } | 327 } |
| 328 *extensions = JoinString(accepted_extensions, ", "); | |
| 230 return true; | 329 return true; |
| 231 } | 330 } |
| 232 | 331 |
| 233 } // namespace | 332 } // namespace |
| 234 | 333 |
| 235 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( | 334 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( |
| 236 scoped_ptr<ClientSocketHandle> connection, | 335 scoped_ptr<ClientSocketHandle> connection, |
| 237 WebSocketStream::ConnectDelegate* connect_delegate, | 336 WebSocketStream::ConnectDelegate* connect_delegate, |
| 238 bool using_proxy, | 337 bool using_proxy, |
| 239 std::vector<std::string> requested_sub_protocols, | 338 std::vector<std::string> requested_sub_protocols, |
| (...skipping 37 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
| 277 enriched_headers.CopyFrom(headers); | 376 enriched_headers.CopyFrom(headers); |
| 278 std::string handshake_challenge; | 377 std::string handshake_challenge; |
| 279 if (handshake_challenge_for_testing_) { | 378 if (handshake_challenge_for_testing_) { |
| 280 handshake_challenge = *handshake_challenge_for_testing_; | 379 handshake_challenge = *handshake_challenge_for_testing_; |
| 281 handshake_challenge_for_testing_.reset(); | 380 handshake_challenge_for_testing_.reset(); |
| 282 } else { | 381 } else { |
| 283 handshake_challenge = GenerateHandshakeChallenge(); | 382 handshake_challenge = GenerateHandshakeChallenge(); |
| 284 } | 383 } |
| 285 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); | 384 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); |
| 286 | 385 |
| 386 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, | |
| 387 requested_extensions_, | |
| 388 &enriched_headers); | |
| 287 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, | 389 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, |
| 288 requested_sub_protocols_, | 390 requested_sub_protocols_, |
| 289 &enriched_headers); | 391 &enriched_headers); |
| 290 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, | |
| 291 requested_extensions_, | |
| 292 &enriched_headers); | |
| 293 | 392 |
| 294 ComputeSecWebSocketAccept(handshake_challenge, | 393 ComputeSecWebSocketAccept(handshake_challenge, |
| 295 &handshake_challenge_response_); | 394 &handshake_challenge_response_); |
| 296 | 395 |
| 297 DCHECK(connect_delegate_); | 396 DCHECK(connect_delegate_); |
| 298 scoped_ptr<WebSocketHandshakeRequestInfo> request( | 397 scoped_ptr<WebSocketHandshakeRequestInfo> request( |
| 299 new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); | 398 new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); |
| 300 request->headers.CopyFrom(enriched_headers); | 399 request->headers.CopyFrom(enriched_headers); |
| 301 connect_delegate_->OnStartOpeningHandshake(request.Pass()); | 400 connect_delegate_->OnStartOpeningHandshake(request.Pass()); |
| 302 | 401 |
| (...skipping 83 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
| 386 drainer->Start(session); | 485 drainer->Start(session); |
| 387 // |drainer| will delete itself. | 486 // |drainer| will delete itself. |
| 388 } | 487 } |
| 389 | 488 |
| 390 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { | 489 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { |
| 391 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is | 490 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is |
| 392 // gone, then copy whatever has happened there over here. | 491 // gone, then copy whatever has happened there over here. |
| 393 } | 492 } |
| 394 | 493 |
| 395 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { | 494 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { |
| 396 // TODO(ricea): Add deflate support. | |
| 397 | |
| 398 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make | 495 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make |
| 399 // sure it does not touch it again before it is destroyed. | 496 // sure it does not touch it again before it is destroyed. |
| 400 state_.DeleteParser(); | 497 state_.DeleteParser(); |
| 401 return scoped_ptr<WebSocketStream>( | 498 scoped_ptr<WebSocketStream> basic_stream( |
| 402 new WebSocketBasicStream(state_.ReleaseConnection(), | 499 new WebSocketBasicStream(state_.ReleaseConnection(), |
| 403 state_.read_buf(), | 500 state_.read_buf(), |
| 404 sub_protocol_, | 501 sub_protocol_, |
| 405 extensions_)); | 502 extensions_)); |
| 503 | |
| 504 if (extension_params_->deflate_enabled) { | |
| 505 return scoped_ptr<WebSocketStream>( | |
| 506 new WebSocketDeflateStream(basic_stream.Pass(), | |
| 507 extension_params_->deflate_mode, | |
| 508 extension_params_->client_window_bits, | |
| 509 scoped_ptr<WebSocketDeflatePredictor>( | |
| 510 new WebSocketDeflatePredictorImpl))); | |
| 511 } else { | |
| 512 return basic_stream.Pass(); | |
| 513 } | |
| 406 } | 514 } |
| 407 | 515 |
| 408 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( | 516 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
| 409 const std::string& key) { | 517 const std::string& key) { |
| 410 handshake_challenge_for_testing_.reset(new std::string(key)); | 518 handshake_challenge_for_testing_.reset(new std::string(key)); |
| 411 } | 519 } |
| 412 | 520 |
| 413 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { | 521 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { |
| 414 return failure_message_; | 522 return failure_message_; |
| 415 } | 523 } |
| (...skipping 41 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
| 457 default: | 565 default: |
| 458 failure_message_ = base::StringPrintf("Unexpected status code: %d", | 566 failure_message_ = base::StringPrintf("Unexpected status code: %d", |
| 459 headers->response_code()); | 567 headers->response_code()); |
| 460 OnFinishOpeningHandshake(); | 568 OnFinishOpeningHandshake(); |
| 461 return ERR_INVALID_RESPONSE; | 569 return ERR_INVALID_RESPONSE; |
| 462 } | 570 } |
| 463 } | 571 } |
| 464 | 572 |
| 465 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( | 573 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| 466 const scoped_refptr<HttpResponseHeaders>& headers) { | 574 const scoped_refptr<HttpResponseHeaders>& headers) { |
| 575 extension_params_.reset(new WebSocketExtensionParams); | |
| 467 if (ValidateUpgrade(headers.get(), &failure_message_) && | 576 if (ValidateUpgrade(headers.get(), &failure_message_) && |
| 468 ValidateSecWebSocketAccept(headers.get(), | 577 ValidateSecWebSocketAccept(headers.get(), |
| 469 handshake_challenge_response_, | 578 handshake_challenge_response_, |
| 470 &failure_message_) && | 579 &failure_message_) && |
| 471 ValidateConnection(headers.get(), &failure_message_) && | 580 ValidateConnection(headers.get(), &failure_message_) && |
| 472 ValidateSubProtocol(headers.get(), | 581 ValidateSubProtocol(headers.get(), |
| 473 requested_sub_protocols_, | 582 requested_sub_protocols_, |
| 474 &sub_protocol_, | 583 &sub_protocol_, |
| 475 &failure_message_) && | 584 &failure_message_) && |
| 476 ValidateExtensions(headers.get(), | 585 ValidateExtensions(headers.get(), |
| 477 requested_extensions_, | 586 requested_extensions_, |
| 478 &extensions_, | 587 &extensions_, |
| 479 &failure_message_)) { | 588 &failure_message_, |
| 589 extension_params_.get())) { | |
| 480 return OK; | 590 return OK; |
| 481 } | 591 } |
| 482 failure_message_ = "Error during WebSocket handshake: " + failure_message_; | 592 failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
| 483 return ERR_INVALID_RESPONSE; | 593 return ERR_INVALID_RESPONSE; |
| 484 } | 594 } |
| 485 | 595 |
| 486 } // namespace net | 596 } // namespace net |
| OLD | NEW |