| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
| 6 | 6 |
| 7 #include <algorithm> | 7 #include <algorithm> |
| 8 #include <iterator> | 8 #include <iterator> |
| 9 #include <set> |
| 9 #include <string> | 10 #include <string> |
| 10 #include <vector> | 11 #include <vector> |
| 11 | 12 |
| 12 #include "base/base64.h" | 13 #include "base/base64.h" |
| 13 #include "base/basictypes.h" | 14 #include "base/basictypes.h" |
| 14 #include "base/bind.h" | 15 #include "base/bind.h" |
| 15 #include "base/containers/hash_tables.h" | 16 #include "base/containers/hash_tables.h" |
| 16 #include "base/stl_util.h" | 17 #include "base/stl_util.h" |
| 18 #include "base/strings/string_number_conversions.h" |
| 17 #include "base/strings/string_util.h" | 19 #include "base/strings/string_util.h" |
| 18 #include "base/strings/stringprintf.h" | 20 #include "base/strings/stringprintf.h" |
| 19 #include "base/time/time.h" | 21 #include "base/time/time.h" |
| 20 #include "crypto/random.h" | 22 #include "crypto/random.h" |
| 21 #include "net/http/http_request_headers.h" | 23 #include "net/http/http_request_headers.h" |
| 22 #include "net/http/http_request_info.h" | 24 #include "net/http/http_request_info.h" |
| 23 #include "net/http/http_response_body_drainer.h" | 25 #include "net/http/http_response_body_drainer.h" |
| 24 #include "net/http/http_response_headers.h" | 26 #include "net/http/http_response_headers.h" |
| 25 #include "net/http/http_status_code.h" | 27 #include "net/http/http_status_code.h" |
| 26 #include "net/http/http_stream_parser.h" | 28 #include "net/http/http_stream_parser.h" |
| 27 #include "net/socket/client_socket_handle.h" | 29 #include "net/socket/client_socket_handle.h" |
| 28 #include "net/websockets/websocket_basic_stream.h" | 30 #include "net/websockets/websocket_basic_stream.h" |
| 31 #include "net/websockets/websocket_deflate_predictor.h" |
| 32 #include "net/websockets/websocket_deflate_predictor_impl.h" |
| 33 #include "net/websockets/websocket_deflate_stream.h" |
| 34 #include "net/websockets/websocket_deflater.h" |
| 29 #include "net/websockets/websocket_extension_parser.h" | 35 #include "net/websockets/websocket_extension_parser.h" |
| 30 #include "net/websockets/websocket_handshake_constants.h" | 36 #include "net/websockets/websocket_handshake_constants.h" |
| 31 #include "net/websockets/websocket_handshake_handler.h" | 37 #include "net/websockets/websocket_handshake_handler.h" |
| 32 #include "net/websockets/websocket_handshake_request_info.h" | 38 #include "net/websockets/websocket_handshake_request_info.h" |
| 33 #include "net/websockets/websocket_handshake_response_info.h" | 39 #include "net/websockets/websocket_handshake_response_info.h" |
| 34 #include "net/websockets/websocket_stream.h" | 40 #include "net/websockets/websocket_stream.h" |
| 35 | 41 |
| 36 namespace net { | 42 namespace net { |
| 43 |
| 44 // TODO(ricea): If more extensions are added, replace this with a more general |
| 45 // mechanism. |
| 46 struct WebSocketExtensionParams { |
| 47 WebSocketExtensionParams() |
| 48 : deflate_enabled(false), |
| 49 client_window_bits(15), |
| 50 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} |
| 51 |
| 52 bool deflate_enabled; |
| 53 int client_window_bits; |
| 54 WebSocketDeflater::ContextTakeOverMode deflate_mode; |
| 55 }; |
| 56 |
| 37 namespace { | 57 namespace { |
| 38 | 58 |
| 39 enum GetHeaderResult { | 59 enum GetHeaderResult { |
| 40 GET_HEADER_OK, | 60 GET_HEADER_OK, |
| 41 GET_HEADER_MISSING, | 61 GET_HEADER_MISSING, |
| 42 GET_HEADER_MULTIPLE, | 62 GET_HEADER_MULTIPLE, |
| 43 }; | 63 }; |
| 44 | 64 |
| 45 std::string MissingHeaderMessage(const std::string& header_name) { | 65 std::string MissingHeaderMessage(const std::string& header_name) { |
| 46 return std::string("'") + header_name + "' header is missing"; | 66 return std::string("'") + header_name + "' header is missing"; |
| (...skipping 148 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 195 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 215 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
| 196 *failure_message = | 216 *failure_message = |
| 197 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 217 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
| 198 "but no response was received"; | 218 "but no response was received"; |
| 199 return false; | 219 return false; |
| 200 } | 220 } |
| 201 *sub_protocol = value; | 221 *sub_protocol = value; |
| 202 return true; | 222 return true; |
| 203 } | 223 } |
| 204 | 224 |
| 225 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, |
| 226 std::string* failure_message, |
| 227 WebSocketExtensionParams* params) { |
| 228 static const char kClientPrefix[] = "client_"; |
| 229 static const char kServerPrefix[] = "server_"; |
| 230 static const char kNoContextTakeover[] = "no_context_takeover"; |
| 231 static const char kMaxWindowBits[] = "max_window_bits"; |
| 232 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
| 233 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, |
| 234 the_strings_server_and_client_must_be_the_same_length); |
| 235 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
| 236 |
| 237 DCHECK(extension.name() == "permessage-deflate"); |
| 238 const ParameterVector& parameters = extension.parameters(); |
| 239 std::set<std::string> seen_names; |
| 240 for (ParameterVector::const_iterator it = parameters.begin(); |
| 241 it != parameters.end(); ++it) { |
| 242 const std::string& name = it->name(); |
| 243 if (seen_names.count(name) != 0) { |
| 244 *failure_message = |
| 245 "Received duplicate permessage-deflate extension parameter " + name; |
| 246 return false; |
| 247 } |
| 248 seen_names.insert(name); |
| 249 const std::string client_or_server(name, 0, kPrefixLen); |
| 250 const bool is_client = (client_or_server == kClientPrefix); |
| 251 if (!is_client && client_or_server != kServerPrefix) { |
| 252 *failure_message = |
| 253 "Received an unexpected permessage-deflate extension parameter"; |
| 254 return false; |
| 255 } |
| 256 const std::string rest(name, kPrefixLen); |
| 257 if (rest == kNoContextTakeover) { |
| 258 if (it->HasValue()) { |
| 259 *failure_message = "Received invalid " + name + " parameter"; |
| 260 return false; |
| 261 } |
| 262 if (is_client) |
| 263 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; |
| 264 } else if (rest == kMaxWindowBits) { |
| 265 if (!it->HasValue()) { |
| 266 *failure_message = name + " must have value"; |
| 267 return false; |
| 268 } |
| 269 int bits = 0; |
| 270 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || |
| 271 it->value()[0] == '0' || |
| 272 it->value().find_first_not_of("0123456789") != std::string::npos) { |
| 273 *failure_message = "Received invalid " + name + " parameter"; |
| 274 return false; |
| 275 } |
| 276 if (is_client) |
| 277 params->client_window_bits = bits; |
| 278 } else { |
| 279 *failure_message = |
| 280 "Received an unexpected permessage-deflate extension parameter"; |
| 281 return false; |
| 282 } |
| 283 } |
| 284 params->deflate_enabled = true; |
| 285 return true; |
| 286 } |
| 287 |
| 205 bool ValidateExtensions(const HttpResponseHeaders* headers, | 288 bool ValidateExtensions(const HttpResponseHeaders* headers, |
| 206 const std::vector<std::string>& requested_extensions, | 289 const std::vector<std::string>& requested_extensions, |
| 207 std::string* extensions, | 290 std::string* extensions, |
| 208 std::string* failure_message) { | 291 std::string* failure_message, |
| 292 WebSocketExtensionParams* params) { |
| 209 void* state = NULL; | 293 void* state = NULL; |
| 210 std::string value; | 294 std::string value; |
| 295 std::vector<std::string> accepted_extensions; |
| 296 // TODO(ricea): If adding support for additional extensions, generalise this |
| 297 // code. |
| 298 bool seen_permessage_deflate = false; |
| 211 while (headers->EnumerateHeader( | 299 while (headers->EnumerateHeader( |
| 212 &state, websockets::kSecWebSocketExtensions, &value)) { | 300 &state, websockets::kSecWebSocketExtensions, &value)) { |
| 213 WebSocketExtensionParser parser; | 301 WebSocketExtensionParser parser; |
| 214 parser.Parse(value); | 302 parser.Parse(value); |
| 215 if (parser.has_error()) { | 303 if (parser.has_error()) { |
| 216 // TODO(yhirano) Set appropriate failure message. | 304 // TODO(yhirano) Set appropriate failure message. |
| 217 *failure_message = | 305 *failure_message = |
| 218 "'Sec-WebSocket-Extensions' header value is " | 306 "'Sec-WebSocket-Extensions' header value is " |
| 219 "rejected by the parser: " + | 307 "rejected by the parser: " + |
| 220 value; | 308 value; |
| 221 return false; | 309 return false; |
| 222 } | 310 } |
| 223 // TODO(ricea): Accept permessage-deflate with valid parameters. | 311 if (parser.extension().name() == "permessage-deflate") { |
| 224 *failure_message = | 312 if (seen_permessage_deflate) { |
| 225 "Found an unsupported extension '" + | 313 *failure_message = "Received duplicate permessage-deflate response"; |
| 226 parser.extension().name() + | 314 return false; |
| 227 "' in 'Sec-WebSocket-Extensions' header"; | 315 } |
| 228 return false; | 316 seen_permessage_deflate = true; |
| 317 if (!ValidatePerMessageDeflateExtension( |
| 318 parser.extension(), failure_message, params)) |
| 319 return false; |
| 320 } else { |
| 321 *failure_message = |
| 322 "Found an unsupported extension '" + |
| 323 parser.extension().name() + |
| 324 "' in 'Sec-WebSocket-Extensions' header"; |
| 325 return false; |
| 326 } |
| 327 accepted_extensions.push_back(value); |
| 229 } | 328 } |
| 329 *extensions = JoinString(accepted_extensions, ", "); |
| 230 return true; | 330 return true; |
| 231 } | 331 } |
| 232 | 332 |
| 233 } // namespace | 333 } // namespace |
| 234 | 334 |
| 235 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( | 335 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( |
| 236 scoped_ptr<ClientSocketHandle> connection, | 336 scoped_ptr<ClientSocketHandle> connection, |
| 237 WebSocketStream::ConnectDelegate* connect_delegate, | 337 WebSocketStream::ConnectDelegate* connect_delegate, |
| 238 bool using_proxy, | 338 bool using_proxy, |
| 239 std::vector<std::string> requested_sub_protocols, | 339 std::vector<std::string> requested_sub_protocols, |
| (...skipping 37 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 277 enriched_headers.CopyFrom(headers); | 377 enriched_headers.CopyFrom(headers); |
| 278 std::string handshake_challenge; | 378 std::string handshake_challenge; |
| 279 if (handshake_challenge_for_testing_) { | 379 if (handshake_challenge_for_testing_) { |
| 280 handshake_challenge = *handshake_challenge_for_testing_; | 380 handshake_challenge = *handshake_challenge_for_testing_; |
| 281 handshake_challenge_for_testing_.reset(); | 381 handshake_challenge_for_testing_.reset(); |
| 282 } else { | 382 } else { |
| 283 handshake_challenge = GenerateHandshakeChallenge(); | 383 handshake_challenge = GenerateHandshakeChallenge(); |
| 284 } | 384 } |
| 285 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); | 385 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); |
| 286 | 386 |
| 387 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, |
| 388 requested_extensions_, |
| 389 &enriched_headers); |
| 287 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, | 390 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, |
| 288 requested_sub_protocols_, | 391 requested_sub_protocols_, |
| 289 &enriched_headers); | 392 &enriched_headers); |
| 290 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, | |
| 291 requested_extensions_, | |
| 292 &enriched_headers); | |
| 293 | 393 |
| 294 ComputeSecWebSocketAccept(handshake_challenge, | 394 ComputeSecWebSocketAccept(handshake_challenge, |
| 295 &handshake_challenge_response_); | 395 &handshake_challenge_response_); |
| 296 | 396 |
| 297 DCHECK(connect_delegate_); | 397 DCHECK(connect_delegate_); |
| 298 scoped_ptr<WebSocketHandshakeRequestInfo> request( | 398 scoped_ptr<WebSocketHandshakeRequestInfo> request( |
| 299 new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); | 399 new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); |
| 300 request->headers.CopyFrom(enriched_headers); | 400 request->headers.CopyFrom(enriched_headers); |
| 301 connect_delegate_->OnStartOpeningHandshake(request.Pass()); | 401 connect_delegate_->OnStartOpeningHandshake(request.Pass()); |
| 302 | 402 |
| (...skipping 83 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 386 drainer->Start(session); | 486 drainer->Start(session); |
| 387 // |drainer| will delete itself. | 487 // |drainer| will delete itself. |
| 388 } | 488 } |
| 389 | 489 |
| 390 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { | 490 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { |
| 391 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is | 491 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is |
| 392 // gone, then copy whatever has happened there over here. | 492 // gone, then copy whatever has happened there over here. |
| 393 } | 493 } |
| 394 | 494 |
| 395 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { | 495 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { |
| 396 // TODO(ricea): Add deflate support. | |
| 397 | |
| 398 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make | 496 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make |
| 399 // sure it does not touch it again before it is destroyed. | 497 // sure it does not touch it again before it is destroyed. |
| 400 state_.DeleteParser(); | 498 state_.DeleteParser(); |
| 401 return scoped_ptr<WebSocketStream>( | 499 scoped_ptr<WebSocketStream> basic_stream( |
| 402 new WebSocketBasicStream(state_.ReleaseConnection(), | 500 new WebSocketBasicStream(state_.ReleaseConnection(), |
| 403 state_.read_buf(), | 501 state_.read_buf(), |
| 404 sub_protocol_, | 502 sub_protocol_, |
| 405 extensions_)); | 503 extensions_)); |
| 504 DCHECK(extension_params_.get()); |
| 505 if (extension_params_->deflate_enabled) { |
| 506 return scoped_ptr<WebSocketStream>( |
| 507 new WebSocketDeflateStream(basic_stream.Pass(), |
| 508 extension_params_->deflate_mode, |
| 509 extension_params_->client_window_bits, |
| 510 scoped_ptr<WebSocketDeflatePredictor>( |
| 511 new WebSocketDeflatePredictorImpl))); |
| 512 } else { |
| 513 return basic_stream.Pass(); |
| 514 } |
| 406 } | 515 } |
| 407 | 516 |
| 408 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( | 517 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
| 409 const std::string& key) { | 518 const std::string& key) { |
| 410 handshake_challenge_for_testing_.reset(new std::string(key)); | 519 handshake_challenge_for_testing_.reset(new std::string(key)); |
| 411 } | 520 } |
| 412 | 521 |
| 413 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { | 522 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { |
| 414 return failure_message_; | 523 return failure_message_; |
| 415 } | 524 } |
| (...skipping 41 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 457 default: | 566 default: |
| 458 failure_message_ = base::StringPrintf("Unexpected status code: %d", | 567 failure_message_ = base::StringPrintf("Unexpected status code: %d", |
| 459 headers->response_code()); | 568 headers->response_code()); |
| 460 OnFinishOpeningHandshake(); | 569 OnFinishOpeningHandshake(); |
| 461 return ERR_INVALID_RESPONSE; | 570 return ERR_INVALID_RESPONSE; |
| 462 } | 571 } |
| 463 } | 572 } |
| 464 | 573 |
| 465 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( | 574 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| 466 const scoped_refptr<HttpResponseHeaders>& headers) { | 575 const scoped_refptr<HttpResponseHeaders>& headers) { |
| 576 extension_params_.reset(new WebSocketExtensionParams); |
| 467 if (ValidateUpgrade(headers.get(), &failure_message_) && | 577 if (ValidateUpgrade(headers.get(), &failure_message_) && |
| 468 ValidateSecWebSocketAccept(headers.get(), | 578 ValidateSecWebSocketAccept(headers.get(), |
| 469 handshake_challenge_response_, | 579 handshake_challenge_response_, |
| 470 &failure_message_) && | 580 &failure_message_) && |
| 471 ValidateConnection(headers.get(), &failure_message_) && | 581 ValidateConnection(headers.get(), &failure_message_) && |
| 472 ValidateSubProtocol(headers.get(), | 582 ValidateSubProtocol(headers.get(), |
| 473 requested_sub_protocols_, | 583 requested_sub_protocols_, |
| 474 &sub_protocol_, | 584 &sub_protocol_, |
| 475 &failure_message_) && | 585 &failure_message_) && |
| 476 ValidateExtensions(headers.get(), | 586 ValidateExtensions(headers.get(), |
| 477 requested_extensions_, | 587 requested_extensions_, |
| 478 &extensions_, | 588 &extensions_, |
| 479 &failure_message_)) { | 589 &failure_message_, |
| 590 extension_params_.get())) { |
| 480 return OK; | 591 return OK; |
| 481 } | 592 } |
| 482 failure_message_ = "Error during WebSocket handshake: " + failure_message_; | 593 failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
| 483 return ERR_INVALID_RESPONSE; | 594 return ERR_INVALID_RESPONSE; |
| 484 } | 595 } |
| 485 | 596 |
| 486 } // namespace net | 597 } // namespace net |
| OLD | NEW |