| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
| 6 | 6 |
| 7 #include <algorithm> | 7 #include <algorithm> |
| 8 #include <iterator> | 8 #include <iterator> |
| 9 #include <set> | 9 #include <set> |
| 10 #include <string> | 10 #include <string> |
| (...skipping 216 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 227 WebSocketExtensionParams* params) { | 227 WebSocketExtensionParams* params) { |
| 228 static const char kClientPrefix[] = "client_"; | 228 static const char kClientPrefix[] = "client_"; |
| 229 static const char kServerPrefix[] = "server_"; | 229 static const char kServerPrefix[] = "server_"; |
| 230 static const char kNoContextTakeover[] = "no_context_takeover"; | 230 static const char kNoContextTakeover[] = "no_context_takeover"; |
| 231 static const char kMaxWindowBits[] = "max_window_bits"; | 231 static const char kMaxWindowBits[] = "max_window_bits"; |
| 232 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | 232 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
| 233 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, | 233 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, |
| 234 the_strings_server_and_client_must_be_the_same_length); | 234 the_strings_server_and_client_must_be_the_same_length); |
| 235 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | 235 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
| 236 | 236 |
| 237 DCHECK(extension.name() == "permessage-deflate"); | 237 DCHECK_EQ("permessage-deflate", extension.name()); |
| 238 const ParameterVector& parameters = extension.parameters(); | 238 const ParameterVector& parameters = extension.parameters(); |
| 239 std::set<std::string> seen_names; | 239 std::set<std::string> seen_names; |
| 240 for (ParameterVector::const_iterator it = parameters.begin(); | 240 for (ParameterVector::const_iterator it = parameters.begin(); |
| 241 it != parameters.end(); ++it) { | 241 it != parameters.end(); ++it) { |
| 242 const std::string& name = it->name(); | 242 const std::string& name = it->name(); |
| 243 if (seen_names.count(name) != 0) { | 243 if (seen_names.count(name) != 0) { |
| 244 *failure_message = | 244 *failure_message = |
| 245 "Received duplicate permessage-deflate extension parameter " + name; | 245 "Received duplicate permessage-deflate extension parameter " + name; |
| 246 return false; | 246 return false; |
| 247 } | 247 } |
| (...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 290 std::string* extensions, | 290 std::string* extensions, |
| 291 std::string* failure_message, | 291 std::string* failure_message, |
| 292 WebSocketExtensionParams* params) { | 292 WebSocketExtensionParams* params) { |
| 293 void* state = NULL; | 293 void* state = NULL; |
| 294 std::string value; | 294 std::string value; |
| 295 std::vector<std::string> accepted_extensions; | 295 std::vector<std::string> accepted_extensions; |
| 296 // TODO(ricea): If adding support for additional extensions, generalise this | 296 // TODO(ricea): If adding support for additional extensions, generalise this |
| 297 // code. | 297 // code. |
| 298 bool seen_permessage_deflate = false; | 298 bool seen_permessage_deflate = false; |
| 299 while (headers->EnumerateHeader( | 299 while (headers->EnumerateHeader( |
| 300 &state, websockets::kSecWebSocketExtensions, &value)) { | 300 &state, websockets::kSecWebSocketExtensions, &value)) { |
| 301 WebSocketExtensionParser parser; | 301 WebSocketExtensionParser parser; |
| 302 parser.Parse(value); | 302 parser.Parse(value); |
| 303 if (parser.has_error()) { | 303 if (parser.has_error()) { |
| 304 // TODO(yhirano) Set appropriate failure message. | 304 // TODO(yhirano) Set appropriate failure message. |
| 305 *failure_message = | 305 *failure_message = |
| 306 "'Sec-WebSocket-Extensions' header value is " | 306 "'Sec-WebSocket-Extensions' header value is " |
| 307 "rejected by the parser: " + | 307 "rejected by the parser: " + |
| 308 value; | 308 value; |
| 309 return false; | 309 return false; |
| 310 } | 310 } |
| 311 if (parser.extension().name() == "permessage-deflate") { | 311 if (parser.extension().name() == "permessage-deflate") { |
| 312 if (seen_permessage_deflate) { | 312 if (seen_permessage_deflate) { |
| 313 *failure_message = "Received duplicate permessage-deflate response"; | 313 *failure_message = "Received duplicate permessage-deflate response"; |
| 314 return false; | 314 return false; |
| 315 } | 315 } |
| 316 seen_permessage_deflate = true; | 316 seen_permessage_deflate = true; |
| 317 if (!ValidatePerMessageDeflateExtension( | 317 if (!ValidatePerMessageDeflateExtension( |
| 318 parser.extension(), failure_message, params)) | 318 parser.extension(), failure_message, params)) |
| 319 return false; | 319 return false; |
| 320 } else { | 320 } else { |
| 321 *failure_message = | 321 *failure_message = |
| 322 "Found an unsupported extension '" + | 322 "Found an unsupported extension '" + |
| 323 parser.extension().name() + | 323 parser.extension().name() + |
| 324 "' in 'Sec-WebSocket-Extensions' header"; | 324 "' in 'Sec-WebSocket-Extensions' header"; |
| 325 return false; | 325 return false; |
| 326 } | 326 } |
| 327 accepted_extensions.push_back(value); | 327 accepted_extensions.push_back(value); |
| 328 } | 328 } |
| (...skipping 80 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 409 // HttpStreamParser uses a weak pointer when reading from the | 409 // HttpStreamParser uses a weak pointer when reading from the |
| 410 // socket, so it won't be called back after being destroyed. The | 410 // socket, so it won't be called back after being destroyed. The |
| 411 // HttpStreamParser is owned by HttpBasicState which is owned by this object, | 411 // HttpStreamParser is owned by HttpBasicState which is owned by this object, |
| 412 // so this use of base::Unretained() is safe. | 412 // so this use of base::Unretained() is safe. |
| 413 int rv = parser()->ReadResponseHeaders( | 413 int rv = parser()->ReadResponseHeaders( |
| 414 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback, | 414 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback, |
| 415 base::Unretained(this), | 415 base::Unretained(this), |
| 416 callback)); | 416 callback)); |
| 417 if (rv == ERR_IO_PENDING) | 417 if (rv == ERR_IO_PENDING) |
| 418 return rv; | 418 return rv; |
| 419 if (rv == OK) | 419 return ValidateResponse(rv); |
| 420 return ValidateResponse(); | |
| 421 OnFinishOpeningHandshake(); | |
| 422 return rv; | |
| 423 } | 420 } |
| 424 | 421 |
| 425 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const { | 422 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const { |
| 426 return parser()->GetResponseInfo(); | 423 return parser()->GetResponseInfo(); |
| 427 } | 424 } |
| 428 | 425 |
| 429 int WebSocketBasicHandshakeStream::ReadResponseBody( | 426 int WebSocketBasicHandshakeStream::ReadResponseBody( |
| 430 IOBuffer* buf, | 427 IOBuffer* buf, |
| 431 int buf_len, | 428 int buf_len, |
| 432 const CompletionCallback& callback) { | 429 const CompletionCallback& callback) { |
| (...skipping 86 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 519 handshake_challenge_for_testing_.reset(new std::string(key)); | 516 handshake_challenge_for_testing_.reset(new std::string(key)); |
| 520 } | 517 } |
| 521 | 518 |
| 522 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { | 519 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { |
| 523 return failure_message_; | 520 return failure_message_; |
| 524 } | 521 } |
| 525 | 522 |
| 526 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( | 523 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( |
| 527 const CompletionCallback& callback, | 524 const CompletionCallback& callback, |
| 528 int result) { | 525 int result) { |
| 529 if (result == OK) | 526 callback.Run(ValidateResponse(result)); |
| 530 result = ValidateResponse(); | |
| 531 else | |
| 532 OnFinishOpeningHandshake(); | |
| 533 callback.Run(result); | |
| 534 } | 527 } |
| 535 | 528 |
| 536 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() { | 529 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() { |
| 537 DCHECK(connect_delegate_); | 530 DCHECK(connect_delegate_); |
| 538 DCHECK(http_response_info_); | 531 DCHECK(http_response_info_); |
| 539 scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers; | 532 scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers; |
| 540 scoped_ptr<WebSocketHandshakeResponseInfo> response( | 533 scoped_ptr<WebSocketHandshakeResponseInfo> response( |
| 541 new WebSocketHandshakeResponseInfo(url_, | 534 new WebSocketHandshakeResponseInfo(url_, |
| 542 headers->response_code(), | 535 headers->response_code(), |
| 543 headers->GetStatusText(), | 536 headers->GetStatusText(), |
| 544 headers, | 537 headers, |
| 545 http_response_info_->response_time)); | 538 http_response_info_->response_time)); |
| 546 connect_delegate_->OnFinishOpeningHandshake(response.Pass()); | 539 connect_delegate_->OnFinishOpeningHandshake(response.Pass()); |
| 547 } | 540 } |
| 548 | 541 |
| 549 int WebSocketBasicHandshakeStream::ValidateResponse() { | 542 int WebSocketBasicHandshakeStream::ValidateResponse(int rv) { |
| 550 DCHECK(http_response_info_); | 543 DCHECK(http_response_info_); |
| 551 const scoped_refptr<HttpResponseHeaders>& headers = | 544 const HttpResponseHeaders* headers = http_response_info_->headers.get(); |
| 552 http_response_info_->headers; | 545 if (rv >= 0) { |
| 546 switch (headers->response_code()) { |
| 547 case HTTP_SWITCHING_PROTOCOLS: |
| 548 OnFinishOpeningHandshake(); |
| 549 return ValidateUpgradeResponse(headers); |
| 553 | 550 |
| 554 switch (headers->response_code()) { | 551 // We need to pass these through for authentication to work. |
| 555 case HTTP_SWITCHING_PROTOCOLS: | 552 case HTTP_UNAUTHORIZED: |
| 556 OnFinishOpeningHandshake(); | 553 case HTTP_PROXY_AUTHENTICATION_REQUIRED: |
| 557 return ValidateUpgradeResponse(headers); | 554 return OK; |
| 558 | 555 |
| 559 // We need to pass these through for authentication to work. | 556 // Other status codes are potentially risky (see the warnings in the |
| 560 case HTTP_UNAUTHORIZED: | 557 // WHATWG WebSocket API spec) and so are dropped by default. |
| 561 case HTTP_PROXY_AUTHENTICATION_REQUIRED: | 558 default: |
| 562 return OK; | 559 failure_message_ = base::StringPrintf( |
| 563 | 560 "Error during WebSocket handshake: Unexpected response code: %d", |
| 564 // Other status codes are potentially risky (see the warnings in the | 561 headers->response_code()); |
| 565 // WHATWG WebSocket API spec) and so are dropped by default. | 562 OnFinishOpeningHandshake(); |
| 566 default: | 563 return ERR_INVALID_RESPONSE; |
| 567 failure_message_ = base::StringPrintf("Unexpected status code: %d", | 564 } |
| 568 headers->response_code()); | 565 } else { |
| 569 OnFinishOpeningHandshake(); | 566 failure_message_ = |
| 570 return ERR_INVALID_RESPONSE; | 567 std::string("Error during WebSocket handshake: ") + ErrorToString(rv); |
| 568 OnFinishOpeningHandshake(); |
| 569 return rv; |
| 571 } | 570 } |
| 572 } | 571 } |
| 573 | 572 |
| 574 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( | 573 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| 575 const scoped_refptr<HttpResponseHeaders>& headers) { | 574 const HttpResponseHeaders* headers) { |
| 576 extension_params_.reset(new WebSocketExtensionParams); | 575 extension_params_.reset(new WebSocketExtensionParams); |
| 577 if (ValidateUpgrade(headers.get(), &failure_message_) && | 576 if (ValidateUpgrade(headers, &failure_message_) && |
| 578 ValidateSecWebSocketAccept(headers.get(), | 577 ValidateSecWebSocketAccept(headers, |
| 579 handshake_challenge_response_, | 578 handshake_challenge_response_, |
| 580 &failure_message_) && | 579 &failure_message_) && |
| 581 ValidateConnection(headers.get(), &failure_message_) && | 580 ValidateConnection(headers, &failure_message_) && |
| 582 ValidateSubProtocol(headers.get(), | 581 ValidateSubProtocol(headers, |
| 583 requested_sub_protocols_, | 582 requested_sub_protocols_, |
| 584 &sub_protocol_, | 583 &sub_protocol_, |
| 585 &failure_message_) && | 584 &failure_message_) && |
| 586 ValidateExtensions(headers.get(), | 585 ValidateExtensions(headers, |
| 587 requested_extensions_, | 586 requested_extensions_, |
| 588 &extensions_, | 587 &extensions_, |
| 589 &failure_message_, | 588 &failure_message_, |
| 590 extension_params_.get())) { | 589 extension_params_.get())) { |
| 591 return OK; | 590 return OK; |
| 592 } | 591 } |
| 593 failure_message_ = "Error during WebSocket handshake: " + failure_message_; | 592 failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
| 594 return ERR_INVALID_RESPONSE; | 593 return ERR_INVALID_RESPONSE; |
| 595 } | 594 } |
| 596 | 595 |
| 597 } // namespace net | 596 } // namespace net |
| OLD | NEW |